[Feature] Non-shared features extractor in on-policy algorithms (#130)
* Modified sb3_contrib/common/maskable/policies.py - Added support for non-shared features extractor in file sb3_contrib/common/maskable/policies.py - updated changelog * Modified sb3_contrib/common/recurrent/policies.py * Modified sb3_contrib/qrdqn/policies.py and sb3_contrib/tqc/policies.py * Updated test_cnn.py * Upgrade SB3 version * Revert changes in formatting * Remove duplicate normalize_images * Add test for image-like inputs * Fixes and add more tests * Update SB3 version * Fix ARS warnings Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
6b23c6cfe3
commit
ab8684f469
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.7.0a4 (WIP)
|
||||
Release 1.7.0a10 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -11,12 +11,15 @@ Breaking Changes:
|
|||
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
|
||||
please use an ``EvalCallback`` instead
|
||||
- Removed deprecated ``sde_net_arch`` parameter
|
||||
- Upgraded to Stable-Baselines3 >= 1.7.0a4
|
||||
- Upgraded to Stable-Baselines3 >= 1.7.0a10
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Introduced mypy type checking
|
||||
- Added ``with_bias`` parameter to ``ARSPolicy``
|
||||
- Added option to have non-shared features extractor between actor and critic in on-policy algorithms (@AlexPasqua)
|
||||
- Features extractors now properly support unnormalized image-like observations (3D tensor)
|
||||
when passing ``normalize_images=False``
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
@ -25,6 +28,7 @@ Bug Fixes:
|
|||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
- You should now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()``
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
|
|
@ -364,4 +368,4 @@ Contributors:
|
|||
-------------
|
||||
|
||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
||||
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong
|
||||
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class ARSPolicy(BasePolicy):
|
|||
|
||||
def forward(self, obs: th.Tensor) -> th.Tensor:
|
||||
|
||||
features = self.extract_features(obs)
|
||||
features = self.extract_features(obs, self.features_extractor)
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
return self.action_net(features)
|
||||
elif isinstance(self.action_space, gym.spaces.Discrete):
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the features extractor.
|
||||
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
|
|
@ -50,6 +51,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
ortho_init: bool = True,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
share_features_extractor: bool = True,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
|
|
@ -68,6 +70,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
features_extractor_kwargs,
|
||||
optimizer_class=optimizer_class,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
normalize_images=normalize_images,
|
||||
squash_output=False,
|
||||
)
|
||||
|
||||
|
|
@ -82,10 +85,21 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
self.activation_fn = activation_fn
|
||||
self.ortho_init = ortho_init
|
||||
|
||||
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
|
||||
self.share_features_extractor = share_features_extractor
|
||||
self.features_extractor = self.make_features_extractor()
|
||||
self.features_dim = self.features_extractor.features_dim
|
||||
if self.share_features_extractor:
|
||||
self.pi_features_extractor = self.features_extractor
|
||||
self.vf_features_extractor = self.features_extractor
|
||||
else:
|
||||
self.pi_features_extractor = self.features_extractor
|
||||
self.vf_features_extractor = self.make_features_extractor()
|
||||
# if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
|
||||
if len(net_arch) > 0 and not isinstance(net_arch[0], dict):
|
||||
raise ValueError(
|
||||
"Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
|
||||
)
|
||||
|
||||
self.normalize_images = normalize_images
|
||||
# Action distribution
|
||||
self.action_dist = make_masked_proba_distribution(action_space)
|
||||
|
||||
|
|
@ -107,7 +121,12 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
"""
|
||||
# Preprocess the observation if needed
|
||||
features = self.extract_features(obs)
|
||||
if self.share_features_extractor:
|
||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
else:
|
||||
pi_features, vf_features = features
|
||||
latent_pi = self.mlp_extractor.forward_actor(pi_features)
|
||||
latent_vf = self.mlp_extractor.forward_critic(vf_features)
|
||||
# Evaluate the values for the given observations
|
||||
values = self.value_net(latent_vf)
|
||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
|
|
@ -117,6 +136,19 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
log_prob = distribution.log_prob(actions)
|
||||
return actions, values, log_prob
|
||||
|
||||
def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
|
||||
"""
|
||||
Preprocess the observation if needed and extract features.
|
||||
:param obs: Observation
|
||||
:return: the output of the features extractor(s)
|
||||
"""
|
||||
if self.share_features_extractor:
|
||||
return super().extract_features(obs, self.features_extractor)
|
||||
else:
|
||||
pi_features = super().extract_features(obs, self.pi_features_extractor)
|
||||
vf_features = super().extract_features(obs, self.vf_features_extractor)
|
||||
return pi_features, vf_features
|
||||
|
||||
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||
data = super()._get_constructor_parameters()
|
||||
|
||||
|
|
@ -174,6 +206,13 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
self.action_net: 0.01,
|
||||
self.value_net: 1,
|
||||
}
|
||||
if not self.share_features_extractor:
|
||||
# Note(antonin): this is to keep SB3 results
|
||||
# consistent, see GH#1148
|
||||
del module_gains[self.features_extractor]
|
||||
module_gains[self.pi_features_extractor] = np.sqrt(2)
|
||||
module_gains[self.vf_features_extractor] = np.sqrt(2)
|
||||
|
||||
for module, gain in module_gains.items():
|
||||
module.apply(partial(self.init_weights, gain=gain))
|
||||
|
||||
|
|
@ -268,13 +307,19 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
Evaluate actions according to the current policy,
|
||||
given the observations.
|
||||
|
||||
:param obs:
|
||||
:param actions:
|
||||
:param obs: Observation
|
||||
:param actions: Actions
|
||||
:return: estimated value, log likelihood of taking those actions
|
||||
and entropy of the action distribution.
|
||||
"""
|
||||
features = self.extract_features(obs)
|
||||
if self.share_features_extractor:
|
||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
else:
|
||||
pi_features, vf_features = features
|
||||
latent_pi = self.mlp_extractor.forward_actor(pi_features)
|
||||
latent_vf = self.mlp_extractor.forward_critic(vf_features)
|
||||
|
||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
if action_masks is not None:
|
||||
distribution.apply_masking(action_masks)
|
||||
|
|
@ -286,11 +331,11 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
"""
|
||||
Get the current policy distribution given the observations.
|
||||
|
||||
:param obs:
|
||||
:param action_masks:
|
||||
:param obs: Observation
|
||||
:param action_masks: Actions' mask
|
||||
:return: the action distribution.
|
||||
"""
|
||||
features = self.extract_features(obs)
|
||||
features = super().extract_features(obs, self.pi_features_extractor)
|
||||
latent_pi = self.mlp_extractor.forward_actor(features)
|
||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
if action_masks is not None:
|
||||
|
|
@ -301,10 +346,10 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
"""
|
||||
Get the estimated values according to the current policy given the observations.
|
||||
|
||||
:param obs:
|
||||
:param obs: Observation
|
||||
:return: the estimated values.
|
||||
"""
|
||||
features = self.extract_features(obs)
|
||||
features = super().extract_features(obs, self.vf_features_extractor)
|
||||
latent_vf = self.mlp_extractor.forward_critic(features)
|
||||
return self.value_net(latent_vf)
|
||||
|
||||
|
|
@ -323,6 +368,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
|
|||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the features extractor.
|
||||
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
|
|
@ -341,6 +387,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
|
|||
ortho_init: bool = True,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
share_features_extractor: bool = True,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
|
|
@ -354,6 +401,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
|
|||
ortho_init,
|
||||
features_extractor_class,
|
||||
features_extractor_kwargs,
|
||||
share_features_extractor,
|
||||
normalize_images,
|
||||
optimizer_class,
|
||||
optimizer_kwargs,
|
||||
|
|
@ -374,6 +422,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
|
|||
:param features_extractor_class: Uses the CombinedExtractor
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the feature extractor.
|
||||
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
|
|
@ -392,6 +441,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
|
|||
ortho_init: bool = True,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
share_features_extractor: bool = True,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
|
|
@ -405,6 +455,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
|
|||
ortho_init,
|
||||
features_extractor_class,
|
||||
features_extractor_kwargs,
|
||||
share_features_extractor,
|
||||
normalize_images,
|
||||
optimizer_class,
|
||||
optimizer_kwargs,
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the features extractor.
|
||||
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
|
|
@ -75,6 +76,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
squash_output: bool = False,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
share_features_extractor: bool = True,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
|
|
@ -99,6 +101,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
squash_output,
|
||||
features_extractor_class,
|
||||
features_extractor_kwargs,
|
||||
share_features_extractor,
|
||||
normalize_images,
|
||||
optimizer_class,
|
||||
optimizer_kwargs,
|
||||
|
|
@ -120,7 +123,11 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
self.lstm_critic = None
|
||||
assert not (
|
||||
self.shared_lstm and self.enable_critic_lstm
|
||||
), "You must choose between shared LSTM, separate or no LSTM for the critic"
|
||||
), "You must choose between shared LSTM, separate or no LSTM for the critic."
|
||||
|
||||
assert not (
|
||||
self.shared_lstm and not self.share_features_extractor
|
||||
), "If the features extractor is not shared, the LSTM cannot be shared."
|
||||
|
||||
# No LSTM for the critic, we still need to convert
|
||||
# output of features extractor to the correct size
|
||||
|
|
@ -222,17 +229,21 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
"""
|
||||
# Preprocess the observation if needed
|
||||
features = self.extract_features(obs)
|
||||
if self.share_features_extractor:
|
||||
pi_features = vf_features = features  # alias
|
||||
else:
|
||||
pi_features, vf_features = features
|
||||
# latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
latent_pi, lstm_states_pi = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor)
|
||||
latent_pi, lstm_states_pi = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor)
|
||||
if self.lstm_critic is not None:
|
||||
latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic)
|
||||
latent_vf, lstm_states_vf = self._process_sequence(vf_features, lstm_states.vf, episode_starts, self.lstm_critic)
|
||||
elif self.shared_lstm:
|
||||
# Re-use LSTM features but do not backpropagate
|
||||
latent_vf = latent_pi.detach()
|
||||
lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach())
|
||||
else:
|
||||
# Critic only has a feedforward network
|
||||
latent_vf = self.critic(features)
|
||||
latent_vf = self.critic(vf_features)
|
||||
lstm_states_vf = lstm_states_pi
|
||||
|
||||
latent_pi = self.mlp_extractor.forward_actor(latent_pi)
|
||||
|
|
@ -260,7 +271,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
or not (we reset the lstm states in that case).
|
||||
:return: the action distribution and new hidden states.
|
||||
"""
|
||||
features = self.extract_features(obs)
|
||||
# Call the method from the parent of the parent class
|
||||
features = super(ActorCriticPolicy, self).extract_features(obs, self.pi_features_extractor)
|
||||
latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor)
|
||||
latent_pi = self.mlp_extractor.forward_actor(latent_pi)
|
||||
return self._get_action_dist_from_latent(latent_pi), lstm_states
|
||||
|
|
@ -280,7 +292,9 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
or not (we reset the lstm states in that case).
|
||||
:return: the estimated values.
|
||||
"""
|
||||
features = self.extract_features(obs)
|
||||
# Call the method from the parent of the parent class
|
||||
features = super(ActorCriticPolicy, self).extract_features(obs, self.vf_features_extractor)
|
||||
|
||||
if self.lstm_critic is not None:
|
||||
latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic)
|
||||
elif self.shared_lstm:
|
||||
|
|
@ -294,11 +308,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
return self.value_net(latent_vf)
|
||||
|
||||
def evaluate_actions(
|
||||
self,
|
||||
obs: th.Tensor,
|
||||
actions: th.Tensor,
|
||||
lstm_states: RNNStates,
|
||||
episode_starts: th.Tensor,
|
||||
self, obs: th.Tensor, actions: th.Tensor, lstm_states: RNNStates, episode_starts: th.Tensor
|
||||
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
|
||||
"""
|
||||
Evaluate actions according to the current policy,
|
||||
|
|
@ -314,14 +324,17 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
"""
|
||||
# Preprocess the observation if needed
|
||||
features = self.extract_features(obs)
|
||||
latent_pi, _ = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor)
|
||||
|
||||
if self.share_features_extractor:
|
||||
pi_features = vf_features = features # alias
|
||||
else:
|
||||
pi_features, vf_features = features
|
||||
latent_pi, _ = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor)
|
||||
if self.lstm_critic is not None:
|
||||
latent_vf, _ = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic)
|
||||
latent_vf, _ = self._process_sequence(vf_features, lstm_states.vf, episode_starts, self.lstm_critic)
|
||||
elif self.shared_lstm:
|
||||
latent_vf = latent_pi.detach()
|
||||
else:
|
||||
latent_vf = self.critic(features)
|
||||
latent_vf = self.critic(vf_features)
|
||||
|
||||
latent_pi = self.mlp_extractor.forward_actor(latent_pi)
|
||||
latent_vf = self.mlp_extractor.forward_critic(latent_vf)
|
||||
|
|
@ -439,6 +452,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
|||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the features extractor.
|
||||
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
|
|
@ -469,6 +483,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
|||
squash_output: bool = False,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
share_features_extractor: bool = True,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
|
|
@ -492,6 +507,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
|||
squash_output,
|
||||
features_extractor_class,
|
||||
features_extractor_kwargs,
|
||||
share_features_extractor,
|
||||
normalize_images,
|
||||
optimizer_class,
|
||||
optimizer_kwargs,
|
||||
|
|
@ -526,6 +542,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
|||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the features extractor.
|
||||
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
|
|
@ -556,6 +573,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
|||
squash_output: bool = False,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
share_features_extractor: bool = True,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
|
|
@ -579,6 +597,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
|||
squash_output,
|
||||
features_extractor_class,
|
||||
features_extractor_kwargs,
|
||||
share_features_extractor,
|
||||
normalize_images,
|
||||
optimizer_class,
|
||||
optimizer_kwargs,
|
||||
|
|
|
|||
|
|
@ -53,7 +53,6 @@ class QuantileNetwork(BasePolicy):
|
|||
self.features_extractor = features_extractor
|
||||
self.features_dim = features_dim
|
||||
self.n_quantiles = n_quantiles
|
||||
self.normalize_images = normalize_images
|
||||
action_dim = self.action_space.n # number of actions
|
||||
quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
|
||||
self.quantile_net = nn.Sequential(*quantile_net)
|
||||
|
|
@ -65,7 +64,7 @@ class QuantileNetwork(BasePolicy):
|
|||
:param obs: Observation
|
||||
:return: The estimated quantiles for each action.
|
||||
"""
|
||||
quantiles = self.quantile_net(self.extract_features(obs))
|
||||
quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
|
||||
return quantiles.view(-1, self.n_quantiles, self.action_space.n)
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
|
|
@ -132,6 +131,7 @@ class QRDQNPolicy(BasePolicy):
|
|||
features_extractor_kwargs,
|
||||
optimizer_class=optimizer_class,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
normalize_images=normalize_images,
|
||||
)
|
||||
|
||||
if net_arch is None:
|
||||
|
|
@ -143,7 +143,6 @@ class QRDQNPolicy(BasePolicy):
|
|||
self.n_quantiles = n_quantiles
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
self.normalize_images = normalize_images
|
||||
|
||||
self.net_args = {
|
||||
"observation_space": self.observation_space,
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ class Actor(BasePolicy):
|
|||
:return:
|
||||
Mean, standard deviation and optional keyword arguments.
|
||||
"""
|
||||
features = self.extract_features(obs)
|
||||
features = self.extract_features(obs, self.features_extractor)
|
||||
latent_pi = self.latent_pi(features)
|
||||
mean_actions = self.mu(latent_pi)
|
||||
|
||||
|
|
@ -230,7 +230,7 @@ class Critic(BaseModel):
|
|||
# Learn the features extractor using the policy loss only
|
||||
# when the features_extractor is shared with the actor
|
||||
with th.set_grad_enabled(not self.share_features_extractor):
|
||||
features = self.extract_features(obs)
|
||||
features = self.extract_features(obs, self.features_extractor)
|
||||
qvalue_input = th.cat([features, action], dim=1)
|
||||
quantiles = th.stack(tuple(qf(qvalue_input) for qf in self.q_networks), dim=1)
|
||||
return quantiles
|
||||
|
|
@ -293,6 +293,7 @@ class TQCPolicy(BasePolicy):
|
|||
features_extractor_kwargs,
|
||||
optimizer_class=optimizer_class,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
normalize_images=normalize_images,
|
||||
squash_output=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.7.0a4
|
||||
1.7.0a10
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=1.7.0a4",
|
||||
"stable_baselines3>=1.7.0a10",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import random
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -6,13 +7,15 @@ import pytest
|
|||
import torch as th
|
||||
from stable_baselines3.common.envs import FakeImageEnv
|
||||
from stable_baselines3.common.utils import zip_strict
|
||||
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize, VecTransposeImage, is_vecenv_wrapped
|
||||
|
||||
from sb3_contrib import QRDQN, TQC, TRPO
|
||||
from sb3_contrib import QRDQN, TQC, TRPO, MaskablePPO, RecurrentPPO
|
||||
from sb3_contrib.common.wrappers import ActionMasker
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [TQC, QRDQN, TRPO])
|
||||
def test_cnn(tmp_path, model_class):
|
||||
@pytest.mark.parametrize("share_features_extractor", [True, False])
|
||||
def test_cnn(tmp_path, model_class, share_features_extractor):
|
||||
SAVE_NAME = "cnn_model.zip"
|
||||
# Fake grayscale with frameskip
|
||||
# Atari after preprocessing: 84x84x1, here we are using lower resolution
|
||||
|
|
@ -23,8 +26,11 @@ def test_cnn(tmp_path, model_class):
|
|||
n_channels=1,
|
||||
discrete=model_class not in {TQC},
|
||||
)
|
||||
kwargs = {}
|
||||
kwargs = dict(policy_kwargs=dict(share_features_extractor=share_features_extractor))
|
||||
if model_class in {TQC, QRDQN}:
|
||||
# share_features_extractor is checked later for offpolicy algorithms
|
||||
if share_features_extractor:
|
||||
return
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features and the number of quantiles
|
||||
kwargs = dict(
|
||||
|
|
@ -34,6 +40,7 @@ def test_cnn(tmp_path, model_class):
|
|||
features_extractor_kwargs=dict(features_dim=32),
|
||||
),
|
||||
)
|
||||
|
||||
model = model_class("CnnPolicy", env, **kwargs).learn(250)
|
||||
|
||||
obs = env.reset()
|
||||
|
|
@ -155,3 +162,53 @@ def test_feature_extractor_target_net(model_class, share_features_extractor):
|
|||
|
||||
# Critic should be the same
|
||||
params_should_match(original_param, model.critic.parameters())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [TRPO, MaskablePPO, RecurrentPPO, QRDQN, TQC])
|
||||
@pytest.mark.parametrize("normalize_images", [True, False])
|
||||
def test_image_like_input(model_class, normalize_images):
|
||||
"""
|
||||
Check that we can handle image-like input (3D tensor)
|
||||
when normalize_images=False
|
||||
"""
|
||||
# Fake grayscale with frameskip
|
||||
# Atari after preprocessing: 84x84x1, here we are using lower resolution
|
||||
# to check that the network handle it automatically
|
||||
env = FakeImageEnv(
|
||||
screen_height=36,
|
||||
screen_width=36,
|
||||
n_channels=1,
|
||||
channel_first=True,
|
||||
discrete=model_class not in {TQC},
|
||||
)
|
||||
if model_class == MaskablePPO:
|
||||
|
||||
def action_mask_fn(env):
|
||||
random_invalid_action = random.randrange(env.action_space.n)
|
||||
return [i != random_invalid_action for i in range(env.action_space.n)]
|
||||
|
||||
env = ActionMasker(env, action_mask_fn)
|
||||
|
||||
vec_env = VecNormalize(DummyVecEnv([lambda: env]))
|
||||
# Reduce the size of the features
|
||||
# deactivate normalization
|
||||
kwargs = dict(
|
||||
policy_kwargs=dict(
|
||||
normalize_images=normalize_images,
|
||||
features_extractor_kwargs=dict(features_dim=32),
|
||||
),
|
||||
seed=1,
|
||||
)
|
||||
policy = "CnnLstmPolicy" if model_class == RecurrentPPO else "CnnPolicy"
|
||||
|
||||
if model_class in {TRPO, MaskablePPO, RecurrentPPO}:
|
||||
kwargs.update(dict(n_steps=64, batch_size=64))
|
||||
else:
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features
|
||||
kwargs.update(dict(buffer_size=250))
|
||||
if normalize_images:
|
||||
with pytest.raises(AssertionError):
|
||||
model_class(policy, vec_env, **kwargs).learn(128)
|
||||
else:
|
||||
model_class(policy, vec_env, **kwargs).learn(128)
|
||||
|
|
|
|||
|
|
@ -226,7 +226,8 @@ def test_discrete_action_space_required():
|
|||
MaskablePPO("MlpPolicy", env)
|
||||
|
||||
|
||||
def test_cnn():
|
||||
@pytest.mark.parametrize("share_features_extractor", [True, False])
|
||||
def test_cnn(share_features_extractor):
|
||||
def action_mask_fn(env):
|
||||
random_invalid_action = random.randrange(env.action_space.n)
|
||||
return [i != random_invalid_action for i in range(env.action_space.n)]
|
||||
|
|
@ -242,6 +243,7 @@ def test_cnn():
|
|||
verbose=1,
|
||||
policy_kwargs=dict(
|
||||
features_extractor_kwargs=dict(features_dim=32),
|
||||
share_features_extractor=share_features_extractor,
|
||||
),
|
||||
)
|
||||
model.learn(100)
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ class CartPoleNoVelEnv(CartPoleEnv):
|
|||
"policy_kwargs",
|
||||
[
|
||||
{},
|
||||
{"share_features_extractor": False},
|
||||
dict(shared_lstm=True, enable_critic_lstm=False),
|
||||
dict(
|
||||
enable_critic_lstm=True,
|
||||
|
|
@ -74,6 +75,11 @@ class CartPoleNoVelEnv(CartPoleEnv):
|
|||
lstm_kwargs=dict(dropout=0.5),
|
||||
n_lstm_layers=2,
|
||||
),
|
||||
dict(
|
||||
enable_critic_lstm=False,
|
||||
lstm_hidden_size=4,
|
||||
share_features_extractor=False,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_cnn(policy_kwargs):
|
||||
|
|
@ -83,6 +89,7 @@ def test_cnn(policy_kwargs):
|
|||
n_steps=16,
|
||||
seed=0,
|
||||
policy_kwargs=dict(**policy_kwargs, features_extractor_kwargs=dict(features_dim=32)),
|
||||
n_epochs=2,
|
||||
)
|
||||
|
||||
model.learn(total_timesteps=32)
|
||||
|
|
@ -130,6 +137,16 @@ def test_check():
|
|||
policy_kwargs=policy_kwargs,
|
||||
)
|
||||
|
||||
policy_kwargs = dict(shared_lstm=True, enable_critic_lstm=False, share_features_extractor=False)
|
||||
with pytest.raises(AssertionError):
|
||||
RecurrentPPO(
|
||||
"MlpLstmPolicy",
|
||||
"CartPole-v1",
|
||||
n_steps=16,
|
||||
seed=0,
|
||||
policy_kwargs=policy_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
|
||||
def test_run(env):
|
||||
|
|
|
|||
Loading…
Reference in New Issue