Fix `Self` return type (#116)

* Self hint for distributions

* Rename TypeVars from the `ClassSelf` pattern to the `SelfClass` pattern (e.g. `ARSSelf` -> `SelfARS`)
This commit is contained in:
Quentin Gallouédec 2022-11-22 13:12:35 +01:00 committed by GitHub
parent a9735b9f31
commit 36aeae18b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 30 additions and 21 deletions

View File

@ -19,7 +19,7 @@ from stable_baselines3.common.utils import get_schedule_fn, safe_mean
from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy
from sb3_contrib.common.vec_env.async_eval import AsyncEval
ARSSelf = TypeVar("ARSSelf", bound="ARS")
SelfARS = TypeVar("SelfARS", bound="ARS")
class ARS(BaseAlgorithm):
@ -304,7 +304,7 @@ class ARS(BaseAlgorithm):
self._n_updates += 1
def learn(
self: ARSSelf,
self: SelfARS,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
@ -312,7 +312,7 @@ class ARS(BaseAlgorithm):
reset_num_timesteps: bool = True,
async_eval: Optional[AsyncEval] = None,
progress_bar: bool = False,
) -> ARSSelf:
) -> SelfARS:
"""
Return a trained model.

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, TypeVar
import numpy as np
import torch as th
@ -9,6 +9,11 @@ from torch import nn
from torch.distributions import Categorical
from torch.distributions.utils import logits_to_probs
SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistribution", bound="MaskableCategoricalDistribution")
SelfMaskableMultiCategoricalDistribution = TypeVar(
"SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution"
)
class MaskableCategorical(Categorical):
"""
@ -115,7 +120,9 @@ class MaskableCategoricalDistribution(MaskableDistribution):
action_logits = nn.Linear(latent_dim, self.action_dim)
return action_logits
def proba_distribution(self, action_logits: th.Tensor) -> "MaskableCategoricalDistribution":
def proba_distribution(
self: SelfMaskableCategoricalDistribution, action_logits: th.Tensor
) -> SelfMaskableCategoricalDistribution:
# Restructure shape to align with logits
reshaped_logits = action_logits.view(-1, self.action_dim)
self.distribution = MaskableCategorical(logits=reshaped_logits)
@ -178,7 +185,9 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
action_logits = nn.Linear(latent_dim, sum(self.action_dims))
return action_logits
def proba_distribution(self, action_logits: th.Tensor) -> "MaskableMultiCategoricalDistribution":
def proba_distribution(
self: SelfMaskableMultiCategoricalDistribution, action_logits: th.Tensor
) -> SelfMaskableMultiCategoricalDistribution:
# Restructure shape to align with logits
reshaped_logits = action_logits.view(-1, sum(self.action_dims))

View File

@ -22,7 +22,7 @@ from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
MaskablePPOSelf = TypeVar("MaskablePPOSelf", bound="MaskablePPO")
SelfMaskablePPO = TypeVar("SelfMaskablePPO", bound="MaskablePPO")
class MaskablePPO(OnPolicyAlgorithm):
@ -498,7 +498,7 @@ class MaskablePPO(OnPolicyAlgorithm):
self.logger.record("train/clip_range_vf", clip_range_vf)
def learn(
self: MaskablePPOSelf,
self: SelfMaskablePPO,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
@ -506,7 +506,7 @@ class MaskablePPO(OnPolicyAlgorithm):
reset_num_timesteps: bool = True,
use_masking: bool = True,
progress_bar: bool = False,
) -> MaskablePPOSelf:
) -> SelfMaskablePPO:
iteration = 0
total_timesteps, callback = self._setup_learn(

View File

@ -20,7 +20,7 @@ from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
from sb3_contrib.common.recurrent.type_aliases import RNNStates
from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy
RecurrentPPOSelf = TypeVar("RecurrentPPOSelf", bound="RecurrentPPO")
SelfRecurrentPPO = TypeVar("SelfRecurrentPPO", bound="RecurrentPPO")
class RecurrentPPO(OnPolicyAlgorithm):
@ -444,14 +444,14 @@ class RecurrentPPO(OnPolicyAlgorithm):
self.logger.record("train/clip_range_vf", clip_range_vf)
def learn(
self: RecurrentPPOSelf,
self: SelfRecurrentPPO,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
tb_log_name: str = "RecurrentPPO",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> RecurrentPPOSelf:
) -> SelfRecurrentPPO:
iteration = 0
total_timesteps, callback = self._setup_learn(

View File

@ -13,7 +13,7 @@ from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name
from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
QRDQNSelf = TypeVar("QRDQNSelf", bound="QRDQN")
SelfQRDQN = TypeVar("SelfQRDQN", bound="QRDQN")
class QRDQN(OffPolicyAlgorithm):
@ -247,14 +247,14 @@ class QRDQN(OffPolicyAlgorithm):
return action, state
def learn(
self: QRDQNSelf,
self: SelfQRDQN,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "QRDQN",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> QRDQNSelf:
) -> SelfQRDQN:
return super().learn(
total_timesteps=total_timesteps,

View File

@ -13,7 +13,7 @@ from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
TQCSelf = TypeVar("TQCSelf", bound="TQC")
SelfTQC = TypeVar("SelfTQC", bound="TQC")
class TQC(OffPolicyAlgorithm):
@ -284,14 +284,14 @@ class TQC(OffPolicyAlgorithm):
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
def learn(
self: TQCSelf,
self: SelfTQC,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "TQC",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> TQCSelf:
) -> SelfTQC:
return super().learn(
total_timesteps=total_timesteps,

View File

@ -17,7 +17,7 @@ from torch.nn import functional as F
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
TRPOSelf = TypeVar("TRPOSelf", bound="TRPO")
SelfTRPO = TypeVar("SelfTRPO", bound="TRPO")
class TRPO(OnPolicyAlgorithm):
@ -400,14 +400,14 @@ class TRPO(OnPolicyAlgorithm):
return flat_grad(jacobian_vector_product, params, retain_graph=retain_graph) + self.cg_damping * vector
def learn(
self: TRPOSelf,
self: SelfTRPO,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
tb_log_name: str = "TRPO",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> TRPOSelf:
) -> SelfTRPO:
return super().learn(
total_timesteps=total_timesteps,