Fix features extractor issue (#5)

* Fix feature extractor issue

* Sync with SB3 PR
This commit is contained in:
Antonin RAFFIN 2020-10-27 14:30:35 +01:00 committed by GitHub
parent 017096371c
commit 2ce8d278cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 157 additions and 34 deletions

View File

@ -1,12 +0,0 @@
## Release 0.10.0a0 (WIP)
### Breaking Changes
### New Features
- Add TQC
### Bug fixes
### Documentation
### Other

View File

@ -178,6 +178,6 @@ make doc
## Changelog and Documentation
Please do not forget to update the changelog (`CHANGELOG.rst`).
Please do not forget to update the changelog (`docs/misc/changelog.rst`).
Credits: this contributing guide is based on the [PyTorch](https://github.com/pytorch/pytorch/) one.

View File

@ -4,7 +4,7 @@ Changelog
==========
Pre-Release 0.10.0a1 (WIP)
Pre-Release 0.10.0a2 (WIP)
------------------------------
Breaking Changes:
@ -16,6 +16,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Fixed feature extractor issue (``TQC`` with ``CnnPolicy``)
Deprecations:
^^^^^^^^^^^^^

View File

@ -203,6 +203,8 @@ class Critic(BaseModel):
:param activation_fn: Activation function
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param share_features_extractor: Whether the features extractor is shared or not
between the actor and the critic (this saves computation time)
"""
def __init__(
@ -216,6 +218,7 @@ class Critic(BaseModel):
normalize_images: bool = True,
n_quantiles: int = 25,
n_critics: int = 2,
share_features_extractor: bool = True,
):
super().__init__(
observation_space,
@ -226,6 +229,7 @@ class Critic(BaseModel):
action_dim = get_action_dim(self.action_space)
self.share_features_extractor = share_features_extractor
self.q_networks = []
self.n_quantiles = n_quantiles
self.n_critics = n_critics
@ -239,8 +243,8 @@ class Critic(BaseModel):
def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]:
# Learn the features extractor using the policy loss only
# this is much faster
with th.no_grad():
# when the features_extractor is shared with the actor
with th.set_grad_enabled(not self.share_features_extractor):
features = self.extract_features(obs)
qvalue_input = th.cat([features, action], dim=1)
quantiles = th.stack(tuple(qf(qvalue_input) for qf in self.q_networks), dim=1)
@ -274,6 +278,8 @@ class TQCPolicy(BasePolicy):
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param share_features_extractor: Whether to share or not the features extractor
between the actor and the critic (this saves computation time)
"""
def __init__(
@ -295,6 +301,7 @@ class TQCPolicy(BasePolicy):
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_quantiles: int = 25,
n_critics: int = 2,
share_features_extractor: bool = True,
):
super(TQCPolicy, self).__init__(
observation_space,
@ -314,17 +321,11 @@ class TQCPolicy(BasePolicy):
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
# Create shared features extractor
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.net_arch = net_arch
self.activation_fn = activation_fn
self.net_args = {
"observation_space": self.observation_space,
"action_space": self.action_space,
"features_extractor": self.features_extractor,
"features_dim": self.features_dim,
"net_arch": actor_arch,
"activation_fn": self.activation_fn,
"normalize_images": normalize_images,
@ -339,10 +340,16 @@ class TQCPolicy(BasePolicy):
}
self.actor_kwargs.update(sde_kwargs)
self.critic_kwargs = self.net_args.copy()
tqc_kwargs = {"n_quantiles": n_quantiles, "n_critics": n_critics, "net_arch": critic_arch}
tqc_kwargs = {
"n_quantiles": n_quantiles,
"n_critics": n_critics,
"net_arch": critic_arch,
"share_features_extractor": share_features_extractor,
}
self.critic_kwargs.update(tqc_kwargs)
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None
self.share_features_extractor = share_features_extractor
self._build(lr_schedule)
@ -350,13 +357,21 @@ class TQCPolicy(BasePolicy):
self.actor = self.make_actor()
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
self.critic = self.make_critic()
self.critic_target = self.make_critic()
self.critic_target.load_state_dict(self.critic.state_dict())
# Do not optimize the shared feature extractor with the critic loss
if self.share_features_extractor:
self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
# Do not optimize the shared features extractor with the critic loss
# otherwise, there are gradient computation issues
# Another solution: having duplicated features extractor but requires more memory and computation
critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
else:
# Create a separate features extractor for the critic
# this requires more memory and computation
self.critic = self.make_critic(features_extractor=None)
critic_parameters = self.critic.parameters()
# Critic target should not share the features extractor with the critic
self.critic_target = self.make_critic(features_extractor=None)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
def _get_data(self) -> Dict[str, Any]:
@ -388,11 +403,13 @@ class TQCPolicy(BasePolicy):
"""
self.actor.reset_noise(batch_size=batch_size)
def make_actor(self) -> Actor:
return Actor(**self.actor_kwargs).to(self.device)
def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
return Actor(**actor_kwargs).to(self.device)
def make_critic(self) -> Critic:
return Critic(**self.critic_kwargs).to(self.device)
def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Critic:
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
return Critic(**critic_kwargs).to(self.device)
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
@ -429,6 +446,8 @@ class CnnPolicy(TQCPolicy):
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param share_features_extractor: Whether to share or not the features extractor
between the actor and the critic (this saves computation time)
"""
def __init__(
@ -450,6 +469,7 @@ class CnnPolicy(TQCPolicy):
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_quantiles: int = 25,
n_critics: int = 2,
share_features_extractor: bool = True,
):
super(CnnPolicy, self).__init__(
observation_space,
@ -469,6 +489,7 @@ class CnnPolicy(TQCPolicy):
optimizer_kwargs,
n_quantiles,
n_critics,
share_features_extractor,
)

View File

@ -1 +1 @@
0.10.0a1
0.10.0a2

113
tests/test_cnn.py Normal file
View File

@ -0,0 +1,113 @@
import os
from copy import deepcopy
import numpy as np
import pytest
import torch as th
from stable_baselines3.common.identity_env import FakeImageEnv
from stable_baselines3.common.utils import zip_strict
from sb3_contrib import TQC
@pytest.mark.parametrize("model_class", [TQC])
def test_cnn(tmp_path, model_class):
    """Train a CNN policy on a fake image env and check the save/load round-trip."""
    save_name = "cnn_model.zip"
    # Fake grayscale env with frameskip, at a lower resolution than Atari's
    # post-processing 84x84x1, so the network must adapt to the input size.
    env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC})
    kwargs = {}
    if model_class in {TQC}:
        # Small replay buffer and reduced feature dim to avoid memory errors
        kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
    model = model_class("CnnPolicy", env, **kwargs).learn(250)

    obs = env.reset()
    action, _ = model.predict(obs, deterministic=True)

    model.save(tmp_path / save_name)
    del model

    model = model_class.load(tmp_path / save_name)
    # The reloaded model must reproduce the prediction made before saving
    assert np.allclose(action, model.predict(obs, deterministic=True)[0])
    os.remove(str(tmp_path / save_name))
def params_should_match(params, other_params):
    """Assert that corresponding tensors from the two iterables are element-wise close."""
    assert all(th.allclose(p1, p2) for p1, p2 in zip_strict(params, other_params))
def params_should_differ(params, other_params):
    """Assert that no pair of corresponding tensors from the two iterables is close."""
    assert not any(th.allclose(p1, p2) for p1, p2 in zip_strict(params, other_params))
@pytest.mark.parametrize("model_class", [TQC])
@pytest.mark.parametrize("share_features_extractor", [True, False])
def test_feature_extractor_target_net(model_class, share_features_extractor):
    """Check features-extractor sharing between actor/critic and target-net updates."""
    env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC})
    # Small replay buffer and reduced feature dim to avoid memory errors
    kwargs = dict(
        buffer_size=250,
        learning_starts=100,
        policy_kwargs=dict(
            features_extractor_kwargs=dict(features_dim=32),
            share_features_extractor=share_features_extractor,
        ),
    )
    model = model_class("CnnPolicy", env, seed=0, **kwargs)

    actor_extractor = model.policy.actor.features_extractor
    critic_extractor = model.policy.critic.features_extractor
    if share_features_extractor:
        # Actor and critic must hold the very same object, not a copy of it
        assert actor_extractor is critic_extractor
    else:
        # Each network must own a separate extractor instance
        assert actor_extractor is not critic_extractor

    # Critic and its target start out identical
    params_should_match(model.critic.parameters(), model.critic_target.parameters())

    model.learn(200)

    # After training, critic and target must have diverged
    params_should_differ(model.critic.parameters(), model.critic_target.parameters())

    # Re-initialize and collect some random data (no gradient steps yet)
    model = model_class("CnnPolicy", env, seed=0, **kwargs).learn(10)

    initial_critic_params = deepcopy(list(model.critic.parameters()))
    initial_target_params = deepcopy(list(model.critic_target.parameters()))

    # Disable the polyak copy to the target network
    model.tau = 0.0
    model.train(gradient_steps=1)

    # Target untouched; critic moved by gradient descent
    params_should_match(initial_target_params, model.critic_target.parameters())
    params_should_differ(initial_critic_params, model.critic.parameters())

    # Refresh the reference: the critic must stay frozen in the next step
    initial_critic_params = deepcopy(list(model.critic.parameters()))

    # Freeze the critic (zero learning rate) and re-enable the polyak update
    model.lr_schedule = lambda _: 0.0
    model.tau = 0.01
    model.train(gradient_steps=1)

    # Only the target (via polyak averaging) should have changed now
    params_should_differ(initial_target_params, model.critic_target.parameters())
    params_should_match(initial_critic_params, model.critic.parameters())