Remove PyType and upgrade to latest SB3 version (#215)
This commit is contained in:
parent
5e437fc4dc
commit
c965ba9d3b
|
|
@ -58,8 +58,6 @@ jobs:
|
||||||
- name: Type check
|
- name: Type check
|
||||||
run: |
|
run: |
|
||||||
make type
|
make type
|
||||||
# skip PyType, doesn't support 3.11 yet
|
|
||||||
if: "!(matrix.python-version == '3.11')"
|
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
run: |
|
run: |
|
||||||
make pytest
|
make pytest
|
||||||
|
|
|
||||||
5
Makefile
5
Makefile
|
|
@ -4,13 +4,10 @@ LINT_PATHS=sb3_contrib/ tests/ setup.py docs/conf.py
|
||||||
pytest:
|
pytest:
|
||||||
./scripts/run_tests.sh
|
./scripts/run_tests.sh
|
||||||
|
|
||||||
pytype:
|
|
||||||
pytype -j auto
|
|
||||||
|
|
||||||
mypy:
|
mypy:
|
||||||
mypy ${LINT_PATHS}
|
mypy ${LINT_PATHS}
|
||||||
|
|
||||||
type: pytype mypy
|
type: mypy
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
# stop the build if there are Python syntax errors or undefined names
|
# stop the build if there are Python syntax errors or undefined names
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
|
||||||
Release 2.2.0a9 (WIP)
|
Release 2.2.0a11 (WIP)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
|
|
|
||||||
|
|
@ -19,9 +19,6 @@ max-complexity = 15
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 127
|
line-length = 127
|
||||||
|
|
||||||
[tool.pytype]
|
|
||||||
inputs = ['sb3_contrib']
|
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
follow_imports = "silent"
|
follow_imports = "silent"
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,8 @@ import torch as th
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.preprocessing import get_action_dim
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
|
from stable_baselines3.common.torch_layers import create_mlp
|
||||||
|
from stable_baselines3.common.type_aliases import PyTorchObs
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -66,10 +67,7 @@ class ARSPolicy(BasePolicy):
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def forward(self, obs: th.Tensor) -> th.Tensor:
|
def forward(self, obs: PyTorchObs) -> th.Tensor:
|
||||||
# Make mypy happy:
|
|
||||||
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
|
|
||||||
|
|
||||||
features = self.extract_features(obs, self.features_extractor)
|
features = self.extract_features(obs, self.features_extractor)
|
||||||
if isinstance(self.action_space, spaces.Box):
|
if isinstance(self.action_space, spaces.Box):
|
||||||
return self.action_net(features)
|
return self.action_net(features)
|
||||||
|
|
@ -79,7 +77,7 @@ class ARSPolicy(BasePolicy):
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor:
|
||||||
# Non deterministic action does not really make sense for ARS, we ignore this parameter for now..
|
# Non deterministic action does not really make sense for ARS, we ignore this parameter for now..
|
||||||
return self(observation)
|
return self(observation)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from stable_baselines3.common.torch_layers import (
|
||||||
NatureCNN,
|
NatureCNN,
|
||||||
create_mlp,
|
create_mlp,
|
||||||
)
|
)
|
||||||
from stable_baselines3.common.type_aliases import Schedule
|
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -58,7 +58,7 @@ class QuantileNetwork(BasePolicy):
|
||||||
quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
|
quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
|
||||||
self.quantile_net = nn.Sequential(*quantile_net)
|
self.quantile_net = nn.Sequential(*quantile_net)
|
||||||
|
|
||||||
def forward(self, obs: th.Tensor) -> th.Tensor:
|
def forward(self, obs: PyTorchObs) -> th.Tensor:
|
||||||
"""
|
"""
|
||||||
Predict the quantiles.
|
Predict the quantiles.
|
||||||
|
|
||||||
|
|
@ -68,7 +68,7 @@ class QuantileNetwork(BasePolicy):
|
||||||
quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
|
quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
|
||||||
return quantiles.view(-1, self.n_quantiles, int(self.action_space.n))
|
return quantiles.view(-1, self.n_quantiles, int(self.action_space.n))
|
||||||
|
|
||||||
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor:
|
||||||
q_values = self(observation).mean(dim=1)
|
q_values = self(observation).mean(dim=1)
|
||||||
# Greedy action
|
# Greedy action
|
||||||
action = q_values.argmax(dim=1).reshape(-1)
|
action = q_values.argmax(dim=1).reshape(-1)
|
||||||
|
|
@ -181,10 +181,10 @@ class QRDQNPolicy(BasePolicy):
|
||||||
net_args = self._update_features_extractor(self.net_args, features_extractor=None)
|
net_args = self._update_features_extractor(self.net_args, features_extractor=None)
|
||||||
return QuantileNetwork(**net_args).to(self.device)
|
return QuantileNetwork(**net_args).to(self.device)
|
||||||
|
|
||||||
def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
def forward(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
|
||||||
return self._predict(obs, deterministic=deterministic)
|
return self._predict(obs, deterministic=deterministic)
|
||||||
|
|
||||||
def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
|
||||||
return self.quantile_net._predict(obs, deterministic=deterministic)
|
return self.quantile_net._predict(obs, deterministic=deterministic)
|
||||||
|
|
||||||
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from stable_baselines3.common.torch_layers import (
|
||||||
create_mlp,
|
create_mlp,
|
||||||
get_actor_critic_arch,
|
get_actor_critic_arch,
|
||||||
)
|
)
|
||||||
from stable_baselines3.common.type_aliases import Schedule
|
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
# CAP the standard deviation of the actor
|
# CAP the standard deviation of the actor
|
||||||
|
|
@ -143,7 +143,7 @@ class Actor(BasePolicy):
|
||||||
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
|
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
|
||||||
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
|
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
|
||||||
|
|
||||||
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
|
def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Get the parameters for the action distribution.
|
Get the parameters for the action distribution.
|
||||||
|
|
||||||
|
|
@ -163,17 +163,17 @@ class Actor(BasePolicy):
|
||||||
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
|
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
|
||||||
return mean_actions, log_std, {}
|
return mean_actions, log_std, {}
|
||||||
|
|
||||||
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||||
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
||||||
# Note: the action is squashed
|
# Note: the action is squashed
|
||||||
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
|
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
|
||||||
|
|
||||||
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]:
|
||||||
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
||||||
# return action and associated log prob
|
# return action and associated log prob
|
||||||
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
|
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
|
||||||
|
|
||||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||||
return self(observation, deterministic)
|
return self(observation, deterministic)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -231,7 +231,7 @@ class Critic(BaseModel):
|
||||||
self.add_module(f"qf{i}", qf_net)
|
self.add_module(f"qf{i}", qf_net)
|
||||||
self.q_networks.append(qf_net)
|
self.q_networks.append(qf_net)
|
||||||
|
|
||||||
def forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor:
|
def forward(self, obs: PyTorchObs, action: th.Tensor) -> th.Tensor:
|
||||||
# Learn the features extractor using the policy loss only
|
# Learn the features extractor using the policy loss only
|
||||||
# when the features_extractor is shared with the actor
|
# when the features_extractor is shared with the actor
|
||||||
with th.set_grad_enabled(not self.share_features_extractor):
|
with th.set_grad_enabled(not self.share_features_extractor):
|
||||||
|
|
@ -411,10 +411,10 @@ class TQCPolicy(BasePolicy):
|
||||||
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
|
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
|
||||||
return Critic(**critic_kwargs).to(self.device)
|
return Critic(**critic_kwargs).to(self.device)
|
||||||
|
|
||||||
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||||
return self._predict(obs, deterministic=deterministic)
|
return self._predict(obs, deterministic=deterministic)
|
||||||
|
|
||||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||||
return self.actor(observation, deterministic)
|
return self.actor(observation, deterministic)
|
||||||
|
|
||||||
def set_training_mode(self, mode: bool) -> None:
|
def set_training_mode(self, mode: bool) -> None:
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
2.2.0a9
|
2.2.0a11
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue