Remove PyType and upgrade to latest SB3 version (#215)

This commit is contained in:
Antonin RAFFIN 2023-11-08 09:50:06 +01:00 committed by Andreas Schaarschmidt
parent 5e437fc4dc
commit c965ba9d3b
8 changed files with 20 additions and 30 deletions

View File

@@ -58,8 +58,6 @@ jobs:
- name: Type check
run: |
make type
# skip PyType, doesn't support 3.11 yet
if: "!(matrix.python-version == '3.11')"
- name: Test with pytest
run: |
make pytest

View File

@@ -4,13 +4,10 @@ LINT_PATHS=sb3_contrib/ tests/ setup.py docs/conf.py
pytest:
./scripts/run_tests.sh
pytype:
pytype -j auto
mypy:
mypy ${LINT_PATHS}
type: pytype mypy
type: mypy
lint:
# stop the build if there are Python syntax errors or undefined names

View File

@@ -4,7 +4,7 @@ Changelog
==========
Release 2.2.0a9 (WIP)
Release 2.2.0a11 (WIP)
--------------------------
Breaking Changes:

View File

@@ -19,9 +19,6 @@ max-complexity = 15
[tool.black]
line-length = 127
[tool.pytype]
inputs = ['sb3_contrib']
[tool.mypy]
ignore_missing_imports = true
follow_imports = "silent"

View File

@@ -4,7 +4,8 @@ import torch as th
from gymnasium import spaces
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
from stable_baselines3.common.torch_layers import create_mlp
from stable_baselines3.common.type_aliases import PyTorchObs
from torch import nn
@@ -66,10 +67,7 @@ class ARSPolicy(BasePolicy):
)
return data
def forward(self, obs: th.Tensor) -> th.Tensor:
# Make mypy happy:
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
def forward(self, obs: PyTorchObs) -> th.Tensor:
features = self.extract_features(obs, self.features_extractor)
if isinstance(self.action_space, spaces.Box):
return self.action_net(features)
@@ -79,7 +77,7 @@ class ARSPolicy(BasePolicy):
else:
raise NotImplementedError()
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor:
# Non deterministic action does not really make sense for ARS, we ignore this parameter for now..
return self(observation)

View File

@@ -10,7 +10,7 @@ from stable_baselines3.common.torch_layers import (
NatureCNN,
create_mlp,
)
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
from torch import nn
@@ -58,7 +58,7 @@ class QuantileNetwork(BasePolicy):
quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
self.quantile_net = nn.Sequential(*quantile_net)
def forward(self, obs: th.Tensor) -> th.Tensor:
def forward(self, obs: PyTorchObs) -> th.Tensor:
"""
Predict the quantiles.
@@ -68,7 +68,7 @@ class QuantileNetwork(BasePolicy):
quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
return quantiles.view(-1, self.n_quantiles, int(self.action_space.n))
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor:
q_values = self(observation).mean(dim=1)
# Greedy action
action = q_values.argmax(dim=1).reshape(-1)
@@ -181,10 +181,10 @@ class QRDQNPolicy(BasePolicy):
net_args = self._update_features_extractor(self.net_args, features_extractor=None)
return QuantileNetwork(**net_args).to(self.device)
def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
def forward(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
return self.quantile_net._predict(obs, deterministic=deterministic)
def _get_constructor_parameters(self) -> Dict[str, Any]:

View File

@@ -13,7 +13,7 @@ from stable_baselines3.common.torch_layers import (
create_mlp,
get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
from torch import nn as nn
# CAP the standard deviation of the actor
@@ -143,7 +143,7 @@ class Actor(BasePolicy):
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
"""
Get the parameters for the action distribution.
@@ -163,17 +163,17 @@
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
return mean_actions, log_std, {}
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# Note: the action is squashed
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# return action and associated log prob
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
return self(observation, deterministic)
@@ -231,7 +231,7 @@ class Critic(BaseModel):
self.add_module(f"qf{i}", qf_net)
self.q_networks.append(qf_net)
def forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor:
def forward(self, obs: PyTorchObs, action: th.Tensor) -> th.Tensor:
# Learn the features extractor using the policy loss only
# when the features_extractor is shared with the actor
with th.set_grad_enabled(not self.share_features_extractor):
@@ -411,10 +411,10 @@ class TQCPolicy(BasePolicy):
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
return Critic(**critic_kwargs).to(self.device)
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
return self.actor(observation, deterministic)
def set_training_mode(self, mode: bool) -> None:

View File

@@ -1 +1 @@
2.2.0a9
2.2.0a11