[Feature] Non-shared features extractor in on-policy algorithms (#130)

* Modified sb3_contrib/common/maskable/policies.py

- Added support for non-shared features extractor in file sb3_contrib/common/maskable/policies.py
- updated changelog

* Modified sb3_contrib/common/recurrent/policies.py

* Modified sb3_contrib/qrdqn/policies.py and sb3_contrib/tqc/policies.py

* Updated test_cnn.py

* Upgrade SB3 version

* Revert changes in formatting

* Remove duplicate normalize_images

* Add test for image-like inputs

* Fixes and add more tests

* Update SB3 version

* Fix ARS warnings

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Alex Pasquali 2022-12-23 00:23:45 +01:00 committed by GitHub
parent 6b23c6cfe3
commit ab8684f469
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 192 additions and 42 deletions

View File

@ -3,7 +3,7 @@
Changelog Changelog
========== ==========
Release 1.7.0a4 (WIP) Release 1.7.0a10 (WIP)
-------------------------- --------------------------
Breaking Changes: Breaking Changes:
@ -11,12 +11,15 @@ Breaking Changes:
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters, - Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
please use an ``EvalCallback`` instead please use an ``EvalCallback`` instead
- Removed deprecated ``sde_net_arch`` parameter - Removed deprecated ``sde_net_arch`` parameter
- Upgraded to Stable-Baselines3 >= 1.7.0a4 - Upgraded to Stable-Baselines3 >= 1.7.0a10
New Features: New Features:
^^^^^^^^^^^^^ ^^^^^^^^^^^^^
- Introduced mypy type checking - Introduced mypy type checking
- Added ``with_bias`` parameter to ``ARSPolicy`` - Added ``with_bias`` parameter to ``ARSPolicy``
- Added option to have non-shared features extractor between actor and critic in on-policy algorithms (@AlexPasqua)
- Features extractors now properly support unnormalized image-like observations (3D tensor)
when passing ``normalize_images=False``
Bug Fixes: Bug Fixes:
^^^^^^^^^^ ^^^^^^^^^^
@ -25,6 +28,7 @@ Bug Fixes:
Deprecations: Deprecations:
^^^^^^^^^^^^^ ^^^^^^^^^^^^^
- You should now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()``
Others: Others:
^^^^^^^ ^^^^^^^
@ -364,4 +368,4 @@ Contributors:
------------- -------------
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua

View File

@ -69,7 +69,7 @@ class ARSPolicy(BasePolicy):
def forward(self, obs: th.Tensor) -> th.Tensor: def forward(self, obs: th.Tensor) -> th.Tensor:
features = self.extract_features(obs) features = self.extract_features(obs, self.features_extractor)
if isinstance(self.action_space, gym.spaces.Box): if isinstance(self.action_space, gym.spaces.Box):
return self.action_net(features) return self.action_net(features)
elif isinstance(self.action_space, gym.spaces.Discrete): elif isinstance(self.action_space, gym.spaces.Discrete):

View File

@ -32,6 +32,7 @@ class MaskableActorCriticPolicy(BasePolicy):
:param features_extractor_class: Features extractor to use. :param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments :param features_extractor_kwargs: Keyword arguments
to pass to the features extractor. to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not, :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default) dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use, :param optimizer_class: The optimizer to use,
@ -50,6 +51,7 @@ class MaskableActorCriticPolicy(BasePolicy):
ortho_init: bool = True, ortho_init: bool = True,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None, features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True, normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None, optimizer_kwargs: Optional[Dict[str, Any]] = None,
@ -68,6 +70,7 @@ class MaskableActorCriticPolicy(BasePolicy):
features_extractor_kwargs, features_extractor_kwargs,
optimizer_class=optimizer_class, optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs, optimizer_kwargs=optimizer_kwargs,
normalize_images=normalize_images,
squash_output=False, squash_output=False,
) )
@ -82,10 +85,21 @@ class MaskableActorCriticPolicy(BasePolicy):
self.activation_fn = activation_fn self.activation_fn = activation_fn
self.ortho_init = ortho_init self.ortho_init = ortho_init
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) self.share_features_extractor = share_features_extractor
self.features_extractor = self.make_features_extractor()
self.features_dim = self.features_extractor.features_dim self.features_dim = self.features_extractor.features_dim
if self.share_features_extractor:
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.features_extractor
else:
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.make_features_extractor()
# if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
if len(net_arch) > 0 and not isinstance(net_arch[0], dict):
raise ValueError(
"Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
)
self.normalize_images = normalize_images
# Action distribution # Action distribution
self.action_dist = make_masked_proba_distribution(action_space) self.action_dist = make_masked_proba_distribution(action_space)
@ -107,7 +121,12 @@ class MaskableActorCriticPolicy(BasePolicy):
""" """
# Preprocess the observation if needed # Preprocess the observation if needed
features = self.extract_features(obs) features = self.extract_features(obs)
if self.share_features_extractor:
latent_pi, latent_vf = self.mlp_extractor(features) latent_pi, latent_vf = self.mlp_extractor(features)
else:
pi_features, vf_features = features
latent_pi = self.mlp_extractor.forward_actor(pi_features)
latent_vf = self.mlp_extractor.forward_critic(vf_features)
# Evaluate the values for the given observations # Evaluate the values for the given observations
values = self.value_net(latent_vf) values = self.value_net(latent_vf)
distribution = self._get_action_dist_from_latent(latent_pi) distribution = self._get_action_dist_from_latent(latent_pi)
@ -117,6 +136,19 @@ class MaskableActorCriticPolicy(BasePolicy):
log_prob = distribution.log_prob(actions) log_prob = distribution.log_prob(actions)
return actions, values, log_prob return actions, values, log_prob
def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
"""
Preprocess the observation if needed and extract features.
:param obs: Observation
:return: the output of the features extractor(s)
"""
if self.share_features_extractor:
return super().extract_features(obs, self.features_extractor)
else:
pi_features = super().extract_features(obs, self.pi_features_extractor)
vf_features = super().extract_features(obs, self.vf_features_extractor)
return pi_features, vf_features
def _get_constructor_parameters(self) -> Dict[str, Any]: def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters() data = super()._get_constructor_parameters()
@ -174,6 +206,13 @@ class MaskableActorCriticPolicy(BasePolicy):
self.action_net: 0.01, self.action_net: 0.01,
self.value_net: 1, self.value_net: 1,
} }
if not self.share_features_extractor:
# Note(antonin): this is to keep SB3 results
# consistent, see GH#1148
del module_gains[self.features_extractor]
module_gains[self.pi_features_extractor] = np.sqrt(2)
module_gains[self.vf_features_extractor] = np.sqrt(2)
for module, gain in module_gains.items(): for module, gain in module_gains.items():
module.apply(partial(self.init_weights, gain=gain)) module.apply(partial(self.init_weights, gain=gain))
@ -268,13 +307,19 @@ class MaskableActorCriticPolicy(BasePolicy):
Evaluate actions according to the current policy, Evaluate actions according to the current policy,
given the observations. given the observations.
:param obs: :param obs: Observation
:param actions: :param actions: Actions
:return: estimated value, log likelihood of taking those actions :return: estimated value, log likelihood of taking those actions
and entropy of the action distribution. and entropy of the action distribution.
""" """
features = self.extract_features(obs) features = self.extract_features(obs)
if self.share_features_extractor:
latent_pi, latent_vf = self.mlp_extractor(features) latent_pi, latent_vf = self.mlp_extractor(features)
else:
pi_features, vf_features = features
latent_pi = self.mlp_extractor.forward_actor(pi_features)
latent_vf = self.mlp_extractor.forward_critic(vf_features)
distribution = self._get_action_dist_from_latent(latent_pi) distribution = self._get_action_dist_from_latent(latent_pi)
if action_masks is not None: if action_masks is not None:
distribution.apply_masking(action_masks) distribution.apply_masking(action_masks)
@ -286,11 +331,11 @@ class MaskableActorCriticPolicy(BasePolicy):
""" """
Get the current policy distribution given the observations. Get the current policy distribution given the observations.
:param obs: :param obs: Observation
:param action_masks: :param action_masks: Actions' mask
:return: the action distribution. :return: the action distribution.
""" """
features = self.extract_features(obs) features = super().extract_features(obs, self.pi_features_extractor)
latent_pi = self.mlp_extractor.forward_actor(features) latent_pi = self.mlp_extractor.forward_actor(features)
distribution = self._get_action_dist_from_latent(latent_pi) distribution = self._get_action_dist_from_latent(latent_pi)
if action_masks is not None: if action_masks is not None:
@ -301,10 +346,10 @@ class MaskableActorCriticPolicy(BasePolicy):
""" """
Get the estimated values according to the current policy given the observations. Get the estimated values according to the current policy given the observations.
:param obs: :param obs: Observation
:return: the estimated values. :return: the estimated values.
""" """
features = self.extract_features(obs) features = super().extract_features(obs, self.vf_features_extractor)
latent_vf = self.mlp_extractor.forward_critic(features) latent_vf = self.mlp_extractor.forward_critic(features)
return self.value_net(latent_vf) return self.value_net(latent_vf)
@ -323,6 +368,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
:param features_extractor_class: Features extractor to use. :param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments :param features_extractor_kwargs: Keyword arguments
to pass to the features extractor. to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not, :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default) dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use, :param optimizer_class: The optimizer to use,
@ -341,6 +387,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
ortho_init: bool = True, ortho_init: bool = True,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None, features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True, normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None, optimizer_kwargs: Optional[Dict[str, Any]] = None,
@ -354,6 +401,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
ortho_init, ortho_init,
features_extractor_class, features_extractor_class,
features_extractor_kwargs, features_extractor_kwargs,
share_features_extractor,
normalize_images, normalize_images,
optimizer_class, optimizer_class,
optimizer_kwargs, optimizer_kwargs,
@ -374,6 +422,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
:param features_extractor_class: Uses the CombinedExtractor :param features_extractor_class: Uses the CombinedExtractor
:param features_extractor_kwargs: Keyword arguments :param features_extractor_kwargs: Keyword arguments
to pass to the feature extractor. to pass to the feature extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not, :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default) dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use, :param optimizer_class: The optimizer to use,
@ -392,6 +441,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
ortho_init: bool = True, ortho_init: bool = True,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None, features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True, normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None, optimizer_kwargs: Optional[Dict[str, Any]] = None,
@ -405,6 +455,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
ortho_init, ortho_init,
features_extractor_class, features_extractor_class,
features_extractor_kwargs, features_extractor_kwargs,
share_features_extractor,
normalize_images, normalize_images,
optimizer_class, optimizer_class,
optimizer_kwargs, optimizer_kwargs,

