Fixed shared_lstm argument in CNN and MultiInput Policies for RecurrentPPO (#90)
* fixed shared_lstm parameter in CNN and MultiInput Policies
* updated tests
* changelog
* Fix FPS for recurrent PPO
* Fix import
* Update changelog

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
7e687ac47c
commit
fc68af8841
|
|
@ -3,7 +3,7 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Release 1.6.1a0 (WIP)
|
Release 1.6.1a1 (WIP)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
|
|
@ -15,8 +15,10 @@ New Features:
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
- Fixed the issue of wrongly passing policy arguments when using CnnLstmPolicy or MultiInputLstmPolicy with ``RecurrentPPO`` (@mlodel)
|
||||||
- Fixed division by zero error when computing FPS when a small amount of time has elapsed in operating systems with low-precision timers.
|
- Fixed division by zero error when computing FPS when a small amount of time has elapsed in operating systems with low-precision timers.
|
||||||
|
|
||||||
|
|
||||||
Deprecations:
|
Deprecations:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|
@ -294,3 +296,4 @@ Contributors:
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
||||||
|
@mlodel
|
||||||
|
|
|
||||||
|
|
@ -483,6 +483,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
lstm_hidden_size: int = 256,
|
lstm_hidden_size: int = 256,
|
||||||
n_lstm_layers: int = 1,
|
n_lstm_layers: int = 1,
|
||||||
|
shared_lstm: bool = False,
|
||||||
enable_critic_lstm: bool = True,
|
enable_critic_lstm: bool = True,
|
||||||
lstm_kwargs: Optional[Dict[str, Any]] = None,
|
lstm_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
|
|
@ -506,6 +507,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
||||||
optimizer_kwargs,
|
optimizer_kwargs,
|
||||||
lstm_hidden_size,
|
lstm_hidden_size,
|
||||||
n_lstm_layers,
|
n_lstm_layers,
|
||||||
|
shared_lstm,
|
||||||
enable_critic_lstm,
|
enable_critic_lstm,
|
||||||
lstm_kwargs,
|
lstm_kwargs,
|
||||||
)
|
)
|
||||||
|
|
@ -573,6 +575,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
lstm_hidden_size: int = 256,
|
lstm_hidden_size: int = 256,
|
||||||
n_lstm_layers: int = 1,
|
n_lstm_layers: int = 1,
|
||||||
|
shared_lstm: bool = False,
|
||||||
enable_critic_lstm: bool = True,
|
enable_critic_lstm: bool = True,
|
||||||
lstm_kwargs: Optional[Dict[str, Any]] = None,
|
lstm_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
|
|
@ -596,6 +599,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
||||||
optimizer_kwargs,
|
optimizer_kwargs,
|
||||||
lstm_hidden_size,
|
lstm_hidden_size,
|
||||||
n_lstm_layers,
|
n_lstm_layers,
|
||||||
|
shared_lstm,
|
||||||
enable_critic_lstm,
|
enable_critic_lstm,
|
||||||
lstm_kwargs,
|
lstm_kwargs,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Dict, Optional, Tuple, Type, Union
|
from typing import Any, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
@ -513,13 +514,14 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
# Display training infos
|
# Display training infos
|
||||||
if log_interval is not None and iteration % log_interval == 0:
|
if log_interval is not None and iteration % log_interval == 0:
|
||||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
|
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
||||||
|
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
||||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
||||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||||
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
||||||
self.logger.record("time/fps", fps)
|
self.logger.record("time/fps", fps)
|
||||||
self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
|
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
||||||
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
||||||
self.logger.dump(step=self.num_timesteps)
|
self.logger.dump(step=self.num_timesteps)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.6.1a0
|
1.6.1a1
|
||||||
|
|
|
||||||
|
|
@ -57,13 +57,30 @@ class CartPoleNoVelEnv(CartPoleEnv):
|
||||||
return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info
|
return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info
|
||||||
|
|
||||||
|
|
||||||
def test_cnn():
|
@pytest.mark.parametrize(
|
||||||
|
"policy_kwargs",
|
||||||
|
[
|
||||||
|
{},
|
||||||
|
dict(shared_lstm=True, enable_critic_lstm=False),
|
||||||
|
dict(
|
||||||
|
enable_critic_lstm=True,
|
||||||
|
lstm_hidden_size=4,
|
||||||
|
lstm_kwargs=dict(dropout=0.5),
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
enable_critic_lstm=False,
|
||||||
|
lstm_hidden_size=4,
|
||||||
|
lstm_kwargs=dict(dropout=0.5),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_cnn(policy_kwargs):
|
||||||
model = RecurrentPPO(
|
model = RecurrentPPO(
|
||||||
"CnnLstmPolicy",
|
"CnnLstmPolicy",
|
||||||
FakeImageEnv(screen_height=40, screen_width=40, n_channels=3),
|
FakeImageEnv(screen_height=40, screen_width=40, n_channels=3),
|
||||||
n_steps=16,
|
n_steps=16,
|
||||||
seed=0,
|
seed=0,
|
||||||
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
policy_kwargs=dict(**policy_kwargs, features_extractor_kwargs=dict(features_dim=32)),
|
||||||
)
|
)
|
||||||
|
|
||||||
model.learn(total_timesteps=32)
|
model.learn(total_timesteps=32)
|
||||||
|
|
@ -138,9 +155,26 @@ def test_run_sde():
|
||||||
model.learn(total_timesteps=200, eval_freq=150)
|
model.learn(total_timesteps=200, eval_freq=150)
|
||||||
|
|
||||||
|
|
||||||
def test_dict_obs():
|
@pytest.mark.parametrize(
|
||||||
|
"policy_kwargs",
|
||||||
|
[
|
||||||
|
{},
|
||||||
|
dict(shared_lstm=True, enable_critic_lstm=False),
|
||||||
|
dict(
|
||||||
|
enable_critic_lstm=True,
|
||||||
|
lstm_hidden_size=4,
|
||||||
|
lstm_kwargs=dict(dropout=0.5),
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
enable_critic_lstm=False,
|
||||||
|
lstm_hidden_size=4,
|
||||||
|
lstm_kwargs=dict(dropout=0.5),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_dict_obs(policy_kwargs):
|
||||||
env = make_vec_env("CartPole-v1", n_envs=1, wrapper_class=ToDictWrapper)
|
env = make_vec_env("CartPole-v1", n_envs=1, wrapper_class=ToDictWrapper)
|
||||||
model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32).learn(64)
|
model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32, policy_kwargs=policy_kwargs).learn(64)
|
||||||
evaluate_policy(model, env, warn=False)
|
evaluate_policy(model, env, warn=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue