Fix `Self` return type (#116)

* Self hint for distributions

* ClassSelf to SelfClass
This commit is contained in:
Quentin Gallouédec 2022-11-22 13:12:35 +01:00 committed by GitHub
parent a9735b9f31
commit 36aeae18b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 30 additions and 21 deletions

View File

@@ -19,7 +19,7 @@ from stable_baselines3.common.utils import get_schedule_fn, safe_mean
from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy
from sb3_contrib.common.vec_env.async_eval import AsyncEval from sb3_contrib.common.vec_env.async_eval import AsyncEval
ARSSelf = TypeVar("ARSSelf", bound="ARS") SelfARS = TypeVar("SelfARS", bound="ARS")
class ARS(BaseAlgorithm): class ARS(BaseAlgorithm):
@@ -304,7 +304,7 @@ class ARS(BaseAlgorithm):
self._n_updates += 1 self._n_updates += 1
def learn( def learn(
self: ARSSelf, self: SelfARS,
total_timesteps: int, total_timesteps: int,
callback: MaybeCallback = None, callback: MaybeCallback = None,
log_interval: int = 1, log_interval: int = 1,
@@ -312,7 +312,7 @@ class ARS(BaseAlgorithm):
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
async_eval: Optional[AsyncEval] = None, async_eval: Optional[AsyncEval] = None,
progress_bar: bool = False, progress_bar: bool = False,
) -> ARSSelf: ) -> SelfARS:
""" """
Return a trained model. Return a trained model.

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, TypeVar
import numpy as np import numpy as np
import torch as th import torch as th
@@ -9,6 +9,11 @@ from torch import nn
from torch.distributions import Categorical from torch.distributions import Categorical
from torch.distributions.utils import logits_to_probs from torch.distributions.utils import logits_to_probs
SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistribution", bound="MaskableCategoricalDistribution")
SelfMaskableMultiCategoricalDistribution = TypeVar(
"SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution"
)
class MaskableCategorical(Categorical): class MaskableCategorical(Categorical):
""" """
@@ -115,7 +120,9 @@ class MaskableCategoricalDistribution(MaskableDistribution):
action_logits = nn.Linear(latent_dim, self.action_dim) action_logits = nn.Linear(latent_dim, self.action_dim)
return action_logits return action_logits
def proba_distribution(self, action_logits: th.Tensor) -> "MaskableCategoricalDistribution": def proba_distribution(
self: SelfMaskableCategoricalDistribution, action_logits: th.Tensor
) -> SelfMaskableCategoricalDistribution:
# Restructure shape to align with logits # Restructure shape to align with logits
reshaped_logits = action_logits.view(-1, self.action_dim) reshaped_logits = action_logits.view(-1, self.action_dim)
self.distribution = MaskableCategorical(logits=reshaped_logits) self.distribution = MaskableCategorical(logits=reshaped_logits)
@@ -178,7 +185,9 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
action_logits = nn.Linear(latent_dim, sum(self.action_dims)) action_logits = nn.Linear(latent_dim, sum(self.action_dims))
return action_logits return action_logits
def proba_distribution(self, action_logits: th.Tensor) -> "MaskableMultiCategoricalDistribution": def proba_distribution(
self: SelfMaskableMultiCategoricalDistribution, action_logits: th.Tensor
) -> SelfMaskableMultiCategoricalDistribution:
# Restructure shape to align with logits # Restructure shape to align with logits
reshaped_logits = action_logits.view(-1, sum(self.action_dims)) reshaped_logits = action_logits.view(-1, sum(self.action_dims))

View File

@@ -22,7 +22,7 @@ from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
MaskablePPOSelf = TypeVar("MaskablePPOSelf", bound="MaskablePPO") SelfMaskablePPO = TypeVar("SelfMaskablePPO", bound="MaskablePPO")
class MaskablePPO(OnPolicyAlgorithm): class MaskablePPO(OnPolicyAlgorithm):
@@ -498,7 +498,7 @@ class MaskablePPO(OnPolicyAlgorithm):
self.logger.record("train/clip_range_vf", clip_range_vf) self.logger.record("train/clip_range_vf", clip_range_vf)
def learn( def learn(
self: MaskablePPOSelf, self: SelfMaskablePPO,
total_timesteps: int, total_timesteps: int,
callback: MaybeCallback = None, callback: MaybeCallback = None,
log_interval: int = 1, log_interval: int = 1,
@@ -506,7 +506,7 @@ class MaskablePPO(OnPolicyAlgorithm):
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
use_masking: bool = True, use_masking: bool = True,
progress_bar: bool = False, progress_bar: bool = False,
) -> MaskablePPOSelf: ) -> SelfMaskablePPO:
iteration = 0 iteration = 0
total_timesteps, callback = self._setup_learn( total_timesteps, callback = self._setup_learn(

View File

@@ -20,7 +20,7 @@ from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
from sb3_contrib.common.recurrent.type_aliases import RNNStates from sb3_contrib.common.recurrent.type_aliases import RNNStates
from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy
RecurrentPPOSelf = TypeVar("RecurrentPPOSelf", bound="RecurrentPPO") SelfRecurrentPPO = TypeVar("SelfRecurrentPPO", bound="RecurrentPPO")
class RecurrentPPO(OnPolicyAlgorithm): class RecurrentPPO(OnPolicyAlgorithm):
@@ -444,14 +444,14 @@ class RecurrentPPO(OnPolicyAlgorithm):
self.logger.record("train/clip_range_vf", clip_range_vf) self.logger.record("train/clip_range_vf", clip_range_vf)
def learn( def learn(
self: RecurrentPPOSelf, self: SelfRecurrentPPO,
total_timesteps: int, total_timesteps: int,
callback: MaybeCallback = None, callback: MaybeCallback = None,
log_interval: int = 1, log_interval: int = 1,
tb_log_name: str = "RecurrentPPO", tb_log_name: str = "RecurrentPPO",
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
progress_bar: bool = False, progress_bar: bool = False,
) -> RecurrentPPOSelf: ) -> SelfRecurrentPPO:
iteration = 0 iteration = 0
total_timesteps, callback = self._setup_learn( total_timesteps, callback = self._setup_learn(

View File

@@ -13,7 +13,7 @@ from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name
from sb3_contrib.common.utils import quantile_huber_loss from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
QRDQNSelf = TypeVar("QRDQNSelf", bound="QRDQN") SelfQRDQN = TypeVar("SelfQRDQN", bound="QRDQN")
class QRDQN(OffPolicyAlgorithm): class QRDQN(OffPolicyAlgorithm):
@@ -247,14 +247,14 @@ class QRDQN(OffPolicyAlgorithm):
return action, state return action, state
def learn( def learn(
self: QRDQNSelf, self: SelfQRDQN,
total_timesteps: int, total_timesteps: int,
callback: MaybeCallback = None, callback: MaybeCallback = None,
log_interval: int = 4, log_interval: int = 4,
tb_log_name: str = "QRDQN", tb_log_name: str = "QRDQN",
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
progress_bar: bool = False, progress_bar: bool = False,
) -> QRDQNSelf: ) -> SelfQRDQN:
return super().learn( return super().learn(
total_timesteps=total_timesteps, total_timesteps=total_timesteps,

View File

@@ -13,7 +13,7 @@ from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from sb3_contrib.common.utils import quantile_huber_loss from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
TQCSelf = TypeVar("TQCSelf", bound="TQC") SelfTQC = TypeVar("SelfTQC", bound="TQC")
class TQC(OffPolicyAlgorithm): class TQC(OffPolicyAlgorithm):
@@ -284,14 +284,14 @@ class TQC(OffPolicyAlgorithm):
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
def learn( def learn(
self: TQCSelf, self: SelfTQC,
total_timesteps: int, total_timesteps: int,
callback: MaybeCallback = None, callback: MaybeCallback = None,
log_interval: int = 4, log_interval: int = 4,
tb_log_name: str = "TQC", tb_log_name: str = "TQC",
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
progress_bar: bool = False, progress_bar: bool = False,
) -> TQCSelf: ) -> SelfTQC:
return super().learn( return super().learn(
total_timesteps=total_timesteps, total_timesteps=total_timesteps,

View File

@@ -17,7 +17,7 @@ from torch.nn import functional as F
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
TRPOSelf = TypeVar("TRPOSelf", bound="TRPO") SelfTRPO = TypeVar("SelfTRPO", bound="TRPO")
class TRPO(OnPolicyAlgorithm): class TRPO(OnPolicyAlgorithm):
@@ -400,14 +400,14 @@ class TRPO(OnPolicyAlgorithm):
return flat_grad(jacobian_vector_product, params, retain_graph=retain_graph) + self.cg_damping * vector return flat_grad(jacobian_vector_product, params, retain_graph=retain_graph) + self.cg_damping * vector
def learn( def learn(
self: TRPOSelf, self: SelfTRPO,
total_timesteps: int, total_timesteps: int,
callback: MaybeCallback = None, callback: MaybeCallback = None,
log_interval: int = 1, log_interval: int = 1,
tb_log_name: str = "TRPO", tb_log_name: str = "TRPO",
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
progress_bar: bool = False, progress_bar: bool = False,
) -> TRPOSelf: ) -> SelfTRPO:
return super().learn( return super().learn(
total_timesteps=total_timesteps, total_timesteps=total_timesteps,