Using policy_aliases instead of register_policy (#66)

* Using policy_aliases instead of register_policy

* Moving policy_aliases definitions

* Update SB3 version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
Grégoire Passault 2022-04-08 15:36:23 -04:00 committed by GitHub
parent 9d7e33d213
commit 99853265a9
13 changed files with 63 additions and 47 deletions
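The change in a nutshell: instead of registering string aliases ("MlpPolicy", "CnnPolicy", ...) in a global registry via register_policy, each algorithm now declares its own name-to-class mapping as a class-level policy_aliases attribute, and the policy_base constructor parameter is gone. A minimal before/after sketch (MyAlgorithm is an illustrative name, not part of this commit):

    from typing import Dict, Type

    from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
    from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy

    # Before this commit: aliases lived in a global registry.
    # from stable_baselines3.common.policies import register_policy
    # register_policy("MlpPolicy", ActorCriticPolicy)

    # After this commit: each algorithm owns its alias mapping,
    # so two algorithms can reuse the same alias name without clashing.
    class MyAlgorithm(OnPolicyAlgorithm):
        policy_aliases: Dict[str, Type[BasePolicy]] = {
            "MlpPolicy": ActorCriticPolicy,
        }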

docs/misc/changelog.rst

@@ -3,6 +3,24 @@
 Changelog
 ==========
 
+Release 1.5.1a1 (WIP)
+-------------------------------
+
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+- Upgraded to Stable-Baselines3 >= 1.5.1a1
+- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former
+  ``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar)
+
+New Features:
+^^^^^^^^^^^^^
+
+Bug Fixes:
+^^^^^^^^^^
+
+Deprecations:
+^^^^^^^^^^^^^
+
 Release 1.5.0 (2022-03-25)
 -------------------------------
@@ -243,4 +261,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
 Contributors:
 -------------
-@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen
+@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar

sb3_contrib/ars/ars.py

@@ -15,7 +15,7 @@ from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import get_schedule_fn, safe_mean
 
-from sb3_contrib.ars.policies import ARSPolicy
+from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy
 from sb3_contrib.common.vec_env.async_eval import AsyncEval
@@ -37,7 +37,6 @@ class ARS(BaseAlgorithm):
     :param alive_bonus_offset: Constant added to the reward at each step, used to cancel out alive bonuses.
     :param n_eval_episodes: Number of episodes to evaluate each candidate.
     :param policy_kwargs: Keyword arguments to pass to the policy on creation
-    :param policy_base: Base class to use for the policy
     :param tensorboard_log: String with the directory to put tensorboard logs:
     :param seed: Random seed for the training
     :param verbose: Verbosity level: 0 no output, 1 info, 2 debug
@@ -45,6 +44,11 @@ class ARS(BaseAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
+    policy_aliases: Dict[str, Type[BasePolicy]] = {
+        "MlpPolicy": MlpPolicy,
+        "LinearPolicy": LinearPolicy,
+    }
+
     def __init__(
         self,
         policy: Union[str, Type[ARSPolicy]],
@@ -57,7 +61,6 @@ class ARS(BaseAlgorithm):
         alive_bonus_offset: float = 0,
         n_eval_episodes: int = 1,
         policy_kwargs: Optional[Dict[str, Any]] = None,
-        policy_base: Type[BasePolicy] = ARSPolicy,
         tensorboard_log: Optional[str] = None,
         seed: Optional[int] = None,
         verbose: int = 0,
@@ -70,7 +73,6 @@ class ARS(BaseAlgorithm):
             env,
             learning_rate=learning_rate,
             tensorboard_log=tensorboard_log,
-            policy_base=policy_base,
             policy_kwargs=policy_kwargs,
             verbose=verbose,
             device=device,

sb3_contrib/ars/policies.py

@@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Type
 
 import gym
 import torch as th
-from stable_baselines3.common.policies import BasePolicy, register_policy
+from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.preprocessing import get_action_dim
 from stable_baselines3.common.torch_layers import create_mlp
 from torch import nn
@@ -113,7 +113,3 @@ class ARSLinearPolicy(ARSPolicy):
 
 MlpPolicy = ARSPolicy
 LinearPolicy = ARSLinearPolicy
-
-
-register_policy("LinearPolicy", LinearPolicy)
-register_policy("MlpPolicy", MlpPolicy)

sb3_contrib/ppo_mask/policies.py

@@ -1,5 +1,3 @@
-from stable_baselines3.common.policies import register_policy
-
 from sb3_contrib.common.maskable.policies import (
     MaskableActorCriticCnnPolicy,
     MaskableActorCriticPolicy,
@@ -9,7 +7,3 @@ from sb3_contrib.common.maskable.policies import (
 MlpPolicy = MaskableActorCriticPolicy
 CnnPolicy = MaskableActorCriticCnnPolicy
 MultiInputPolicy = MaskableMultiInputActorCriticPolicy
-
-register_policy("MlpPolicy", MaskableActorCriticPolicy)
-register_policy("CnnPolicy", MaskableActorCriticCnnPolicy)
-register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy)

sb3_contrib/ppo_mask/ppo_mask.py

@@ -10,6 +10,7 @@ from stable_baselines3.common import utils
 from stable_baselines3.common.buffers import RolloutBuffer
 from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
+from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
 from stable_baselines3.common.vec_env import VecEnv
@@ -18,6 +19,7 @@ from torch.nn import functional as F
 from sb3_contrib.common.maskable.buffers import MaskableDictRolloutBuffer, MaskableRolloutBuffer
 from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
 from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
+from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
 
 
 class MaskablePPO(OnPolicyAlgorithm):
@@ -65,6 +67,12 @@ class MaskablePPO(OnPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
+    policy_aliases: Dict[str, Type[BasePolicy]] = {
+        "MlpPolicy": MlpPolicy,
+        "CnnPolicy": CnnPolicy,
+        "MultiInputPolicy": MultiInputPolicy,
+    }
+
     def __init__(
         self,
         policy: Union[str, Type[MaskableActorCriticPolicy]],
@@ -105,7 +113,6 @@ class MaskablePPO(OnPolicyAlgorithm):
             tensorboard_log=tensorboard_log,
             create_eval_env=create_eval_env,
             policy_kwargs=policy_kwargs,
-            policy_base=MaskableActorCriticPolicy,
             verbose=verbose,
             seed=seed,
             device=device,

sb3_contrib/qrdqn/policies.py

@@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Type
 
 import gym
 import torch as th
-from stable_baselines3.common.policies import BasePolicy, register_policy
+from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.torch_layers import (
     BaseFeaturesExtractor,
     CombinedExtractor,
@@ -307,8 +307,3 @@ class MultiInputPolicy(QRDQNPolicy):
             optimizer_class,
             optimizer_kwargs,
         )
-
-
-register_policy("MlpPolicy", MlpPolicy)
-register_policy("CnnPolicy", CnnPolicy)
-register_policy("MultiInputPolicy", MultiInputPolicy)

sb3_contrib/qrdqn/qrdqn.py

@@ -5,12 +5,13 @@ import numpy as np
 import torch as th
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
+from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.preprocessing import maybe_transpose
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
 
 from sb3_contrib.common.utils import quantile_huber_loss
-from sb3_contrib.qrdqn.policies import QRDQNPolicy
+from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
 
 
 class QRDQN(OffPolicyAlgorithm):
@@ -57,6 +58,12 @@ class QRDQN(OffPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
+    policy_aliases: Dict[str, Type[BasePolicy]] = {
+        "MlpPolicy": MlpPolicy,
+        "CnnPolicy": CnnPolicy,
+        "MultiInputPolicy": MultiInputPolicy,
+    }
+
     def __init__(
         self,
         policy: Union[str, Type[QRDQNPolicy]],
@@ -89,7 +96,6 @@ class QRDQN(OffPolicyAlgorithm):
         super(QRDQN, self).__init__(
             policy,
             env,
-            QRDQNPolicy,
             learning_rate,
             buffer_size,
             learning_starts,

sb3_contrib/tqc/policies.py

@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
 import gym
 import torch as th
 from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
-from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy
+from stable_baselines3.common.policies import BaseModel, BasePolicy
 from stable_baselines3.common.preprocessing import get_action_dim
 from stable_baselines3.common.torch_layers import (
     BaseFeaturesExtractor,
@@ -580,8 +580,3 @@ class MultiInputPolicy(TQCPolicy):
             n_critics,
             share_features_extractor,
         )
-
-
-register_policy("MlpPolicy", MlpPolicy)
-register_policy("CnnPolicy", CnnPolicy)
-register_policy("MultiInputPolicy", MultiInputPolicy)

sb3_contrib/tqc/tqc.py

@@ -6,11 +6,12 @@ import torch as th
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.noise import ActionNoise
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
+from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
 from stable_baselines3.common.utils import polyak_update
 
 from sb3_contrib.common.utils import quantile_huber_loss
-from sb3_contrib.tqc.policies import TQCPolicy
+from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
 
 
 class TQC(OffPolicyAlgorithm):
@@ -64,6 +65,12 @@ class TQC(OffPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
+    policy_aliases: Dict[str, Type[BasePolicy]] = {
+        "MlpPolicy": MlpPolicy,
+        "CnnPolicy": CnnPolicy,
+        "MultiInputPolicy": MultiInputPolicy,
+    }
+
     def __init__(
         self,
         policy: Union[str, Type[TQCPolicy]],
@@ -99,7 +106,6 @@ class TQC(OffPolicyAlgorithm):
         super(TQC, self).__init__(
             policy,
             env,
-            TQCPolicy,
             learning_rate,
             buffer_size,
             learning_starts,

sb3_contrib/trpo/policies.py

@@ -1,16 +1,7 @@
 # This file is here just to define MlpPolicy/CnnPolicy
 # that work for TRPO
-from stable_baselines3.common.policies import (
-    ActorCriticCnnPolicy,
-    ActorCriticPolicy,
-    MultiInputActorCriticPolicy,
-    register_policy,
-)
+from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
 
 MlpPolicy = ActorCriticPolicy
 CnnPolicy = ActorCriticCnnPolicy
 MultiInputPolicy = MultiInputActorCriticPolicy
-
-register_policy("MlpPolicy", ActorCriticPolicy)
-register_policy("CnnPolicy", ActorCriticCnnPolicy)
-register_policy("MultiInputPolicy", MultiInputPolicy)

sb3_contrib/trpo/trpo.py

@@ -7,7 +7,7 @@ import numpy as np
 import torch as th
 from gym import spaces
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
-from stable_baselines3.common.policies import ActorCriticPolicy
+from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule
 from stable_baselines3.common.utils import explained_variance
 from torch import nn
@@ -15,6 +15,7 @@ from torch.distributions import kl_divergence
 from torch.nn import functional as F
 
 from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
+from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
 
 
 class TRPO(OnPolicyAlgorithm):
@@ -66,6 +67,12 @@ class TRPO(OnPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
+    policy_aliases: Dict[str, Type[BasePolicy]] = {
+        "MlpPolicy": MlpPolicy,
+        "CnnPolicy": CnnPolicy,
+        "MultiInputPolicy": MultiInputPolicy,
+    }
+
     def __init__(
         self,
         policy: Union[str, Type[ActorCriticPolicy]],
@@ -106,7 +113,6 @@ class TRPO(OnPolicyAlgorithm):
             max_grad_norm=0.0,
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
-            policy_base=ActorCriticPolicy,
             tensorboard_log=tensorboard_log,
             policy_kwargs=policy_kwargs,
             verbose=verbose,

sb3_contrib/version.txt

@@ -1 +1 @@
-1.5.0
+1.5.1a1

setup.py

@@ -65,7 +65,7 @@ setup(
     packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
     package_data={"sb3_contrib": ["py.typed", "version.txt"]},
     install_requires=[
-        "stable_baselines3>=1.5.0",
+        "stable_baselines3>=1.5.1a1",
     ],
     description="Contrib package of Stable Baselines3, experimental code.",
     author="Antonin Raffin",