Fixed the return type of ``.load()`` methods (#106)
* Fix return type for learn using TypeVar * Update changelog Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
dec7b5303a
commit
e9c97948c8
|
|
@ -25,6 +25,8 @@ Deprecations:
|
|||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Fixed the return type of ``.load()`` methods so that they now use ``TypeVar``
|
||||
|
||||
|
||||
Release 1.6.1 (2022-09-29)
|
||||
-------------------------------
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import sys
|
|||
import time
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Optional, Type, Union
|
||||
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -19,6 +19,8 @@ from stable_baselines3.common.utils import get_schedule_fn, safe_mean
|
|||
from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy
|
||||
from sb3_contrib.common.vec_env.async_eval import AsyncEval
|
||||
|
||||
ARSSelf = TypeVar("ARSSelf", bound="ARS")
|
||||
|
||||
|
||||
class ARS(BaseAlgorithm):
|
||||
"""
|
||||
|
|
@ -302,7 +304,7 @@ class ARS(BaseAlgorithm):
|
|||
self._n_updates += 1
|
||||
|
||||
def learn(
|
||||
self,
|
||||
self: ARSSelf,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
|
|
@ -313,7 +315,7 @@ class ARS(BaseAlgorithm):
|
|||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
async_eval: Optional[AsyncEval] = None,
|
||||
) -> "ARS":
|
||||
) -> ARSSelf:
|
||||
"""
|
||||
Return a trained model.
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import sys
|
|||
import time
|
||||
import warnings
|
||||
from collections import deque
|
||||
from typing import Any, Dict, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -23,6 +23,8 @@ from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
|
|||
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
|
||||
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||
|
||||
MaskablePPOSelf = TypeVar("MaskablePPOSelf", bound="MaskablePPO")
|
||||
|
||||
|
||||
class MaskablePPO(OnPolicyAlgorithm):
|
||||
"""
|
||||
|
|
@ -550,7 +552,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
self.logger.record("train/clip_range_vf", clip_range_vf)
|
||||
|
||||
def learn(
|
||||
self,
|
||||
self: MaskablePPOSelf,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
|
|
@ -561,7 +563,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
use_masking: bool = True,
|
||||
) -> "MaskablePPO":
|
||||
) -> MaskablePPOSelf:
|
||||
iteration = 0
|
||||
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import sys
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -20,6 +20,8 @@ from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
|
|||
from sb3_contrib.common.recurrent.type_aliases import RNNStates
|
||||
from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy
|
||||
|
||||
RecurrentPPOSelf = TypeVar("RecurrentPPOSelf", bound="RecurrentPPO")
|
||||
|
||||
|
||||
class RecurrentPPO(OnPolicyAlgorithm):
|
||||
"""
|
||||
|
|
@ -488,7 +490,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
self.logger.record("train/clip_range_vf", clip_range_vf)
|
||||
|
||||
def learn(
|
||||
self,
|
||||
self: RecurrentPPOSelf,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
|
|
@ -498,7 +500,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
tb_log_name: str = "RecurrentPPO",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> "RecurrentPPO":
|
||||
) -> RecurrentPPOSelf:
|
||||
iteration = 0
|
||||
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -13,6 +13,8 @@ from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name
|
|||
from sb3_contrib.common.utils import quantile_huber_loss
|
||||
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
|
||||
|
||||
QRDQNSelf = TypeVar("QRDQNSelf", bound="QRDQN")
|
||||
|
||||
|
||||
class QRDQN(OffPolicyAlgorithm):
|
||||
"""
|
||||
|
|
@ -250,7 +252,7 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
return action, state
|
||||
|
||||
def learn(
|
||||
self,
|
||||
self: QRDQNSelf,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
|
|
@ -260,7 +262,7 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
tb_log_name: str = "QRDQN",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> OffPolicyAlgorithm:
|
||||
) -> QRDQNSelf:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -13,6 +13,8 @@ from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
|||
from sb3_contrib.common.utils import quantile_huber_loss
|
||||
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
|
||||
|
||||
TQCSelf = TypeVar("TQCSelf", bound="TQC")
|
||||
|
||||
|
||||
class TQC(OffPolicyAlgorithm):
|
||||
"""
|
||||
|
|
@ -287,7 +289,7 @@ class TQC(OffPolicyAlgorithm):
|
|||
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
||||
|
||||
def learn(
|
||||
self,
|
||||
self: TQCSelf,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
|
|
@ -297,7 +299,7 @@ class TQC(OffPolicyAlgorithm):
|
|||
tb_log_name: str = "TQC",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> OffPolicyAlgorithm:
|
||||
) -> TQCSelf:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import copy
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -17,6 +17,8 @@ from torch.nn import functional as F
|
|||
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
|
||||
from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||
|
||||
TRPOSelf = TypeVar("TRPOSelf", bound="TRPO")
|
||||
|
||||
|
||||
class TRPO(OnPolicyAlgorithm):
|
||||
"""
|
||||
|
|
@ -403,7 +405,7 @@ class TRPO(OnPolicyAlgorithm):
|
|||
return flat_grad(jacobian_vector_product, params, retain_graph=retain_graph) + self.cg_damping * vector
|
||||
|
||||
def learn(
|
||||
self,
|
||||
self: TRPOSelf,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
|
|
@ -413,7 +415,7 @@ class TRPO(OnPolicyAlgorithm):
|
|||
tb_log_name: str = "TRPO",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> OnPolicyAlgorithm:
|
||||
) -> TRPOSelf:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
Loading…
Reference in New Issue