Fix `Self` return type (#116)
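* Add `Self` type hints for the maskable distributions' `proba_distribution` methods
* Rename the self-type TypeVars from `ClassSelf` to `SelfClass` (e.g. `ARSSelf` -> `SelfARS`)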
This commit is contained in:
parent
a9735b9f31
commit
36aeae18b5
|
|
@ -19,7 +19,7 @@ from stable_baselines3.common.utils import get_schedule_fn, safe_mean
|
||||||
from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy
|
from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy
|
||||||
from sb3_contrib.common.vec_env.async_eval import AsyncEval
|
from sb3_contrib.common.vec_env.async_eval import AsyncEval
|
||||||
|
|
||||||
ARSSelf = TypeVar("ARSSelf", bound="ARS")
|
SelfARS = TypeVar("SelfARS", bound="ARS")
|
||||||
|
|
||||||
|
|
||||||
class ARS(BaseAlgorithm):
|
class ARS(BaseAlgorithm):
|
||||||
|
|
@ -304,7 +304,7 @@ class ARS(BaseAlgorithm):
|
||||||
self._n_updates += 1
|
self._n_updates += 1
|
||||||
|
|
||||||
def learn(
|
def learn(
|
||||||
self: ARSSelf,
|
self: SelfARS,
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
callback: MaybeCallback = None,
|
callback: MaybeCallback = None,
|
||||||
log_interval: int = 1,
|
log_interval: int = 1,
|
||||||
|
|
@ -312,7 +312,7 @@ class ARS(BaseAlgorithm):
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
async_eval: Optional[AsyncEval] = None,
|
async_eval: Optional[AsyncEval] = None,
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
) -> ARSSelf:
|
) -> SelfARS:
|
||||||
"""
|
"""
|
||||||
Return a trained model.
|
Return a trained model.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
|
|
@ -9,6 +9,11 @@ from torch import nn
|
||||||
from torch.distributions import Categorical
|
from torch.distributions import Categorical
|
||||||
from torch.distributions.utils import logits_to_probs
|
from torch.distributions.utils import logits_to_probs
|
||||||
|
|
||||||
|
SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistribution", bound="MaskableCategoricalDistribution")
|
||||||
|
SelfMaskableMultiCategoricalDistribution = TypeVar(
|
||||||
|
"SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MaskableCategorical(Categorical):
|
class MaskableCategorical(Categorical):
|
||||||
"""
|
"""
|
||||||
|
|
@ -115,7 +120,9 @@ class MaskableCategoricalDistribution(MaskableDistribution):
|
||||||
action_logits = nn.Linear(latent_dim, self.action_dim)
|
action_logits = nn.Linear(latent_dim, self.action_dim)
|
||||||
return action_logits
|
return action_logits
|
||||||
|
|
||||||
def proba_distribution(self, action_logits: th.Tensor) -> "MaskableCategoricalDistribution":
|
def proba_distribution(
|
||||||
|
self: SelfMaskableCategoricalDistribution, action_logits: th.Tensor
|
||||||
|
) -> SelfMaskableCategoricalDistribution:
|
||||||
# Restructure shape to align with logits
|
# Restructure shape to align with logits
|
||||||
reshaped_logits = action_logits.view(-1, self.action_dim)
|
reshaped_logits = action_logits.view(-1, self.action_dim)
|
||||||
self.distribution = MaskableCategorical(logits=reshaped_logits)
|
self.distribution = MaskableCategorical(logits=reshaped_logits)
|
||||||
|
|
@ -178,7 +185,9 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
|
||||||
action_logits = nn.Linear(latent_dim, sum(self.action_dims))
|
action_logits = nn.Linear(latent_dim, sum(self.action_dims))
|
||||||
return action_logits
|
return action_logits
|
||||||
|
|
||||||
def proba_distribution(self, action_logits: th.Tensor) -> "MaskableMultiCategoricalDistribution":
|
def proba_distribution(
|
||||||
|
self: SelfMaskableMultiCategoricalDistribution, action_logits: th.Tensor
|
||||||
|
) -> SelfMaskableMultiCategoricalDistribution:
|
||||||
# Restructure shape to align with logits
|
# Restructure shape to align with logits
|
||||||
reshaped_logits = action_logits.view(-1, sum(self.action_dims))
|
reshaped_logits = action_logits.view(-1, sum(self.action_dims))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
|
||||||
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
|
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
|
||||||
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||||
|
|
||||||
MaskablePPOSelf = TypeVar("MaskablePPOSelf", bound="MaskablePPO")
|
SelfMaskablePPO = TypeVar("SelfMaskablePPO", bound="MaskablePPO")
|
||||||
|
|
||||||
|
|
||||||
class MaskablePPO(OnPolicyAlgorithm):
|
class MaskablePPO(OnPolicyAlgorithm):
|
||||||
|
|
@ -498,7 +498,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
self.logger.record("train/clip_range_vf", clip_range_vf)
|
self.logger.record("train/clip_range_vf", clip_range_vf)
|
||||||
|
|
||||||
def learn(
|
def learn(
|
||||||
self: MaskablePPOSelf,
|
self: SelfMaskablePPO,
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
callback: MaybeCallback = None,
|
callback: MaybeCallback = None,
|
||||||
log_interval: int = 1,
|
log_interval: int = 1,
|
||||||
|
|
@ -506,7 +506,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
use_masking: bool = True,
|
use_masking: bool = True,
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
) -> MaskablePPOSelf:
|
) -> SelfMaskablePPO:
|
||||||
iteration = 0
|
iteration = 0
|
||||||
|
|
||||||
total_timesteps, callback = self._setup_learn(
|
total_timesteps, callback = self._setup_learn(
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
|
||||||
from sb3_contrib.common.recurrent.type_aliases import RNNStates
|
from sb3_contrib.common.recurrent.type_aliases import RNNStates
|
||||||
from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy
|
from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy
|
||||||
|
|
||||||
RecurrentPPOSelf = TypeVar("RecurrentPPOSelf", bound="RecurrentPPO")
|
SelfRecurrentPPO = TypeVar("SelfRecurrentPPO", bound="RecurrentPPO")
|
||||||
|
|
||||||
|
|
||||||
class RecurrentPPO(OnPolicyAlgorithm):
|
class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
|
|
@ -444,14 +444,14 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
self.logger.record("train/clip_range_vf", clip_range_vf)
|
self.logger.record("train/clip_range_vf", clip_range_vf)
|
||||||
|
|
||||||
def learn(
|
def learn(
|
||||||
self: RecurrentPPOSelf,
|
self: SelfRecurrentPPO,
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
callback: MaybeCallback = None,
|
callback: MaybeCallback = None,
|
||||||
log_interval: int = 1,
|
log_interval: int = 1,
|
||||||
tb_log_name: str = "RecurrentPPO",
|
tb_log_name: str = "RecurrentPPO",
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
) -> RecurrentPPOSelf:
|
) -> SelfRecurrentPPO:
|
||||||
iteration = 0
|
iteration = 0
|
||||||
|
|
||||||
total_timesteps, callback = self._setup_learn(
|
total_timesteps, callback = self._setup_learn(
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name
|
||||||
from sb3_contrib.common.utils import quantile_huber_loss
|
from sb3_contrib.common.utils import quantile_huber_loss
|
||||||
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
|
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
|
||||||
|
|
||||||
QRDQNSelf = TypeVar("QRDQNSelf", bound="QRDQN")
|
SelfQRDQN = TypeVar("SelfQRDQN", bound="QRDQN")
|
||||||
|
|
||||||
|
|
||||||
class QRDQN(OffPolicyAlgorithm):
|
class QRDQN(OffPolicyAlgorithm):
|
||||||
|
|
@ -247,14 +247,14 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
return action, state
|
return action, state
|
||||||
|
|
||||||
def learn(
|
def learn(
|
||||||
self: QRDQNSelf,
|
self: SelfQRDQN,
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
callback: MaybeCallback = None,
|
callback: MaybeCallback = None,
|
||||||
log_interval: int = 4,
|
log_interval: int = 4,
|
||||||
tb_log_name: str = "QRDQN",
|
tb_log_name: str = "QRDQN",
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
) -> QRDQNSelf:
|
) -> SelfQRDQN:
|
||||||
|
|
||||||
return super().learn(
|
return super().learn(
|
||||||
total_timesteps=total_timesteps,
|
total_timesteps=total_timesteps,
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
||||||
from sb3_contrib.common.utils import quantile_huber_loss
|
from sb3_contrib.common.utils import quantile_huber_loss
|
||||||
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
|
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
|
||||||
|
|
||||||
TQCSelf = TypeVar("TQCSelf", bound="TQC")
|
SelfTQC = TypeVar("SelfTQC", bound="TQC")
|
||||||
|
|
||||||
|
|
||||||
class TQC(OffPolicyAlgorithm):
|
class TQC(OffPolicyAlgorithm):
|
||||||
|
|
@ -284,14 +284,14 @@ class TQC(OffPolicyAlgorithm):
|
||||||
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
||||||
|
|
||||||
def learn(
|
def learn(
|
||||||
self: TQCSelf,
|
self: SelfTQC,
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
callback: MaybeCallback = None,
|
callback: MaybeCallback = None,
|
||||||
log_interval: int = 4,
|
log_interval: int = 4,
|
||||||
tb_log_name: str = "TQC",
|
tb_log_name: str = "TQC",
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
) -> TQCSelf:
|
) -> SelfTQC:
|
||||||
|
|
||||||
return super().learn(
|
return super().learn(
|
||||||
total_timesteps=total_timesteps,
|
total_timesteps=total_timesteps,
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from torch.nn import functional as F
|
||||||
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
|
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
|
||||||
from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||||
|
|
||||||
TRPOSelf = TypeVar("TRPOSelf", bound="TRPO")
|
SelfTRPO = TypeVar("SelfTRPO", bound="TRPO")
|
||||||
|
|
||||||
|
|
||||||
class TRPO(OnPolicyAlgorithm):
|
class TRPO(OnPolicyAlgorithm):
|
||||||
|
|
@ -400,14 +400,14 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
return flat_grad(jacobian_vector_product, params, retain_graph=retain_graph) + self.cg_damping * vector
|
return flat_grad(jacobian_vector_product, params, retain_graph=retain_graph) + self.cg_damping * vector
|
||||||
|
|
||||||
def learn(
|
def learn(
|
||||||
self: TRPOSelf,
|
self: SelfTRPO,
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
callback: MaybeCallback = None,
|
callback: MaybeCallback = None,
|
||||||
log_interval: int = 1,
|
log_interval: int = 1,
|
||||||
tb_log_name: str = "TRPO",
|
tb_log_name: str = "TRPO",
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
) -> TRPOSelf:
|
) -> SelfTRPO:
|
||||||
|
|
||||||
return super().learn(
|
return super().learn(
|
||||||
total_timesteps=total_timesteps,
|
total_timesteps=total_timesteps,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue