diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index ed9c1dd..29f832e 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -18,6 +18,7 @@ New Features:
 
 Bug Fixes:
 ^^^^^^^^^^
+- Fixed MaskablePPO ignoring stats_window_size argument
 
 Deprecations:
 ^^^^^^^^^^^^^
diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py
index ca23a22..ffd831c 100644
--- a/sb3_contrib/ppo_mask/ppo_mask.py
+++ b/sb3_contrib/ppo_mask/ppo_mask.py
@@ -222,8 +222,8 @@ class MaskablePPO(OnPolicyAlgorithm):
         self.start_time = time.time_ns()
         if self.ep_info_buffer is None or reset_num_timesteps:
             # Initialize buffers if they don't exist, or reinitialize if resetting counters
-            self.ep_info_buffer = deque(maxlen=100)
-            self.ep_success_buffer = deque(maxlen=100)
+            self.ep_info_buffer = deque(maxlen=self._stats_window_size)
+            self.ep_success_buffer = deque(maxlen=self._stats_window_size)
 
         if reset_num_timesteps:
             self.num_timesteps = 0
diff --git a/tests/test_run.py b/tests/test_run.py
index 9ce4c71..9da6525 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -161,12 +161,13 @@ def test_advantage_normalization(normalize_advantage):
     model.learn(64)
 
 
-@pytest.mark.parametrize("algo", [TRPO, QRDQN])
+@pytest.mark.parametrize("algo", [TRPO, QRDQN, MaskablePPO])
 @pytest.mark.parametrize("stats_window_size", [1, 42])
 def test_ep_buffers_stats_window_size(algo, stats_window_size):
     """Set stats_window_size for logging to non-default value and check if
     ep_info_buffer and ep_success_buffer are initialized to the correct length"""
-    model = algo("MlpPolicy", "CartPole-v1", stats_window_size=stats_window_size)
+    env = InvalidActionEnvDiscrete() if algo == MaskablePPO else "CartPole-v1"
+    model = algo("MlpPolicy", env, stats_window_size=stats_window_size)
     model.learn(total_timesteps=10)
     assert model.ep_info_buffer.maxlen == stats_window_size
     assert model.ep_success_buffer.maxlen == stats_window_size