Bugfix/ppo mask stats window size (#199)

* fixes issue #198 regarding stats_window_size

* updates changelog

* updates test of stats_window_size

* updates test using maskable env

* removes print statement
This commit is contained in:
PatrickHelm 2023-08-01 11:53:27 +02:00 committed by GitHub
parent 35f06254ba
commit dfa23bdf9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 4 deletions

View File

@ -18,6 +18,7 @@ New Features:
Bug Fixes: Bug Fixes:
^^^^^^^^^^ ^^^^^^^^^^
- Fixed MaskablePPO ignoring stats_window_size argument
Deprecations: Deprecations:
^^^^^^^^^^^^^ ^^^^^^^^^^^^^

View File

@ -222,8 +222,8 @@ class MaskablePPO(OnPolicyAlgorithm):
self.start_time = time.time_ns() self.start_time = time.time_ns()
if self.ep_info_buffer is None or reset_num_timesteps: if self.ep_info_buffer is None or reset_num_timesteps:
# Initialize buffers if they don't exist, or reinitialize if resetting counters # Initialize buffers if they don't exist, or reinitialize if resetting counters
self.ep_info_buffer = deque(maxlen=100) self.ep_info_buffer = deque(maxlen=self._stats_window_size)
self.ep_success_buffer = deque(maxlen=100) self.ep_success_buffer = deque(maxlen=self._stats_window_size)
if reset_num_timesteps: if reset_num_timesteps:
self.num_timesteps = 0 self.num_timesteps = 0

View File

@ -150,12 +150,13 @@ def test_advantage_normalization(normalize_advantage):
model.learn(64) model.learn(64)
@pytest.mark.parametrize("algo", [TRPO, QRDQN]) @pytest.mark.parametrize("algo", [TRPO, QRDQN, MaskablePPO])
@pytest.mark.parametrize("stats_window_size", [1, 42]) @pytest.mark.parametrize("stats_window_size", [1, 42])
def test_ep_buffers_stats_window_size(algo, stats_window_size): def test_ep_buffers_stats_window_size(algo, stats_window_size):
"""Set stats_window_size for logging to non-default value and check if """Set stats_window_size for logging to non-default value and check if
ep_info_buffer and ep_success_buffer are initialized to the correct length""" ep_info_buffer and ep_success_buffer are initialized to the correct length"""
model = algo("MlpPolicy", "CartPole-v1", stats_window_size=stats_window_size) env = InvalidActionEnvDiscrete() if algo == MaskablePPO else "CartPole-v1"
model = algo("MlpPolicy", env, stats_window_size=stats_window_size)
model.learn(total_timesteps=10) model.learn(total_timesteps=10)
assert model.ep_info_buffer.maxlen == stats_window_size assert model.ep_info_buffer.maxlen == stats_window_size
assert model.ep_success_buffer.maxlen == stats_window_size assert model.ep_success_buffer.maxlen == stats_window_size