From 36aeae18b5c026930b4acf228d86fcfa9c50cfc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 22 Nov 2022 13:12:35 +0100 Subject: [PATCH] Fix `Self` return type (#116) * Self hint for distributions * ClassSelf to SelfClass --- sb3_contrib/ars/ars.py | 6 +++--- sb3_contrib/common/maskable/distributions.py | 15 ++++++++++++--- sb3_contrib/ppo_mask/ppo_mask.py | 6 +++--- sb3_contrib/ppo_recurrent/ppo_recurrent.py | 6 +++--- sb3_contrib/qrdqn/qrdqn.py | 6 +++--- sb3_contrib/tqc/tqc.py | 6 +++--- sb3_contrib/trpo/trpo.py | 6 +++--- 7 files changed, 30 insertions(+), 21 deletions(-) diff --git a/sb3_contrib/ars/ars.py b/sb3_contrib/ars/ars.py index 33a0579..ce73698 100644 --- a/sb3_contrib/ars/ars.py +++ b/sb3_contrib/ars/ars.py @@ -19,7 +19,7 @@ from stable_baselines3.common.utils import get_schedule_fn, safe_mean from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy from sb3_contrib.common.vec_env.async_eval import AsyncEval -ARSSelf = TypeVar("ARSSelf", bound="ARS") +SelfARS = TypeVar("SelfARS", bound="ARS") class ARS(BaseAlgorithm): @@ -304,7 +304,7 @@ class ARS(BaseAlgorithm): self._n_updates += 1 def learn( - self: ARSSelf, + self: SelfARS, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, @@ -312,7 +312,7 @@ class ARS(BaseAlgorithm): reset_num_timesteps: bool = True, async_eval: Optional[AsyncEval] = None, progress_bar: bool = False, - ) -> ARSSelf: + ) -> SelfARS: """ Return a trained model. diff --git a/sb3_contrib/common/maskable/distributions.py b/sb3_contrib/common/maskable/distributions.py index 1ccf5ee..a89d9fb 100644 --- a/sb3_contrib/common/maskable/distributions.py +++ b/sb3_contrib/common/maskable/distributions.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, TypeVar import numpy as np import torch as th @@ -9,6 +9,11 @@ from torch import nn from torch.distributions import Categorical from torch.distributions.utils import logits_to_probs +SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistribution", bound="MaskableCategoricalDistribution") +SelfMaskableMultiCategoricalDistribution = TypeVar( + "SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution" +) + class MaskableCategorical(Categorical): """ @@ -115,7 +120,9 @@ class MaskableCategoricalDistribution(MaskableDistribution): action_logits = nn.Linear(latent_dim, self.action_dim) return action_logits - def proba_distribution(self, action_logits: th.Tensor) -> "MaskableCategoricalDistribution": + def proba_distribution( + self: SelfMaskableCategoricalDistribution, action_logits: th.Tensor + ) -> SelfMaskableCategoricalDistribution: # Restructure shape to align with logits reshaped_logits = action_logits.view(-1, self.action_dim) self.distribution = MaskableCategorical(logits=reshaped_logits) @@ -178,7 +185,9 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution): action_logits = nn.Linear(latent_dim, sum(self.action_dims)) return action_logits - def proba_distribution(self, action_logits: th.Tensor) -> "MaskableMultiCategoricalDistribution": + def proba_distribution( + self: SelfMaskableMultiCategoricalDistribution, action_logits: th.Tensor + ) -> SelfMaskableMultiCategoricalDistribution: # Restructure shape to align with logits reshaped_logits = action_logits.view(-1, sum(self.action_dims)) diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index 769461d..78aa58f 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -22,7 +22,7 @@ from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy -MaskablePPOSelf = TypeVar("MaskablePPOSelf", bound="MaskablePPO") +SelfMaskablePPO = TypeVar("SelfMaskablePPO", bound="MaskablePPO") class MaskablePPO(OnPolicyAlgorithm): @@ -498,7 +498,7 @@ class MaskablePPO(OnPolicyAlgorithm): self.logger.record("train/clip_range_vf", clip_range_vf) def learn( - self: MaskablePPOSelf, + self: SelfMaskablePPO, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, @@ -506,7 +506,7 @@ class MaskablePPO(OnPolicyAlgorithm): reset_num_timesteps: bool = True, use_masking: bool = True, progress_bar: bool = False, - ) -> MaskablePPOSelf: + ) -> SelfMaskablePPO: iteration = 0 total_timesteps, callback = self._setup_learn( diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 7ee0f55..12a8c7f 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -20,7 +20,7 @@ from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy from sb3_contrib.common.recurrent.type_aliases import RNNStates from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy -RecurrentPPOSelf = TypeVar("RecurrentPPOSelf", bound="RecurrentPPO") +SelfRecurrentPPO = TypeVar("SelfRecurrentPPO", bound="RecurrentPPO") class RecurrentPPO(OnPolicyAlgorithm): @@ -444,14 +444,14 @@ class RecurrentPPO(OnPolicyAlgorithm): self.logger.record("train/clip_range_vf", clip_range_vf) def learn( - self: RecurrentPPOSelf, + self: SelfRecurrentPPO, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, tb_log_name: str = "RecurrentPPO", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> RecurrentPPOSelf: + ) -> SelfRecurrentPPO: iteration = 0 total_timesteps, callback = self._setup_learn( diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index d4aef00..22c5810 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -13,7 +13,7 @@ from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name from sb3_contrib.common.utils import quantile_huber_loss from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy -QRDQNSelf = TypeVar("QRDQNSelf", bound="QRDQN") +SelfQRDQN = TypeVar("SelfQRDQN", bound="QRDQN") class QRDQN(OffPolicyAlgorithm): @@ -247,14 +247,14 @@ class QRDQN(OffPolicyAlgorithm): return action, state def learn( - self: QRDQNSelf, + self: SelfQRDQN, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, tb_log_name: str = "QRDQN", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> QRDQNSelf: + ) -> SelfQRDQN: return super().learn( total_timesteps=total_timesteps, diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 7f95baa..6ec91a1 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -13,7 +13,7 @@ from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from sb3_contrib.common.utils import quantile_huber_loss from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy -TQCSelf = TypeVar("TQCSelf", bound="TQC") +SelfTQC = TypeVar("SelfTQC", bound="TQC") class TQC(OffPolicyAlgorithm): @@ -284,14 +284,14 @@ class TQC(OffPolicyAlgorithm): self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) def learn( - self: TQCSelf, + self: SelfTQC, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, tb_log_name: str = "TQC", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> TQCSelf: + ) -> SelfTQC: return super().learn( total_timesteps=total_timesteps, diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index d8ae9be..eee1f98 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -17,7 +17,7 @@ from torch.nn import functional as F from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy -TRPOSelf = TypeVar("TRPOSelf", bound="TRPO") +SelfTRPO = TypeVar("SelfTRPO", bound="TRPO") class TRPO(OnPolicyAlgorithm): @@ -400,14 +400,14 @@ class TRPO(OnPolicyAlgorithm): return flat_grad(jacobian_vector_product, params, retain_graph=retain_graph) + self.cg_damping * vector def learn( - self: TRPOSelf, + self: SelfTRPO, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, tb_log_name: str = "TRPO", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> TRPOSelf: + ) -> SelfTRPO: return super().learn( total_timesteps=total_timesteps,