From e9c97948c8d1a3487f859755f5bac904c7e7c37e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 10 Oct 2022 17:21:38 +0200 Subject: [PATCH] Fixed the return type of ``.learn()`` methods (#106) * Fix return type for learn using TypeVar * Update changelog Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 2 ++ sb3_contrib/ars/ars.py | 8 +++++--- sb3_contrib/ppo_mask/ppo_mask.py | 8 +++++--- sb3_contrib/ppo_recurrent/ppo_recurrent.py | 8 +++++--- sb3_contrib/qrdqn/qrdqn.py | 8 +++++--- sb3_contrib/tqc/tqc.py | 8 +++++--- sb3_contrib/trpo/trpo.py | 8 +++++--- 7 files changed, 32 insertions(+), 18 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 24c1295..bb1f7b0 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -25,6 +25,8 @@ Deprecations: Others: ^^^^^^^ +- Fixed the return type of ``.learn()`` methods so that they now use ``TypeVar`` + Release 1.6.1 (2022-09-29) ------------------------------- diff --git a/sb3_contrib/ars/ars.py b/sb3_contrib/ars/ars.py index a5bc1a1..241a814 100644 --- a/sb3_contrib/ars/ars.py +++ b/sb3_contrib/ars/ars.py @@ -3,7 +3,7 @@ import sys import time import warnings from functools import partial -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Type, TypeVar, Union import gym import numpy as np @@ -19,6 +19,8 @@ from stable_baselines3.common.utils import get_schedule_fn, safe_mean from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy from sb3_contrib.common.vec_env.async_eval import AsyncEval +ARSSelf = TypeVar("ARSSelf", bound="ARS") + class ARS(BaseAlgorithm): """ @@ -302,7 +304,7 @@ class ARS(BaseAlgorithm): self._n_updates += 1 def learn( - self, + self: ARSSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, @@ -313,7 +315,7 @@ class ARS(BaseAlgorithm): eval_log_path: Optional[str] = None, 
reset_num_timesteps: bool = True, async_eval: Optional[AsyncEval] = None, - ) -> "ARS": + ) -> ARSSelf: """ Return a trained model. diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index 86e8d21..0cedc4c 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -2,7 +2,7 @@ import sys import time import warnings from collections import deque -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -23,6 +23,8 @@ from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy +MaskablePPOSelf = TypeVar("MaskablePPOSelf", bound="MaskablePPO") + class MaskablePPO(OnPolicyAlgorithm): """ @@ -550,7 +552,7 @@ class MaskablePPO(OnPolicyAlgorithm): self.logger.record("train/clip_range_vf", clip_range_vf) def learn( - self, + self: MaskablePPOSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, @@ -561,7 +563,7 @@ class MaskablePPO(OnPolicyAlgorithm): eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, use_masking: bool = True, - ) -> "MaskablePPO": + ) -> MaskablePPOSelf: iteration = 0 total_timesteps, callback = self._setup_learn( diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 5580fde..68d501f 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -1,7 +1,7 @@ import sys import time from copy import deepcopy -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -20,6 +20,8 @@ from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy from 
sb3_contrib.common.recurrent.type_aliases import RNNStates from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy +RecurrentPPOSelf = TypeVar("RecurrentPPOSelf", bound="RecurrentPPO") + class RecurrentPPO(OnPolicyAlgorithm): """ @@ -488,7 +490,7 @@ class RecurrentPPO(OnPolicyAlgorithm): self.logger.record("train/clip_range_vf", clip_range_vf) def learn( - self, + self: RecurrentPPOSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, @@ -498,7 +500,7 @@ class RecurrentPPO(OnPolicyAlgorithm): tb_log_name: str = "RecurrentPPO", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> "RecurrentPPO": + ) -> RecurrentPPOSelf: iteration = 0 total_timesteps, callback = self._setup_learn( diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index 67242a3..38f32ff 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -13,6 +13,8 @@ from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name from sb3_contrib.common.utils import quantile_huber_loss from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy +QRDQNSelf = TypeVar("QRDQNSelf", bound="QRDQN") + class QRDQN(OffPolicyAlgorithm): """ @@ -250,7 +252,7 @@ class QRDQN(OffPolicyAlgorithm): return action, state def learn( - self, + self: QRDQNSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, @@ -260,7 +262,7 @@ class QRDQN(OffPolicyAlgorithm): tb_log_name: str = "QRDQN", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> OffPolicyAlgorithm: + ) -> QRDQNSelf: return super().learn( total_timesteps=total_timesteps, diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index ed2f389..b38b53c 
100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -13,6 +13,8 @@ from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from sb3_contrib.common.utils import quantile_huber_loss from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy +TQCSelf = TypeVar("TQCSelf", bound="TQC") + class TQC(OffPolicyAlgorithm): """ @@ -287,7 +289,7 @@ class TQC(OffPolicyAlgorithm): self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) def learn( - self, + self: TQCSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, @@ -297,7 +299,7 @@ class TQC(OffPolicyAlgorithm): tb_log_name: str = "TQC", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> OffPolicyAlgorithm: + ) -> TQCSelf: return super().learn( total_timesteps=total_timesteps, diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index b15a869..a66d9e1 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -1,7 +1,7 @@ import copy import warnings from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import numpy as np import torch as th @@ -17,6 +17,8 @@ from torch.nn import functional as F from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy +TRPOSelf = TypeVar("TRPOSelf", bound="TRPO") + class TRPO(OnPolicyAlgorithm): """ @@ -403,7 +405,7 @@ class TRPO(OnPolicyAlgorithm): return flat_grad(jacobian_vector_product, params, retain_graph=retain_graph) + self.cg_damping * vector def learn( - self, + self: TRPOSelf, total_timesteps: int, callback: 
MaybeCallback = None, log_interval: int = 1, @@ -413,7 +415,7 @@ class TRPO(OnPolicyAlgorithm): tb_log_name: str = "TRPO", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> OnPolicyAlgorithm: + ) -> TRPOSelf: return super().learn( total_timesteps=total_timesteps,