View File

@ -44,6 +44,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
:param features_extractor_class: Features extractor to use. :param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments :param features_extractor_kwargs: Keyword arguments
to pass to the features extractor. to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not, :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default) dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use, :param optimizer_class: The optimizer to use,
@ -75,6 +76,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
squash_output: bool = False, squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None, features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True, normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None, optimizer_kwargs: Optional[Dict[str, Any]] = None,
@ -99,6 +101,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
squash_output, squash_output,
features_extractor_class, features_extractor_class,
features_extractor_kwargs, features_extractor_kwargs,
share_features_extractor,
normalize_images, normalize_images,
optimizer_class, optimizer_class,
optimizer_kwargs, optimizer_kwargs,
@ -120,7 +123,11 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
self.lstm_critic = None self.lstm_critic = None
assert not ( assert not (
self.shared_lstm and self.enable_critic_lstm self.shared_lstm and self.enable_critic_lstm
), "You must choose between shared LSTM, separate or no LSTM for the critic" ), "You must choose between shared LSTM, separate or no LSTM for the critic."
assert not (
self.shared_lstm and not self.share_features_extractor
), "If the features extractor is not shared, the LSTM cannot be shared."
# No LSTM for the critic, we still need to convert # No LSTM for the critic, we still need to convert
# output of features extractor to the correct size # output of features extractor to the correct size
@ -222,17 +229,21 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
""" """
# Preprocess the observation if needed # Preprocess the observation if needed
features = self.extract_features(obs) features = self.extract_features(obs)
if self.share_features_extractor:
pi_features = vf_features = features # alias
else:
pi_features, vf_features = features
# latent_pi, latent_vf = self.mlp_extractor(features) # latent_pi, latent_vf = self.mlp_extractor(features)
latent_pi, lstm_states_pi = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor) latent_pi, lstm_states_pi = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor)
if self.lstm_critic is not None: if self.lstm_critic is not None:
latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic) latent_vf, lstm_states_vf = self._process_sequence(vf_features, lstm_states.vf, episode_starts, self.lstm_critic)
elif self.shared_lstm: elif self.shared_lstm:
# Re-use LSTM features but do not backpropagate # Re-use LSTM features but do not backpropagate
latent_vf = latent_pi.detach() latent_vf = latent_pi.detach()
lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach()) lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach())
else: else:
# Critic only has a feedforward network # Critic only has a feedforward network
latent_vf = self.critic(features) latent_vf = self.critic(vf_features)
lstm_states_vf = lstm_states_pi lstm_states_vf = lstm_states_pi
latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_pi = self.mlp_extractor.forward_actor(latent_pi)
@ -260,7 +271,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
or not (we reset the lstm states in that case). or not (we reset the lstm states in that case).
:return: the action distribution and new hidden states. :return: the action distribution and new hidden states.
""" """
features = self.extract_features(obs) # Call the method from the parent of the parent class
features = super(ActorCriticPolicy, self).extract_features(obs, self.pi_features_extractor)
latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor) latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor)
latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_pi = self.mlp_extractor.forward_actor(latent_pi)
return self._get_action_dist_from_latent(latent_pi), lstm_states return self._get_action_dist_from_latent(latent_pi), lstm_states
@ -280,7 +292,9 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
or not (we reset the lstm states in that case). or not (we reset the lstm states in that case).
:return: the estimated values. :return: the estimated values.
""" """
features = self.extract_features(obs) # Call the method from the parent of the parent class
features = super(ActorCriticPolicy, self).extract_features(obs, self.vf_features_extractor)
if self.lstm_critic is not None: if self.lstm_critic is not None:
latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic) latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic)
elif self.shared_lstm: elif self.shared_lstm:
@ -294,11 +308,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
return self.value_net(latent_vf) return self.value_net(latent_vf)
def evaluate_actions( def evaluate_actions(
self, self, obs: th.Tensor, actions: th.Tensor, lstm_states: RNNStates, episode_starts: th.Tensor
obs: th.Tensor,
actions: th.Tensor,
lstm_states: RNNStates,
episode_starts: th.Tensor,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
""" """
Evaluate actions according to the current policy, Evaluate actions according to the current policy,
@ -314,14 +324,17 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
""" """
# Preprocess the observation if needed # Preprocess the observation if needed
features = self.extract_features(obs) features = self.extract_features(obs)
latent_pi, _ = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor) if self.share_features_extractor:
pi_features = vf_features = features # alias
else:
pi_features, vf_features = features
latent_pi, _ = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor)
if self.lstm_critic is not None: if self.lstm_critic is not None:
latent_vf, _ = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic) latent_vf, _ = self._process_sequence(vf_features, lstm_states.vf, episode_starts, self.lstm_critic)
elif self.shared_lstm: elif self.shared_lstm:
latent_vf = latent_pi.detach() latent_vf = latent_pi.detach()
else: else:
latent_vf = self.critic(features) latent_vf = self.critic(vf_features)
latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_pi = self.mlp_extractor.forward_actor(latent_pi)
latent_vf = self.mlp_extractor.forward_critic(latent_vf) latent_vf = self.mlp_extractor.forward_critic(latent_vf)
@ -439,6 +452,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
:param features_extractor_class: Features extractor to use. :param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments :param features_extractor_kwargs: Keyword arguments
to pass to the features extractor. to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not, :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default) dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use, :param optimizer_class: The optimizer to use,
@ -469,6 +483,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
squash_output: bool = False, squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None, features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True, normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None, optimizer_kwargs: Optional[Dict[str, Any]] = None,
@ -492,6 +507,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
squash_output, squash_output,
features_extractor_class, features_extractor_class,
features_extractor_kwargs, features_extractor_kwargs,
share_features_extractor,
normalize_images, normalize_images,
optimizer_class, optimizer_class,
optimizer_kwargs, optimizer_kwargs,
@ -526,6 +542,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
:param features_extractor_class: Features extractor to use. :param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments :param features_extractor_kwargs: Keyword arguments
to pass to the features extractor. to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not, :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default) dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use, :param optimizer_class: The optimizer to use,
@ -556,6 +573,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
squash_output: bool = False, squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None, features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True, normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None, optimizer_kwargs: Optional[Dict[str, Any]] = None,
@ -579,6 +597,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
squash_output, squash_output,
features_extractor_class, features_extractor_class,
features_extractor_kwargs, features_extractor_kwargs,
share_features_extractor,
normalize_images, normalize_images,
optimizer_class, optimizer_class,
optimizer_kwargs, optimizer_kwargs,

