Fixed shared_lstm argument in CNN and MultiInput Policies for RecurrentPPO (#90)

* fixed shared_lstm parameter in CNN and MultiInput Policies

* updated tests

* changelog

* Fix FPS for recurrent PPO

* Fix import

* Update changelog

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Author: Max Lodel <2022-07-26 00:27:17 +02:00> (committed by GitHub)
Parent: 7e687ac47c
Commit: fc68af8841
5 changed files with 51 additions and 8 deletions
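
The user-facing effect of the shared_lstm fix: CnnLstmPolicy and MultiInputLstmPolicy forwarded their arguments to RecurrentActorCriticPolicy positionally without a shared_lstm slot, so everything after n_lstm_layers was bound one parameter off. A minimal usage sketch of the now-working path (the env id and timestep budget are illustrative, not from this commit):

    from sb3_contrib import RecurrentPPO

    # Prior to this commit, passing shared_lstm through policy_kwargs to
    # CnnLstmPolicy shifted the remaining positional arguments in the
    # parent constructor; with the fix it behaves like MlpLstmPolicy.
    model = RecurrentPPO(
        "CnnLstmPolicy",
        "BreakoutNoFrameskip-v4",  # illustrative Atari env (requires atari extras)
        policy_kwargs=dict(
            shared_lstm=True,          # actor and critic share a single LSTM...
            enable_critic_lstm=False,  # ...so the separate critic LSTM must be off
        ),
    )
    model.learn(total_timesteps=128)

The shared_lstm=True / enable_critic_lstm=False pairing matches the new test cases added below; the two options are mutually exclusive.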

docs/misc/changelog.rst

@@ -3,7 +3,7 @@
 Changelog
 ==========

-Release 1.6.1a0 (WIP)
+Release 1.6.1a1 (WIP)
 -------------------------------

 Breaking Changes:
@@ -15,8 +15,10 @@ New Features:

 Bug Fixes:
 ^^^^^^^^^^
+- Fixed the issue of wrongly passing policy arguments when using CnnLstmPolicy or MultiInputLstmPolicy with ``RecurrentPPO`` (@mlodel)
 - Fixed division by zero error when computing FPS when a small amount of time has elapsed, on operating systems with low-precision timers.

 Deprecations:
 ^^^^^^^^^^^^^
@@ -294,3 +296,4 @@ Contributors:
 -------------
 @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
+@mlodel

sb3_contrib/common/recurrent/policies.py

@@ -483,6 +483,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         lstm_hidden_size: int = 256,
         n_lstm_layers: int = 1,
+        shared_lstm: bool = False,
         enable_critic_lstm: bool = True,
         lstm_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -506,6 +507,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
             optimizer_kwargs,
             lstm_hidden_size,
             n_lstm_layers,
+            shared_lstm,
             enable_critic_lstm,
             lstm_kwargs,
         )
@@ -573,6 +575,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         lstm_hidden_size: int = 256,
         n_lstm_layers: int = 1,
+        shared_lstm: bool = False,
         enable_critic_lstm: bool = True,
         lstm_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -596,6 +599,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
             optimizer_kwargs,
             lstm_hidden_size,
             n_lstm_layers,
+            shared_lstm,
             enable_critic_lstm,
             lstm_kwargs,
         )
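
The hunks above also show why the bug was easy to miss: super().__init__ is called with positional arguments, so before the shared_lstm slot existed, every argument after n_lstm_layers was silently bound one parameter too early in the parent signature. A stripped-down illustration of that failure mode (class and parameter names are stand-ins, not sb3-contrib code):

    class Parent:
        def __init__(self, n_lstm_layers=1, shared_lstm=False, enable_critic_lstm=True, lstm_kwargs=None):
            self.shared_lstm = shared_lstm
            self.enable_critic_lstm = enable_critic_lstm

    class Child(Parent):
        # Pre-fix shape: no shared_lstm parameter, positional forwarding.
        def __init__(self, n_lstm_layers=1, enable_critic_lstm=True, lstm_kwargs=None):
            super().__init__(n_lstm_layers, enable_critic_lstm, lstm_kwargs)

    child = Child(enable_critic_lstm=True)
    print(child.shared_lstm)         # True -> enable_critic_lstm landed in shared_lstm
    print(child.enable_critic_lstm)  # None -> lstm_kwargs landed in enable_critic_lstm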

sb3_contrib/ppo_recurrent/ppo_recurrent.py

@@ -1,3 +1,4 @@
+import sys
 import time
 from copy import deepcopy
 from typing import Any, Dict, Optional, Tuple, Type, Union
@@ -513,13 +514,14 @@ class RecurrentPPO(OnPolicyAlgorithm):
             # Display training infos
             if log_interval is not None and iteration % log_interval == 0:
-                fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
+                time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
+                fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
                 self.logger.record("time/iterations", iteration, exclude="tensorboard")
                 if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                     self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
                     self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
                 self.logger.record("time/fps", fps)
-                self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
+                self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
                 self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
                 self.logger.dump(step=self.num_timesteps)
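
The FPS change replaces the raw time.time() delta with a nanosecond clock and floors the elapsed time at sys.float_info.epsilon, so the division can never see zero on platforms whose wall clock ticks coarsely. A standalone sketch of the guard (the timestep values are illustrative, and it assumes start_time is stored in nanoseconds, as the hunk above relies on):

    import sys
    import time

    start_time = time.time_ns()  # nanosecond-resolution wall clock
    num_timesteps, num_timesteps_at_start = 2048, 0

    # On OSes with low-precision timers, two time.time() calls in quick
    # succession can return the same value; the epsilon floor prevents
    # the resulting ZeroDivisionError.
    time_elapsed = max((time.time_ns() - start_time) / 1e9, sys.float_info.epsilon)
    fps = int((num_timesteps - num_timesteps_at_start) / time_elapsed)
    print(f"time/fps: {fps}")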

sb3_contrib/version.txt

@@ -1 +1 @@
-1.6.1a0
+1.6.1a1

tests/test_lstm.py

@@ -57,13 +57,30 @@ class CartPoleNoVelEnv(CartPoleEnv):
         return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info


-def test_cnn():
+@pytest.mark.parametrize(
+    "policy_kwargs",
+    [
+        {},
+        dict(shared_lstm=True, enable_critic_lstm=False),
+        dict(
+            enable_critic_lstm=True,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+        dict(
+            enable_critic_lstm=False,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+    ],
+)
+def test_cnn(policy_kwargs):
     model = RecurrentPPO(
         "CnnLstmPolicy",
         FakeImageEnv(screen_height=40, screen_width=40, n_channels=3),
         n_steps=16,
         seed=0,
-        policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
+        policy_kwargs=dict(**policy_kwargs, features_extractor_kwargs=dict(features_dim=32)),
     )
     model.learn(total_timesteps=32)
@@ -138,9 +155,26 @@ def test_run_sde():
     model.learn(total_timesteps=200, eval_freq=150)


-def test_dict_obs():
+@pytest.mark.parametrize(
+    "policy_kwargs",
+    [
+        {},
+        dict(shared_lstm=True, enable_critic_lstm=False),
+        dict(
+            enable_critic_lstm=True,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+        dict(
+            enable_critic_lstm=False,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+    ],
+)
+def test_dict_obs(policy_kwargs):
     env = make_vec_env("CartPole-v1", n_envs=1, wrapper_class=ToDictWrapper)
-    model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32).learn(64)
+    model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32, policy_kwargs=policy_kwargs).learn(64)
     evaluate_policy(model, env, warn=False)
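
To run just the newly parametrized tests, something along these lines should work (the test-file path follows the repo layout and is an assumption here):

    import pytest

    # Select only the two parametrized tests; each now expands into
    # four policy_kwargs variants.
    pytest.main(["tests/test_lstm.py", "-k", "test_cnn or test_dict_obs", "-v"])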