Remove PyType and upgrade to latest SB3 version (#215)
This commit is contained in:
parent
5e437fc4dc
commit
c965ba9d3b
|
|
@ -58,8 +58,6 @@ jobs:
|
|||
- name: Type check
|
||||
run: |
|
||||
make type
|
||||
# skip PyType, doesn't support 3.11 yet
|
||||
if: "!(matrix.python-version == '3.11')"
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make pytest
|
||||
|
|
|
|||
5
Makefile
5
Makefile
|
|
@ -4,13 +4,10 @@ LINT_PATHS=sb3_contrib/ tests/ setup.py docs/conf.py
|
|||
pytest:
|
||||
./scripts/run_tests.sh
|
||||
|
||||
pytype:
|
||||
pytype -j auto
|
||||
|
||||
mypy:
|
||||
mypy ${LINT_PATHS}
|
||||
|
||||
type: pytype mypy
|
||||
type: mypy
|
||||
|
||||
lint:
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 2.2.0a9 (WIP)
|
||||
Release 2.2.0a11 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
|
|||
|
|
@ -19,9 +19,6 @@ max-complexity = 15
|
|||
[tool.black]
|
||||
line-length = 127
|
||||
|
||||
[tool.pytype]
|
||||
inputs = ['sb3_contrib']
|
||||
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = true
|
||||
follow_imports = "silent"
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@ import torch as th
|
|||
from gymnasium import spaces
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
|
||||
from stable_baselines3.common.torch_layers import create_mlp
|
||||
from stable_baselines3.common.type_aliases import PyTorchObs
|
||||
from torch import nn
|
||||
|
||||
|
||||
|
|
@ -66,10 +67,7 @@ class ARSPolicy(BasePolicy):
|
|||
)
|
||||
return data
|
||||
|
||||
def forward(self, obs: th.Tensor) -> th.Tensor:
|
||||
# Make mypy happy:
|
||||
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
|
||||
|
||||
def forward(self, obs: PyTorchObs) -> th.Tensor:
|
||||
features = self.extract_features(obs, self.features_extractor)
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
return self.action_net(features)
|
||||
|
|
@ -79,7 +77,7 @@ class ARSPolicy(BasePolicy):
|
|||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor:
|
||||
# Non deterministic action does not really make sense for ARS, we ignore this parameter for now..
|
||||
return self(observation)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from stable_baselines3.common.torch_layers import (
|
|||
NatureCNN,
|
||||
create_mlp,
|
||||
)
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
|
||||
from torch import nn
|
||||
|
||||
|
||||
|
|
@ -58,7 +58,7 @@ class QuantileNetwork(BasePolicy):
|
|||
quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
|
||||
self.quantile_net = nn.Sequential(*quantile_net)
|
||||
|
||||
def forward(self, obs: th.Tensor) -> th.Tensor:
|
||||
def forward(self, obs: PyTorchObs) -> th.Tensor:
|
||||
"""
|
||||
Predict the quantiles.
|
||||
|
||||
|
|
@ -68,7 +68,7 @@ class QuantileNetwork(BasePolicy):
|
|||
quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
|
||||
return quantiles.view(-1, self.n_quantiles, int(self.action_space.n))
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor:
|
||||
q_values = self(observation).mean(dim=1)
|
||||
# Greedy action
|
||||
action = q_values.argmax(dim=1).reshape(-1)
|
||||
|
|
@ -181,10 +181,10 @@ class QRDQNPolicy(BasePolicy):
|
|||
net_args = self._update_features_extractor(self.net_args, features_extractor=None)
|
||||
return QuantileNetwork(**net_args).to(self.device)
|
||||
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
def forward(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
|
||||
return self._predict(obs, deterministic=deterministic)
|
||||
|
||||
def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
|
||||
return self.quantile_net._predict(obs, deterministic=deterministic)
|
||||
|
||||
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from stable_baselines3.common.torch_layers import (
|
|||
create_mlp,
|
||||
get_actor_critic_arch,
|
||||
)
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
|
||||
from torch import nn as nn
|
||||
|
||||
# CAP the standard deviation of the actor
|
||||
|
|
@ -143,7 +143,7 @@ class Actor(BasePolicy):
|
|||
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
|
||||
|
||||
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
|
||||
def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
|
||||
"""
|
||||
Get the parameters for the action distribution.
|
||||
|
||||
|
|
@ -163,17 +163,17 @@ class Actor(BasePolicy):
|
|||
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
|
||||
return mean_actions, log_std, {}
|
||||
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
||||
# Note: the action is squashed
|
||||
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
|
||||
|
||||
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]:
|
||||
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
||||
# return action and associated log prob
|
||||
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
return self(observation, deterministic)
|
||||
|
||||
|
||||
|
|
@ -231,7 +231,7 @@ class Critic(BaseModel):
|
|||
self.add_module(f"qf{i}", qf_net)
|
||||
self.q_networks.append(qf_net)
|
||||
|
||||
def forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor:
|
||||
def forward(self, obs: PyTorchObs, action: th.Tensor) -> th.Tensor:
|
||||
# Learn the features extractor using the policy loss only
|
||||
# when the features_extractor is shared with the actor
|
||||
with th.set_grad_enabled(not self.share_features_extractor):
|
||||
|
|
@ -411,10 +411,10 @@ class TQCPolicy(BasePolicy):
|
|||
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
|
||||
return Critic(**critic_kwargs).to(self.device)
|
||||
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
return self._predict(obs, deterministic=deterministic)
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
return self.actor(observation, deterministic)
|
||||
|
||||
def set_training_mode(self, mode: bool) -> None:
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.2.0a9
|
||||
2.2.0a11
|
||||
|
|
|
|||
Loading…
Reference in New Issue