View File

@ -53,7 +53,6 @@ class QuantileNetwork(BasePolicy):
self.features_extractor = features_extractor self.features_extractor = features_extractor
self.features_dim = features_dim self.features_dim = features_dim
self.n_quantiles = n_quantiles self.n_quantiles = n_quantiles
self.normalize_images = normalize_images
action_dim = self.action_space.n # number of actions action_dim = self.action_space.n # number of actions
quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn) quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
self.quantile_net = nn.Sequential(*quantile_net) self.quantile_net = nn.Sequential(*quantile_net)
@ -65,7 +64,7 @@ class QuantileNetwork(BasePolicy):
:param obs: Observation :param obs: Observation
:return: The estimated quantiles for each action. :return: The estimated quantiles for each action.
""" """
quantiles = self.quantile_net(self.extract_features(obs)) quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
return quantiles.view(-1, self.n_quantiles, self.action_space.n) return quantiles.view(-1, self.n_quantiles, self.action_space.n)
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor: def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
@ -132,6 +131,7 @@ class QRDQNPolicy(BasePolicy):
features_extractor_kwargs, features_extractor_kwargs,
optimizer_class=optimizer_class, optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs, optimizer_kwargs=optimizer_kwargs,
normalize_images=normalize_images,
) )
if net_arch is None: if net_arch is None:
@ -143,7 +143,6 @@ class QRDQNPolicy(BasePolicy):
self.n_quantiles = n_quantiles self.n_quantiles = n_quantiles
self.net_arch = net_arch self.net_arch = net_arch
self.activation_fn = activation_fn self.activation_fn = activation_fn
self.normalize_images = normalize_images
self.net_args = { self.net_args = {
"observation_space": self.observation_space, "observation_space": self.observation_space,

View File

@ -149,7 +149,7 @@ class Actor(BasePolicy):
:return: :return:
Mean, standard deviation and optional keyword arguments. Mean, standard deviation and optional keyword arguments.
""" """
features = self.extract_features(obs) features = self.extract_features(obs, self.features_extractor)
latent_pi = self.latent_pi(features) latent_pi = self.latent_pi(features)
mean_actions = self.mu(latent_pi) mean_actions = self.mu(latent_pi)
@ -230,7 +230,7 @@ class Critic(BaseModel):
# Learn the features extractor using the policy loss only # Learn the features extractor using the policy loss only
# when the features_extractor is shared with the actor # when the features_extractor is shared with the actor
with th.set_grad_enabled(not self.share_features_extractor): with th.set_grad_enabled(not self.share_features_extractor):
features = self.extract_features(obs) features = self.extract_features(obs, self.features_extractor)
qvalue_input = th.cat([features, action], dim=1) qvalue_input = th.cat([features, action], dim=1)
quantiles = th.stack(tuple(qf(qvalue_input) for qf in self.q_networks), dim=1) quantiles = th.stack(tuple(qf(qvalue_input) for qf in self.q_networks), dim=1)
return quantiles return quantiles
@ -293,6 +293,7 @@ class TQCPolicy(BasePolicy):
features_extractor_kwargs, features_extractor_kwargs,
optimizer_class=optimizer_class, optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs, optimizer_kwargs=optimizer_kwargs,
normalize_images=normalize_images,
squash_output=True, squash_output=True,
) )

View File

@ -1 +1 @@
1.7.0a4 1.7.0a10

View File

@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")], packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]}, package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[ install_requires=[
"stable_baselines3>=1.7.0a4", "stable_baselines3>=1.7.0a10",
], ],
description="Contrib package of Stable Baselines3, experimental code.", description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin", author="Antonin Raffin",

View File

@ -1,4 +1,5 @@
import os import os
import random
from copy import deepcopy from copy import deepcopy
import numpy as np import numpy as np
@ -6,13 +7,15 @@ import pytest
import torch as th import torch as th
from stable_baselines3.common.envs import FakeImageEnv from stable_baselines3.common.envs import FakeImageEnv
from stable_baselines3.common.utils import zip_strict from stable_baselines3.common.utils import zip_strict
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize, VecTransposeImage, is_vecenv_wrapped
from sb3_contrib import QRDQN, TQC, TRPO from sb3_contrib import QRDQN, TQC, TRPO, MaskablePPO, RecurrentPPO
from sb3_contrib.common.wrappers import ActionMasker
@pytest.mark.parametrize("model_class", [TQC, QRDQN, TRPO]) @pytest.mark.parametrize("model_class", [TQC, QRDQN, TRPO])
def test_cnn(tmp_path, model_class): @pytest.mark.parametrize("share_features_extractor", [True, False])
def test_cnn(tmp_path, model_class, share_features_extractor):
SAVE_NAME = "cnn_model.zip" SAVE_NAME = "cnn_model.zip"
# Fake grayscale with frameskip # Fake grayscale with frameskip
# Atari after preprocessing: 84x84x1, here we are using lower resolution # Atari after preprocessing: 84x84x1, here we are using lower resolution
@ -23,8 +26,11 @@ def test_cnn(tmp_path, model_class):
n_channels=1, n_channels=1,
discrete=model_class not in {TQC}, discrete=model_class not in {TQC},
) )
kwargs = {} kwargs = dict(policy_kwargs=dict(share_features_extractor=share_features_extractor))
if model_class in {TQC, QRDQN}: if model_class in {TQC, QRDQN}:
# share_features_extractor is checked later for offpolicy algorithms
if share_features_extractor:
return
# Avoid memory error when using replay buffer # Avoid memory error when using replay buffer
# Reduce the size of the features and the number of quantiles # Reduce the size of the features and the number of quantiles
kwargs = dict( kwargs = dict(
@ -34,6 +40,7 @@ def test_cnn(tmp_path, model_class):
features_extractor_kwargs=dict(features_dim=32), features_extractor_kwargs=dict(features_dim=32),
), ),
) )
model = model_class("CnnPolicy", env, **kwargs).learn(250) model = model_class("CnnPolicy", env, **kwargs).learn(250)
obs = env.reset() obs = env.reset()
@ -155,3 +162,53 @@ def test_feature_extractor_target_net(model_class, share_features_extractor):
# Critic should be the same # Critic should be the same
params_should_match(original_param, model.critic.parameters()) params_should_match(original_param, model.critic.parameters())
@pytest.mark.parametrize("model_class", [TRPO, MaskablePPO, RecurrentPPO, QRDQN, TQC])
@pytest.mark.parametrize("normalize_images", [True, False])
def test_image_like_input(model_class, normalize_images):
"""
Check that we can handle image-like input (3D tensor)
when normalize_images=False
"""
# Fake grayscale with frameskip
# Atari after preprocessing: 84x84x1, here we are using lower resolution
# to check that the network handle it automatically
env = FakeImageEnv(
screen_height=36,
screen_width=36,
n_channels=1,
channel_first=True,
discrete=model_class not in {TQC},
)
if model_class == MaskablePPO:
def action_mask_fn(env):
random_invalid_action = random.randrange(env.action_space.n)
return [i != random_invalid_action for i in range(env.action_space.n)]
env = ActionMasker(env, action_mask_fn)
vec_env = VecNormalize(DummyVecEnv([lambda: env]))
# Reduce the size of the features
# deactivate normalization
kwargs = dict(
policy_kwargs=dict(
normalize_images=normalize_images,
features_extractor_kwargs=dict(features_dim=32),
),
seed=1,
)
policy = "CnnLstmPolicy" if model_class == RecurrentPPO else "CnnPolicy"
if model_class in {TRPO, MaskablePPO, RecurrentPPO}:
kwargs.update(dict(n_steps=64, batch_size=64))
else:
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs.update(dict(buffer_size=250))
if normalize_images:
with pytest.raises(AssertionError):
model_class(policy, vec_env, **kwargs).learn(128)
else:
model_class(policy, vec_env, **kwargs).learn(128)

