Fix features extractor issue (#5)
* Fix feature extractor issue * Sync with SB3 PR
This commit is contained in:
parent
017096371c
commit
2ce8d278cc
12
CHANGELOG.md
12
CHANGELOG.md
|
|
@ -1,12 +0,0 @@
|
|||
## Release 0.10.0a0 (WIP)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
### New Features
|
||||
- Add TQC
|
||||
|
||||
### Bug fixes
|
||||
|
||||
### Documentation
|
||||
|
||||
### Other
|
||||
|
|
@ -178,6 +178,6 @@ make doc
|
|||
|
||||
## Changelog and Documentation
|
||||
|
||||
Please do not forget to update the changelog (`CHANGELOG.rst`).
|
||||
Please do not forget to update the changelog (`docs/misc/changelog.rst`).
|
||||
|
||||
Credits: this contributing guide is based on the [PyTorch](https://github.com/pytorch/pytorch/) one.
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Pre-Release 0.10.0a1 (WIP)
|
||||
Pre-Release 0.10.0a2 (WIP)
|
||||
------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -16,6 +16,7 @@ New Features:
|
|||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed feature extractor issue (``TQC`` with ``CnnPolicy``)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -203,6 +203,8 @@ class Critic(BaseModel):
|
|||
:param activation_fn: Activation function
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param share_features_extractor: Whether the features extractor is shared or not
|
||||
between the actor and the critic (this saves computation time)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -216,6 +218,7 @@ class Critic(BaseModel):
|
|||
normalize_images: bool = True,
|
||||
n_quantiles: int = 25,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
observation_space,
|
||||
|
|
@ -226,6 +229,7 @@ class Critic(BaseModel):
|
|||
|
||||
action_dim = get_action_dim(self.action_space)
|
||||
|
||||
self.share_features_extractor = share_features_extractor
|
||||
self.q_networks = []
|
||||
self.n_quantiles = n_quantiles
|
||||
self.n_critics = n_critics
|
||||
|
|
@ -239,8 +243,8 @@ class Critic(BaseModel):
|
|||
|
||||
def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]:
|
||||
# Learn the features extractor using the policy loss only
|
||||
# this is much faster
|
||||
with th.no_grad():
|
||||
# when the features_extractor is shared with the actor
|
||||
with th.set_grad_enabled(not self.share_features_extractor):
|
||||
features = self.extract_features(obs)
|
||||
qvalue_input = th.cat([features, action], dim=1)
|
||||
quantiles = th.stack(tuple(qf(qvalue_input) for qf in self.q_networks), dim=1)
|
||||
|
|
@ -274,6 +278,8 @@ class TQCPolicy(BasePolicy):
|
|||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
:param share_features_extractor: Whether to share or not the features extractor
|
||||
between the actor and the critic (this saves computation time)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -295,6 +301,7 @@ class TQCPolicy(BasePolicy):
|
|||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_quantiles: int = 25,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
):
|
||||
super(TQCPolicy, self).__init__(
|
||||
observation_space,
|
||||
|
|
@ -314,17 +321,11 @@ class TQCPolicy(BasePolicy):
|
|||
|
||||
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
|
||||
|
||||
# Create shared features extractor
|
||||
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
|
||||
self.features_dim = self.features_extractor.features_dim
|
||||
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
self.net_args = {
|
||||
"observation_space": self.observation_space,
|
||||
"action_space": self.action_space,
|
||||
"features_extractor": self.features_extractor,
|
||||
"features_dim": self.features_dim,
|
||||
"net_arch": actor_arch,
|
||||
"activation_fn": self.activation_fn,
|
||||
"normalize_images": normalize_images,
|
||||
|
|
@ -339,10 +340,16 @@ class TQCPolicy(BasePolicy):
|
|||
}
|
||||
self.actor_kwargs.update(sde_kwargs)
|
||||
self.critic_kwargs = self.net_args.copy()
|
||||
tqc_kwargs = {"n_quantiles": n_quantiles, "n_critics": n_critics, "net_arch": critic_arch}
|
||||
tqc_kwargs = {
|
||||
"n_quantiles": n_quantiles,
|
||||
"n_critics": n_critics,
|
||||
"net_arch": critic_arch,
|
||||
"share_features_extractor": share_features_extractor,
|
||||
}
|
||||
self.critic_kwargs.update(tqc_kwargs)
|
||||
self.actor, self.actor_target = None, None
|
||||
self.critic, self.critic_target = None, None
|
||||
self.share_features_extractor = share_features_extractor
|
||||
|
||||
self._build(lr_schedule)
|
||||
|
||||
|
|
@ -350,13 +357,21 @@ class TQCPolicy(BasePolicy):
|
|||
self.actor = self.make_actor()
|
||||
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
|
||||
self.critic = self.make_critic()
|
||||
self.critic_target = self.make_critic()
|
||||
if self.share_features_extractor:
|
||||
self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
|
||||
# Do not optimize the shared features extractor with the critic loss
|
||||
# otherwise, there are gradient computation issues
|
||||
critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
|
||||
else:
|
||||
# Create a separate features extractor for the critic
|
||||
# this requires more memory and computation
|
||||
self.critic = self.make_critic(features_extractor=None)
|
||||
critic_parameters = self.critic.parameters()
|
||||
|
||||
# Critic target should not share the feature extactor with critic
|
||||
self.critic_target = self.make_critic(features_extractor=None)
|
||||
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||
# Do not optimize the shared feature extractor with the critic loss
|
||||
# otherwise, there are gradient computation issues
|
||||
# Another solution: having duplicated features extractor but requires more memory and computation
|
||||
critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
|
||||
|
||||
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
|
||||
def _get_data(self) -> Dict[str, Any]:
|
||||
|
|
@ -388,11 +403,13 @@ class TQCPolicy(BasePolicy):
|
|||
"""
|
||||
self.actor.reset_noise(batch_size=batch_size)
|
||||
|
||||
def make_actor(self) -> Actor:
|
||||
return Actor(**self.actor_kwargs).to(self.device)
|
||||
def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
|
||||
actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
|
||||
return Actor(**actor_kwargs).to(self.device)
|
||||
|
||||
def make_critic(self) -> Critic:
|
||||
return Critic(**self.critic_kwargs).to(self.device)
|
||||
def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Critic:
|
||||
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
|
||||
return Critic(**critic_kwargs).to(self.device)
|
||||
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
return self._predict(obs, deterministic=deterministic)
|
||||
|
|
@ -429,6 +446,8 @@ class CnnPolicy(TQCPolicy):
|
|||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
:param share_features_extractor: Whether to share or not the features extractor
|
||||
between the actor and the critic (this saves computation time)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -450,6 +469,7 @@ class CnnPolicy(TQCPolicy):
|
|||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_quantiles: int = 25,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
):
|
||||
super(CnnPolicy, self).__init__(
|
||||
observation_space,
|
||||
|
|
@ -469,6 +489,7 @@ class CnnPolicy(TQCPolicy):
|
|||
optimizer_kwargs,
|
||||
n_quantiles,
|
||||
n_critics,
|
||||
share_features_extractor,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
0.10.0a1
|
||||
0.10.0a2
|
||||
|
|
|
|||
|
|
@ -0,0 +1,113 @@
|
|||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch as th
|
||||
from stable_baselines3.common.identity_env import FakeImageEnv
|
||||
from stable_baselines3.common.utils import zip_strict
|
||||
|
||||
from sb3_contrib import TQC
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [TQC])
|
||||
def test_cnn(tmp_path, model_class):
|
||||
SAVE_NAME = "cnn_model.zip"
|
||||
# Fake grayscale with frameskip
|
||||
# Atari after preprocessing: 84x84x1, here we are using lower resolution
|
||||
# to check that the network handle it automatically
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC})
|
||||
kwargs = {}
|
||||
if model_class in {TQC}:
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features
|
||||
kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
|
||||
model = model_class("CnnPolicy", env, **kwargs).learn(250)
|
||||
|
||||
obs = env.reset()
|
||||
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
|
||||
model.save(tmp_path / SAVE_NAME)
|
||||
del model
|
||||
|
||||
model = model_class.load(tmp_path / SAVE_NAME)
|
||||
|
||||
# Check that the prediction is the same
|
||||
assert np.allclose(action, model.predict(obs, deterministic=True)[0])
|
||||
|
||||
os.remove(str(tmp_path / SAVE_NAME))
|
||||
|
||||
|
||||
def params_should_match(params, other_params):
|
||||
for param, other_param in zip_strict(params, other_params):
|
||||
assert th.allclose(param, other_param)
|
||||
|
||||
|
||||
def params_should_differ(params, other_params):
|
||||
for param, other_param in zip_strict(params, other_params):
|
||||
assert not th.allclose(param, other_param)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [TQC])
|
||||
@pytest.mark.parametrize("share_features_extractor", [True, False])
|
||||
def test_feature_extractor_target_net(model_class, share_features_extractor):
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC})
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features
|
||||
kwargs = dict(
|
||||
buffer_size=250,
|
||||
learning_starts=100,
|
||||
policy_kwargs=dict(
|
||||
features_extractor_kwargs=dict(features_dim=32),
|
||||
share_features_extractor=share_features_extractor,
|
||||
),
|
||||
)
|
||||
model = model_class("CnnPolicy", env, seed=0, **kwargs)
|
||||
|
||||
if share_features_extractor:
|
||||
# Check that the objects are the same and not just copied
|
||||
assert id(model.policy.actor.features_extractor) == id(model.policy.critic.features_extractor)
|
||||
else:
|
||||
# Check that the objects differ
|
||||
assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor)
|
||||
|
||||
# Critic and target should be equal at the begginning of training
|
||||
params_should_match(model.critic.parameters(), model.critic_target.parameters())
|
||||
|
||||
model.learn(200)
|
||||
|
||||
# Critic and target should differ
|
||||
params_should_differ(model.critic.parameters(), model.critic_target.parameters())
|
||||
|
||||
# Re-initialize and collect some random data (without doing gradient steps)
|
||||
model = model_class("CnnPolicy", env, seed=0, **kwargs).learn(10)
|
||||
|
||||
original_param = deepcopy(list(model.critic.parameters()))
|
||||
original_target_param = deepcopy(list(model.critic_target.parameters()))
|
||||
|
||||
# Deactivate copy to target
|
||||
model.tau = 0.0
|
||||
model.train(gradient_steps=1)
|
||||
|
||||
# Target should be the same
|
||||
params_should_match(original_target_param, model.critic_target.parameters())
|
||||
|
||||
# not the same for critic net (updated by gradient descent)
|
||||
params_should_differ(original_param, model.critic.parameters())
|
||||
|
||||
# Update the reference as it should not change in the next step
|
||||
original_param = deepcopy(list(model.critic.parameters()))
|
||||
|
||||
# Deactivate learning rate
|
||||
model.lr_schedule = lambda _: 0.0
|
||||
# Re-activate polyak update
|
||||
model.tau = 0.01
|
||||
|
||||
model.train(gradient_steps=1)
|
||||
|
||||
# Target should have changed now (due to polyak update)
|
||||
params_should_differ(original_target_param, model.critic_target.parameters())
|
||||
|
||||
# Critic should be the same
|
||||
params_should_match(original_param, model.critic.parameters())
|
||||
Loading…
Reference in New Issue