Using policy_aliases instead of register_policy (#66)

* Using policy_aliases instead of register_policy
* Moving policy_aliases definitions
* Update SB3 version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
commit 99853265a9 (parent 9d7e33d213)
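For context, the change swaps a global, mutable policy registry for a per-algorithm class attribute. A minimal before/after sketch of the two patterns (MyAlgo is a hypothetical stand-in, not a class from this diff):

from typing import Dict, Type

from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy

# Before (SB3 < 1.5.1a1): aliases were added to one global registry,
# shared by every algorithm:
#   from stable_baselines3.common.policies import register_policy
#   register_policy("MlpPolicy", ActorCriticPolicy)

# After (SB3 >= 1.5.1a1): each algorithm owns its alias table as a
# static attribute, so lookups are local and no import-time side
# effects are needed.
class MyAlgo:  # hypothetical stand-in for a BaseAlgorithm subclass
    policy_aliases: Dict[str, Type[BasePolicy]] = {
        "MlpPolicy": ActorCriticPolicy,
    }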
docs/misc/changelog.rst
@@ -3,6 +3,24 @@
 Changelog
 ==========
 
+Release 1.5.1a1 (WIP)
+-------------------------------
+
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+- Upgraded to Stable-Baselines3 >= 1.5.1a1
+- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former
+  ``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar)
+
+New Features:
+^^^^^^^^^^^^^
+
+Bug Fixes:
+^^^^^^^^^^
+
+Deprecations:
+^^^^^^^^^^^^^
+
 Release 1.5.0 (2022-03-25)
 -------------------------------
 
@@ -243,4 +261,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
 Contributors:
 -------------
 
-@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen
+@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar
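For library users the string aliases are unchanged; only code that imported ``register_policy`` directly needs updating. A hedged sketch of what still works and what breaks (MyCustomPolicy is a hypothetical name):

from sb3_contrib import TQC

# Unchanged: string aliases resolve through TQC.policy_aliases.
model = TQC("MlpPolicy", "Pendulum-v1")

# Broken by this release: the global registry helper is gone.
#   from stable_baselines3.common.policies import register_policy  # ImportError
# Custom policies are now passed as a class instead of being registered:
#   model = TQC(MyCustomPolicy, "Pendulum-v1")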
sb3_contrib/ars/ars.py
@@ -15,7 +15,7 @@ from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import get_schedule_fn, safe_mean
 
-from sb3_contrib.ars.policies import ARSPolicy
+from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy
 from sb3_contrib.common.vec_env.async_eval import AsyncEval
 
 
@@ -37,7 +37,6 @@ class ARS(BaseAlgorithm):
     :param alive_bonus_offset: Constant added to the reward at each step, used to cancel out alive bonuses.
     :param n_eval_episodes: Number of episodes to evaluate each candidate.
     :param policy_kwargs: Keyword arguments to pass to the policy on creation
-    :param policy_base: Base class to use for the policy
     :param tensorboard_log: String with the directory to put tensorboard logs:
     :param seed: Random seed for the training
     :param verbose: Verbosity level: 0 no output, 1 info, 2 debug
@@ -45,6 +44,11 @@ class ARS(BaseAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
+    policy_aliases: Dict[str, Type[BasePolicy]] = {
+        "MlpPolicy": MlpPolicy,
+        "LinearPolicy": LinearPolicy,
+    }
+
     def __init__(
         self,
         policy: Union[str, Type[ARSPolicy]],
@@ -57,7 +61,6 @@ class ARS(BaseAlgorithm):
         alive_bonus_offset: float = 0,
         n_eval_episodes: int = 1,
         policy_kwargs: Optional[Dict[str, Any]] = None,
-        policy_base: Type[BasePolicy] = ARSPolicy,
         tensorboard_log: Optional[str] = None,
         seed: Optional[int] = None,
         verbose: int = 0,
@@ -70,7 +73,6 @@ class ARS(BaseAlgorithm):
             env,
             learning_rate=learning_rate,
             tensorboard_log=tensorboard_log,
-            policy_base=policy_base,
             policy_kwargs=policy_kwargs,
             verbose=verbose,
             device=device,
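With the attribute in place, ARS resolves its two aliases locally instead of via the registry. A quick usage sketch (assuming a continuous-action environment such as Pendulum-v1):

from sb3_contrib import ARS

# "LinearPolicy" is now looked up in ARS.policy_aliases.
model = ARS("LinearPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=500)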
sb3_contrib/ars/policies.py
@@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Type
 
 import gym
 import torch as th
-from stable_baselines3.common.policies import BasePolicy, register_policy
+from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.preprocessing import get_action_dim
 from stable_baselines3.common.torch_layers import create_mlp
 from torch import nn
@@ -113,7 +113,3 @@ class ARSLinearPolicy(ARSPolicy):
 
 MlpPolicy = ARSPolicy
 LinearPolicy = ARSLinearPolicy
-
-
-register_policy("LinearPolicy", LinearPolicy)
-register_policy("MlpPolicy", MlpPolicy)
sb3_contrib/ppo_mask/policies.py
@@ -1,5 +1,3 @@
-from stable_baselines3.common.policies import register_policy
-
 from sb3_contrib.common.maskable.policies import (
     MaskableActorCriticCnnPolicy,
     MaskableActorCriticPolicy,
@@ -9,7 +7,3 @@ from sb3_contrib.common.maskable.policies import (
 MlpPolicy = MaskableActorCriticPolicy
 CnnPolicy = MaskableActorCriticCnnPolicy
 MultiInputPolicy = MaskableMultiInputActorCriticPolicy
-
-register_policy("MlpPolicy", MaskableActorCriticPolicy)
-register_policy("CnnPolicy", MaskableActorCriticCnnPolicy)
-register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy)
sb3_contrib/ppo_mask/ppo_mask.py
@@ -10,6 +10,7 @@ from stable_baselines3.common import utils
 from stable_baselines3.common.buffers import RolloutBuffer
 from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
+from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
 from stable_baselines3.common.vec_env import VecEnv
@@ -18,6 +19,7 @@ from torch.nn import functional as F
 from sb3_contrib.common.maskable.buffers import MaskableDictRolloutBuffer, MaskableRolloutBuffer
 from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
 from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
+from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
 
 
 class MaskablePPO(OnPolicyAlgorithm):
@@ -65,6 +67,12 @@ class MaskablePPO(OnPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
+    policy_aliases: Dict[str, Type[BasePolicy]] = {
+        "MlpPolicy": MlpPolicy,
+        "CnnPolicy": CnnPolicy,
+        "MultiInputPolicy": MultiInputPolicy,
+    }
+
     def __init__(
         self,
         policy: Union[str, Type[MaskableActorCriticPolicy]],
@@ -105,7 +113,6 @@ class MaskablePPO(OnPolicyAlgorithm):
             tensorboard_log=tensorboard_log,
             create_eval_env=create_eval_env,
             policy_kwargs=policy_kwargs,
-            policy_base=MaskableActorCriticPolicy,
             verbose=verbose,
             seed=seed,
             device=device,
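The base class is what consumes these tables when ``policy`` is given as a string: roughly, the old registry lookup becomes a dictionary lookup on the class. This is a simplified sketch of the idea, not the exact SB3 implementation:

from typing import Dict, Type

from stable_baselines3.common.policies import BasePolicy


def resolve_policy(  # simplified stand-in for the lookup in BaseAlgorithm
    aliases: Dict[str, Type[BasePolicy]], policy_name: str
) -> Type[BasePolicy]:
    """Map an alias like "MlpPolicy" to a policy class, or fail loudly."""
    try:
        return aliases[policy_name]
    except KeyError:
        raise ValueError(f"Policy {policy_name} unknown") from None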
sb3_contrib/qrdqn/policies.py
@@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Type
 
 import gym
 import torch as th
-from stable_baselines3.common.policies import BasePolicy, register_policy
+from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.torch_layers import (
     BaseFeaturesExtractor,
     CombinedExtractor,
@@ -307,8 +307,3 @@ class MultiInputPolicy(QRDQNPolicy):
             optimizer_class,
             optimizer_kwargs,
         )
-
-
-register_policy("MlpPolicy", MlpPolicy)
-register_policy("CnnPolicy", CnnPolicy)
-register_policy("MultiInputPolicy", MultiInputPolicy)
sb3_contrib/qrdqn/qrdqn.py
@@ -5,12 +5,13 @@ import numpy as np
 import torch as th
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
+from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.preprocessing import maybe_transpose
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
 
 from sb3_contrib.common.utils import quantile_huber_loss
-from sb3_contrib.qrdqn.policies import QRDQNPolicy
+from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
 
 
 class QRDQN(OffPolicyAlgorithm):
@@ -57,6 +58,12 @@ class QRDQN(OffPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
+    policy_aliases: Dict[str, Type[BasePolicy]] = {
+        "MlpPolicy": MlpPolicy,
+        "CnnPolicy": CnnPolicy,
+        "MultiInputPolicy": MultiInputPolicy,
+    }
+
     def __init__(
         self,
         policy: Union[str, Type[QRDQNPolicy]],
@@ -89,7 +96,6 @@ class QRDQN(OffPolicyAlgorithm):
         super(QRDQN, self).__init__(
             policy,
             env,
-            QRDQNPolicy,
             learning_rate,
             buffer_size,
             learning_starts,
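Note that ``policy`` can still be a class rather than a string (the ``Union[str, Type[QRDQNPolicy]]`` annotation above is unchanged); in that case the alias table is bypassed entirely:

from sb3_contrib import QRDQN
from sb3_contrib.qrdqn.policies import MlpPolicy

# Two equivalent ways to select the same policy:
model_by_alias = QRDQN("MlpPolicy", "CartPole-v1")
model_by_class = QRDQN(MlpPolicy, "CartPole-v1")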
sb3_contrib/tqc/policies.py
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
 import gym
 import torch as th
 from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
-from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy
+from stable_baselines3.common.policies import BaseModel, BasePolicy
 from stable_baselines3.common.preprocessing import get_action_dim
 from stable_baselines3.common.torch_layers import (
     BaseFeaturesExtractor,
@@ -580,8 +580,3 @@ class MultiInputPolicy(TQCPolicy):
             n_critics,
             share_features_extractor,
         )
-
-
-register_policy("MlpPolicy", MlpPolicy)
-register_policy("CnnPolicy", CnnPolicy)
-register_policy("MultiInputPolicy", MultiInputPolicy)
sb3_contrib/tqc/tqc.py
@@ -6,11 +6,12 @@ import torch as th
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.noise import ActionNoise
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
+from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
 from stable_baselines3.common.utils import polyak_update
 
 from sb3_contrib.common.utils import quantile_huber_loss
-from sb3_contrib.tqc.policies import TQCPolicy
+from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
 
 
 class TQC(OffPolicyAlgorithm):
@@ -64,6 +65,12 @@ class TQC(OffPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
+    policy_aliases: Dict[str, Type[BasePolicy]] = {
+        "MlpPolicy": MlpPolicy,
+        "CnnPolicy": CnnPolicy,
+        "MultiInputPolicy": MultiInputPolicy,
+    }
+
     def __init__(
         self,
         policy: Union[str, Type[TQCPolicy]],
@@ -99,7 +106,6 @@ class TQC(OffPolicyAlgorithm):
         super(TQC, self).__init__(
             policy,
             env,
-            TQCPolicy,
             learning_rate,
             buffer_size,
             learning_starts,
sb3_contrib/trpo/policies.py
@@ -1,16 +1,7 @@
 # This file is here just to define MlpPolicy/CnnPolicy
 # that work for TRPO
-from stable_baselines3.common.policies import (
-    ActorCriticCnnPolicy,
-    ActorCriticPolicy,
-    MultiInputActorCriticPolicy,
-    register_policy,
-)
+from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
 
 MlpPolicy = ActorCriticPolicy
 CnnPolicy = ActorCriticCnnPolicy
 MultiInputPolicy = MultiInputActorCriticPolicy
-
-register_policy("MlpPolicy", ActorCriticPolicy)
-register_policy("CnnPolicy", ActorCriticCnnPolicy)
-register_policy("MultiInputPolicy", MultiInputPolicy)
sb3_contrib/trpo/trpo.py
@@ -7,7 +7,7 @@ import numpy as np
 import torch as th
 from gym import spaces
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
-from stable_baselines3.common.policies import ActorCriticPolicy
+from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule
 from stable_baselines3.common.utils import explained_variance
 from torch import nn
@@ -15,6 +15,7 @@ from torch.distributions import kl_divergence
 from torch.nn import functional as F
 
 from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
+from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
 
 
 class TRPO(OnPolicyAlgorithm):
@@ -66,6 +67,12 @@ class TRPO(OnPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
+    policy_aliases: Dict[str, Type[BasePolicy]] = {
+        "MlpPolicy": MlpPolicy,
+        "CnnPolicy": CnnPolicy,
+        "MultiInputPolicy": MultiInputPolicy,
+    }
+
     def __init__(
         self,
         policy: Union[str, Type[ActorCriticPolicy]],
@@ -106,7 +113,6 @@ class TRPO(OnPolicyAlgorithm):
             max_grad_norm=0.0,
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
-            policy_base=ActorCriticPolicy,
             tensorboard_log=tensorboard_log,
             policy_kwargs=policy_kwargs,
             verbose=verbose,
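The same pattern extends to downstream code: a subclass can extend the table instead of mutating a global registry. A hedged sketch (CustomTRPO and TinyPolicy are hypothetical names, not part of this diff):

from typing import Dict, Type

from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy

from sb3_contrib import TRPO


class TinyPolicy(ActorCriticPolicy):  # hypothetical custom policy
    pass


class CustomTRPO(TRPO):  # hypothetical subclass adding one alias
    policy_aliases: Dict[str, Type[BasePolicy]] = {
        **TRPO.policy_aliases,
        "TinyPolicy": TinyPolicy,
    }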
sb3_contrib/version.txt
@@ -1 +1 @@
-1.5.0
+1.5.1a1
setup.py
@@ -65,7 +65,7 @@ setup(
     packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
     package_data={"sb3_contrib": ["py.typed", "version.txt"]},
     install_requires=[
-        "stable_baselines3>=1.5.0",
+        "stable_baselines3>=1.5.1a1",
     ],
     description="Contrib package of Stable Baselines3, experimental code.",
     author="Antonin Raffin",