From 99853265a9603ec703acf518d27048cd764e0a0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gr=C3=A9goire=20Passault?= Date: Fri, 8 Apr 2022 15:36:23 -0400 Subject: [PATCH] Using policy_aliases instead of register_policy (#66) * Using policy_aliases instead of register_policy * Moving policy_aliases definitions * Update SB3 version Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 20 +++++++++++++++++++- sb3_contrib/ars/ars.py | 10 ++++++---- sb3_contrib/ars/policies.py | 6 +----- sb3_contrib/ppo_mask/policies.py | 6 ------ sb3_contrib/ppo_mask/ppo_mask.py | 9 ++++++++- sb3_contrib/qrdqn/policies.py | 7 +------ sb3_contrib/qrdqn/qrdqn.py | 10 ++++++++-- sb3_contrib/tqc/policies.py | 7 +------ sb3_contrib/tqc/tqc.py | 10 ++++++++-- sb3_contrib/trpo/policies.py | 11 +---------- sb3_contrib/trpo/trpo.py | 10 ++++++++-- sb3_contrib/version.txt | 2 +- setup.py | 2 +- 13 files changed, 63 insertions(+), 47 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index f6db019..d2b0136 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,24 @@ Changelog ========== +Release 1.5.1a1 (WIP) +------------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Upgraded to Stable-Baselines3 >= 1.5.1a1 +- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former + ``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar) + +New Features: +^^^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + Release 1.5.0 (2022-03-25) ------------------------------- @@ -243,4 +261,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_) Contributors: ------------- -@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen +@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar diff --git a/sb3_contrib/ars/ars.py b/sb3_contrib/ars/ars.py index 8fed745..cc40a3f 100644 --- a/sb3_contrib/ars/ars.py +++ b/sb3_contrib/ars/ars.py @@ -15,7 +15,7 @@ from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_schedule_fn, safe_mean -from sb3_contrib.ars.policies import ARSPolicy +from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy from sb3_contrib.common.vec_env.async_eval import AsyncEval @@ -37,7 +37,6 @@ class ARS(BaseAlgorithm): :param alive_bonus_offset: Constant added to the reward at each step, used to cancel out alive bonuses. :param n_eval_episodes: Number of episodes to evaluate each candidate. :param policy_kwargs: Keyword arguments to pass to the policy on creation - :param policy_base: Base class to use for the policy :param tensorboard_log: String with the directory to put tensorboard logs: :param seed: Random seed for the training :param verbose: Verbosity level: 0 no output, 1 info, 2 debug @@ -45,6 +44,11 @@ class ARS(BaseAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "LinearPolicy": LinearPolicy, + } + def __init__( self, policy: Union[str, Type[ARSPolicy]], @@ -57,7 +61,6 @@ class ARS(BaseAlgorithm): alive_bonus_offset: float = 0, n_eval_episodes: int = 1, policy_kwargs: Optional[Dict[str, Any]] = None, - policy_base: Type[BasePolicy] = ARSPolicy, tensorboard_log: Optional[str] = None, seed: Optional[int] = None, verbose: int = 0, @@ -70,7 +73,6 @@ class ARS(BaseAlgorithm): env, learning_rate=learning_rate, tensorboard_log=tensorboard_log, - policy_base=policy_base, policy_kwargs=policy_kwargs, verbose=verbose, device=device, diff --git a/sb3_contrib/ars/policies.py b/sb3_contrib/ars/policies.py index e90927d..218cf0b 100644 --- a/sb3_contrib/ars/policies.py +++ b/sb3_contrib/ars/policies.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Type import gym import torch as th -from stable_baselines3.common.policies import BasePolicy, register_policy +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.torch_layers import create_mlp from torch import nn @@ -113,7 +113,3 @@ class ARSLinearPolicy(ARSPolicy): MlpPolicy = ARSPolicy LinearPolicy = ARSLinearPolicy - - -register_policy("LinearPolicy", LinearPolicy) -register_policy("MlpPolicy", MlpPolicy) diff --git a/sb3_contrib/ppo_mask/policies.py b/sb3_contrib/ppo_mask/policies.py index ad76571..b5252fa 100644 --- a/sb3_contrib/ppo_mask/policies.py +++ b/sb3_contrib/ppo_mask/policies.py @@ -1,5 +1,3 @@ -from stable_baselines3.common.policies import register_policy - from sb3_contrib.common.maskable.policies import ( MaskableActorCriticCnnPolicy, MaskableActorCriticPolicy, @@ -9,7 +7,3 @@ from sb3_contrib.common.maskable.policies import ( MlpPolicy = MaskableActorCriticPolicy CnnPolicy = MaskableActorCriticCnnPolicy MultiInputPolicy = MaskableMultiInputActorCriticPolicy - -register_policy("MlpPolicy", MaskableActorCriticPolicy) -register_policy("CnnPolicy", MaskableActorCriticCnnPolicy) -register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy) diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index b9cd6b5..2f70ce0 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -10,6 +10,7 @@ from stable_baselines3.common import utils from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean from stable_baselines3.common.vec_env import VecEnv @@ -18,6 +19,7 @@ from torch.nn import functional as F from sb3_contrib.common.maskable.buffers import MaskableDictRolloutBuffer, MaskableRolloutBuffer from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported +from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy class MaskablePPO(OnPolicyAlgorithm): @@ -65,6 +67,12 @@ class MaskablePPO(OnPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + def __init__( self, policy: Union[str, Type[MaskableActorCriticPolicy]], @@ -105,7 +113,6 @@ class MaskablePPO(OnPolicyAlgorithm): tensorboard_log=tensorboard_log, create_eval_env=create_eval_env, policy_kwargs=policy_kwargs, - policy_base=MaskableActorCriticPolicy, verbose=verbose, seed=seed, device=device, diff --git a/sb3_contrib/qrdqn/policies.py b/sb3_contrib/qrdqn/policies.py index d21b4ba..4aa0f67 100644 --- a/sb3_contrib/qrdqn/policies.py +++ b/sb3_contrib/qrdqn/policies.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Type import gym import torch as th -from stable_baselines3.common.policies import BasePolicy, register_policy +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, CombinedExtractor, @@ -307,8 +307,3 @@ class MultiInputPolicy(QRDQNPolicy): optimizer_class, optimizer_kwargs, ) - - -register_policy("MlpPolicy", MlpPolicy) -register_policy("CnnPolicy", CnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index bd08b59..24f286c 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -5,12 +5,13 @@ import numpy as np import torch as th from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import maybe_transpose from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update from sb3_contrib.common.utils import quantile_huber_loss -from sb3_contrib.qrdqn.policies import QRDQNPolicy +from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy class QRDQN(OffPolicyAlgorithm): @@ -57,6 +58,12 @@ class QRDQN(OffPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + def __init__( self, policy: Union[str, Type[QRDQNPolicy]], @@ -89,7 +96,6 @@ class QRDQN(OffPolicyAlgorithm): super(QRDQN, self).__init__( policy, env, - QRDQNPolicy, learning_rate, buffer_size, learning_starts, diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index d098be0..171846d 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union import gym import torch as th from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution -from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy +from stable_baselines3.common.policies import BaseModel, BasePolicy from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, @@ -580,8 +580,3 @@ class MultiInputPolicy(TQCPolicy): n_critics, share_features_extractor, ) - - -register_policy("MlpPolicy", MlpPolicy) -register_policy("CnnPolicy", CnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 687fc46..9c574e2 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -6,11 +6,12 @@ import torch as th from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback from stable_baselines3.common.utils import polyak_update from sb3_contrib.common.utils import quantile_huber_loss -from sb3_contrib.tqc.policies import TQCPolicy +from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy class TQC(OffPolicyAlgorithm): @@ -64,6 +65,12 @@ class TQC(OffPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + def __init__( self, policy: Union[str, Type[TQCPolicy]], @@ -99,7 +106,6 @@ class TQC(OffPolicyAlgorithm): super(TQC, self).__init__( policy, env, - TQCPolicy, learning_rate, buffer_size, learning_starts, diff --git a/sb3_contrib/trpo/policies.py b/sb3_contrib/trpo/policies.py index 27cde53..a715207 100644 --- a/sb3_contrib/trpo/policies.py +++ b/sb3_contrib/trpo/policies.py @@ -1,16 +1,7 @@ # This file is here just to define MlpPolicy/CnnPolicy # that work for TRPO -from stable_baselines3.common.policies import ( - ActorCriticCnnPolicy, - ActorCriticPolicy, - MultiInputActorCriticPolicy, - register_policy, -) +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy MlpPolicy = ActorCriticPolicy CnnPolicy = ActorCriticCnnPolicy MultiInputPolicy = MultiInputActorCriticPolicy - -register_policy("MlpPolicy", ActorCriticPolicy) -register_policy("CnnPolicy", ActorCriticCnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index 5eeeef0..e3d2f03 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -7,7 +7,7 @@ import numpy as np import torch as th from gym import spaces from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule from stable_baselines3.common.utils import explained_variance from torch import nn @@ -15,6 +15,7 @@ from torch.distributions import kl_divergence from torch.nn import functional as F from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad +from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy class TRPO(OnPolicyAlgorithm): @@ -66,6 +67,12 @@ class TRPO(OnPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + def __init__( self, policy: Union[str, Type[ActorCriticPolicy]], @@ -106,7 +113,6 @@ class TRPO(OnPolicyAlgorithm): max_grad_norm=0.0, use_sde=use_sde, sde_sample_freq=sde_sample_freq, - policy_base=ActorCriticPolicy, tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs, verbose=verbose, diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index bc80560..1110517 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.5.0 +1.5.1a1 diff --git a/setup.py b/setup.py index 7a4fb3a..c26523a 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.5.0", + "stable_baselines3>=1.5.1a1", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin",