Remove deprecated features (#108)
* Remove deprecated features * Upgrade SB3 * Fix tests
This commit is contained in:
parent
52795a307e
commit
c75ad7dd58
|
|
@ -3,6 +3,27 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
Release 1.7.0a0 (WIP)
|
||||||
|
--------------------------
|
||||||
|
|
||||||
|
Breaking Changes:
|
||||||
|
^^^^^^^^^^^^^^^^^
|
||||||
|
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
|
||||||
|
please use an ``EvalCallback`` instead
|
||||||
|
- Removed deprecated ``sde_net_arch`` parameter
|
||||||
|
|
||||||
|
New Features:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Bug Fixes:
|
||||||
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
Deprecations:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Others:
|
||||||
|
^^^^^^^
|
||||||
|
|
||||||
|
|
||||||
Release 1.6.2 (2022-10-10)
|
Release 1.6.2 (2022-10-10)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
|
||||||
|
|
@ -309,10 +309,6 @@ class ARS(BaseAlgorithm):
|
||||||
callback: MaybeCallback = None,
|
callback: MaybeCallback = None,
|
||||||
log_interval: int = 1,
|
log_interval: int = 1,
|
||||||
tb_log_name: str = "ARS",
|
tb_log_name: str = "ARS",
|
||||||
eval_env: Optional[GymEnv] = None,
|
|
||||||
eval_freq: int = -1,
|
|
||||||
n_eval_episodes: int = 5,
|
|
||||||
eval_log_path: Optional[str] = None,
|
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
async_eval: Optional[AsyncEval] = None,
|
async_eval: Optional[AsyncEval] = None,
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
|
|
@ -324,14 +320,6 @@ class ARS(BaseAlgorithm):
|
||||||
:param callback: callback(s) called at every step with state of the algorithm.
|
:param callback: callback(s) called at every step with state of the algorithm.
|
||||||
:param log_interval: The number of timesteps before logging.
|
:param log_interval: The number of timesteps before logging.
|
||||||
:param tb_log_name: the name of the run for TensorBoard logging
|
:param tb_log_name: the name of the run for TensorBoard logging
|
||||||
:param eval_env: Environment to use for evaluation.
|
|
||||||
Caution, this parameter is deprecated and will be removed in the future.
|
|
||||||
Please use `EvalCallback` or a custom Callback instead.
|
|
||||||
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
|
|
||||||
Caution, this parameter is deprecated and will be removed in the future.
|
|
||||||
Please use `EvalCallback` or a custom Callback instead.
|
|
||||||
:param n_eval_episodes: Number of episodes to evaluate the agent
|
|
||||||
:param eval_log_path: Path to a folder where the evaluations will be saved
|
|
||||||
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
|
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
|
||||||
:param async_eval: The object for asynchronous evaluation of candidates.
|
:param async_eval: The object for asynchronous evaluation of candidates.
|
||||||
:param progress_bar: Display a progress bar using tqdm and rich.
|
:param progress_bar: Display a progress bar using tqdm and rich.
|
||||||
|
|
@ -340,11 +328,7 @@ class ARS(BaseAlgorithm):
|
||||||
|
|
||||||
total_steps, callback = self._setup_learn(
|
total_steps, callback = self._setup_learn(
|
||||||
total_timesteps,
|
total_timesteps,
|
||||||
eval_env,
|
|
||||||
callback,
|
callback,
|
||||||
eval_freq,
|
|
||||||
n_eval_episodes,
|
|
||||||
eval_log_path,
|
|
||||||
reset_num_timesteps,
|
reset_num_timesteps,
|
||||||
tb_log_name,
|
tb_log_name,
|
||||||
progress_bar,
|
progress_bar,
|
||||||
|
|
|
||||||
|
|
@ -88,18 +88,21 @@ def evaluate_policy( # noqa: C901
|
||||||
current_lengths = np.zeros(n_envs, dtype="int")
|
current_lengths = np.zeros(n_envs, dtype="int")
|
||||||
observations = env.reset()
|
observations = env.reset()
|
||||||
states = None
|
states = None
|
||||||
|
episode_starts = np.ones((env.num_envs,), dtype=bool)
|
||||||
while (episode_counts < episode_count_targets).any():
|
while (episode_counts < episode_count_targets).any():
|
||||||
if use_masking:
|
if use_masking:
|
||||||
action_masks = get_action_masks(env)
|
action_masks = get_action_masks(env)
|
||||||
actions, state = model.predict(
|
actions, state = model.predict(
|
||||||
observations,
|
observations,
|
||||||
state=states,
|
state=states,
|
||||||
|
episode_start=episode_starts,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
action_masks=action_masks,
|
action_masks=action_masks,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
actions, states = model.predict(observations, state=states, deterministic=deterministic)
|
actions, states = model.predict(
|
||||||
|
observations, state=states, episode_start=episode_starts, deterministic=deterministic
|
||||||
|
)
|
||||||
observations, rewards, dones, infos = env.step(actions)
|
observations, rewards, dones, infos = env.step(actions)
|
||||||
current_rewards += rewards
|
current_rewards += rewards
|
||||||
current_lengths += 1
|
current_lengths += 1
|
||||||
|
|
@ -110,6 +113,7 @@ def evaluate_policy( # noqa: C901
|
||||||
reward = rewards[i]
|
reward = rewards[i]
|
||||||
done = dones[i]
|
done = dones[i]
|
||||||
info = infos[i]
|
info = infos[i]
|
||||||
|
episode_starts[i] = done
|
||||||
|
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback(locals(), globals())
|
callback(locals(), globals())
|
||||||
|
|
@ -133,8 +137,6 @@ def evaluate_policy( # noqa: C901
|
||||||
episode_counts[i] += 1
|
episode_counts[i] += 1
|
||||||
current_rewards[i] = 0
|
current_rewards[i] = 0
|
||||||
current_lengths[i] = 0
|
current_lengths[i] = 0
|
||||||
if states is not None:
|
|
||||||
states[i] *= 0
|
|
||||||
|
|
||||||
if render:
|
if render:
|
||||||
env.render()
|
env.render()
|
||||||
|
|
|
||||||
|
|
@ -36,9 +36,6 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
:param log_std_init: Initial value for the log standard deviation
|
:param log_std_init: Initial value for the log standard deviation
|
||||||
:param full_std: Whether to use (n_features x n_actions) parameters
|
:param full_std: Whether to use (n_features x n_actions) parameters
|
||||||
for the std instead of only (n_features,) when using gSDE
|
for the std instead of only (n_features,) when using gSDE
|
||||||
:param sde_net_arch: Network architecture for extracting features
|
|
||||||
when using gSDE. If None, the latent features from the policy will be used.
|
|
||||||
Pass an empty list to use the states as features.
|
|
||||||
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
|
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
|
||||||
a positive standard deviation (cf paper). It allows to keep variance
|
a positive standard deviation (cf paper). It allows to keep variance
|
||||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||||
|
|
@ -74,7 +71,6 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
log_std_init: float = 0.0,
|
log_std_init: float = 0.0,
|
||||||
full_std: bool = True,
|
full_std: bool = True,
|
||||||
sde_net_arch: Optional[List[int]] = None,
|
|
||||||
use_expln: bool = False,
|
use_expln: bool = False,
|
||||||
squash_output: bool = False,
|
squash_output: bool = False,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||||
|
|
@ -99,7 +95,6 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
use_sde,
|
use_sde,
|
||||||
log_std_init,
|
log_std_init,
|
||||||
full_std,
|
full_std,
|
||||||
sde_net_arch,
|
|
||||||
use_expln,
|
use_expln,
|
||||||
squash_output,
|
squash_output,
|
||||||
features_extractor_class,
|
features_extractor_class,
|
||||||
|
|
@ -436,9 +431,6 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
||||||
:param log_std_init: Initial value for the log standard deviation
|
:param log_std_init: Initial value for the log standard deviation
|
||||||
:param full_std: Whether to use (n_features x n_actions) parameters
|
:param full_std: Whether to use (n_features x n_actions) parameters
|
||||||
for the std instead of only (n_features,) when using gSDE
|
for the std instead of only (n_features,) when using gSDE
|
||||||
:param sde_net_arch: Network architecture for extracting features
|
|
||||||
when using gSDE. If None, the latent features from the policy will be used.
|
|
||||||
Pass an empty list to use the states as features.
|
|
||||||
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
|
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
|
||||||
a positive standard deviation (cf paper). It allows to keep variance
|
a positive standard deviation (cf paper). It allows to keep variance
|
||||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||||
|
|
@ -473,7 +465,6 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
log_std_init: float = 0.0,
|
log_std_init: float = 0.0,
|
||||||
full_std: bool = True,
|
full_std: bool = True,
|
||||||
sde_net_arch: Optional[List[int]] = None,
|
|
||||||
use_expln: bool = False,
|
use_expln: bool = False,
|
||||||
squash_output: bool = False,
|
squash_output: bool = False,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||||
|
|
@ -497,7 +488,6 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
||||||
use_sde,
|
use_sde,
|
||||||
log_std_init,
|
log_std_init,
|
||||||
full_std,
|
full_std,
|
||||||
sde_net_arch,
|
|
||||||
use_expln,
|
use_expln,
|
||||||
squash_output,
|
squash_output,
|
||||||
features_extractor_class,
|
features_extractor_class,
|
||||||
|
|
@ -528,9 +518,6 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
||||||
:param log_std_init: Initial value for the log standard deviation
|
:param log_std_init: Initial value for the log standard deviation
|
||||||
:param full_std: Whether to use (n_features x n_actions) parameters
|
:param full_std: Whether to use (n_features x n_actions) parameters
|
||||||
for the std instead of only (n_features,) when using gSDE
|
for the std instead of only (n_features,) when using gSDE
|
||||||
:param sde_net_arch: Network architecture for extracting features
|
|
||||||
when using gSDE. If None, the latent features from the policy will be used.
|
|
||||||
Pass an empty list to use the states as features.
|
|
||||||
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
|
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
|
||||||
a positive standard deviation (cf paper). It allows to keep variance
|
a positive standard deviation (cf paper). It allows to keep variance
|
||||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||||
|
|
@ -565,7 +552,6 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
log_std_init: float = 0.0,
|
log_std_init: float = 0.0,
|
||||||
full_std: bool = True,
|
full_std: bool = True,
|
||||||
sde_net_arch: Optional[List[int]] = None,
|
|
||||||
use_expln: bool = False,
|
use_expln: bool = False,
|
||||||
squash_output: bool = False,
|
squash_output: bool = False,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
||||||
|
|
@ -589,7 +575,6 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
||||||
use_sde,
|
use_sde,
|
||||||
log_std_init,
|
log_std_init,
|
||||||
full_std,
|
full_std,
|
||||||
sde_net_arch,
|
|
||||||
use_expln,
|
use_expln,
|
||||||
squash_output,
|
squash_output,
|
||||||
features_extractor_class,
|
features_extractor_class,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import warnings
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
|
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
|
|
@ -61,9 +60,6 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
|
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
|
||||||
By default, there is no limit on the kl div.
|
By default, there is no limit on the kl div.
|
||||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||||
:param create_eval_env: Whether to create a second environment that will be
|
|
||||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
|
||||||
Caution, this parameter is deprecated and will be removed in the future.
|
|
||||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||||
:param seed: Seed for the pseudo random generators
|
:param seed: Seed for the pseudo random generators
|
||||||
|
|
@ -96,7 +92,6 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
max_grad_norm: float = 0.5,
|
max_grad_norm: float = 0.5,
|
||||||
target_kl: Optional[float] = None,
|
target_kl: Optional[float] = None,
|
||||||
tensorboard_log: Optional[str] = None,
|
tensorboard_log: Optional[str] = None,
|
||||||
create_eval_env: bool = False,
|
|
||||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
verbose: int = 0,
|
verbose: int = 0,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
|
|
@ -116,7 +111,6 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
use_sde=False,
|
use_sde=False,
|
||||||
sde_sample_freq=-1,
|
sde_sample_freq=-1,
|
||||||
tensorboard_log=tensorboard_log,
|
tensorboard_log=tensorboard_log,
|
||||||
create_eval_env=create_eval_env,
|
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
|
@ -179,23 +173,11 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
def _init_callback(
|
def _init_callback(
|
||||||
self,
|
self,
|
||||||
callback: MaybeCallback,
|
callback: MaybeCallback,
|
||||||
eval_env: Optional[VecEnv] = None,
|
|
||||||
eval_freq: int = 10000,
|
|
||||||
n_eval_episodes: int = 5,
|
|
||||||
log_path: Optional[str] = None,
|
|
||||||
use_masking: bool = True,
|
use_masking: bool = True,
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
) -> BaseCallback:
|
) -> BaseCallback:
|
||||||
"""
|
"""
|
||||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
:param callback: Callback(s) called at every step with state of the algorithm.
|
||||||
:param eval_env: Environment to use for evaluation.
|
|
||||||
Caution, this parameter is deprecated and will be removed in the future.
|
|
||||||
Please use `MaskableEvalCallback` or a custom Callback instead.
|
|
||||||
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
|
|
||||||
Caution, this parameter is deprecated and will be removed in the future.
|
|
||||||
Please use `MaskableEvalCallback` or a custom Callback instead.
|
|
||||||
:param n_eval_episodes: How many episodes to play per evaluation
|
|
||||||
:param log_path: Path to a folder where the evaluations will be saved
|
|
||||||
:param use_masking: Whether or not to use invalid action masks during evaluation
|
:param use_masking: Whether or not to use invalid action masks during evaluation
|
||||||
:param progress_bar: Display a progress bar using tqdm and rich.
|
:param progress_bar: Display a progress bar using tqdm and rich.
|
||||||
:return: A hybrid callback calling `callback` and performing evaluation.
|
:return: A hybrid callback calling `callback` and performing evaluation.
|
||||||
|
|
@ -212,33 +194,13 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
if progress_bar:
|
if progress_bar:
|
||||||
callback = CallbackList([callback, ProgressBarCallback()])
|
callback = CallbackList([callback, ProgressBarCallback()])
|
||||||
|
|
||||||
# Create eval callback in charge of the evaluation
|
|
||||||
if eval_env is not None:
|
|
||||||
# Avoid circular import error
|
|
||||||
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
|
|
||||||
|
|
||||||
eval_callback = MaskableEvalCallback(
|
|
||||||
eval_env,
|
|
||||||
best_model_save_path=log_path,
|
|
||||||
log_path=log_path,
|
|
||||||
eval_freq=eval_freq,
|
|
||||||
n_eval_episodes=n_eval_episodes,
|
|
||||||
use_masking=use_masking,
|
|
||||||
verbose=self.verbose,
|
|
||||||
)
|
|
||||||
callback = CallbackList([callback, eval_callback])
|
|
||||||
|
|
||||||
callback.init_callback(self)
|
callback.init_callback(self)
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
def _setup_learn(
|
def _setup_learn(
|
||||||
self,
|
self,
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
eval_env: Optional[GymEnv],
|
|
||||||
callback: MaybeCallback = None,
|
callback: MaybeCallback = None,
|
||||||
eval_freq: int = 10000,
|
|
||||||
n_eval_episodes: int = 5,
|
|
||||||
log_path: Optional[str] = None,
|
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
tb_log_name: str = "run",
|
tb_log_name: str = "run",
|
||||||
use_masking: bool = True,
|
use_masking: bool = True,
|
||||||
|
|
@ -248,15 +210,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
Initialize different variables needed for training.
|
Initialize different variables needed for training.
|
||||||
|
|
||||||
:param total_timesteps: The total number of samples (env steps) to train on
|
:param total_timesteps: The total number of samples (env steps) to train on
|
||||||
:param eval_env: Environment to use for evaluation.
|
|
||||||
Caution, this parameter is deprecated and will be removed in the future.
|
|
||||||
Please use `MaskableEvalCallback` or a custom Callback instead.
|
|
||||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
:param callback: Callback(s) called at every step with state of the algorithm.
|
||||||
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
|
|
||||||
Caution, this parameter is deprecated and will be removed in the future.
|
|
||||||
Please use `MaskableEvalCallback` or a custom Callback instead.
|
|
||||||
:param n_eval_episodes: How many episodes to play per evaluation
|
|
||||||
:param log_path: Path to a folder where the evaluations will be saved
|
|
||||||
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
|
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
|
||||||
:param tb_log_name: the name of the run for tensorboard log
|
:param tb_log_name: the name of the run for tensorboard log
|
||||||
:param use_masking: Whether or not to use invalid action masks during training
|
:param use_masking: Whether or not to use invalid action masks during training
|
||||||
|
|
@ -264,17 +218,6 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if eval_env is not None or eval_freq != -1:
|
|
||||||
warnings.warn(
|
|
||||||
"Parameters `eval_env` and `eval_freq` are deprecated and will be removed in the future. "
|
|
||||||
"Please use `MaskableEvalCallback` or a custom Callback instead.",
|
|
||||||
DeprecationWarning,
|
|
||||||
# By setting the `stacklevel` we refer to the initial caller of the deprecated feature.
|
|
||||||
# This causes the `DeprecationWarning` to not be ignored and to be shown to the user. See
|
|
||||||
# https://github.com/DLR-RM/stable-baselines3/pull/1082#discussion_r989842855 for more details.
|
|
||||||
stacklevel=4,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.start_time = time.time_ns()
|
self.start_time = time.time_ns()
|
||||||
if self.ep_info_buffer is None or reset_num_timesteps:
|
if self.ep_info_buffer is None or reset_num_timesteps:
|
||||||
# Initialize buffers if they don't exist, or reinitialize if resetting counters
|
# Initialize buffers if they don't exist, or reinitialize if resetting counters
|
||||||
|
|
@ -297,17 +240,12 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
if self._vec_normalize_env is not None:
|
if self._vec_normalize_env is not None:
|
||||||
self._last_original_obs = self._vec_normalize_env.get_original_obs()
|
self._last_original_obs = self._vec_normalize_env.get_original_obs()
|
||||||
|
|
||||||
if eval_env is not None and self.seed is not None:
|
|
||||||
eval_env.seed(self.seed)
|
|
||||||
|
|
||||||
eval_env = self._get_eval_env(eval_env)
|
|
||||||
|
|
||||||
# Configure logger's outputs if no logger was passed
|
# Configure logger's outputs if no logger was passed
|
||||||
if not self._custom_logger:
|
if not self._custom_logger:
|
||||||
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
|
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
|
||||||
|
|
||||||
# Create eval callback if needed
|
# Create eval callback if needed
|
||||||
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking, progress_bar)
|
callback = self._init_callback(callback, use_masking, progress_bar)
|
||||||
|
|
||||||
return total_timesteps, callback
|
return total_timesteps, callback
|
||||||
|
|
||||||
|
|
@ -564,11 +502,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
callback: MaybeCallback = None,
|
callback: MaybeCallback = None,
|
||||||
log_interval: int = 1,
|
log_interval: int = 1,
|
||||||
eval_env: Optional[GymEnv] = None,
|
|
||||||
eval_freq: int = -1,
|
|
||||||
n_eval_episodes: int = 5,
|
|
||||||
tb_log_name: str = "PPO",
|
tb_log_name: str = "PPO",
|
||||||
eval_log_path: Optional[str] = None,
|
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
use_masking: bool = True,
|
use_masking: bool = True,
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
|
|
@ -577,11 +511,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
total_timesteps, callback = self._setup_learn(
|
total_timesteps, callback = self._setup_learn(
|
||||||
total_timesteps,
|
total_timesteps,
|
||||||
eval_env,
|
|
||||||
callback,
|
callback,
|
||||||
eval_freq,
|
|
||||||
n_eval_episodes,
|
|
||||||
eval_log_path,
|
|
||||||
reset_num_timesteps,
|
reset_num_timesteps,
|
||||||
tb_log_name,
|
tb_log_name,
|
||||||
use_masking,
|
use_masking,
|
||||||
|
|
|
||||||
|
|
@ -58,9 +58,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
|
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
|
||||||
By default, there is no limit on the kl div.
|
By default, there is no limit on the kl div.
|
||||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||||
:param create_eval_env: Whether to create a second environment that will be
|
|
||||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
|
||||||
Caution, this parameter is deprecated and will be removed in the future.
|
|
||||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||||
:param seed: Seed for the pseudo random generators
|
:param seed: Seed for the pseudo random generators
|
||||||
|
|
@ -95,7 +92,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
sde_sample_freq: int = -1,
|
sde_sample_freq: int = -1,
|
||||||
target_kl: Optional[float] = None,
|
target_kl: Optional[float] = None,
|
||||||
tensorboard_log: Optional[str] = None,
|
tensorboard_log: Optional[str] = None,
|
||||||
create_eval_env: bool = False,
|
|
||||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
verbose: int = 0,
|
verbose: int = 0,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
|
|
@ -115,7 +111,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
use_sde=use_sde,
|
use_sde=use_sde,
|
||||||
sde_sample_freq=sde_sample_freq,
|
sde_sample_freq=sde_sample_freq,
|
||||||
tensorboard_log=tensorboard_log,
|
tensorboard_log=tensorboard_log,
|
||||||
create_eval_env=create_eval_env,
|
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
|
@ -453,11 +448,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
callback: MaybeCallback = None,
|
callback: MaybeCallback = None,
|
||||||
log_interval: int = 1,
|
log_interval: int = 1,
|
||||||
eval_env: Optional[GymEnv] = None,
|
|
||||||
eval_freq: int = -1,
|
|
||||||
n_eval_episodes: int = 5,
|
|
||||||
tb_log_name: str = "RecurrentPPO",
|
tb_log_name: str = "RecurrentPPO",
|
||||||
eval_log_path: Optional[str] = None,
|
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
) -> RecurrentPPOSelf:
|
) -> RecurrentPPOSelf:
|
||||||
|
|
@ -465,11 +456,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
total_timesteps, callback = self._setup_learn(
|
total_timesteps, callback = self._setup_learn(
|
||||||
total_timesteps,
|
total_timesteps,
|
||||||
eval_env,
|
|
||||||
callback,
|
callback,
|
||||||
eval_freq,
|
|
||||||
n_eval_episodes,
|
|
||||||
eval_log_path,
|
|
||||||
reset_num_timesteps,
|
reset_num_timesteps,
|
||||||
tb_log_name,
|
tb_log_name,
|
||||||
progress_bar,
|
progress_bar,
|
||||||
|
|
|
||||||
|
|
@ -50,9 +50,6 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
:param exploration_final_eps: final value of random action probability
|
:param exploration_final_eps: final value of random action probability
|
||||||
:param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping)
|
:param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping)
|
||||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||||
:param create_eval_env: Whether to create a second environment that will be
|
|
||||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
|
||||||
Caution, this parameter is deprecated and will be removed in the future.
|
|
||||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||||
:param seed: Seed for the pseudo random generators
|
:param seed: Seed for the pseudo random generators
|
||||||
|
|
@ -88,7 +85,6 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
exploration_final_eps: float = 0.01,
|
exploration_final_eps: float = 0.01,
|
||||||
max_grad_norm: Optional[float] = None,
|
max_grad_norm: Optional[float] = None,
|
||||||
tensorboard_log: Optional[str] = None,
|
tensorboard_log: Optional[str] = None,
|
||||||
create_eval_env: bool = False,
|
|
||||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
verbose: int = 0,
|
verbose: int = 0,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
|
|
@ -114,7 +110,6 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
tensorboard_log=tensorboard_log,
|
tensorboard_log=tensorboard_log,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
device=device,
|
device=device,
|
||||||
create_eval_env=create_eval_env,
|
|
||||||
seed=seed,
|
seed=seed,
|
||||||
sde_support=False,
|
sde_support=False,
|
||||||
optimize_memory_usage=optimize_memory_usage,
|
optimize_memory_usage=optimize_memory_usage,
|
||||||
|
|
@ -256,11 +251,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
callback: MaybeCallback = None,
|
callback: MaybeCallback = None,
|
||||||
log_interval: int = 4,
|
log_interval: int = 4,
|
||||||
eval_env: Optional[GymEnv] = None,
|
|
||||||
eval_freq: int = -1,
|
|
||||||
n_eval_episodes: int = 5,
|
|
||||||
tb_log_name: str = "QRDQN",
|
tb_log_name: str = "QRDQN",
|
||||||
eval_log_path: Optional[str] = None,
|
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
) -> QRDQNSelf:
|
) -> QRDQNSelf:
|
||||||
|
|
@ -269,11 +260,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
total_timesteps=total_timesteps,
|
total_timesteps=total_timesteps,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
log_interval=log_interval,
|
log_interval=log_interval,
|
||||||
eval_env=eval_env,
|
|
||||||
eval_freq=eval_freq,
|
|
||||||
n_eval_episodes=n_eval_episodes,
|
|
||||||
tb_log_name=tb_log_name,
|
tb_log_name=tb_log_name,
|
||||||
eval_log_path=eval_log_path,
|
|
||||||
reset_num_timesteps=reset_num_timesteps,
|
reset_num_timesteps=reset_num_timesteps,
|
||||||
progress_bar=progress_bar,
|
progress_bar=progress_bar,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import warnings
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
|
|
@ -37,9 +36,6 @@ class Actor(BasePolicy):
|
||||||
:param log_std_init: Initial value for the log standard deviation
|
:param log_std_init: Initial value for the log standard deviation
|
||||||
:param full_std: Whether to use (n_features x n_actions) parameters
|
:param full_std: Whether to use (n_features x n_actions) parameters
|
||||||
for the std instead of only (n_features,) when using gSDE.
|
for the std instead of only (n_features,) when using gSDE.
|
||||||
:param sde_net_arch: Network architecture for extracting features
|
|
||||||
when using gSDE. If None, the latent features from the policy will be used.
|
|
||||||
Pass an empty list to use the states as features.
|
|
||||||
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
||||||
a positive standard deviation (cf paper). It allows to keep variance
|
a positive standard deviation (cf paper). It allows to keep variance
|
||||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||||
|
|
@ -59,7 +55,6 @@ class Actor(BasePolicy):
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
log_std_init: float = -3,
|
log_std_init: float = -3,
|
||||||
full_std: bool = True,
|
full_std: bool = True,
|
||||||
sde_net_arch: Optional[List[int]] = None,
|
|
||||||
use_expln: bool = False,
|
use_expln: bool = False,
|
||||||
clip_mean: float = 2.0,
|
clip_mean: float = 2.0,
|
||||||
normalize_images: bool = True,
|
normalize_images: bool = True,
|
||||||
|
|
@ -79,7 +74,6 @@ class Actor(BasePolicy):
|
||||||
self.features_dim = features_dim
|
self.features_dim = features_dim
|
||||||
self.activation_fn = activation_fn
|
self.activation_fn = activation_fn
|
||||||
self.log_std_init = log_std_init
|
self.log_std_init = log_std_init
|
||||||
self.sde_net_arch = sde_net_arch
|
|
||||||
self.use_expln = use_expln
|
self.use_expln = use_expln
|
||||||
self.full_std = full_std
|
self.full_std = full_std
|
||||||
self.clip_mean = clip_mean
|
self.clip_mean = clip_mean
|
||||||
|
|
@ -89,9 +83,6 @@ class Actor(BasePolicy):
|
||||||
self.latent_pi = nn.Sequential(*latent_pi_net)
|
self.latent_pi = nn.Sequential(*latent_pi_net)
|
||||||
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
|
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
|
||||||
|
|
||||||
if sde_net_arch is not None:
|
|
||||||
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
|
||||||
|
|
||||||
if self.use_sde:
|
if self.use_sde:
|
||||||
self.action_dist = StateDependentNoiseDistribution(
|
self.action_dist = StateDependentNoiseDistribution(
|
||||||
action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
|
action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
|
||||||
|
|
@ -256,9 +247,6 @@ class TQCPolicy(BasePolicy):
|
||||||
:param activation_fn: Activation function
|
:param activation_fn: Activation function
|
||||||
:param use_sde: Whether to use State Dependent Exploration or not
|
:param use_sde: Whether to use State Dependent Exploration or not
|
||||||
:param log_std_init: Initial value for the log standard deviation
|
:param log_std_init: Initial value for the log standard deviation
|
||||||
:param sde_net_arch: Network architecture for extracting features
|
|
||||||
when using gSDE. If None, the latent features from the policy will be used.
|
|
||||||
Pass an empty list to use the states as features.
|
|
||||||
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
||||||
a positive standard deviation (cf paper). It allows to keep variance
|
a positive standard deviation (cf paper). It allows to keep variance
|
||||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||||
|
|
@ -287,7 +275,6 @@ class TQCPolicy(BasePolicy):
|
||||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
log_std_init: float = -3,
|
log_std_init: float = -3,
|
||||||
sde_net_arch: Optional[List[int]] = None,
|
|
||||||
use_expln: bool = False,
|
use_expln: bool = False,
|
||||||
clip_mean: float = 2.0,
|
clip_mean: float = 2.0,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||||
|
|
@ -325,9 +312,6 @@ class TQCPolicy(BasePolicy):
|
||||||
}
|
}
|
||||||
self.actor_kwargs = self.net_args.copy()
|
self.actor_kwargs = self.net_args.copy()
|
||||||
|
|
||||||
if sde_net_arch is not None:
|
|
||||||
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
|
||||||
|
|
||||||
sde_kwargs = {
|
sde_kwargs = {
|
||||||
"use_sde": use_sde,
|
"use_sde": use_sde,
|
||||||
"log_std_init": log_std_init,
|
"log_std_init": log_std_init,
|
||||||
|
|
@ -442,9 +426,6 @@ class CnnPolicy(TQCPolicy):
|
||||||
:param activation_fn: Activation function
|
:param activation_fn: Activation function
|
||||||
:param use_sde: Whether to use State Dependent Exploration or not
|
:param use_sde: Whether to use State Dependent Exploration or not
|
||||||
:param log_std_init: Initial value for the log standard deviation
|
:param log_std_init: Initial value for the log standard deviation
|
||||||
:param sde_net_arch: Network architecture for extracting features
|
|
||||||
when using gSDE. If None, the latent features from the policy will be used.
|
|
||||||
Pass an empty list to use the states as features.
|
|
||||||
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
||||||
a positive standard deviation (cf paper). It allows to keep variance
|
a positive standard deviation (cf paper). It allows to keep variance
|
||||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||||
|
|
@ -471,7 +452,6 @@ class CnnPolicy(TQCPolicy):
|
||||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
log_std_init: float = -3,
|
log_std_init: float = -3,
|
||||||
sde_net_arch: Optional[List[int]] = None,
|
|
||||||
use_expln: bool = False,
|
use_expln: bool = False,
|
||||||
clip_mean: float = 2.0,
|
clip_mean: float = 2.0,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||||
|
|
@ -491,7 +471,6 @@ class CnnPolicy(TQCPolicy):
|
||||||
activation_fn,
|
activation_fn,
|
||||||
use_sde,
|
use_sde,
|
||||||
log_std_init,
|
log_std_init,
|
||||||
sde_net_arch,
|
|
||||||
use_expln,
|
use_expln,
|
||||||
clip_mean,
|
clip_mean,
|
||||||
features_extractor_class,
|
features_extractor_class,
|
||||||
|
|
@ -516,9 +495,6 @@ class MultiInputPolicy(TQCPolicy):
|
||||||
:param activation_fn: Activation function
|
:param activation_fn: Activation function
|
||||||
:param use_sde: Whether to use State Dependent Exploration or not
|
:param use_sde: Whether to use State Dependent Exploration or not
|
||||||
:param log_std_init: Initial value for the log standard deviation
|
:param log_std_init: Initial value for the log standard deviation
|
||||||
:param sde_net_arch: Network architecture for extracting features
|
|
||||||
when using gSDE. If None, the latent features from the policy will be used.
|
|
||||||
Pass an empty list to use the states as features.
|
|
||||||
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
||||||
a positive standard deviation (cf paper). It allows to keep variance
|
a positive standard deviation (cf paper). It allows to keep variance
|
||||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||||
|
|
@ -545,7 +521,6 @@ class MultiInputPolicy(TQCPolicy):
|
||||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
log_std_init: float = -3,
|
log_std_init: float = -3,
|
||||||
sde_net_arch: Optional[List[int]] = None,
|
|
||||||
use_expln: bool = False,
|
use_expln: bool = False,
|
||||||
clip_mean: float = 2.0,
|
clip_mean: float = 2.0,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
||||||
|
|
@ -565,7 +540,6 @@ class MultiInputPolicy(TQCPolicy):
|
||||||
activation_fn,
|
activation_fn,
|
||||||
use_sde,
|
use_sde,
|
||||||
log_std_init,
|
log_std_init,
|
||||||
sde_net_arch,
|
|
||||||
use_expln,
|
use_expln,
|
||||||
clip_mean,
|
clip_mean,
|
||||||
features_extractor_class,
|
features_extractor_class,
|
||||||
|
|
|
||||||
|
|
@ -57,9 +57,6 @@ class TQC(OffPolicyAlgorithm):
|
||||||
Default: -1 (only sample at the beginning of the rollout)
|
Default: -1 (only sample at the beginning of the rollout)
|
||||||
:param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
|
:param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
|
||||||
during the warm up phase (before learning starts)
|
during the warm up phase (before learning starts)
|
||||||
:param create_eval_env: Whether to create a second environment that will be
|
|
||||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
|
||||||
Caution, this parameter is deprecated and will be removed in the future.
|
|
||||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||||
:param seed: Seed for the pseudo random generators
|
:param seed: Seed for the pseudo random generators
|
||||||
|
|
@ -98,7 +95,6 @@ class TQC(OffPolicyAlgorithm):
|
||||||
sde_sample_freq: int = -1,
|
sde_sample_freq: int = -1,
|
||||||
use_sde_at_warmup: bool = False,
|
use_sde_at_warmup: bool = False,
|
||||||
tensorboard_log: Optional[str] = None,
|
tensorboard_log: Optional[str] = None,
|
||||||
create_eval_env: bool = False,
|
|
||||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
verbose: int = 0,
|
verbose: int = 0,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
|
|
@ -124,7 +120,6 @@ class TQC(OffPolicyAlgorithm):
|
||||||
tensorboard_log=tensorboard_log,
|
tensorboard_log=tensorboard_log,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
device=device,
|
device=device,
|
||||||
create_eval_env=create_eval_env,
|
|
||||||
seed=seed,
|
seed=seed,
|
||||||
use_sde=use_sde,
|
use_sde=use_sde,
|
||||||
sde_sample_freq=sde_sample_freq,
|
sde_sample_freq=sde_sample_freq,
|
||||||
|
|
@ -293,11 +288,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
callback: MaybeCallback = None,
|
callback: MaybeCallback = None,
|
||||||
log_interval: int = 4,
|
log_interval: int = 4,
|
||||||
eval_env: Optional[GymEnv] = None,
|
|
||||||
eval_freq: int = -1,
|
|
||||||
n_eval_episodes: int = 5,
|
|
||||||
tb_log_name: str = "TQC",
|
tb_log_name: str = "TQC",
|
||||||
eval_log_path: Optional[str] = None,
|
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
) -> TQCSelf:
|
) -> TQCSelf:
|
||||||
|
|
@ -306,11 +297,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
total_timesteps=total_timesteps,
|
total_timesteps=total_timesteps,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
log_interval=log_interval,
|
log_interval=log_interval,
|
||||||
eval_env=eval_env,
|
|
||||||
eval_freq=eval_freq,
|
|
||||||
n_eval_episodes=n_eval_episodes,
|
|
||||||
tb_log_name=tb_log_name,
|
tb_log_name=tb_log_name,
|
||||||
eval_log_path=eval_log_path,
|
|
||||||
reset_num_timesteps=reset_num_timesteps,
|
reset_num_timesteps=reset_num_timesteps,
|
||||||
progress_bar=progress_bar,
|
progress_bar=progress_bar,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -59,9 +59,6 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
:param sub_sampling_factor: Sub-sample the batch to make computation faster
|
:param sub_sampling_factor: Sub-sample the batch to make computation faster
|
||||||
see p40-42 of John Schulman thesis http://joschu.net/docs/thesis.pdf
|
see p40-42 of John Schulman thesis http://joschu.net/docs/thesis.pdf
|
||||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||||
:param create_eval_env: Whether to create a second environment that will be
|
|
||||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
|
||||||
Caution, this parameter is deprecated and will be removed in the future.
|
|
||||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||||
:param seed: Seed for the pseudo random generators
|
:param seed: Seed for the pseudo random generators
|
||||||
|
|
@ -96,7 +93,6 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
target_kl: float = 0.01,
|
target_kl: float = 0.01,
|
||||||
sub_sampling_factor: int = 1,
|
sub_sampling_factor: int = 1,
|
||||||
tensorboard_log: Optional[str] = None,
|
tensorboard_log: Optional[str] = None,
|
||||||
create_eval_env: bool = False,
|
|
||||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
verbose: int = 0,
|
verbose: int = 0,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
|
|
@ -120,7 +116,6 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
device=device,
|
device=device,
|
||||||
create_eval_env=create_eval_env,
|
|
||||||
seed=seed,
|
seed=seed,
|
||||||
_init_setup_model=False,
|
_init_setup_model=False,
|
||||||
supported_action_spaces=(
|
supported_action_spaces=(
|
||||||
|
|
@ -409,11 +404,7 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
callback: MaybeCallback = None,
|
callback: MaybeCallback = None,
|
||||||
log_interval: int = 1,
|
log_interval: int = 1,
|
||||||
eval_env: Optional[GymEnv] = None,
|
|
||||||
eval_freq: int = -1,
|
|
||||||
n_eval_episodes: int = 5,
|
|
||||||
tb_log_name: str = "TRPO",
|
tb_log_name: str = "TRPO",
|
||||||
eval_log_path: Optional[str] = None,
|
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
) -> TRPOSelf:
|
) -> TRPOSelf:
|
||||||
|
|
@ -422,11 +413,7 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
total_timesteps=total_timesteps,
|
total_timesteps=total_timesteps,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
log_interval=log_interval,
|
log_interval=log_interval,
|
||||||
eval_env=eval_env,
|
|
||||||
eval_freq=eval_freq,
|
|
||||||
n_eval_episodes=n_eval_episodes,
|
|
||||||
tb_log_name=tb_log_name,
|
tb_log_name=tb_log_name,
|
||||||
eval_log_path=eval_log_path,
|
|
||||||
reset_num_timesteps=reset_num_timesteps,
|
reset_num_timesteps=reset_num_timesteps,
|
||||||
progress_bar=progress_bar,
|
progress_bar=progress_bar,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.6.2
|
1.7.0a0
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.6.2",
|
"stable_baselines3>=1.7.0a0",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
|
|
@ -64,15 +64,6 @@ def test_bootstraping():
|
||||||
model.learn(128)
|
model.learn(128)
|
||||||
|
|
||||||
|
|
||||||
def test_eval_env():
|
|
||||||
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
|
||||||
eval_env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
|
||||||
model = MaskablePPO("MlpPolicy", env, clip_range_vf=0.2, n_steps=32, seed=8)
|
|
||||||
with pytest.warns(DeprecationWarning): # `eval_env` is deprecated
|
|
||||||
model.learn(32, eval_env=eval_env, eval_freq=16)
|
|
||||||
model.learn(32, reset_num_timesteps=False)
|
|
||||||
|
|
||||||
|
|
||||||
def test_supports_discrete_action_space():
|
def test_supports_discrete_action_space():
|
||||||
"""
|
"""
|
||||||
No errors using algorithm with an env that has a discrete action space
|
No errors using algorithm with an env that has a discrete action space
|
||||||
|
|
|
||||||
|
|
@ -129,32 +129,28 @@ def test_check():
|
||||||
|
|
||||||
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
|
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
|
||||||
def test_run(env):
|
def test_run(env):
|
||||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
model = RecurrentPPO(
|
||||||
model = RecurrentPPO(
|
"MlpLstmPolicy",
|
||||||
"MlpLstmPolicy",
|
env,
|
||||||
env,
|
n_steps=16,
|
||||||
n_steps=16,
|
seed=0,
|
||||||
seed=0,
|
)
|
||||||
create_eval_env=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
model.learn(total_timesteps=32, eval_freq=16)
|
model.learn(total_timesteps=32)
|
||||||
|
|
||||||
|
|
||||||
def test_run_sde():
|
def test_run_sde():
|
||||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
model = RecurrentPPO(
|
||||||
model = RecurrentPPO(
|
"MlpLstmPolicy",
|
||||||
"MlpLstmPolicy",
|
"Pendulum-v1",
|
||||||
"Pendulum-v1",
|
n_steps=16,
|
||||||
n_steps=16,
|
seed=0,
|
||||||
seed=0,
|
sde_sample_freq=4,
|
||||||
create_eval_env=True,
|
use_sde=True,
|
||||||
sde_sample_freq=4,
|
clip_range_vf=0.1,
|
||||||
use_sde=True,
|
)
|
||||||
clip_range_vf=0.1,
|
|
||||||
)
|
|
||||||
|
|
||||||
model.learn(total_timesteps=200, eval_freq=150)
|
model.learn(total_timesteps=200)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
||||||
|
|
@ -10,17 +10,15 @@ from sb3_contrib.common.vec_env import AsyncEval
|
||||||
|
|
||||||
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
|
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
|
||||||
def test_tqc(ent_coef):
|
def test_tqc(ent_coef):
|
||||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
model = TQC(
|
||||||
model = TQC(
|
"MlpPolicy",
|
||||||
"MlpPolicy",
|
"Pendulum-v1",
|
||||||
"Pendulum-v1",
|
policy_kwargs=dict(net_arch=[64, 64]),
|
||||||
policy_kwargs=dict(net_arch=[64, 64]),
|
learning_starts=100,
|
||||||
learning_starts=100,
|
verbose=1,
|
||||||
verbose=1,
|
ent_coef=ent_coef,
|
||||||
create_eval_env=True,
|
)
|
||||||
ent_coef=ent_coef,
|
model.learn(total_timesteps=300, progress_bar=True)
|
||||||
)
|
|
||||||
model.learn(total_timesteps=300, eval_freq=250, progress_bar=True)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n_critics", [1, 3])
|
@pytest.mark.parametrize("n_critics", [1, 3])
|
||||||
|
|
@ -51,18 +49,16 @@ def test_sde():
|
||||||
|
|
||||||
|
|
||||||
def test_qrdqn():
|
def test_qrdqn():
|
||||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
model = QRDQN(
|
||||||
model = QRDQN(
|
"MlpPolicy",
|
||||||
"MlpPolicy",
|
"CartPole-v1",
|
||||||
"CartPole-v1",
|
policy_kwargs=dict(n_quantiles=25, net_arch=[64, 64]),
|
||||||
policy_kwargs=dict(n_quantiles=25, net_arch=[64, 64]),
|
learning_starts=100,
|
||||||
learning_starts=100,
|
buffer_size=500,
|
||||||
buffer_size=500,
|
learning_rate=3e-4,
|
||||||
learning_rate=3e-4,
|
verbose=1,
|
||||||
verbose=1,
|
)
|
||||||
create_eval_env=True,
|
model.learn(total_timesteps=500)
|
||||||
)
|
|
||||||
model.learn(total_timesteps=500, eval_freq=250)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
|
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
|
||||||
|
|
@ -91,8 +87,7 @@ def test_trpo_params():
|
||||||
@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"])
|
@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"])
|
||||||
def test_ars(policy_str, env_id):
|
def test_ars(policy_str, env_id):
|
||||||
model = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0)
|
model = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0)
|
||||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
model.learn(total_timesteps=500, log_interval=1)
|
||||||
model.learn(total_timesteps=500, log_interval=1, eval_freq=250)
|
|
||||||
|
|
||||||
|
|
||||||
def test_ars_multi_env():
|
def test_ars_multi_env():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue