Fixed the return type of ``.load()`` methods (#106)

* Fix return type for learn using TypeVar

* Update changelog

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Quentin Gallouédec 2022-10-10 17:21:38 +02:00 committed by GitHub
parent dec7b5303a
commit e9c97948c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 32 additions and 18 deletions

View File

@@ -25,6 +25,8 @@ Deprecations:
 Others:
 ^^^^^^^
+- Fixed the return type of ``.load()`` methods so that they now use ``TypeVar``
+
 
 Release 1.6.1 (2022-09-29)
 -------------------------------

View File

@@ -3,7 +3,7 @@ import sys
 import time
 import warnings
 from functools import partial
-from typing import Any, Dict, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, TypeVar, Union
 
 import gym
 import numpy as np
@@ -19,6 +19,8 @@ from stable_baselines3.common.utils import get_schedule_fn, safe_mean
 from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy
 from sb3_contrib.common.vec_env.async_eval import AsyncEval
 
+ARSSelf = TypeVar("ARSSelf", bound="ARS")
+
 
 class ARS(BaseAlgorithm):
     """
@@ -302,7 +304,7 @@ class ARS(BaseAlgorithm):
         self._n_updates += 1
 
     def learn(
-        self,
+        self: ARSSelf,
         total_timesteps: int,
         callback: MaybeCallback = None,
         log_interval: int = 1,
@@ -313,7 +315,7 @@ class ARS(BaseAlgorithm):
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
         async_eval: Optional[AsyncEval] = None,
-    ) -> "ARS":
+    ) -> ARSSelf:
         """
         Return a trained model.

View File

@@ -2,7 +2,7 @@ import sys
 import time
 import warnings
 from collections import deque
-from typing import Any, Dict, Optional, Tuple, Type, Union
+from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
 
 import gym
 import numpy as np
@@ -23,6 +23,8 @@ from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
 from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
 from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
 
+MaskablePPOSelf = TypeVar("MaskablePPOSelf", bound="MaskablePPO")
+
 
 class MaskablePPO(OnPolicyAlgorithm):
     """
@@ -550,7 +552,7 @@ class MaskablePPO(OnPolicyAlgorithm):
         self.logger.record("train/clip_range_vf", clip_range_vf)
 
     def learn(
-        self,
+        self: MaskablePPOSelf,
         total_timesteps: int,
         callback: MaybeCallback = None,
         log_interval: int = 1,
@@ -561,7 +563,7 @@ class MaskablePPO(OnPolicyAlgorithm):
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
         use_masking: bool = True,
-    ) -> "MaskablePPO":
+    ) -> MaskablePPOSelf:
 
         iteration = 0
 
         total_timesteps, callback = self._setup_learn(

View File

@@ -1,7 +1,7 @@
 import sys
 import time
 from copy import deepcopy
-from typing import Any, Dict, Optional, Tuple, Type, Union
+from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
 
 import gym
 import numpy as np
@@ -20,6 +20,8 @@ from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
 from sb3_contrib.common.recurrent.type_aliases import RNNStates
 from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy
 
+RecurrentPPOSelf = TypeVar("RecurrentPPOSelf", bound="RecurrentPPO")
+
 
 class RecurrentPPO(OnPolicyAlgorithm):
     """
@@ -488,7 +490,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
         self.logger.record("train/clip_range_vf", clip_range_vf)
 
     def learn(
-        self,
+        self: RecurrentPPOSelf,
         total_timesteps: int,
         callback: MaybeCallback = None,
         log_interval: int = 1,
@@ -498,7 +500,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
         tb_log_name: str = "RecurrentPPO",
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
-    ) -> "RecurrentPPO":
+    ) -> RecurrentPPOSelf:
 
         iteration = 0
 
         total_timesteps, callback = self._setup_learn(

View File

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 import gym
 import numpy as np
@@ -13,6 +13,8 @@ from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name
 from sb3_contrib.common.utils import quantile_huber_loss
 from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
 
+QRDQNSelf = TypeVar("QRDQNSelf", bound="QRDQN")
+
 
 class QRDQN(OffPolicyAlgorithm):
     """
@@ -250,7 +252,7 @@ class QRDQN(OffPolicyAlgorithm):
         return action, state
 
     def learn(
-        self,
+        self: QRDQNSelf,
         total_timesteps: int,
         callback: MaybeCallback = None,
         log_interval: int = 4,
@@ -260,7 +262,7 @@ class QRDQN(OffPolicyAlgorithm):
         tb_log_name: str = "QRDQN",
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
-    ) -> OffPolicyAlgorithm:
+    ) -> QRDQNSelf:
 
         return super().learn(
             total_timesteps=total_timesteps,

View File

@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 import gym
 import numpy as np
@@ -13,6 +13,8 @@ from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
 from sb3_contrib.common.utils import quantile_huber_loss
 from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
 
+TQCSelf = TypeVar("TQCSelf", bound="TQC")
+
 
 class TQC(OffPolicyAlgorithm):
     """
@@ -287,7 +289,7 @@ class TQC(OffPolicyAlgorithm):
         self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
 
     def learn(
-        self,
+        self: TQCSelf,
        total_timesteps: int,
         callback: MaybeCallback = None,
         log_interval: int = 4,
@@ -297,7 +299,7 @@ class TQC(OffPolicyAlgorithm):
         tb_log_name: str = "TQC",
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
-    ) -> OffPolicyAlgorithm:
+    ) -> TQCSelf:
 
         return super().learn(
             total_timesteps=total_timesteps,

View File

@@ -1,7 +1,7 @@
 import copy
 import warnings
 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 import numpy as np
 import torch as th
@@ -17,6 +17,8 @@ from torch.nn import functional as F
 from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
 from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
 
+TRPOSelf = TypeVar("TRPOSelf", bound="TRPO")
+
 
 class TRPO(OnPolicyAlgorithm):
     """
@@ -403,7 +405,7 @@ class TRPO(OnPolicyAlgorithm):
         return flat_grad(jacobian_vector_product, params, retain_graph=retain_graph) + self.cg_damping * vector
 
     def learn(
-        self,
+        self: TRPOSelf,
         total_timesteps: int,
         callback: MaybeCallback = None,
         log_interval: int = 1,
@@ -413,7 +415,7 @@ class TRPO(OnPolicyAlgorithm):
         tb_log_name: str = "TRPO",
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
-    ) -> OnPolicyAlgorithm:
+    ) -> TRPOSelf:
 
         return super().learn(
             total_timesteps=total_timesteps,