diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9642924..048fb2a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -58,8 +58,6 @@ jobs: - name: Type check run: | make type - # skip PyType, doesn't support 3.11 yet - if: "!(matrix.python-version == '3.11')" - name: Test with pytest run: | make pytest diff --git a/Makefile b/Makefile index fab99ff..7f10123 100644 --- a/Makefile +++ b/Makefile @@ -4,13 +4,10 @@ LINT_PATHS=sb3_contrib/ tests/ setup.py docs/conf.py pytest: ./scripts/run_tests.sh -pytype: - pytype -j auto - mypy: mypy ${LINT_PATHS} -type: pytype mypy +type: mypy lint: # stop the build if there are Python syntax errors or undefined names diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index be39952..8f555cb 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 2.2.0a9 (WIP) +Release 2.2.0a11 (WIP) -------------------------- Breaking Changes: diff --git a/pyproject.toml b/pyproject.toml index 925931c..127d2b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,9 +19,6 @@ max-complexity = 15 [tool.black] line-length = 127 -[tool.pytype] -inputs = ['sb3_contrib'] - [tool.mypy] ignore_missing_imports = true follow_imports = "silent" diff --git a/sb3_contrib/ars/policies.py b/sb3_contrib/ars/policies.py index 8267cd3..2c3aeec 100644 --- a/sb3_contrib/ars/policies.py +++ b/sb3_contrib/ars/policies.py @@ -4,7 +4,8 @@ import torch as th from gymnasium import spaces from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import get_action_dim -from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp +from stable_baselines3.common.torch_layers import create_mlp +from stable_baselines3.common.type_aliases import PyTorchObs from torch import nn @@ -66,10 +67,7 @@ class ARSPolicy(BasePolicy): ) return data - def forward(self, obs: th.Tensor) -> th.Tensor: - # Make mypy happy: - assert isinstance(self.features_extractor, BaseFeaturesExtractor) - + def forward(self, obs: PyTorchObs) -> th.Tensor: features = self.extract_features(obs, self.features_extractor) if isinstance(self.action_space, spaces.Box): return self.action_net(features) @@ -79,7 +77,7 @@ class ARSPolicy(BasePolicy): else: raise NotImplementedError() - def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor: + def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor: # Non deterministic action does not really make sense for ARS, we ignore this parameter for now.. return self(observation) diff --git a/sb3_contrib/qrdqn/policies.py b/sb3_contrib/qrdqn/policies.py index a58e47b..07752cd 100644 --- a/sb3_contrib/qrdqn/policies.py +++ b/sb3_contrib/qrdqn/policies.py @@ -10,7 +10,7 @@ from stable_baselines3.common.torch_layers import ( NatureCNN, create_mlp, ) -from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.type_aliases import PyTorchObs, Schedule from torch import nn @@ -58,7 +58,7 @@ class QuantileNetwork(BasePolicy): quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn) self.quantile_net = nn.Sequential(*quantile_net) - def forward(self, obs: th.Tensor) -> th.Tensor: + def forward(self, obs: PyTorchObs) -> th.Tensor: """ Predict the quantiles. @@ -68,7 +68,7 @@ class QuantileNetwork(BasePolicy): quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor)) return quantiles.view(-1, self.n_quantiles, int(self.action_space.n)) - def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor: + def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor: q_values = self(observation).mean(dim=1) # Greedy action action = q_values.argmax(dim=1).reshape(-1) @@ -181,10 +181,10 @@ class QRDQNPolicy(BasePolicy): net_args = self._update_features_extractor(self.net_args, features_extractor=None) return QuantileNetwork(**net_args).to(self.device) - def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: + def forward(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor: return self._predict(obs, deterministic=deterministic) - def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: + def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor: return self.quantile_net._predict(obs, deterministic=deterministic) def _get_constructor_parameters(self) -> Dict[str, Any]: diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index 9913d80..5f1d55b 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -13,7 +13,7 @@ from stable_baselines3.common.torch_layers import ( create_mlp, get_actor_critic_arch, ) -from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.type_aliases import PyTorchObs, Schedule from torch import nn as nn # CAP the standard deviation of the actor @@ -143,7 +143,7 @@ class Actor(BasePolicy): assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg self.action_dist.sample_weights(self.log_std, batch_size=batch_size) - def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]: + def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]: """ Get the parameters for the action distribution. @@ -163,17 +163,17 @@ class Actor(BasePolicy): log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) return mean_actions, log_std, {} - def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: + def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor: mean_actions, log_std, kwargs = self.get_action_dist_params(obs) # Note: the action is squashed return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs) - def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]: mean_actions, log_std, kwargs = self.get_action_dist_params(obs) # return action and associated log prob return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs) - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor: return self(observation, deterministic) @@ -231,7 +231,7 @@ class Critic(BaseModel): self.add_module(f"qf{i}", qf_net) self.q_networks.append(qf_net) - def forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor: + def forward(self, obs: PyTorchObs, action: th.Tensor) -> th.Tensor: # Learn the features extractor using the policy loss only # when the features_extractor is shared with the actor with th.set_grad_enabled(not self.share_features_extractor): @@ -411,10 +411,10 @@ class TQCPolicy(BasePolicy): critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor) return Critic(**critic_kwargs).to(self.device) - def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: + def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor: return self._predict(obs, deterministic=deterministic) - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor: return self.actor(observation, deterministic) def set_training_mode(self, mode: bool) -> None: diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index b7120ad..13ce6d7 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -2.2.0a9 +2.2.0a11