From 2ce8d278cc1ff91b40ef603894efbd6cb9b15bf2 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 27 Oct 2020 14:30:35 +0100 Subject: [PATCH] Fix features extractor issue (#5) * Fix feature extractor issue * Sync with SB3 PR --- CHANGELOG.md | 12 ---- CONTRIBUTING.md | 2 +- docs/misc/changelog.rst | 3 +- sb3_contrib/tqc/policies.py | 59 +++++++++++++------ sb3_contrib/version.txt | 2 +- tests/test_cnn.py | 113 ++++++++++++++++++++++++++++++++++++ 6 files changed, 157 insertions(+), 34 deletions(-) delete mode 100644 CHANGELOG.md create mode 100644 tests/test_cnn.py diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index 40573a5..0000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,12 +0,0 @@ -## Release 0.10.0a0 (WIP) - -### Breaking Changes - -### New Features -- Add TQC - -### Bug fixes - -### Documentation - -### Other diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d51f9f1..1112e98 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -178,6 +178,6 @@ make doc ## Changelog and Documentation -Please do not forget to update the changelog (`CHANGELOG.rst`). +Please do not forget to update the changelog (`docs/misc/changelog.rst`). Credits: this contributing guide is based on the [PyTorch](https://github.com/pytorch/pytorch/) one. 
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 8ab60da..d004eea 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Pre-Release 0.10.0a1 (WIP) +Pre-Release 0.10.0a2 (WIP) ------------------------------ Breaking Changes: @@ -16,6 +16,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- Fixed feature extractor issue (``TQC`` with ``CnnPolicy``) Deprecations: ^^^^^^^^^^^^^ diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index 788a7ba..4ad4b15 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -203,6 +203,8 @@ class Critic(BaseModel): :param activation_fn: Activation function :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) + :param share_features_extractor: Whether the features extractor is shared or not + between the actor and the critic (this saves computation time) """ def __init__( @@ -216,6 +218,7 @@ class Critic(BaseModel): normalize_images: bool = True, n_quantiles: int = 25, n_critics: int = 2, + share_features_extractor: bool = True, ): super().__init__( observation_space, @@ -226,6 +229,7 @@ class Critic(BaseModel): action_dim = get_action_dim(self.action_space) + self.share_features_extractor = share_features_extractor self.q_networks = [] self.n_quantiles = n_quantiles self.n_critics = n_critics @@ -239,8 +243,8 @@ class Critic(BaseModel): def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]: # Learn the features extractor using the policy loss only - # this is much faster - with th.no_grad(): + # when the features_extractor is shared with the actor + with th.set_grad_enabled(not self.share_features_extractor): features = self.extract_features(obs) qvalue_input = th.cat([features, action], dim=1) quantiles = th.stack(tuple(qf(qvalue_input) for qf in self.q_networks), dim=1) @@ -274,6 +278,8 @@ class TQCPolicy(BasePolicy): ``th.optim.Adam`` by default :param optimizer_kwargs: 
Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critic (this saves computation time) """ def __init__( @@ -295,6 +301,7 @@ class TQCPolicy(BasePolicy): optimizer_kwargs: Optional[Dict[str, Any]] = None, n_quantiles: int = 25, n_critics: int = 2, + share_features_extractor: bool = True, ): super(TQCPolicy, self).__init__( observation_space, @@ -314,17 +321,11 @@ class TQCPolicy(BasePolicy): actor_arch, critic_arch = get_actor_critic_arch(net_arch) - # Create shared features extractor - self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) - self.features_dim = self.features_extractor.features_dim - self.net_arch = net_arch self.activation_fn = activation_fn self.net_args = { "observation_space": self.observation_space, "action_space": self.action_space, - "features_extractor": self.features_extractor, - "features_dim": self.features_dim, "net_arch": actor_arch, "activation_fn": self.activation_fn, "normalize_images": normalize_images, @@ -339,10 +340,16 @@ class TQCPolicy(BasePolicy): } self.actor_kwargs.update(sde_kwargs) self.critic_kwargs = self.net_args.copy() - tqc_kwargs = {"n_quantiles": n_quantiles, "n_critics": n_critics, "net_arch": critic_arch} + tqc_kwargs = { + "n_quantiles": n_quantiles, + "n_critics": n_critics, + "net_arch": critic_arch, + "share_features_extractor": share_features_extractor, + } self.critic_kwargs.update(tqc_kwargs) self.actor, self.actor_target = None, None self.critic, self.critic_target = None, None + self.share_features_extractor = share_features_extractor self._build(lr_schedule) @@ -350,13 +357,21 @@ class TQCPolicy(BasePolicy): self.actor = self.make_actor() self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) - self.critic = self.make_critic() - self.critic_target = 
self.make_critic() + if self.share_features_extractor: + self.critic = self.make_critic(features_extractor=self.actor.features_extractor) + # Do not optimize the shared features extractor with the critic loss + # otherwise, there are gradient computation issues + critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name] + else: + # Create a separate features extractor for the critic + # this requires more memory and computation + self.critic = self.make_critic(features_extractor=None) + critic_parameters = self.critic.parameters() + + # Critic target should not share the feature extractor with critic + self.critic_target = self.make_critic(features_extractor=None) self.critic_target.load_state_dict(self.critic.state_dict()) - # Do not optimize the shared feature extractor with the critic loss - # otherwise, there are gradient computation issues - # Another solution: having duplicated features extractor but requires more memory and computation - critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name] + self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs) def _get_data(self) -> Dict[str, Any]: @@ -388,11 +403,13 @@ class TQCPolicy(BasePolicy): """ self.actor.reset_noise(batch_size=batch_size) - def make_actor(self) -> Actor: - return Actor(**self.actor_kwargs).to(self.device) + def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor: + actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor) + return Actor(**actor_kwargs).to(self.device) - def make_critic(self) -> Critic: - return Critic(**self.critic_kwargs).to(self.device) + def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Critic: + critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor) + return 
Critic(**critic_kwargs).to(self.device) def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: return self._predict(obs, deterministic=deterministic) @@ -429,6 +446,8 @@ class CnnPolicy(TQCPolicy): ``th.optim.Adam`` by default :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critic (this saves computation time) """ def __init__( @@ -450,6 +469,7 @@ class CnnPolicy(TQCPolicy): optimizer_kwargs: Optional[Dict[str, Any]] = None, n_quantiles: int = 25, n_critics: int = 2, + share_features_extractor: bool = True, ): super(CnnPolicy, self).__init__( observation_space, @@ -469,6 +489,7 @@ class CnnPolicy(TQCPolicy): optimizer_kwargs, n_quantiles, n_critics, + share_features_extractor, ) diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 8dabd1f..bb625ca 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -0.10.0a1 +0.10.0a2 diff --git a/tests/test_cnn.py b/tests/test_cnn.py new file mode 100644 index 0000000..0e39012 --- /dev/null +++ b/tests/test_cnn.py @@ -0,0 +1,113 @@ +import os +from copy import deepcopy + +import numpy as np +import pytest +import torch as th +from stable_baselines3.common.identity_env import FakeImageEnv +from stable_baselines3.common.utils import zip_strict + +from sb3_contrib import TQC + + +@pytest.mark.parametrize("model_class", [TQC]) +def test_cnn(tmp_path, model_class): + SAVE_NAME = "cnn_model.zip" + # Fake grayscale with frameskip + # Atari after preprocessing: 84x84x1, here we are using lower resolution + # to check that the network handle it automatically + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC}) + kwargs = {} + if model_class in {TQC}: + # Avoid memory error when using replay buffer + # Reduce the size of the features + kwargs = 
dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32))) + model = model_class("CnnPolicy", env, **kwargs).learn(250) + + obs = env.reset() + + action, _ = model.predict(obs, deterministic=True) + + model.save(tmp_path / SAVE_NAME) + del model + + model = model_class.load(tmp_path / SAVE_NAME) + + # Check that the prediction is the same + assert np.allclose(action, model.predict(obs, deterministic=True)[0]) + + os.remove(str(tmp_path / SAVE_NAME)) + + + def params_should_match(params, other_params): + for param, other_param in zip_strict(params, other_params): + assert th.allclose(param, other_param) + + + def params_should_differ(params, other_params): + for param, other_param in zip_strict(params, other_params): + assert not th.allclose(param, other_param) + + + @pytest.mark.parametrize("model_class", [TQC]) + @pytest.mark.parametrize("share_features_extractor", [True, False]) + def test_feature_extractor_target_net(model_class, share_features_extractor): + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC}) + # Avoid memory error when using replay buffer + # Reduce the size of the features + kwargs = dict( + buffer_size=250, + learning_starts=100, + policy_kwargs=dict( + features_extractor_kwargs=dict(features_dim=32), + share_features_extractor=share_features_extractor, + ), + ) + model = model_class("CnnPolicy", env, seed=0, **kwargs) + + if share_features_extractor: + # Check that the objects are the same and not just copied + assert id(model.policy.actor.features_extractor) == id(model.policy.critic.features_extractor) + else: + # Check that the objects differ + assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor) + + # Critic and target should be equal at the beginning of training + params_should_match(model.critic.parameters(), model.critic_target.parameters()) + + model.learn(200) + + # Critic and target should differ + 
params_should_differ(model.critic.parameters(), model.critic_target.parameters()) + + # Re-initialize and collect some random data (without doing gradient steps) + model = model_class("CnnPolicy", env, seed=0, **kwargs).learn(10) + + original_param = deepcopy(list(model.critic.parameters())) + original_target_param = deepcopy(list(model.critic_target.parameters())) + + # Deactivate copy to target + model.tau = 0.0 + model.train(gradient_steps=1) + + # Target should be the same + params_should_match(original_target_param, model.critic_target.parameters()) + + # not the same for critic net (updated by gradient descent) + params_should_differ(original_param, model.critic.parameters()) + + # Update the reference as it should not change in the next step + original_param = deepcopy(list(model.critic.parameters())) + + # Deactivate learning rate + model.lr_schedule = lambda _: 0.0 + # Re-activate polyak update + model.tau = 0.01 + + model.train(gradient_steps=1) + + # Target should have changed now (due to polyak update) + params_should_differ(original_target_param, model.critic_target.parameters()) + + # Critic should be the same + params_should_match(original_param, model.critic.parameters())