Fixed the return type of ``.load()`` methods (#106)

* Fix return type for learn using TypeVar

* Update changelog

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Quentin Gallouédec 2022-10-10 17:21:38 +02:00 committed by GitHub
parent dec7b5303a
commit e9c97948c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 32 additions and 18 deletions

View File

@@ -25,6 +25,8 @@ Deprecations:
Others:
^^^^^^^
- Fixed the return type of ``.load()`` methods so that they now use ``TypeVar``
Release 1.6.1 (2022-09-29)
-------------------------------

View File

@@ -3,7 +3,7 @@ import sys
import time
import warnings
from functools import partial
from typing import Any, Dict, Optional, Type, Union
from typing import Any, Dict, Optional, Type, TypeVar, Union
import gym
import numpy as np
@@ -19,6 +19,8 @@ from stable_baselines3.common.utils import get_schedule_fn, safe_mean
from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy
from sb3_contrib.common.vec_env.async_eval import AsyncEval
ARSSelf = TypeVar("ARSSelf", bound="ARS")
class ARS(BaseAlgorithm):
"""
@@ -302,7 +304,7 @@ class ARS(BaseAlgorithm):
self._n_updates += 1
def learn(
self,
self: ARSSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
@@ -313,7 +315,7 @@ class ARS(BaseAlgorithm):
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
async_eval: Optional[AsyncEval] = None,
) -> "ARS":
) -> ARSSelf:
"""
Return a trained model.

View File

@@ -2,7 +2,7 @@ import sys
import time
import warnings
from collections import deque
from typing import Any, Dict, Optional, Tuple, Type, Union
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
@@ -23,6 +23,8 @@ from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
MaskablePPOSelf = TypeVar("MaskablePPOSelf", bound="MaskablePPO")
class MaskablePPO(OnPolicyAlgorithm):
"""
@@ -550,7 +552,7 @@ class MaskablePPO(OnPolicyAlgorithm):
self.logger.record("train/clip_range_vf", clip_range_vf)
def learn(
self,
self: MaskablePPOSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
@@ -561,7 +563,7 @@ class MaskablePPO(OnPolicyAlgorithm):
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
use_masking: bool = True,
) -> "MaskablePPO":
) -> MaskablePPOSelf:
iteration = 0
total_timesteps, callback = self._setup_learn(

View File

@@ -1,7 +1,7 @@
import sys
import time
from copy import deepcopy
from typing import Any, Dict, Optional, Tuple, Type, Union
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
@@ -20,6 +20,8 @@ from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
from sb3_contrib.common.recurrent.type_aliases import RNNStates
from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy
RecurrentPPOSelf = TypeVar("RecurrentPPOSelf", bound="RecurrentPPO")
class RecurrentPPO(OnPolicyAlgorithm):
"""
@@ -488,7 +490,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
self.logger.record("train/clip_range_vf", clip_range_vf)
def learn(
self,
self: RecurrentPPOSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
@@ -498,7 +500,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
tb_log_name: str = "RecurrentPPO",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> "RecurrentPPO":
) -> RecurrentPPOSelf:
iteration = 0
total_timesteps, callback = self._setup_learn(

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
@@ -13,6 +13,8 @@ from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name
from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
QRDQNSelf = TypeVar("QRDQNSelf", bound="QRDQN")
class QRDQN(OffPolicyAlgorithm):
"""
@@ -250,7 +252,7 @@ class QRDQN(OffPolicyAlgorithm):
return action, state
def learn(
self,
self: QRDQNSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
@@ -260,7 +262,7 @@ class QRDQN(OffPolicyAlgorithm):
tb_log_name: str = "QRDQN",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm:
) -> QRDQNSelf:
return super().learn(
total_timesteps=total_timesteps,

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
@@ -13,6 +13,8 @@ from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
TQCSelf = TypeVar("TQCSelf", bound="TQC")
class TQC(OffPolicyAlgorithm):
"""
@@ -287,7 +289,7 @@ class TQC(OffPolicyAlgorithm):
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
def learn(
self,
self: TQCSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
@@ -297,7 +299,7 @@ class TQC(OffPolicyAlgorithm):
tb_log_name: str = "TQC",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm:
) -> TQCSelf:
return super().learn(
total_timesteps=total_timesteps,

View File

@@ -1,7 +1,7 @@
import copy
import warnings
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
@@ -17,6 +17,8 @@ from torch.nn import functional as F
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
TRPOSelf = TypeVar("TRPOSelf", bound="TRPO")
class TRPO(OnPolicyAlgorithm):
"""
@@ -403,7 +405,7 @@ class TRPO(OnPolicyAlgorithm):
return flat_grad(jacobian_vector_product, params, retain_graph=retain_graph) + self.cg_damping * vector
def learn(
self,
self: TRPOSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
@@ -413,7 +415,7 @@ class TRPO(OnPolicyAlgorithm):
tb_log_name: str = "TRPO",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> OnPolicyAlgorithm:
) -> TRPOSelf:
return super().learn(
total_timesteps=total_timesteps,