Fix `Self` return type (#116)
* Add `Self` type hints for the maskable distributions
* Rename the `<Class>Self` TypeVars to `Self<Class>`
This commit is contained in:
parent
a9735b9f31
commit
36aeae18b5
|
|
@ -19,7 +19,7 @@ from stable_baselines3.common.utils import get_schedule_fn, safe_mean
|
|||
from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy
|
||||
from sb3_contrib.common.vec_env.async_eval import AsyncEval
|
||||
|
||||
ARSSelf = TypeVar("ARSSelf", bound="ARS")
|
||||
SelfARS = TypeVar("SelfARS", bound="ARS")
|
||||
|
||||
|
||||
class ARS(BaseAlgorithm):
|
||||
|
|
@ -304,7 +304,7 @@ class ARS(BaseAlgorithm):
|
|||
self._n_updates += 1
|
||||
|
||||
def learn(
|
||||
self: ARSSelf,
|
||||
self: SelfARS,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
|
|
@ -312,7 +312,7 @@ class ARS(BaseAlgorithm):
|
|||
reset_num_timesteps: bool = True,
|
||||
async_eval: Optional[AsyncEval] = None,
|
||||
progress_bar: bool = False,
|
||||
) -> ARSSelf:
|
||||
) -> SelfARS:
|
||||
"""
|
||||
Return a trained model.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -9,6 +9,11 @@ from torch import nn
|
|||
from torch.distributions import Categorical
|
||||
from torch.distributions.utils import logits_to_probs
|
||||
|
||||
SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistribution", bound="MaskableCategoricalDistribution")
|
||||
SelfMaskableMultiCategoricalDistribution = TypeVar(
|
||||
"SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution"
|
||||
)
|
||||
|
||||
|
||||
class MaskableCategorical(Categorical):
|
||||
"""
|
||||
|
|
@ -115,7 +120,9 @@ class MaskableCategoricalDistribution(MaskableDistribution):
|
|||
action_logits = nn.Linear(latent_dim, self.action_dim)
|
||||
return action_logits
|
||||
|
||||
def proba_distribution(self, action_logits: th.Tensor) -> "MaskableCategoricalDistribution":
|
||||
def proba_distribution(
|
||||
self: SelfMaskableCategoricalDistribution, action_logits: th.Tensor
|
||||
) -> SelfMaskableCategoricalDistribution:
|
||||
# Restructure shape to align with logits
|
||||
reshaped_logits = action_logits.view(-1, self.action_dim)
|
||||
self.distribution = MaskableCategorical(logits=reshaped_logits)
|
||||
|
|
@ -178,7 +185,9 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
|
|||
action_logits = nn.Linear(latent_dim, sum(self.action_dims))
|
||||
return action_logits
|
||||
|
||||
def proba_distribution(self, action_logits: th.Tensor) -> "MaskableMultiCategoricalDistribution":
|
||||
def proba_distribution(
|
||||
self: SelfMaskableMultiCategoricalDistribution, action_logits: th.Tensor
|
||||
) -> SelfMaskableMultiCategoricalDistribution:
|
||||
# Restructure shape to align with logits
|
||||
reshaped_logits = action_logits.view(-1, sum(self.action_dims))
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
|
|||
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
|
||||
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||
|
||||
MaskablePPOSelf = TypeVar("MaskablePPOSelf", bound="MaskablePPO")
|
||||
SelfMaskablePPO = TypeVar("SelfMaskablePPO", bound="MaskablePPO")
|
||||
|
||||
|
||||
class MaskablePPO(OnPolicyAlgorithm):
|
||||
|
|
@ -498,7 +498,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
self.logger.record("train/clip_range_vf", clip_range_vf)
|
||||
|
||||
def learn(
|
||||
self: MaskablePPOSelf,
|
||||
self: SelfMaskablePPO,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
|
|
@ -506,7 +506,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
reset_num_timesteps: bool = True,
|
||||
use_masking: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> MaskablePPOSelf:
|
||||
) -> SelfMaskablePPO:
|
||||
iteration = 0
|
||||
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
|
|||
from sb3_contrib.common.recurrent.type_aliases import RNNStates
|
||||
from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy
|
||||
|
||||
RecurrentPPOSelf = TypeVar("RecurrentPPOSelf", bound="RecurrentPPO")
|
||||
SelfRecurrentPPO = TypeVar("SelfRecurrentPPO", bound="RecurrentPPO")
|
||||
|
||||
|
||||
class RecurrentPPO(OnPolicyAlgorithm):
|
||||
|
|
@ -444,14 +444,14 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
self.logger.record("train/clip_range_vf", clip_range_vf)
|
||||
|
||||
def learn(
|
||||
self: RecurrentPPOSelf,
|
||||
self: SelfRecurrentPPO,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
tb_log_name: str = "RecurrentPPO",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> RecurrentPPOSelf:
|
||||
) -> SelfRecurrentPPO:
|
||||
iteration = 0
|
||||
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name
|
|||
from sb3_contrib.common.utils import quantile_huber_loss
|
||||
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
|
||||
|
||||
QRDQNSelf = TypeVar("QRDQNSelf", bound="QRDQN")
|
||||
SelfQRDQN = TypeVar("SelfQRDQN", bound="QRDQN")
|
||||
|
||||
|
||||
class QRDQN(OffPolicyAlgorithm):
|
||||
|
|
@ -247,14 +247,14 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
return action, state
|
||||
|
||||
def learn(
|
||||
self: QRDQNSelf,
|
||||
self: SelfQRDQN,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
tb_log_name: str = "QRDQN",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> QRDQNSelf:
|
||||
) -> SelfQRDQN:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
|||
from sb3_contrib.common.utils import quantile_huber_loss
|
||||
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
|
||||
|
||||
TQCSelf = TypeVar("TQCSelf", bound="TQC")
|
||||
SelfTQC = TypeVar("SelfTQC", bound="TQC")
|
||||
|
||||
|
||||
class TQC(OffPolicyAlgorithm):
|
||||
|
|
@ -284,14 +284,14 @@ class TQC(OffPolicyAlgorithm):
|
|||
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
||||
|
||||
def learn(
|
||||
self: TQCSelf,
|
||||
self: SelfTQC,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
tb_log_name: str = "TQC",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> TQCSelf:
|
||||
) -> SelfTQC:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from torch.nn import functional as F
|
|||
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
|
||||
from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||
|
||||
TRPOSelf = TypeVar("TRPOSelf", bound="TRPO")
|
||||
SelfTRPO = TypeVar("SelfTRPO", bound="TRPO")
|
||||
|
||||
|
||||
class TRPO(OnPolicyAlgorithm):
|
||||
|
|
@ -400,14 +400,14 @@ class TRPO(OnPolicyAlgorithm):
|
|||
return flat_grad(jacobian_vector_product, params, retain_graph=retain_graph) + self.cg_damping * vector
|
||||
|
||||
def learn(
|
||||
self: TRPOSelf,
|
||||
self: SelfTRPO,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
tb_log_name: str = "TRPO",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> TRPOSelf:
|
||||
) -> SelfTRPO:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
Loading…
Reference in New Issue