Fix features extractor issue (#5)
* Fix feature extractor issue * Sync with SB3 PR
This commit is contained in:
parent
017096371c
commit
2ce8d278cc
12
CHANGELOG.md
12
CHANGELOG.md
|
|
@ -1,12 +0,0 @@
|
||||||
## Release 0.10.0a0 (WIP)
|
|
||||||
|
|
||||||
### Breaking Changes
|
|
||||||
|
|
||||||
### New Features
|
|
||||||
- Add TQC
|
|
||||||
|
|
||||||
### Bug fixes
|
|
||||||
|
|
||||||
### Documentation
|
|
||||||
|
|
||||||
### Other
|
|
||||||
|
|
@ -178,6 +178,6 @@ make doc
|
||||||
|
|
||||||
## Changelog and Documentation
|
## Changelog and Documentation
|
||||||
|
|
||||||
Please do not forget to update the changelog (`CHANGELOG.rst`).
|
Please do not forget to update the changelog (`docs/misc/changelog.rst`).
|
||||||
|
|
||||||
Credits: this contributing guide is based on the [PyTorch](https://github.com/pytorch/pytorch/) one.
|
Credits: this contributing guide is based on the [PyTorch](https://github.com/pytorch/pytorch/) one.
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
|
||||||
Pre-Release 0.10.0a1 (WIP)
|
Pre-Release 0.10.0a2 (WIP)
|
||||||
------------------------------
|
------------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
|
|
@ -16,6 +16,7 @@ New Features:
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
- Fixed feature extractor issue (``TQC`` with ``CnnPolicy``)
|
||||||
|
|
||||||
Deprecations:
|
Deprecations:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -203,6 +203,8 @@ class Critic(BaseModel):
|
||||||
:param activation_fn: Activation function
|
:param activation_fn: Activation function
|
||||||
:param normalize_images: Whether to normalize images or not,
|
:param normalize_images: Whether to normalize images or not,
|
||||||
dividing by 255.0 (True by default)
|
dividing by 255.0 (True by default)
|
||||||
|
:param share_features_extractor: Whether the features extractor is shared or not
|
||||||
|
between the actor and the critic (this saves computation time)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -216,6 +218,7 @@ class Critic(BaseModel):
|
||||||
normalize_images: bool = True,
|
normalize_images: bool = True,
|
||||||
n_quantiles: int = 25,
|
n_quantiles: int = 25,
|
||||||
n_critics: int = 2,
|
n_critics: int = 2,
|
||||||
|
share_features_extractor: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
observation_space,
|
observation_space,
|
||||||
|
|
@ -226,6 +229,7 @@ class Critic(BaseModel):
|
||||||
|
|
||||||
action_dim = get_action_dim(self.action_space)
|
action_dim = get_action_dim(self.action_space)
|
||||||
|
|
||||||
|
self.share_features_extractor = share_features_extractor
|
||||||
self.q_networks = []
|
self.q_networks = []
|
||||||
self.n_quantiles = n_quantiles
|
self.n_quantiles = n_quantiles
|
||||||
self.n_critics = n_critics
|
self.n_critics = n_critics
|
||||||
|
|
@ -239,8 +243,8 @@ class Critic(BaseModel):
|
||||||
|
|
||||||
def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]:
|
def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]:
|
||||||
# Learn the features extractor using the policy loss only
|
# Learn the features extractor using the policy loss only
|
||||||
# this is much faster
|
# when the features_extractor is shared with the actor
|
||||||
with th.no_grad():
|
with th.set_grad_enabled(not self.share_features_extractor):
|
||||||
features = self.extract_features(obs)
|
features = self.extract_features(obs)
|
||||||
qvalue_input = th.cat([features, action], dim=1)
|
qvalue_input = th.cat([features, action], dim=1)
|
||||||
quantiles = th.stack(tuple(qf(qvalue_input) for qf in self.q_networks), dim=1)
|
quantiles = th.stack(tuple(qf(qvalue_input) for qf in self.q_networks), dim=1)
|
||||||
|
|
@ -274,6 +278,8 @@ class TQCPolicy(BasePolicy):
|
||||||
``th.optim.Adam`` by default
|
``th.optim.Adam`` by default
|
||||||
:param optimizer_kwargs: Additional keyword arguments,
|
:param optimizer_kwargs: Additional keyword arguments,
|
||||||
excluding the learning rate, to pass to the optimizer
|
excluding the learning rate, to pass to the optimizer
|
||||||
|
:param share_features_extractor: Whether to share or not the features extractor
|
||||||
|
between the actor and the critic (this saves computation time)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -295,6 +301,7 @@ class TQCPolicy(BasePolicy):
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
n_quantiles: int = 25,
|
n_quantiles: int = 25,
|
||||||
n_critics: int = 2,
|
n_critics: int = 2,
|
||||||
|
share_features_extractor: bool = True,
|
||||||
):
|
):
|
||||||
super(TQCPolicy, self).__init__(
|
super(TQCPolicy, self).__init__(
|
||||||
observation_space,
|
observation_space,
|
||||||
|
|
@ -314,17 +321,11 @@ class TQCPolicy(BasePolicy):
|
||||||
|
|
||||||
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
|
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
|
||||||
|
|
||||||
# Create shared features extractor
|
|
||||||
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
|
|
||||||
self.features_dim = self.features_extractor.features_dim
|
|
||||||
|
|
||||||
self.net_arch = net_arch
|
self.net_arch = net_arch
|
||||||
self.activation_fn = activation_fn
|
self.activation_fn = activation_fn
|
||||||
self.net_args = {
|
self.net_args = {
|
||||||
"observation_space": self.observation_space,
|
"observation_space": self.observation_space,
|
||||||
"action_space": self.action_space,
|
"action_space": self.action_space,
|
||||||
"features_extractor": self.features_extractor,
|
|
||||||
"features_dim": self.features_dim,
|
|
||||||
"net_arch": actor_arch,
|
"net_arch": actor_arch,
|
||||||
"activation_fn": self.activation_fn,
|
"activation_fn": self.activation_fn,
|
||||||
"normalize_images": normalize_images,
|
"normalize_images": normalize_images,
|
||||||
|
|
@ -339,10 +340,16 @@ class TQCPolicy(BasePolicy):
|
||||||
}
|
}
|
||||||
self.actor_kwargs.update(sde_kwargs)
|
self.actor_kwargs.update(sde_kwargs)
|
||||||
self.critic_kwargs = self.net_args.copy()
|
self.critic_kwargs = self.net_args.copy()
|
||||||
tqc_kwargs = {"n_quantiles": n_quantiles, "n_critics": n_critics, "net_arch": critic_arch}
|
tqc_kwargs = {
|
||||||
|
"n_quantiles": n_quantiles,
|
||||||
|
"n_critics": n_critics,
|
||||||
|
"net_arch": critic_arch,
|
||||||
|
"share_features_extractor": share_features_extractor,
|
||||||
|
}
|
||||||
self.critic_kwargs.update(tqc_kwargs)
|
self.critic_kwargs.update(tqc_kwargs)
|
||||||
self.actor, self.actor_target = None, None
|
self.actor, self.actor_target = None, None
|
||||||
self.critic, self.critic_target = None, None
|
self.critic, self.critic_target = None, None
|
||||||
|
self.share_features_extractor = share_features_extractor
|
||||||
|
|
||||||
self._build(lr_schedule)
|
self._build(lr_schedule)
|
||||||
|
|
||||||
|
|
@ -350,13 +357,21 @@ class TQCPolicy(BasePolicy):
|
||||||
self.actor = self.make_actor()
|
self.actor = self.make_actor()
|
||||||
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||||
|
|
||||||
self.critic = self.make_critic()
|
if self.share_features_extractor:
|
||||||
self.critic_target = self.make_critic()
|
self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
|
||||||
self.critic_target.load_state_dict(self.critic.state_dict())
|
# Do not optimize the shared features extractor with the critic loss
|
||||||
# Do not optimize the shared feature extractor with the critic loss
|
|
||||||
# otherwise, there are gradient computation issues
|
# otherwise, there are gradient computation issues
|
||||||
# Another solution: having duplicated features extractor but requires more memory and computation
|
|
||||||
critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
|
critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
|
||||||
|
else:
|
||||||
|
# Create a separate features extractor for the critic
|
||||||
|
# this requires more memory and computation
|
||||||
|
self.critic = self.make_critic(features_extractor=None)
|
||||||
|
critic_parameters = self.critic.parameters()
|
||||||
|
|
||||||
|
# Critic target should not share the feature extactor with critic
|
||||||
|
self.critic_target = self.make_critic(features_extractor=None)
|
||||||
|
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||||
|
|
||||||
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
|
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||||
|
|
||||||
def _get_data(self) -> Dict[str, Any]:
|
def _get_data(self) -> Dict[str, Any]:
|
||||||
|
|
@ -388,11 +403,13 @@ class TQCPolicy(BasePolicy):
|
||||||
"""
|
"""
|
||||||
self.actor.reset_noise(batch_size=batch_size)
|
self.actor.reset_noise(batch_size=batch_size)
|
||||||
|
|
||||||
def make_actor(self) -> Actor:
|
def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
|
||||||
return Actor(**self.actor_kwargs).to(self.device)
|
actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
|
||||||
|
return Actor(**actor_kwargs).to(self.device)
|
||||||
|
|
||||||
def make_critic(self) -> Critic:
|
def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Critic:
|
||||||
return Critic(**self.critic_kwargs).to(self.device)
|
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
|
||||||
|
return Critic(**critic_kwargs).to(self.device)
|
||||||
|
|
||||||
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||||
return self._predict(obs, deterministic=deterministic)
|
return self._predict(obs, deterministic=deterministic)
|
||||||
|
|
@ -429,6 +446,8 @@ class CnnPolicy(TQCPolicy):
|
||||||
``th.optim.Adam`` by default
|
``th.optim.Adam`` by default
|
||||||
:param optimizer_kwargs: Additional keyword arguments,
|
:param optimizer_kwargs: Additional keyword arguments,
|
||||||
excluding the learning rate, to pass to the optimizer
|
excluding the learning rate, to pass to the optimizer
|
||||||
|
:param share_features_extractor: Whether to share or not the features extractor
|
||||||
|
between the actor and the critic (this saves computation time)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -450,6 +469,7 @@ class CnnPolicy(TQCPolicy):
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
n_quantiles: int = 25,
|
n_quantiles: int = 25,
|
||||||
n_critics: int = 2,
|
n_critics: int = 2,
|
||||||
|
share_features_extractor: bool = True,
|
||||||
):
|
):
|
||||||
super(CnnPolicy, self).__init__(
|
super(CnnPolicy, self).__init__(
|
||||||
observation_space,
|
observation_space,
|
||||||
|
|
@ -469,6 +489,7 @@ class CnnPolicy(TQCPolicy):
|
||||||
optimizer_kwargs,
|
optimizer_kwargs,
|
||||||
n_quantiles,
|
n_quantiles,
|
||||||
n_critics,
|
n_critics,
|
||||||
|
share_features_extractor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
0.10.0a1
|
0.10.0a2
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,113 @@
|
||||||
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch as th
|
||||||
|
from stable_baselines3.common.identity_env import FakeImageEnv
|
||||||
|
from stable_baselines3.common.utils import zip_strict
|
||||||
|
|
||||||
|
from sb3_contrib import TQC
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_class", [TQC])
|
||||||
|
def test_cnn(tmp_path, model_class):
|
||||||
|
SAVE_NAME = "cnn_model.zip"
|
||||||
|
# Fake grayscale with frameskip
|
||||||
|
# Atari after preprocessing: 84x84x1, here we are using lower resolution
|
||||||
|
# to check that the network handle it automatically
|
||||||
|
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC})
|
||||||
|
kwargs = {}
|
||||||
|
if model_class in {TQC}:
|
||||||
|
# Avoid memory error when using replay buffer
|
||||||
|
# Reduce the size of the features
|
||||||
|
kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
|
||||||
|
model = model_class("CnnPolicy", env, **kwargs).learn(250)
|
||||||
|
|
||||||
|
obs = env.reset()
|
||||||
|
|
||||||
|
action, _ = model.predict(obs, deterministic=True)
|
||||||
|
|
||||||
|
model.save(tmp_path / SAVE_NAME)
|
||||||
|
del model
|
||||||
|
|
||||||
|
model = model_class.load(tmp_path / SAVE_NAME)
|
||||||
|
|
||||||
|
# Check that the prediction is the same
|
||||||
|
assert np.allclose(action, model.predict(obs, deterministic=True)[0])
|
||||||
|
|
||||||
|
os.remove(str(tmp_path / SAVE_NAME))
|
||||||
|
|
||||||
|
|
||||||
|
def params_should_match(params, other_params):
|
||||||
|
for param, other_param in zip_strict(params, other_params):
|
||||||
|
assert th.allclose(param, other_param)
|
||||||
|
|
||||||
|
|
||||||
|
def params_should_differ(params, other_params):
|
||||||
|
for param, other_param in zip_strict(params, other_params):
|
||||||
|
assert not th.allclose(param, other_param)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_class", [TQC])
|
||||||
|
@pytest.mark.parametrize("share_features_extractor", [True, False])
|
||||||
|
def test_feature_extractor_target_net(model_class, share_features_extractor):
|
||||||
|
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC})
|
||||||
|
# Avoid memory error when using replay buffer
|
||||||
|
# Reduce the size of the features
|
||||||
|
kwargs = dict(
|
||||||
|
buffer_size=250,
|
||||||
|
learning_starts=100,
|
||||||
|
policy_kwargs=dict(
|
||||||
|
features_extractor_kwargs=dict(features_dim=32),
|
||||||
|
share_features_extractor=share_features_extractor,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
model = model_class("CnnPolicy", env, seed=0, **kwargs)
|
||||||
|
|
||||||
|
if share_features_extractor:
|
||||||
|
# Check that the objects are the same and not just copied
|
||||||
|
assert id(model.policy.actor.features_extractor) == id(model.policy.critic.features_extractor)
|
||||||
|
else:
|
||||||
|
# Check that the objects differ
|
||||||
|
assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor)
|
||||||
|
|
||||||
|
# Critic and target should be equal at the begginning of training
|
||||||
|
params_should_match(model.critic.parameters(), model.critic_target.parameters())
|
||||||
|
|
||||||
|
model.learn(200)
|
||||||
|
|
||||||
|
# Critic and target should differ
|
||||||
|
params_should_differ(model.critic.parameters(), model.critic_target.parameters())
|
||||||
|
|
||||||
|
# Re-initialize and collect some random data (without doing gradient steps)
|
||||||
|
model = model_class("CnnPolicy", env, seed=0, **kwargs).learn(10)
|
||||||
|
|
||||||
|
original_param = deepcopy(list(model.critic.parameters()))
|
||||||
|
original_target_param = deepcopy(list(model.critic_target.parameters()))
|
||||||
|
|
||||||
|
# Deactivate copy to target
|
||||||
|
model.tau = 0.0
|
||||||
|
model.train(gradient_steps=1)
|
||||||
|
|
||||||
|
# Target should be the same
|
||||||
|
params_should_match(original_target_param, model.critic_target.parameters())
|
||||||
|
|
||||||
|
# not the same for critic net (updated by gradient descent)
|
||||||
|
params_should_differ(original_param, model.critic.parameters())
|
||||||
|
|
||||||
|
# Update the reference as it should not change in the next step
|
||||||
|
original_param = deepcopy(list(model.critic.parameters()))
|
||||||
|
|
||||||
|
# Deactivate learning rate
|
||||||
|
model.lr_schedule = lambda _: 0.0
|
||||||
|
# Re-activate polyak update
|
||||||
|
model.tau = 0.01
|
||||||
|
|
||||||
|
model.train(gradient_steps=1)
|
||||||
|
|
||||||
|
# Target should have changed now (due to polyak update)
|
||||||
|
params_should_differ(original_target_param, model.critic_target.parameters())
|
||||||
|
|
||||||
|
# Critic should be the same
|
||||||
|
params_should_match(original_param, model.critic.parameters())
|
||||||
Loading…
Reference in New Issue