Merge branch 'Stable-Baselines-Team:master' into master
This commit is contained in:
commit
4e99b74e90
|
|
@ -18,6 +18,7 @@ New Features:
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
- Fixed MaskablePPO ignoring stats_window_size argument
|
||||||
|
|
||||||
Deprecations:
|
Deprecations:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -222,8 +222,8 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
self.start_time = time.time_ns()
|
self.start_time = time.time_ns()
|
||||||
if self.ep_info_buffer is None or reset_num_timesteps:
|
if self.ep_info_buffer is None or reset_num_timesteps:
|
||||||
# Initialize buffers if they don't exist, or reinitialize if resetting counters
|
# Initialize buffers if they don't exist, or reinitialize if resetting counters
|
||||||
self.ep_info_buffer = deque(maxlen=100)
|
self.ep_info_buffer = deque(maxlen=self._stats_window_size)
|
||||||
self.ep_success_buffer = deque(maxlen=100)
|
self.ep_success_buffer = deque(maxlen=self._stats_window_size)
|
||||||
|
|
||||||
if reset_num_timesteps:
|
if reset_num_timesteps:
|
||||||
self.num_timesteps = 0
|
self.num_timesteps = 0
|
||||||
|
|
|
||||||
|
|
@ -161,12 +161,13 @@ def test_advantage_normalization(normalize_advantage):
|
||||||
model.learn(64)
|
model.learn(64)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("algo", [TRPO, QRDQN])
|
@pytest.mark.parametrize("algo", [TRPO, QRDQN, MaskablePPO])
|
||||||
@pytest.mark.parametrize("stats_window_size", [1, 42])
|
@pytest.mark.parametrize("stats_window_size", [1, 42])
|
||||||
def test_ep_buffers_stats_window_size(algo, stats_window_size):
|
def test_ep_buffers_stats_window_size(algo, stats_window_size):
|
||||||
"""Set stats_window_size for logging to non-default value and check if
|
"""Set stats_window_size for logging to non-default value and check if
|
||||||
ep_info_buffer and ep_success_buffer are initialized to the correct length"""
|
ep_info_buffer and ep_success_buffer are initialized to the correct length"""
|
||||||
model = algo("MlpPolicy", "CartPole-v1", stats_window_size=stats_window_size)
|
env = InvalidActionEnvDiscrete() if algo == MaskablePPO else "CartPole-v1"
|
||||||
|
model = algo("MlpPolicy", env, stats_window_size=stats_window_size)
|
||||||
model.learn(total_timesteps=10)
|
model.learn(total_timesteps=10)
|
||||||
assert model.ep_info_buffer.maxlen == stats_window_size
|
assert model.ep_info_buffer.maxlen == stats_window_size
|
||||||
assert model.ep_success_buffer.maxlen == stats_window_size
|
assert model.ep_success_buffer.maxlen == stats_window_size
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue