From ab8684f4696e3e339533b3f871b8ec0f39984f43 Mon Sep 17 00:00:00 2001 From: Alex Pasquali Date: Fri, 23 Dec 2022 00:23:45 +0100 Subject: [PATCH] [Feature] Non-shared features extractor in on-policy algorithms (#130) * Modified sb3_contrib/common/maskable/policies.py - Added support for non-shared features extractor in file sb3_contrib/common/maskable/policies.py - updated changelog * Modified sb3_contrib/common/recurrent/policies.py * Modified sb3_contrib/qrdqn/policies.py and sb3_contrib/tqc/policies.py * Updated test_cnn.py * Upgrade SB3 version * Revert changes in formatting * Remove duplicate normalize_images * Add test for image-like inputs * Fixes and add more tests * Update SB3 version * Fix ARS warnings Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 10 +++- sb3_contrib/ars/policies.py | 2 +- sb3_contrib/common/maskable/policies.py | 73 ++++++++++++++++++++---- sb3_contrib/common/recurrent/policies.py | 49 +++++++++++----- sb3_contrib/qrdqn/policies.py | 5 +- sb3_contrib/tqc/policies.py | 5 +- sb3_contrib/version.txt | 2 +- setup.py | 2 +- tests/test_cnn.py | 65 +++++++++++++++++++-- tests/test_invalid_actions.py | 4 +- tests/test_lstm.py | 17 ++++++ 11 files changed, 192 insertions(+), 42 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 942c98a..564f426 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 1.7.0a4 (WIP) +Release 1.7.0a10 (WIP) -------------------------- Breaking Changes: @@ -11,12 +11,15 @@ Breaking Changes: - Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters, please use an ``EvalCallback`` instead - Removed deprecated ``sde_net_arch`` parameter -- Upgraded to Stable-Baselines3 >= 1.7.0a4 +- Upgraded to Stable-Baselines3 >= 1.7.0a10 New Features: ^^^^^^^^^^^^^ - Introduced mypy type checking - Added ``with_bias`` parameter to ``ARSPolicy`` +- Added option to have 
non-shared features extractor between actor and critic in on-policy algorithms (@AlexPasqua) +- Features extractors now properly support unnormalized image-like observations (3D tensor) + when passing ``normalize_images=False`` Bug Fixes: ^^^^^^^^^^ @@ -25,6 +28,7 @@ Bug Fixes: Deprecations: ^^^^^^^^^^^^^ +- You should now explicitely pass a ``features_extractor`` parameter when calling ``extract_features()`` Others: ^^^^^^^ @@ -364,4 +368,4 @@ Contributors: ------------- @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec -@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong +@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua diff --git a/sb3_contrib/ars/policies.py b/sb3_contrib/ars/policies.py index 6ab17cc..4b1a677 100644 --- a/sb3_contrib/ars/policies.py +++ b/sb3_contrib/ars/policies.py @@ -69,7 +69,7 @@ class ARSPolicy(BasePolicy): def forward(self, obs: th.Tensor) -> th.Tensor: - features = self.extract_features(obs) + features = self.extract_features(obs, self.features_extractor) if isinstance(self.action_space, gym.spaces.Box): return self.action_net(features) elif isinstance(self.action_space, gym.spaces.Discrete): diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index 37b9e7f..c4fbdcc 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -32,6 +32,7 @@ class MaskableActorCriticPolicy(BasePolicy): :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments to pass to the features extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. 
:param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) :param optimizer_class: The optimizer to use, @@ -50,6 +51,7 @@ class MaskableActorCriticPolicy(BasePolicy): ortho_init: bool = True, features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, features_extractor_kwargs: Optional[Dict[str, Any]] = None, + share_features_extractor: bool = True, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, @@ -68,6 +70,7 @@ class MaskableActorCriticPolicy(BasePolicy): features_extractor_kwargs, optimizer_class=optimizer_class, optimizer_kwargs=optimizer_kwargs, + normalize_images=normalize_images, squash_output=False, ) @@ -82,10 +85,21 @@ class MaskableActorCriticPolicy(BasePolicy): self.activation_fn = activation_fn self.ortho_init = ortho_init - self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) + self.share_features_extractor = share_features_extractor + self.features_extractor = self.make_features_extractor() self.features_dim = self.features_extractor.features_dim + if self.share_features_extractor: + self.pi_features_extractor = self.features_extractor + self.vf_features_extractor = self.features_extractor + else: + self.pi_features_extractor = self.features_extractor + self.vf_features_extractor = self.make_features_extractor() + # if the features extractor is not shared, there cannot be shared layers in the mlp_extractor + if len(net_arch) > 0 and not isinstance(net_arch[0], dict): + raise ValueError( + "Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor" + ) - self.normalize_images = normalize_images # Action distribution self.action_dist = make_masked_proba_distribution(action_space) @@ -107,7 +121,12 @@ class MaskableActorCriticPolicy(BasePolicy): """ # Preprocess the observation if needed features = 
self.extract_features(obs) - latent_pi, latent_vf = self.mlp_extractor(features) + if self.share_features_extractor: + latent_pi, latent_vf = self.mlp_extractor(features) + else: + pi_features, vf_features = features + latent_pi = self.mlp_extractor.forward_actor(pi_features) + latent_vf = self.mlp_extractor.forward_critic(vf_features) # Evaluate the values for the given observations values = self.value_net(latent_vf) distribution = self._get_action_dist_from_latent(latent_pi) @@ -117,6 +136,19 @@ class MaskableActorCriticPolicy(BasePolicy): log_prob = distribution.log_prob(actions) return actions, values, log_prob + def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]: + """ + Preprocess the observation if needed and extract features. + :param obs: Observation + :return: the output of the features extractor(s) + """ + if self.share_features_extractor: + return super().extract_features(obs, self.features_extractor) + else: + pi_features = super().extract_features(obs, self.pi_features_extractor) + vf_features = super().extract_features(obs, self.vf_features_extractor) + return pi_features, vf_features + def _get_constructor_parameters(self) -> Dict[str, Any]: data = super()._get_constructor_parameters() @@ -174,6 +206,13 @@ class MaskableActorCriticPolicy(BasePolicy): self.action_net: 0.01, self.value_net: 1, } + if not self.share_features_extractor: + # Note(antonin): this is to keep SB3 results + # consistent, see GH#1148 + del module_gains[self.features_extractor] + module_gains[self.pi_features_extractor] = np.sqrt(2) + module_gains[self.vf_features_extractor] = np.sqrt(2) + for module, gain in module_gains.items(): module.apply(partial(self.init_weights, gain=gain)) @@ -268,13 +307,19 @@ class MaskableActorCriticPolicy(BasePolicy): Evaluate actions according to the current policy, given the observations. 
- :param obs: - :param actions: + :param obs: Observation + :param actions: Actions :return: estimated value, log likelihood of taking those actions and entropy of the action distribution. """ features = self.extract_features(obs) - latent_pi, latent_vf = self.mlp_extractor(features) + if self.share_features_extractor: + latent_pi, latent_vf = self.mlp_extractor(features) + else: + pi_features, vf_features = features + latent_pi = self.mlp_extractor.forward_actor(pi_features) + latent_vf = self.mlp_extractor.forward_critic(vf_features) + distribution = self._get_action_dist_from_latent(latent_pi) if action_masks is not None: distribution.apply_masking(action_masks) @@ -286,11 +331,11 @@ class MaskableActorCriticPolicy(BasePolicy): """ Get the current policy distribution given the observations. - :param obs: - :param action_masks: + :param obs: Observation + :param action_masks: Actions' mask :return: the action distribution. """ - features = self.extract_features(obs) + features = super().extract_features(obs, self.pi_features_extractor) latent_pi = self.mlp_extractor.forward_actor(features) distribution = self._get_action_dist_from_latent(latent_pi) if action_masks is not None: @@ -301,10 +346,10 @@ class MaskableActorCriticPolicy(BasePolicy): """ Get the estimated values according to the current policy given the observations. - :param obs: + :param obs: Observation :return: the estimated values. """ - features = self.extract_features(obs) + features = super().extract_features(obs, self.vf_features_extractor) latent_vf = self.mlp_extractor.forward_critic(features) return self.value_net(latent_vf) @@ -323,6 +368,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy): :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments to pass to the features extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. 
:param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) :param optimizer_class: The optimizer to use, @@ -341,6 +387,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy): ortho_init: bool = True, features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, features_extractor_kwargs: Optional[Dict[str, Any]] = None, + share_features_extractor: bool = True, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, @@ -354,6 +401,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy): ortho_init, features_extractor_class, features_extractor_kwargs, + share_features_extractor, normalize_images, optimizer_class, optimizer_kwargs, @@ -374,6 +422,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy): :param features_extractor_class: Uses the CombinedExtractor :param features_extractor_kwargs: Keyword arguments to pass to the feature extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. 
:param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) :param optimizer_class: The optimizer to use, @@ -392,6 +441,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy): ortho_init: bool = True, features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, features_extractor_kwargs: Optional[Dict[str, Any]] = None, + share_features_extractor: bool = True, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, @@ -405,6 +455,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy): ortho_init, features_extractor_class, features_extractor_kwargs, + share_features_extractor, normalize_images, optimizer_class, optimizer_kwargs, diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index cbc2718..7cb5b37 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -44,6 +44,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments to pass to the features extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. 
:param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) :param optimizer_class: The optimizer to use, @@ -75,6 +76,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): squash_output: bool = False, features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, features_extractor_kwargs: Optional[Dict[str, Any]] = None, + share_features_extractor: bool = True, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, @@ -99,6 +101,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): squash_output, features_extractor_class, features_extractor_kwargs, + share_features_extractor, normalize_images, optimizer_class, optimizer_kwargs, @@ -120,7 +123,11 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): self.lstm_critic = None assert not ( self.shared_lstm and self.enable_critic_lstm - ), "You must choose between shared LSTM, seperate or no LSTM for the critic" + ), "You must choose between shared LSTM, separate or no LSTM for the critic." + + assert not ( + self.shared_lstm and not self.share_features_extractor + ), "If the features extractor is not shared, the LSTM cannot be shared."
# No LSTM for the critic, we still need to convert # output of features extractor to the correct size @@ -222,17 +229,21 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): """ # Preprocess the observation if needed features = self.extract_features(obs) + if self.share_features_extractor: + pi_features = vf_features = features # alias + else: + pi_features, vf_features = features # latent_pi, latent_vf = self.mlp_extractor(features) - latent_pi, lstm_states_pi = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor) + latent_pi, lstm_states_pi = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor) if self.lstm_critic is not None: - latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic) + latent_vf, lstm_states_vf = self._process_sequence(vf_features, lstm_states.vf, episode_starts, self.lstm_critic) elif self.shared_lstm: # Re-use LSTM features but do not backpropagate latent_vf = latent_pi.detach() lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach()) else: # Critic only has a feedforward network - latent_vf = self.critic(features) + latent_vf = self.critic(vf_features) lstm_states_vf = lstm_states_pi latent_pi = self.mlp_extractor.forward_actor(latent_pi) @@ -260,7 +271,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): or not (we reset the lstm states in that case). :return: the action distribution and new hidden states.
""" - features = self.extract_features(obs) + # Call the method from the parent of the parent class + features = super(ActorCriticPolicy, self).extract_features(obs, self.pi_features_extractor) latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor) latent_pi = self.mlp_extractor.forward_actor(latent_pi) return self._get_action_dist_from_latent(latent_pi), lstm_states @@ -280,7 +292,9 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): or not (we reset the lstm states in that case). :return: the estimated values. """ - features = self.extract_features(obs) + # Call the method from the parent of the parent class + features = super(ActorCriticPolicy, self).extract_features(obs, self.vf_features_extractor) + if self.lstm_critic is not None: latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic) elif self.shared_lstm: @@ -294,11 +308,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): return self.value_net(latent_vf) def evaluate_actions( - self, - obs: th.Tensor, - actions: th.Tensor, - lstm_states: RNNStates, - episode_starts: th.Tensor, + self, obs: th.Tensor, actions: th.Tensor, lstm_states: RNNStates, episode_starts: th.Tensor ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: """ Evaluate actions according to the current policy, @@ -314,14 +324,17 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): """ # Preprocess the observation if needed features = self.extract_features(obs) - latent_pi, _ = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor) - + if self.share_features_extractor: + pi_features = vf_features = features # alias + else: + pi_features, vf_features = features + latent_pi, _ = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor) if self.lstm_critic is not None: - latent_vf, _ = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic) + latent_vf, _ = 
self._process_sequence(vf_features, lstm_states.vf, episode_starts, self.lstm_critic) elif self.shared_lstm: latent_vf = latent_pi.detach() else: - latent_vf = self.critic(features) + latent_vf = self.critic(vf_features) latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_vf = self.mlp_extractor.forward_critic(latent_vf) @@ -439,6 +452,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy): :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments to pass to the features extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) :param optimizer_class: The optimizer to use, @@ -469,6 +483,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy): squash_output: bool = False, features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, features_extractor_kwargs: Optional[Dict[str, Any]] = None, + share_features_extractor: bool = True, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, @@ -492,6 +507,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy): squash_output, features_extractor_class, features_extractor_kwargs, + share_features_extractor, normalize_images, optimizer_class, optimizer_kwargs, @@ -526,6 +542,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy): :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments to pass to the features extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. 
:param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) :param optimizer_class: The optimizer to use, @@ -556,6 +573,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy): squash_output: bool = False, features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, features_extractor_kwargs: Optional[Dict[str, Any]] = None, + share_features_extractor: bool = True, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, @@ -579,6 +597,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy): squash_output, features_extractor_class, features_extractor_kwargs, + share_features_extractor, normalize_images, optimizer_class, optimizer_kwargs, diff --git a/sb3_contrib/qrdqn/policies.py b/sb3_contrib/qrdqn/policies.py index d283c2f..a4a0010 100644 --- a/sb3_contrib/qrdqn/policies.py +++ b/sb3_contrib/qrdqn/policies.py @@ -53,7 +53,6 @@ class QuantileNetwork(BasePolicy): self.features_extractor = features_extractor self.features_dim = features_dim self.n_quantiles = n_quantiles - self.normalize_images = normalize_images action_dim = self.action_space.n # number of actions quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn) self.quantile_net = nn.Sequential(*quantile_net) @@ -65,7 +64,7 @@ class QuantileNetwork(BasePolicy): :param obs: Observation :return: The estimated quantiles for each action. 
""" - quantiles = self.quantile_net(self.extract_features(obs)) + quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor)) return quantiles.view(-1, self.n_quantiles, self.action_space.n) def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor: @@ -132,6 +131,7 @@ class QRDQNPolicy(BasePolicy): features_extractor_kwargs, optimizer_class=optimizer_class, optimizer_kwargs=optimizer_kwargs, + normalize_images=normalize_images, ) if net_arch is None: @@ -143,7 +143,6 @@ class QRDQNPolicy(BasePolicy): self.n_quantiles = n_quantiles self.net_arch = net_arch self.activation_fn = activation_fn - self.normalize_images = normalize_images self.net_args = { "observation_space": self.observation_space, diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index e2266f3..c25ebe3 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -149,7 +149,7 @@ class Actor(BasePolicy): :return: Mean, standard deviation and optional keyword arguments. 
""" - features = self.extract_features(obs) + features = self.extract_features(obs, self.features_extractor) latent_pi = self.latent_pi(features) mean_actions = self.mu(latent_pi) @@ -230,7 +230,7 @@ class Critic(BaseModel): # Learn the features extractor using the policy loss only # when the features_extractor is shared with the actor with th.set_grad_enabled(not self.share_features_extractor): - features = self.extract_features(obs) + features = self.extract_features(obs, self.features_extractor) qvalue_input = th.cat([features, action], dim=1) quantiles = th.stack(tuple(qf(qvalue_input) for qf in self.q_networks), dim=1) return quantiles @@ -293,6 +293,7 @@ class TQCPolicy(BasePolicy): features_extractor_kwargs, optimizer_class=optimizer_class, optimizer_kwargs=optimizer_kwargs, + normalize_images=normalize_images, squash_output=True, ) diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 3f38852..89e17c2 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.7.0a4 \ No newline at end of file +1.7.0a10 diff --git a/setup.py b/setup.py index 60fe1d3..2abcfd9 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.7.0a4", + "stable_baselines3>=1.7.0a10", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", diff --git a/tests/test_cnn.py b/tests/test_cnn.py index e570aab..be8758a 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -1,4 +1,5 @@ import os +import random from copy import deepcopy import numpy as np @@ -6,13 +7,15 @@ import pytest import torch as th from stable_baselines3.common.envs import FakeImageEnv from stable_baselines3.common.utils import zip_strict -from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped +from 
stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize, VecTransposeImage, is_vecenv_wrapped -from sb3_contrib import QRDQN, TQC, TRPO +from sb3_contrib import QRDQN, TQC, TRPO, MaskablePPO, RecurrentPPO +from sb3_contrib.common.wrappers import ActionMasker @pytest.mark.parametrize("model_class", [TQC, QRDQN, TRPO]) -def test_cnn(tmp_path, model_class): +@pytest.mark.parametrize("share_features_extractor", [True, False]) +def test_cnn(tmp_path, model_class, share_features_extractor): SAVE_NAME = "cnn_model.zip" # Fake grayscale with frameskip # Atari after preprocessing: 84x84x1, here we are using lower resolution @@ -23,8 +26,11 @@ def test_cnn(tmp_path, model_class): n_channels=1, discrete=model_class not in {TQC}, ) - kwargs = {} + kwargs = dict(policy_kwargs=dict(share_features_extractor=share_features_extractor)) if model_class in {TQC, QRDQN}: + # share_features_extractor is checked later for offpolicy algorithms + if share_features_extractor: + return # Avoid memory error when using replay buffer # Reduce the size of the features and the number of quantiles kwargs = dict( @@ -34,6 +40,7 @@ def test_cnn(tmp_path, model_class): features_extractor_kwargs=dict(features_dim=32), ), ) + model = model_class("CnnPolicy", env, **kwargs).learn(250) obs = env.reset() @@ -155,3 +162,53 @@ def test_feature_extractor_target_net(model_class, share_features_extractor): # Critic should be the same params_should_match(original_param, model.critic.parameters()) + + +@pytest.mark.parametrize("model_class", [TRPO, MaskablePPO, RecurrentPPO, QRDQN, TQC]) +@pytest.mark.parametrize("normalize_images", [True, False]) +def test_image_like_input(model_class, normalize_images): + """ + Check that we can handle image-like input (3D tensor) + when normalize_images=False + """ + # Fake grayscale with frameskip + # Atari after preprocessing: 84x84x1, here we are using lower resolution + # to check that the network handle it automatically + env = FakeImageEnv( + screen_height=36, 
+ screen_width=36, + n_channels=1, + channel_first=True, + discrete=model_class not in {TQC}, + ) + if model_class == MaskablePPO: + + def action_mask_fn(env): + random_invalid_action = random.randrange(env.action_space.n) + return [i != random_invalid_action for i in range(env.action_space.n)] + + env = ActionMasker(env, action_mask_fn) + + vec_env = VecNormalize(DummyVecEnv([lambda: env])) + # Reduce the size of the features + # deactivate normalization + kwargs = dict( + policy_kwargs=dict( + normalize_images=normalize_images, + features_extractor_kwargs=dict(features_dim=32), + ), + seed=1, + ) + policy = "CnnLstmPolicy" if model_class == RecurrentPPO else "CnnPolicy" + + if model_class in {TRPO, MaskablePPO, RecurrentPPO}: + kwargs.update(dict(n_steps=64, batch_size=64)) + else: + # Avoid memory error when using replay buffer + # Reduce the size of the features + kwargs.update(dict(buffer_size=250)) + if normalize_images: + with pytest.raises(AssertionError): + model_class(policy, vec_env, **kwargs).learn(128) + else: + model_class(policy, vec_env, **kwargs).learn(128) diff --git a/tests/test_invalid_actions.py b/tests/test_invalid_actions.py index 57138f4..3c2c5f8 100644 --- a/tests/test_invalid_actions.py +++ b/tests/test_invalid_actions.py @@ -226,7 +226,8 @@ def test_discrete_action_space_required(): MaskablePPO("MlpPolicy", env) -def test_cnn(): +@pytest.mark.parametrize("share_features_extractor", [True, False]) +def test_cnn(share_features_extractor): def action_mask_fn(env): random_invalid_action = random.randrange(env.action_space.n) return [i != random_invalid_action for i in range(env.action_space.n)] @@ -242,6 +243,7 @@ def test_cnn(): verbose=1, policy_kwargs=dict( features_extractor_kwargs=dict(features_dim=32), + share_features_extractor=share_features_extractor, ), ) model.learn(100) diff --git a/tests/test_lstm.py b/tests/test_lstm.py index 1da6779..1e62e0e 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -61,6 +61,7 @@ class 
CartPoleNoVelEnv(CartPoleEnv): "policy_kwargs", [ {}, + {"share_features_extractor": False}, dict(shared_lstm=True, enable_critic_lstm=False), dict( enable_critic_lstm=True, @@ -74,6 +75,11 @@ class CartPoleNoVelEnv(CartPoleEnv): lstm_kwargs=dict(dropout=0.5), n_lstm_layers=2, ), + dict( + enable_critic_lstm=False, + lstm_hidden_size=4, + share_features_extractor=False, + ), ], ) def test_cnn(policy_kwargs): @@ -83,6 +89,7 @@ def test_cnn(policy_kwargs): n_steps=16, seed=0, policy_kwargs=dict(**policy_kwargs, features_extractor_kwargs=dict(features_dim=32)), + n_epochs=2, ) model.learn(total_timesteps=32) @@ -130,6 +137,16 @@ def test_check(): policy_kwargs=policy_kwargs, ) + policy_kwargs = dict(shared_lstm=True, enable_critic_lstm=False, share_features_extractor=False) + with pytest.raises(AssertionError): + RecurrentPPO( + "MlpLstmPolicy", + "CartPole-v1", + n_steps=16, + seed=0, + policy_kwargs=policy_kwargs, + ) + @pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"]) def test_run(env):