View File

@ -226,7 +226,8 @@ def test_discrete_action_space_required():
MaskablePPO("MlpPolicy", env) MaskablePPO("MlpPolicy", env)
def test_cnn(): @pytest.mark.parametrize("share_features_extractor", [True, False])
def test_cnn(share_features_extractor):
def action_mask_fn(env): def action_mask_fn(env):
random_invalid_action = random.randrange(env.action_space.n) random_invalid_action = random.randrange(env.action_space.n)
return [i != random_invalid_action for i in range(env.action_space.n)] return [i != random_invalid_action for i in range(env.action_space.n)]
@ -242,6 +243,7 @@ def test_cnn():
verbose=1, verbose=1,
policy_kwargs=dict( policy_kwargs=dict(
features_extractor_kwargs=dict(features_dim=32), features_extractor_kwargs=dict(features_dim=32),
share_features_extractor=share_features_extractor,
), ),
) )
model.learn(100) model.learn(100)

View File

@ -61,6 +61,7 @@ class CartPoleNoVelEnv(CartPoleEnv):
"policy_kwargs", "policy_kwargs",
[ [
{}, {},
{"share_features_extractor": False},
dict(shared_lstm=True, enable_critic_lstm=False), dict(shared_lstm=True, enable_critic_lstm=False),
dict( dict(
enable_critic_lstm=True, enable_critic_lstm=True,
@ -74,6 +75,11 @@ class CartPoleNoVelEnv(CartPoleEnv):
lstm_kwargs=dict(dropout=0.5), lstm_kwargs=dict(dropout=0.5),
n_lstm_layers=2, n_lstm_layers=2,
), ),
dict(
enable_critic_lstm=False,
lstm_hidden_size=4,
share_features_extractor=False,
),
], ],
) )
def test_cnn(policy_kwargs): def test_cnn(policy_kwargs):
@ -83,6 +89,7 @@ def test_cnn(policy_kwargs):
n_steps=16, n_steps=16,
seed=0, seed=0,
policy_kwargs=dict(**policy_kwargs, features_extractor_kwargs=dict(features_dim=32)), policy_kwargs=dict(**policy_kwargs, features_extractor_kwargs=dict(features_dim=32)),
n_epochs=2,
) )
model.learn(total_timesteps=32) model.learn(total_timesteps=32)
@ -130,6 +137,16 @@ def test_check():
policy_kwargs=policy_kwargs, policy_kwargs=policy_kwargs,
) )
policy_kwargs = dict(shared_lstm=True, enable_critic_lstm=False, share_features_extractor=False)
with pytest.raises(AssertionError):
RecurrentPPO(
"MlpLstmPolicy",
"CartPole-v1",
n_steps=16,
seed=0,
policy_kwargs=policy_kwargs,
)
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"]) @pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
def test_run(env): def test_run(env):