Using policy_aliases instead of register_policy (#66)
* Using policy_aliases instead of register_policy
* Moving policy_aliases definitions
* Update SB3 version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
9d7e33d213
commit
99853265a9
|
|
@ -3,6 +3,24 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
Release 1.5.1a1 (WIP)
|
||||||
|
-------------------------------
|
||||||
|
|
||||||
|
Breaking Changes:
|
||||||
|
^^^^^^^^^^^^^^^^^
|
||||||
|
- Upgraded to Stable-Baselines3 >= 1.5.1a1
|
||||||
|
- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former
|
||||||
|
``register_policy`` helper and ``policy_base`` parameter, and using ``policy_aliases`` static attributes instead (@Gregwar)
|
||||||
|
|
||||||
|
New Features:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Bug Fixes:
|
||||||
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
Deprecations:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
Release 1.5.0 (2022-03-25)
|
Release 1.5.0 (2022-03-25)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
|
|
@ -243,4 +261,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
|
||||||
Contributors:
|
Contributors:
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen
|
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
from stable_baselines3.common.utils import get_schedule_fn, safe_mean
|
from stable_baselines3.common.utils import get_schedule_fn, safe_mean
|
||||||
|
|
||||||
from sb3_contrib.ars.policies import ARSPolicy
|
from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy
|
||||||
from sb3_contrib.common.vec_env.async_eval import AsyncEval
|
from sb3_contrib.common.vec_env.async_eval import AsyncEval
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -37,7 +37,6 @@ class ARS(BaseAlgorithm):
|
||||||
:param alive_bonus_offset: Constant added to the reward at each step, used to cancel out alive bonuses.
|
:param alive_bonus_offset: Constant added to the reward at each step, used to cancel out alive bonuses.
|
||||||
:param n_eval_episodes: Number of episodes to evaluate each candidate.
|
:param n_eval_episodes: Number of episodes to evaluate each candidate.
|
||||||
:param policy_kwargs: Keyword arguments to pass to the policy on creation
|
:param policy_kwargs: Keyword arguments to pass to the policy on creation
|
||||||
:param policy_base: Base class to use for the policy
|
|
||||||
:param tensorboard_log: String with the directory to put tensorboard logs:
|
:param tensorboard_log: String with the directory to put tensorboard logs:
|
||||||
:param seed: Random seed for the training
|
:param seed: Random seed for the training
|
||||||
:param verbose: Verbosity level: 0 no output, 1 info, 2 debug
|
:param verbose: Verbosity level: 0 no output, 1 info, 2 debug
|
||||||
|
|
@ -45,6 +44,11 @@ class ARS(BaseAlgorithm):
|
||||||
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
policy_aliases: Dict[str, Type[BasePolicy]] = {
|
||||||
|
"MlpPolicy": MlpPolicy,
|
||||||
|
"LinearPolicy": LinearPolicy,
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
policy: Union[str, Type[ARSPolicy]],
|
policy: Union[str, Type[ARSPolicy]],
|
||||||
|
|
@ -57,7 +61,6 @@ class ARS(BaseAlgorithm):
|
||||||
alive_bonus_offset: float = 0,
|
alive_bonus_offset: float = 0,
|
||||||
n_eval_episodes: int = 1,
|
n_eval_episodes: int = 1,
|
||||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
policy_base: Type[BasePolicy] = ARSPolicy,
|
|
||||||
tensorboard_log: Optional[str] = None,
|
tensorboard_log: Optional[str] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
verbose: int = 0,
|
verbose: int = 0,
|
||||||
|
|
@ -70,7 +73,6 @@ class ARS(BaseAlgorithm):
|
||||||
env,
|
env,
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
tensorboard_log=tensorboard_log,
|
tensorboard_log=tensorboard_log,
|
||||||
policy_base=policy_base,
|
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
device=device,
|
device=device,
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common.policies import BasePolicy, register_policy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.preprocessing import get_action_dim
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
from stable_baselines3.common.torch_layers import create_mlp
|
from stable_baselines3.common.torch_layers import create_mlp
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
@ -113,7 +113,3 @@ class ARSLinearPolicy(ARSPolicy):
|
||||||
|
|
||||||
MlpPolicy = ARSPolicy
|
MlpPolicy = ARSPolicy
|
||||||
LinearPolicy = ARSLinearPolicy
|
LinearPolicy = ARSLinearPolicy
|
||||||
|
|
||||||
|
|
||||||
register_policy("LinearPolicy", LinearPolicy)
|
|
||||||
register_policy("MlpPolicy", MlpPolicy)
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
from stable_baselines3.common.policies import register_policy
|
|
||||||
|
|
||||||
from sb3_contrib.common.maskable.policies import (
|
from sb3_contrib.common.maskable.policies import (
|
||||||
MaskableActorCriticCnnPolicy,
|
MaskableActorCriticCnnPolicy,
|
||||||
MaskableActorCriticPolicy,
|
MaskableActorCriticPolicy,
|
||||||
|
|
@ -9,7 +7,3 @@ from sb3_contrib.common.maskable.policies import (
|
||||||
MlpPolicy = MaskableActorCriticPolicy
|
MlpPolicy = MaskableActorCriticPolicy
|
||||||
CnnPolicy = MaskableActorCriticCnnPolicy
|
CnnPolicy = MaskableActorCriticCnnPolicy
|
||||||
MultiInputPolicy = MaskableMultiInputActorCriticPolicy
|
MultiInputPolicy = MaskableMultiInputActorCriticPolicy
|
||||||
|
|
||||||
register_policy("MlpPolicy", MaskableActorCriticPolicy)
|
|
||||||
register_policy("CnnPolicy", MaskableActorCriticCnnPolicy)
|
|
||||||
register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy)
|
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ from stable_baselines3.common import utils
|
||||||
from stable_baselines3.common.buffers import RolloutBuffer
|
from stable_baselines3.common.buffers import RolloutBuffer
|
||||||
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback
|
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback
|
||||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||||
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
|
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
|
||||||
from stable_baselines3.common.vec_env import VecEnv
|
from stable_baselines3.common.vec_env import VecEnv
|
||||||
|
|
@ -18,6 +19,7 @@ from torch.nn import functional as F
|
||||||
from sb3_contrib.common.maskable.buffers import MaskableDictRolloutBuffer, MaskableRolloutBuffer
|
from sb3_contrib.common.maskable.buffers import MaskableDictRolloutBuffer, MaskableRolloutBuffer
|
||||||
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
|
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
|
||||||
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
|
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
|
||||||
|
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||||
|
|
||||||
|
|
||||||
class MaskablePPO(OnPolicyAlgorithm):
|
class MaskablePPO(OnPolicyAlgorithm):
|
||||||
|
|
@ -65,6 +67,12 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
policy_aliases: Dict[str, Type[BasePolicy]] = {
|
||||||
|
"MlpPolicy": MlpPolicy,
|
||||||
|
"CnnPolicy": CnnPolicy,
|
||||||
|
"MultiInputPolicy": MultiInputPolicy,
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
policy: Union[str, Type[MaskableActorCriticPolicy]],
|
policy: Union[str, Type[MaskableActorCriticPolicy]],
|
||||||
|
|
@ -105,7 +113,6 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
tensorboard_log=tensorboard_log,
|
tensorboard_log=tensorboard_log,
|
||||||
create_eval_env=create_eval_env,
|
create_eval_env=create_eval_env,
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
policy_base=MaskableActorCriticPolicy,
|
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
device=device,
|
device=device,
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common.policies import BasePolicy, register_policy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.torch_layers import (
|
from stable_baselines3.common.torch_layers import (
|
||||||
BaseFeaturesExtractor,
|
BaseFeaturesExtractor,
|
||||||
CombinedExtractor,
|
CombinedExtractor,
|
||||||
|
|
@ -307,8 +307,3 @@ class MultiInputPolicy(QRDQNPolicy):
|
||||||
optimizer_class,
|
optimizer_class,
|
||||||
optimizer_kwargs,
|
optimizer_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_policy("MlpPolicy", MlpPolicy)
|
|
||||||
register_policy("CnnPolicy", CnnPolicy)
|
|
||||||
register_policy("MultiInputPolicy", MultiInputPolicy)
|
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,13 @@ import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common.buffers import ReplayBuffer
|
from stable_baselines3.common.buffers import ReplayBuffer
|
||||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||||
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.preprocessing import maybe_transpose
|
from stable_baselines3.common.preprocessing import maybe_transpose
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
|
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
|
||||||
|
|
||||||
from sb3_contrib.common.utils import quantile_huber_loss
|
from sb3_contrib.common.utils import quantile_huber_loss
|
||||||
from sb3_contrib.qrdqn.policies import QRDQNPolicy
|
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
|
||||||
|
|
||||||
|
|
||||||
class QRDQN(OffPolicyAlgorithm):
|
class QRDQN(OffPolicyAlgorithm):
|
||||||
|
|
@ -57,6 +58,12 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
policy_aliases: Dict[str, Type[BasePolicy]] = {
|
||||||
|
"MlpPolicy": MlpPolicy,
|
||||||
|
"CnnPolicy": CnnPolicy,
|
||||||
|
"MultiInputPolicy": MultiInputPolicy,
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
policy: Union[str, Type[QRDQNPolicy]],
|
policy: Union[str, Type[QRDQNPolicy]],
|
||||||
|
|
@ -89,7 +96,6 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
super(QRDQN, self).__init__(
|
super(QRDQN, self).__init__(
|
||||||
policy,
|
policy,
|
||||||
env,
|
env,
|
||||||
QRDQNPolicy,
|
|
||||||
learning_rate,
|
learning_rate,
|
||||||
buffer_size,
|
buffer_size,
|
||||||
learning_starts,
|
learning_starts,
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
import gym
|
import gym
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||||
from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy
|
from stable_baselines3.common.policies import BaseModel, BasePolicy
|
||||||
from stable_baselines3.common.preprocessing import get_action_dim
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
from stable_baselines3.common.torch_layers import (
|
from stable_baselines3.common.torch_layers import (
|
||||||
BaseFeaturesExtractor,
|
BaseFeaturesExtractor,
|
||||||
|
|
@ -580,8 +580,3 @@ class MultiInputPolicy(TQCPolicy):
|
||||||
n_critics,
|
n_critics,
|
||||||
share_features_extractor,
|
share_features_extractor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_policy("MlpPolicy", MlpPolicy)
|
|
||||||
register_policy("CnnPolicy", CnnPolicy)
|
|
||||||
register_policy("MultiInputPolicy", MultiInputPolicy)
|
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,12 @@ import torch as th
|
||||||
from stable_baselines3.common.buffers import ReplayBuffer
|
from stable_baselines3.common.buffers import ReplayBuffer
|
||||||
from stable_baselines3.common.noise import ActionNoise
|
from stable_baselines3.common.noise import ActionNoise
|
||||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||||
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
||||||
from stable_baselines3.common.utils import polyak_update
|
from stable_baselines3.common.utils import polyak_update
|
||||||
|
|
||||||
from sb3_contrib.common.utils import quantile_huber_loss
|
from sb3_contrib.common.utils import quantile_huber_loss
|
||||||
from sb3_contrib.tqc.policies import TQCPolicy
|
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
|
||||||
|
|
||||||
|
|
||||||
class TQC(OffPolicyAlgorithm):
|
class TQC(OffPolicyAlgorithm):
|
||||||
|
|
@ -64,6 +65,12 @@ class TQC(OffPolicyAlgorithm):
|
||||||
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
policy_aliases: Dict[str, Type[BasePolicy]] = {
|
||||||
|
"MlpPolicy": MlpPolicy,
|
||||||
|
"CnnPolicy": CnnPolicy,
|
||||||
|
"MultiInputPolicy": MultiInputPolicy,
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
policy: Union[str, Type[TQCPolicy]],
|
policy: Union[str, Type[TQCPolicy]],
|
||||||
|
|
@ -99,7 +106,6 @@ class TQC(OffPolicyAlgorithm):
|
||||||
super(TQC, self).__init__(
|
super(TQC, self).__init__(
|
||||||
policy,
|
policy,
|
||||||
env,
|
env,
|
||||||
TQCPolicy,
|
|
||||||
learning_rate,
|
learning_rate,
|
||||||
buffer_size,
|
buffer_size,
|
||||||
learning_starts,
|
learning_starts,
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,7 @@
|
||||||
# This file is here just to define MlpPolicy/CnnPolicy
|
# This file is here just to define MlpPolicy/CnnPolicy
|
||||||
# that work for TRPO
|
# that work for TRPO
|
||||||
from stable_baselines3.common.policies import (
|
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
|
||||||
ActorCriticCnnPolicy,
|
|
||||||
ActorCriticPolicy,
|
|
||||||
MultiInputActorCriticPolicy,
|
|
||||||
register_policy,
|
|
||||||
)
|
|
||||||
|
|
||||||
MlpPolicy = ActorCriticPolicy
|
MlpPolicy = ActorCriticPolicy
|
||||||
CnnPolicy = ActorCriticCnnPolicy
|
CnnPolicy = ActorCriticCnnPolicy
|
||||||
MultiInputPolicy = MultiInputActorCriticPolicy
|
MultiInputPolicy = MultiInputActorCriticPolicy
|
||||||
|
|
||||||
register_policy("MlpPolicy", ActorCriticPolicy)
|
|
||||||
register_policy("CnnPolicy", ActorCriticCnnPolicy)
|
|
||||||
register_policy("MultiInputPolicy", MultiInputPolicy)
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule
|
||||||
from stable_baselines3.common.utils import explained_variance
|
from stable_baselines3.common.utils import explained_variance
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
@ -15,6 +15,7 @@ from torch.distributions import kl_divergence
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
|
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
|
||||||
|
from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||||
|
|
||||||
|
|
||||||
class TRPO(OnPolicyAlgorithm):
|
class TRPO(OnPolicyAlgorithm):
|
||||||
|
|
@ -66,6 +67,12 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
policy_aliases: Dict[str, Type[BasePolicy]] = {
|
||||||
|
"MlpPolicy": MlpPolicy,
|
||||||
|
"CnnPolicy": CnnPolicy,
|
||||||
|
"MultiInputPolicy": MultiInputPolicy,
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
policy: Union[str, Type[ActorCriticPolicy]],
|
policy: Union[str, Type[ActorCriticPolicy]],
|
||||||
|
|
@ -106,7 +113,6 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
max_grad_norm=0.0,
|
max_grad_norm=0.0,
|
||||||
use_sde=use_sde,
|
use_sde=use_sde,
|
||||||
sde_sample_freq=sde_sample_freq,
|
sde_sample_freq=sde_sample_freq,
|
||||||
policy_base=ActorCriticPolicy,
|
|
||||||
tensorboard_log=tensorboard_log,
|
tensorboard_log=tensorboard_log,
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.5.0
|
1.5.1a1
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.5.0",
|
"stable_baselines3>=1.5.1a1",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue