[Feature] Non-shared features extractor in on-policy algorithms (#130)
* Modified sb3_contrib/common/maskable/policies.py - Added support for non-shared features extractor in file sb3_contrib/common/maskable/policies.py - updated changelog * Modified sb3_contrib/common/recurrent/policies.py * Modified sb3_contrib/qrdqn/policies.py and sb3_contrib/tqc/policies.py * Updated test_cnn.py * Upgrade SB3 version * Revert changes in formatting * Remove duplicate normalize_images * Add test for image-like inputs * Fixes and add more tests * Update SB3 version * Fix ARS warnings Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
6b23c6cfe3
commit
ab8684f469
|
|
@ -3,7 +3,7 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Release 1.7.0a4 (WIP)
|
Release 1.7.0a10 (WIP)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
|
|
@ -11,12 +11,15 @@ Breaking Changes:
|
||||||
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
|
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
|
||||||
please use an ``EvalCallback`` instead
|
please use an ``EvalCallback`` instead
|
||||||
- Removed deprecated ``sde_net_arch`` parameter
|
- Removed deprecated ``sde_net_arch`` parameter
|
||||||
- Upgraded to Stable-Baselines3 >= 1.7.0a4
|
- Upgraded to Stable-Baselines3 >= 1.7.0a10
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
- Introduced mypy type checking
|
- Introduced mypy type checking
|
||||||
- Added ``with_bias`` parameter to ``ARSPolicy``
|
- Added ``with_bias`` parameter to ``ARSPolicy``
|
||||||
|
- Added option to have non-shared features extractor between actor and critic in on-policy algorithms (@AlexPasqua)
|
||||||
|
- Features extractors now properly support unnormalized image-like observations (3D tensor)
|
||||||
|
when passing ``normalize_images=False``
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
@ -25,6 +28,7 @@ Bug Fixes:
|
||||||
|
|
||||||
Deprecations:
|
Deprecations:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
- You should now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()``
|
||||||
|
|
||||||
Others:
|
Others:
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
@ -364,4 +368,4 @@ Contributors:
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
||||||
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong
|
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua
|
||||||
|
|
|
||||||
|
|
@ -69,7 +69,7 @@ class ARSPolicy(BasePolicy):
|
||||||
|
|
||||||
def forward(self, obs: th.Tensor) -> th.Tensor:
|
def forward(self, obs: th.Tensor) -> th.Tensor:
|
||||||
|
|
||||||
features = self.extract_features(obs)
|
features = self.extract_features(obs, self.features_extractor)
|
||||||
if isinstance(self.action_space, gym.spaces.Box):
|
if isinstance(self.action_space, gym.spaces.Box):
|
||||||
return self.action_net(features)
|
return self.action_net(features)
|
||||||
elif isinstance(self.action_space, gym.spaces.Discrete):
|
elif isinstance(self.action_space, gym.spaces.Discrete):
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
:param features_extractor_class: Features extractor to use.
|
:param features_extractor_class: Features extractor to use.
|
||||||
:param features_extractor_kwargs: Keyword arguments
|
:param features_extractor_kwargs: Keyword arguments
|
||||||
to pass to the features extractor.
|
to pass to the features extractor.
|
||||||
|
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
|
||||||
:param normalize_images: Whether to normalize images or not,
|
:param normalize_images: Whether to normalize images or not,
|
||||||
dividing by 255.0 (True by default)
|
dividing by 255.0 (True by default)
|
||||||
:param optimizer_class: The optimizer to use,
|
:param optimizer_class: The optimizer to use,
|
||||||
|
|
@ -50,6 +51,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
ortho_init: bool = True,
|
ortho_init: bool = True,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
share_features_extractor: bool = True,
|
||||||
normalize_images: bool = True,
|
normalize_images: bool = True,
|
||||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
|
@ -68,6 +70,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
features_extractor_kwargs,
|
features_extractor_kwargs,
|
||||||
optimizer_class=optimizer_class,
|
optimizer_class=optimizer_class,
|
||||||
optimizer_kwargs=optimizer_kwargs,
|
optimizer_kwargs=optimizer_kwargs,
|
||||||
|
normalize_images=normalize_images,
|
||||||
squash_output=False,
|
squash_output=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -82,10 +85,21 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
self.activation_fn = activation_fn
|
self.activation_fn = activation_fn
|
||||||
self.ortho_init = ortho_init
|
self.ortho_init = ortho_init
|
||||||
|
|
||||||
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
|
self.share_features_extractor = share_features_extractor
|
||||||
|
self.features_extractor = self.make_features_extractor()
|
||||||
self.features_dim = self.features_extractor.features_dim
|
self.features_dim = self.features_extractor.features_dim
|
||||||
|
if self.share_features_extractor:
|
||||||
|
self.pi_features_extractor = self.features_extractor
|
||||||
|
self.vf_features_extractor = self.features_extractor
|
||||||
|
else:
|
||||||
|
self.pi_features_extractor = self.features_extractor
|
||||||
|
self.vf_features_extractor = self.make_features_extractor()
|
||||||
|
# if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
|
||||||
|
if len(net_arch) > 0 and not isinstance(net_arch[0], dict):
|
||||||
|
raise ValueError(
|
||||||
|
"Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
|
||||||
|
)
|
||||||
|
|
||||||
self.normalize_images = normalize_images
|
|
||||||
# Action distribution
|
# Action distribution
|
||||||
self.action_dist = make_masked_proba_distribution(action_space)
|
self.action_dist = make_masked_proba_distribution(action_space)
|
||||||
|
|
||||||
|
|
@ -107,7 +121,12 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
"""
|
"""
|
||||||
# Preprocess the observation if needed
|
# Preprocess the observation if needed
|
||||||
features = self.extract_features(obs)
|
features = self.extract_features(obs)
|
||||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
if self.share_features_extractor:
|
||||||
|
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||||
|
else:
|
||||||
|
pi_features, vf_features = features
|
||||||
|
latent_pi = self.mlp_extractor.forward_actor(pi_features)
|
||||||
|
latent_vf = self.mlp_extractor.forward_critic(vf_features)
|
||||||
# Evaluate the values for the given observations
|
# Evaluate the values for the given observations
|
||||||
values = self.value_net(latent_vf)
|
values = self.value_net(latent_vf)
|
||||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||||
|
|
@ -117,6 +136,19 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
log_prob = distribution.log_prob(actions)
|
log_prob = distribution.log_prob(actions)
|
||||||
return actions, values, log_prob
|
return actions, values, log_prob
|
||||||
|
|
||||||
|
def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
|
||||||
|
"""
|
||||||
|
Preprocess the observation if needed and extract features.
|
||||||
|
:param obs: Observation
|
||||||
|
:return: the output of the features extractor(s)
|
||||||
|
"""
|
||||||
|
if self.share_features_extractor:
|
||||||
|
return super().extract_features(obs, self.features_extractor)
|
||||||
|
else:
|
||||||
|
pi_features = super().extract_features(obs, self.pi_features_extractor)
|
||||||
|
vf_features = super().extract_features(obs, self.vf_features_extractor)
|
||||||
|
return pi_features, vf_features
|
||||||
|
|
||||||
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||||
data = super()._get_constructor_parameters()
|
data = super()._get_constructor_parameters()
|
||||||
|
|
||||||
|
|
@ -174,6 +206,13 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
self.action_net: 0.01,
|
self.action_net: 0.01,
|
||||||
self.value_net: 1,
|
self.value_net: 1,
|
||||||
}
|
}
|
||||||
|
if not self.share_features_extractor:
|
||||||
|
# Note(antonin): this is to keep SB3 results
|
||||||
|
# consistent, see GH#1148
|
||||||
|
del module_gains[self.features_extractor]
|
||||||
|
module_gains[self.pi_features_extractor] = np.sqrt(2)
|
||||||
|
module_gains[self.vf_features_extractor] = np.sqrt(2)
|
||||||
|
|
||||||
for module, gain in module_gains.items():
|
for module, gain in module_gains.items():
|
||||||
module.apply(partial(self.init_weights, gain=gain))
|
module.apply(partial(self.init_weights, gain=gain))
|
||||||
|
|
||||||
|
|
@ -268,13 +307,19 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
Evaluate actions according to the current policy,
|
Evaluate actions according to the current policy,
|
||||||
given the observations.
|
given the observations.
|
||||||
|
|
||||||
:param obs:
|
:param obs: Observation
|
||||||
:param actions:
|
:param actions: Actions
|
||||||
:return: estimated value, log likelihood of taking those actions
|
:return: estimated value, log likelihood of taking those actions
|
||||||
and entropy of the action distribution.
|
and entropy of the action distribution.
|
||||||
"""
|
"""
|
||||||
features = self.extract_features(obs)
|
features = self.extract_features(obs)
|
||||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
if self.share_features_extractor:
|
||||||
|
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||||
|
else:
|
||||||
|
pi_features, vf_features = features
|
||||||
|
latent_pi = self.mlp_extractor.forward_actor(pi_features)
|
||||||
|
latent_vf = self.mlp_extractor.forward_critic(vf_features)
|
||||||
|
|
||||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||||
if action_masks is not None:
|
if action_masks is not None:
|
||||||
distribution.apply_masking(action_masks)
|
distribution.apply_masking(action_masks)
|
||||||
|
|
@ -286,11 +331,11 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
"""
|
"""
|
||||||
Get the current policy distribution given the observations.
|
Get the current policy distribution given the observations.
|
||||||
|
|
||||||
:param obs:
|
:param obs: Observation
|
||||||
:param action_masks:
|
:param action_masks: Actions' mask
|
||||||
:return: the action distribution.
|
:return: the action distribution.
|
||||||
"""
|
"""
|
||||||
features = self.extract_features(obs)
|
features = super().extract_features(obs, self.pi_features_extractor)
|
||||||
latent_pi = self.mlp_extractor.forward_actor(features)
|
latent_pi = self.mlp_extractor.forward_actor(features)
|
||||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||||
if action_masks is not None:
|
if action_masks is not None:
|
||||||
|
|
@ -301,10 +346,10 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
"""
|
"""
|
||||||
Get the estimated values according to the current policy given the observations.
|
Get the estimated values according to the current policy given the observations.
|
||||||
|
|
||||||
:param obs:
|
:param obs: Observation
|
||||||
:return: the estimated values.
|
:return: the estimated values.
|
||||||
"""
|
"""
|
||||||
features = self.extract_features(obs)
|
features = super().extract_features(obs, self.vf_features_extractor)
|
||||||
latent_vf = self.mlp_extractor.forward_critic(features)
|
latent_vf = self.mlp_extractor.forward_critic(features)
|
||||||
return self.value_net(latent_vf)
|
return self.value_net(latent_vf)
|
||||||
|
|
||||||
|
|
@ -323,6 +368,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
|
||||||
:param features_extractor_class: Features extractor to use.
|
:param features_extractor_class: Features extractor to use.
|
||||||
:param features_extractor_kwargs: Keyword arguments
|
:param features_extractor_kwargs: Keyword arguments
|
||||||
to pass to the features extractor.
|
to pass to the features extractor.
|
||||||
|
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
|
||||||
:param normalize_images: Whether to normalize images or not,
|
:param normalize_images: Whether to normalize images or not,
|
||||||
dividing by 255.0 (True by default)
|
dividing by 255.0 (True by default)
|
||||||
:param optimizer_class: The optimizer to use,
|
:param optimizer_class: The optimizer to use,
|
||||||
|
|
@ -341,6 +387,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
|
||||||
ortho_init: bool = True,
|
ortho_init: bool = True,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
share_features_extractor: bool = True,
|
||||||
normalize_images: bool = True,
|
normalize_images: bool = True,
|
||||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
|
@ -354,6 +401,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
|
||||||
ortho_init,
|
ortho_init,
|
||||||
features_extractor_class,
|
features_extractor_class,
|
||||||
features_extractor_kwargs,
|
features_extractor_kwargs,
|
||||||
|
share_features_extractor,
|
||||||
normalize_images,
|
normalize_images,
|
||||||
optimizer_class,
|
optimizer_class,
|
||||||
optimizer_kwargs,
|
optimizer_kwargs,
|
||||||
|
|
@ -374,6 +422,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
|
||||||
:param features_extractor_class: Uses the CombinedExtractor
|
:param features_extractor_class: Uses the CombinedExtractor
|
||||||
:param features_extractor_kwargs: Keyword arguments
|
:param features_extractor_kwargs: Keyword arguments
|
||||||
to pass to the feature extractor.
|
to pass to the feature extractor.
|
||||||
|
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
|
||||||
:param normalize_images: Whether to normalize images or not,
|
:param normalize_images: Whether to normalize images or not,
|
||||||
dividing by 255.0 (True by default)
|
dividing by 255.0 (True by default)
|
||||||
:param optimizer_class: The optimizer to use,
|
:param optimizer_class: The optimizer to use,
|
||||||
|
|
@ -392,6 +441,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
|
||||||
ortho_init: bool = True,
|
ortho_init: bool = True,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
||||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
share_features_extractor: bool = True,
|
||||||
normalize_images: bool = True,
|
normalize_images: bool = True,
|
||||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
|
@ -405,6 +455,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
|
||||||
ortho_init,
|
ortho_init,
|
||||||
features_extractor_class,
|
features_extractor_class,
|
||||||
features_extractor_kwargs,
|
features_extractor_kwargs,
|
||||||
|
share_features_extractor,
|
||||||
normalize_images,
|
normalize_images,
|
||||||
optimizer_class,
|
optimizer_class,
|
||||||
optimizer_kwargs,
|
optimizer_kwargs,
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
:param features_extractor_class: Features extractor to use.
|
:param features_extractor_class: Features extractor to use.
|
||||||
:param features_extractor_kwargs: Keyword arguments
|
:param features_extractor_kwargs: Keyword arguments
|
||||||
to pass to the features extractor.
|
to pass to the features extractor.
|
||||||
|
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
|
||||||
:param normalize_images: Whether to normalize images or not,
|
:param normalize_images: Whether to normalize images or not,
|
||||||
dividing by 255.0 (True by default)
|
dividing by 255.0 (True by default)
|
||||||
:param optimizer_class: The optimizer to use,
|
:param optimizer_class: The optimizer to use,
|
||||||
|
|
@ -75,6 +76,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
squash_output: bool = False,
|
squash_output: bool = False,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
share_features_extractor: bool = True,
|
||||||
normalize_images: bool = True,
|
normalize_images: bool = True,
|
||||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
|
@ -99,6 +101,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
squash_output,
|
squash_output,
|
||||||
features_extractor_class,
|
features_extractor_class,
|
||||||
features_extractor_kwargs,
|
features_extractor_kwargs,
|
||||||
|
share_features_extractor,
|
||||||
normalize_images,
|
normalize_images,
|
||||||
optimizer_class,
|
optimizer_class,
|
||||||
optimizer_kwargs,
|
optimizer_kwargs,
|
||||||
|
|
@ -120,7 +123,11 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
self.lstm_critic = None
|
self.lstm_critic = None
|
||||||
assert not (
|
assert not (
|
||||||
self.shared_lstm and self.enable_critic_lstm
|
self.shared_lstm and self.enable_critic_lstm
|
||||||
), "You must choose between shared LSTM, separate or no LSTM for the critic"
|
), "You must choose between shared LSTM, separate or no LSTM for the critic."
|
||||||
|
|
||||||
|
assert not (
|
||||||
|
self.shared_lstm and not self.share_features_extractor
|
||||||
|
), "If the features extractor is not shared, the LSTM cannot be shared."
|
||||||
|
|
||||||
# No LSTM for the critic, we still need to convert
|
# No LSTM for the critic, we still need to convert
|
||||||
# output of features extractor to the correct size
|
# output of features extractor to the correct size
|
||||||
|
|
@ -222,17 +229,21 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
"""
|
"""
|
||||||
# Preprocess the observation if needed
|
# Preprocess the observation if needed
|
||||||
features = self.extract_features(obs)
|
features = self.extract_features(obs)
|
||||||
|
if self.share_features_extractor:
|
||||||
|
pi_features = vf_features = features  # alias
|
||||||
|
else:
|
||||||
|
pi_features, vf_features = features
|
||||||
# latent_pi, latent_vf = self.mlp_extractor(features)
|
# latent_pi, latent_vf = self.mlp_extractor(features)
|
||||||
latent_pi, lstm_states_pi = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor)
|
latent_pi, lstm_states_pi = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor)
|
||||||
if self.lstm_critic is not None:
|
if self.lstm_critic is not None:
|
||||||
latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic)
|
latent_vf, lstm_states_vf = self._process_sequence(vf_features, lstm_states.vf, episode_starts, self.lstm_critic)
|
||||||
elif self.shared_lstm:
|
elif self.shared_lstm:
|
||||||
# Re-use LSTM features but do not backpropagate
|
# Re-use LSTM features but do not backpropagate
|
||||||
latent_vf = latent_pi.detach()
|
latent_vf = latent_pi.detach()
|
||||||
lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach())
|
lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach())
|
||||||
else:
|
else:
|
||||||
# Critic only has a feedforward network
|
# Critic only has a feedforward network
|
||||||
latent_vf = self.critic(features)
|
latent_vf = self.critic(vf_features)
|
||||||
lstm_states_vf = lstm_states_pi
|
lstm_states_vf = lstm_states_pi
|
||||||
|
|
||||||
latent_pi = self.mlp_extractor.forward_actor(latent_pi)
|
latent_pi = self.mlp_extractor.forward_actor(latent_pi)
|
||||||
|
|
@ -260,7 +271,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
or not (we reset the lstm states in that case).
|
or not (we reset the lstm states in that case).
|
||||||
:return: the action distribution and new hidden states.
|
:return: the action distribution and new hidden states.
|
||||||
"""
|
"""
|
||||||
features = self.extract_features(obs)
|
# Call the method from the parent of the parent class
|
||||||
|
features = super(ActorCriticPolicy, self).extract_features(obs, self.pi_features_extractor)
|
||||||
latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor)
|
latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor)
|
||||||
latent_pi = self.mlp_extractor.forward_actor(latent_pi)
|
latent_pi = self.mlp_extractor.forward_actor(latent_pi)
|
||||||
return self._get_action_dist_from_latent(latent_pi), lstm_states
|
return self._get_action_dist_from_latent(latent_pi), lstm_states
|
||||||
|
|
@ -280,7 +292,9 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
or not (we reset the lstm states in that case).
|
or not (we reset the lstm states in that case).
|
||||||
:return: the estimated values.
|
:return: the estimated values.
|
||||||
"""
|
"""
|
||||||
features = self.extract_features(obs)
|
# Call the method from the parent of the parent class
|
||||||
|
features = super(ActorCriticPolicy, self).extract_features(obs, self.vf_features_extractor)
|
||||||
|
|
||||||
if self.lstm_critic is not None:
|
if self.lstm_critic is not None:
|
||||||
latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic)
|
latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic)
|
||||||
elif self.shared_lstm:
|
elif self.shared_lstm:
|
||||||
|
|
@ -294,11 +308,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
return self.value_net(latent_vf)
|
return self.value_net(latent_vf)
|
||||||
|
|
||||||
def evaluate_actions(
|
def evaluate_actions(
|
||||||
self,
|
self, obs: th.Tensor, actions: th.Tensor, lstm_states: RNNStates, episode_starts: th.Tensor
|
||||||
obs: th.Tensor,
|
|
||||||
actions: th.Tensor,
|
|
||||||
lstm_states: RNNStates,
|
|
||||||
episode_starts: th.Tensor,
|
|
||||||
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
|
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
|
||||||
"""
|
"""
|
||||||
Evaluate actions according to the current policy,
|
Evaluate actions according to the current policy,
|
||||||
|
|
@ -314,14 +324,17 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
"""
|
"""
|
||||||
# Preprocess the observation if needed
|
# Preprocess the observation if needed
|
||||||
features = self.extract_features(obs)
|
features = self.extract_features(obs)
|
||||||
latent_pi, _ = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor)
|
if self.share_features_extractor:
|
||||||
|
pi_features = vf_features = features # alias
|
||||||
|
else:
|
||||||
|
pi_features, vf_features = features
|
||||||
|
latent_pi, _ = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor)
|
||||||
if self.lstm_critic is not None:
|
if self.lstm_critic is not None:
|
||||||
latent_vf, _ = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic)
|
latent_vf, _ = self._process_sequence(vf_features, lstm_states.vf, episode_starts, self.lstm_critic)
|
||||||
elif self.shared_lstm:
|
elif self.shared_lstm:
|
||||||
latent_vf = latent_pi.detach()
|
latent_vf = latent_pi.detach()
|
||||||
else:
|
else:
|
||||||
latent_vf = self.critic(features)
|
latent_vf = self.critic(vf_features)
|
||||||
|
|
||||||
latent_pi = self.mlp_extractor.forward_actor(latent_pi)
|
latent_pi = self.mlp_extractor.forward_actor(latent_pi)
|
||||||
latent_vf = self.mlp_extractor.forward_critic(latent_vf)
|
latent_vf = self.mlp_extractor.forward_critic(latent_vf)
|
||||||
|
|
@ -439,6 +452,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
||||||
:param features_extractor_class: Features extractor to use.
|
:param features_extractor_class: Features extractor to use.
|
||||||
:param features_extractor_kwargs: Keyword arguments
|
:param features_extractor_kwargs: Keyword arguments
|
||||||
to pass to the features extractor.
|
to pass to the features extractor.
|
||||||
|
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
|
||||||
:param normalize_images: Whether to normalize images or not,
|
:param normalize_images: Whether to normalize images or not,
|
||||||
dividing by 255.0 (True by default)
|
dividing by 255.0 (True by default)
|
||||||
:param optimizer_class: The optimizer to use,
|
:param optimizer_class: The optimizer to use,
|
||||||
|
|
@ -469,6 +483,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
||||||
squash_output: bool = False,
|
squash_output: bool = False,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
share_features_extractor: bool = True,
|
||||||
normalize_images: bool = True,
|
normalize_images: bool = True,
|
||||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
|
@ -492,6 +507,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
||||||
squash_output,
|
squash_output,
|
||||||
features_extractor_class,
|
features_extractor_class,
|
||||||
features_extractor_kwargs,
|
features_extractor_kwargs,
|
||||||
|
share_features_extractor,
|
||||||
normalize_images,
|
normalize_images,
|
||||||
optimizer_class,
|
optimizer_class,
|
||||||
optimizer_kwargs,
|
optimizer_kwargs,
|
||||||
|
|
@ -526,6 +542,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
||||||
:param features_extractor_class: Features extractor to use.
|
:param features_extractor_class: Features extractor to use.
|
||||||
:param features_extractor_kwargs: Keyword arguments
|
:param features_extractor_kwargs: Keyword arguments
|
||||||
to pass to the features extractor.
|
to pass to the features extractor.
|
||||||
|
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
|
||||||
:param normalize_images: Whether to normalize images or not,
|
:param normalize_images: Whether to normalize images or not,
|
||||||
dividing by 255.0 (True by default)
|
dividing by 255.0 (True by default)
|
||||||
:param optimizer_class: The optimizer to use,
|
:param optimizer_class: The optimizer to use,
|
||||||
|
|
@ -556,6 +573,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
||||||
squash_output: bool = False,
|
squash_output: bool = False,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
||||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
share_features_extractor: bool = True,
|
||||||
normalize_images: bool = True,
|
normalize_images: bool = True,
|
||||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
|
@ -579,6 +597,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
||||||
squash_output,
|
squash_output,
|
||||||
features_extractor_class,
|
features_extractor_class,
|
||||||
features_extractor_kwargs,
|
features_extractor_kwargs,
|
||||||
|
share_features_extractor,
|
||||||
normalize_images,
|
normalize_images,
|
||||||
optimizer_class,
|
optimizer_class,
|
||||||
optimizer_kwargs,
|
optimizer_kwargs,
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,6 @@ class QuantileNetwork(BasePolicy):
|
||||||
self.features_extractor = features_extractor
|
self.features_extractor = features_extractor
|
||||||
self.features_dim = features_dim
|
self.features_dim = features_dim
|
||||||
self.n_quantiles = n_quantiles
|
self.n_quantiles = n_quantiles
|
||||||
self.normalize_images = normalize_images
|
|
||||||
action_dim = self.action_space.n # number of actions
|
action_dim = self.action_space.n # number of actions
|
||||||
quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
|
quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
|
||||||
self.quantile_net = nn.Sequential(*quantile_net)
|
self.quantile_net = nn.Sequential(*quantile_net)
|
||||||
|
|
@ -65,7 +64,7 @@ class QuantileNetwork(BasePolicy):
|
||||||
:param obs: Observation
|
:param obs: Observation
|
||||||
:return: The estimated quantiles for each action.
|
:return: The estimated quantiles for each action.
|
||||||
"""
|
"""
|
||||||
quantiles = self.quantile_net(self.extract_features(obs))
|
quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
|
||||||
return quantiles.view(-1, self.n_quantiles, self.action_space.n)
|
return quantiles.view(-1, self.n_quantiles, self.action_space.n)
|
||||||
|
|
||||||
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||||
|
|
@ -132,6 +131,7 @@ class QRDQNPolicy(BasePolicy):
|
||||||
features_extractor_kwargs,
|
features_extractor_kwargs,
|
||||||
optimizer_class=optimizer_class,
|
optimizer_class=optimizer_class,
|
||||||
optimizer_kwargs=optimizer_kwargs,
|
optimizer_kwargs=optimizer_kwargs,
|
||||||
|
normalize_images=normalize_images,
|
||||||
)
|
)
|
||||||
|
|
||||||
if net_arch is None:
|
if net_arch is None:
|
||||||
|
|
@ -143,7 +143,6 @@ class QRDQNPolicy(BasePolicy):
|
||||||
self.n_quantiles = n_quantiles
|
self.n_quantiles = n_quantiles
|
||||||
self.net_arch = net_arch
|
self.net_arch = net_arch
|
||||||
self.activation_fn = activation_fn
|
self.activation_fn = activation_fn
|
||||||
self.normalize_images = normalize_images
|
|
||||||
|
|
||||||
self.net_args = {
|
self.net_args = {
|
||||||
"observation_space": self.observation_space,
|
"observation_space": self.observation_space,
|
||||||
|
|
|
||||||
|
|
@ -149,7 +149,7 @@ class Actor(BasePolicy):
|
||||||
:return:
|
:return:
|
||||||
Mean, standard deviation and optional keyword arguments.
|
Mean, standard deviation and optional keyword arguments.
|
||||||
"""
|
"""
|
||||||
features = self.extract_features(obs)
|
features = self.extract_features(obs, self.features_extractor)
|
||||||
latent_pi = self.latent_pi(features)
|
latent_pi = self.latent_pi(features)
|
||||||
mean_actions = self.mu(latent_pi)
|
mean_actions = self.mu(latent_pi)
|
||||||
|
|
||||||
|
|
@ -230,7 +230,7 @@ class Critic(BaseModel):
|
||||||
# Learn the features extractor using the policy loss only
|
# Learn the features extractor using the policy loss only
|
||||||
# when the features_extractor is shared with the actor
|
# when the features_extractor is shared with the actor
|
||||||
with th.set_grad_enabled(not self.share_features_extractor):
|
with th.set_grad_enabled(not self.share_features_extractor):
|
||||||
features = self.extract_features(obs)
|
features = self.extract_features(obs, self.features_extractor)
|
||||||
qvalue_input = th.cat([features, action], dim=1)
|
qvalue_input = th.cat([features, action], dim=1)
|
||||||
quantiles = th.stack(tuple(qf(qvalue_input) for qf in self.q_networks), dim=1)
|
quantiles = th.stack(tuple(qf(qvalue_input) for qf in self.q_networks), dim=1)
|
||||||
return quantiles
|
return quantiles
|
||||||
|
|
@ -293,6 +293,7 @@ class TQCPolicy(BasePolicy):
|
||||||
features_extractor_kwargs,
|
features_extractor_kwargs,
|
||||||
optimizer_class=optimizer_class,
|
optimizer_class=optimizer_class,
|
||||||
optimizer_kwargs=optimizer_kwargs,
|
optimizer_kwargs=optimizer_kwargs,
|
||||||
|
normalize_images=normalize_images,
|
||||||
squash_output=True,
|
squash_output=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.7.0a4
|
1.7.0a10
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.7.0a4",
|
"stable_baselines3>=1.7.0a10",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -6,13 +7,15 @@ import pytest
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common.envs import FakeImageEnv
|
from stable_baselines3.common.envs import FakeImageEnv
|
||||||
from stable_baselines3.common.utils import zip_strict
|
from stable_baselines3.common.utils import zip_strict
|
||||||
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize, VecTransposeImage, is_vecenv_wrapped
|
||||||
|
|
||||||
from sb3_contrib import QRDQN, TQC, TRPO
|
from sb3_contrib import QRDQN, TQC, TRPO, MaskablePPO, RecurrentPPO
|
||||||
|
from sb3_contrib.common.wrappers import ActionMasker
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_class", [TQC, QRDQN, TRPO])
|
@pytest.mark.parametrize("model_class", [TQC, QRDQN, TRPO])
|
||||||
def test_cnn(tmp_path, model_class):
|
@pytest.mark.parametrize("share_features_extractor", [True, False])
|
||||||
|
def test_cnn(tmp_path, model_class, share_features_extractor):
|
||||||
SAVE_NAME = "cnn_model.zip"
|
SAVE_NAME = "cnn_model.zip"
|
||||||
# Fake grayscale with frameskip
|
# Fake grayscale with frameskip
|
||||||
# Atari after preprocessing: 84x84x1, here we are using lower resolution
|
# Atari after preprocessing: 84x84x1, here we are using lower resolution
|
||||||
|
|
@ -23,8 +26,11 @@ def test_cnn(tmp_path, model_class):
|
||||||
n_channels=1,
|
n_channels=1,
|
||||||
discrete=model_class not in {TQC},
|
discrete=model_class not in {TQC},
|
||||||
)
|
)
|
||||||
kwargs = {}
|
kwargs = dict(policy_kwargs=dict(share_features_extractor=share_features_extractor))
|
||||||
if model_class in {TQC, QRDQN}:
|
if model_class in {TQC, QRDQN}:
|
||||||
|
# share_features_extractor is checked later for offpolicy algorithms
|
||||||
|
if share_features_extractor:
|
||||||
|
return
|
||||||
# Avoid memory error when using replay buffer
|
# Avoid memory error when using replay buffer
|
||||||
# Reduce the size of the features and the number of quantiles
|
# Reduce the size of the features and the number of quantiles
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
|
|
@ -34,6 +40,7 @@ def test_cnn(tmp_path, model_class):
|
||||||
features_extractor_kwargs=dict(features_dim=32),
|
features_extractor_kwargs=dict(features_dim=32),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
model = model_class("CnnPolicy", env, **kwargs).learn(250)
|
model = model_class("CnnPolicy", env, **kwargs).learn(250)
|
||||||
|
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
|
@ -155,3 +162,53 @@ def test_feature_extractor_target_net(model_class, share_features_extractor):
|
||||||
|
|
||||||
# Critic should be the same
|
# Critic should be the same
|
||||||
params_should_match(original_param, model.critic.parameters())
|
params_should_match(original_param, model.critic.parameters())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_class", [TRPO, MaskablePPO, RecurrentPPO, QRDQN, TQC])
|
||||||
|
@pytest.mark.parametrize("normalize_images", [True, False])
|
||||||
|
def test_image_like_input(model_class, normalize_images):
|
||||||
|
"""
|
||||||
|
Check that we can handle image-like input (3D tensor)
|
||||||
|
when normalize_images=False
|
||||||
|
"""
|
||||||
|
# Fake grayscale with frameskip
|
||||||
|
# Atari after preprocessing: 84x84x1, here we are using lower resolution
|
||||||
|
# to check that the network handle it automatically
|
||||||
|
env = FakeImageEnv(
|
||||||
|
screen_height=36,
|
||||||
|
screen_width=36,
|
||||||
|
n_channels=1,
|
||||||
|
channel_first=True,
|
||||||
|
discrete=model_class not in {TQC},
|
||||||
|
)
|
||||||
|
if model_class == MaskablePPO:
|
||||||
|
|
||||||
|
def action_mask_fn(env):
|
||||||
|
random_invalid_action = random.randrange(env.action_space.n)
|
||||||
|
return [i != random_invalid_action for i in range(env.action_space.n)]
|
||||||
|
|
||||||
|
env = ActionMasker(env, action_mask_fn)
|
||||||
|
|
||||||
|
vec_env = VecNormalize(DummyVecEnv([lambda: env]))
|
||||||
|
# Reduce the size of the features
|
||||||
|
# deactivate normalization
|
||||||
|
kwargs = dict(
|
||||||
|
policy_kwargs=dict(
|
||||||
|
normalize_images=normalize_images,
|
||||||
|
features_extractor_kwargs=dict(features_dim=32),
|
||||||
|
),
|
||||||
|
seed=1,
|
||||||
|
)
|
||||||
|
policy = "CnnLstmPolicy" if model_class == RecurrentPPO else "CnnPolicy"
|
||||||
|
|
||||||
|
if model_class in {TRPO, MaskablePPO, RecurrentPPO}:
|
||||||
|
kwargs.update(dict(n_steps=64, batch_size=64))
|
||||||
|
else:
|
||||||
|
# Avoid memory error when using replay buffer
|
||||||
|
# Reduce the size of the features
|
||||||
|
kwargs.update(dict(buffer_size=250))
|
||||||
|
if normalize_images:
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
model_class(policy, vec_env, **kwargs).learn(128)
|
||||||
|
else:
|
||||||
|
model_class(policy, vec_env, **kwargs).learn(128)
|
||||||
|
|
|
||||||
|
|
@ -226,7 +226,8 @@ def test_discrete_action_space_required():
|
||||||
MaskablePPO("MlpPolicy", env)
|
MaskablePPO("MlpPolicy", env)
|
||||||
|
|
||||||
|
|
||||||
def test_cnn():
|
@pytest.mark.parametrize("share_features_extractor", [True, False])
|
||||||
|
def test_cnn(share_features_extractor):
|
||||||
def action_mask_fn(env):
|
def action_mask_fn(env):
|
||||||
random_invalid_action = random.randrange(env.action_space.n)
|
random_invalid_action = random.randrange(env.action_space.n)
|
||||||
return [i != random_invalid_action for i in range(env.action_space.n)]
|
return [i != random_invalid_action for i in range(env.action_space.n)]
|
||||||
|
|
@ -242,6 +243,7 @@ def test_cnn():
|
||||||
verbose=1,
|
verbose=1,
|
||||||
policy_kwargs=dict(
|
policy_kwargs=dict(
|
||||||
features_extractor_kwargs=dict(features_dim=32),
|
features_extractor_kwargs=dict(features_dim=32),
|
||||||
|
share_features_extractor=share_features_extractor,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
model.learn(100)
|
model.learn(100)
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,7 @@ class CartPoleNoVelEnv(CartPoleEnv):
|
||||||
"policy_kwargs",
|
"policy_kwargs",
|
||||||
[
|
[
|
||||||
{},
|
{},
|
||||||
|
{"share_features_extractor": False},
|
||||||
dict(shared_lstm=True, enable_critic_lstm=False),
|
dict(shared_lstm=True, enable_critic_lstm=False),
|
||||||
dict(
|
dict(
|
||||||
enable_critic_lstm=True,
|
enable_critic_lstm=True,
|
||||||
|
|
@ -74,6 +75,11 @@ class CartPoleNoVelEnv(CartPoleEnv):
|
||||||
lstm_kwargs=dict(dropout=0.5),
|
lstm_kwargs=dict(dropout=0.5),
|
||||||
n_lstm_layers=2,
|
n_lstm_layers=2,
|
||||||
),
|
),
|
||||||
|
dict(
|
||||||
|
enable_critic_lstm=False,
|
||||||
|
lstm_hidden_size=4,
|
||||||
|
share_features_extractor=False,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_cnn(policy_kwargs):
|
def test_cnn(policy_kwargs):
|
||||||
|
|
@ -83,6 +89,7 @@ def test_cnn(policy_kwargs):
|
||||||
n_steps=16,
|
n_steps=16,
|
||||||
seed=0,
|
seed=0,
|
||||||
policy_kwargs=dict(**policy_kwargs, features_extractor_kwargs=dict(features_dim=32)),
|
policy_kwargs=dict(**policy_kwargs, features_extractor_kwargs=dict(features_dim=32)),
|
||||||
|
n_epochs=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
model.learn(total_timesteps=32)
|
model.learn(total_timesteps=32)
|
||||||
|
|
@ -130,6 +137,16 @@ def test_check():
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
policy_kwargs = dict(shared_lstm=True, enable_critic_lstm=False, share_features_extractor=False)
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
RecurrentPPO(
|
||||||
|
"MlpLstmPolicy",
|
||||||
|
"CartPole-v1",
|
||||||
|
n_steps=16,
|
||||||
|
seed=0,
|
||||||
|
policy_kwargs=policy_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
|
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
|
||||||
def test_run(env):
|
def test_run(env):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue