Remove deprecated features (#108)

* Remove deprecated features

* Upgrade SB3

* Fix tests
This commit is contained in:
Antonin RAFFIN 2022-10-11 13:04:18 +02:00 committed by GitHub
parent 52795a307e
commit c75ad7dd58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 67 additions and 241 deletions

View File

@ -3,6 +3,27 @@
Changelog Changelog
========== ==========
Release 1.7.0a0 (WIP)
--------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
please use an ``EvalCallback`` instead
- Removed deprecated ``sde_net_arch`` parameter
New Features:
^^^^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
Release 1.6.2 (2022-10-10) Release 1.6.2 (2022-10-10)
-------------------------- --------------------------

View File

@ -309,10 +309,6 @@ class ARS(BaseAlgorithm):
callback: MaybeCallback = None, callback: MaybeCallback = None,
log_interval: int = 1, log_interval: int = 1,
tb_log_name: str = "ARS", tb_log_name: str = "ARS",
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
async_eval: Optional[AsyncEval] = None, async_eval: Optional[AsyncEval] = None,
progress_bar: bool = False, progress_bar: bool = False,
@ -324,14 +320,6 @@ class ARS(BaseAlgorithm):
:param callback: callback(s) called at every step with state of the algorithm. :param callback: callback(s) called at every step with state of the algorithm.
:param log_interval: The number of timesteps before logging. :param log_interval: The number of timesteps before logging.
:param tb_log_name: the name of the run for TensorBoard logging :param tb_log_name: the name of the run for TensorBoard logging
:param eval_env: Environment to use for evaluation.
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param n_eval_episodes: Number of episodes to evaluate the agent
:param eval_log_path: Path to a folder where the evaluations will be saved
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging) :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
:param async_eval: The object for asynchronous evaluation of candidates. :param async_eval: The object for asynchronous evaluation of candidates.
:param progress_bar: Display a progress bar using tqdm and rich. :param progress_bar: Display a progress bar using tqdm and rich.
@ -340,11 +328,7 @@ class ARS(BaseAlgorithm):
total_steps, callback = self._setup_learn( total_steps, callback = self._setup_learn(
total_timesteps, total_timesteps,
eval_env,
callback, callback,
eval_freq,
n_eval_episodes,
eval_log_path,
reset_num_timesteps, reset_num_timesteps,
tb_log_name, tb_log_name,
progress_bar, progress_bar,

View File

@ -88,18 +88,21 @@ def evaluate_policy( # noqa: C901
current_lengths = np.zeros(n_envs, dtype="int") current_lengths = np.zeros(n_envs, dtype="int")
observations = env.reset() observations = env.reset()
states = None states = None
episode_starts = np.ones((env.num_envs,), dtype=bool)
while (episode_counts < episode_count_targets).any(): while (episode_counts < episode_count_targets).any():
if use_masking: if use_masking:
action_masks = get_action_masks(env) action_masks = get_action_masks(env)
actions, state = model.predict( actions, state = model.predict(
observations, observations,
state=states, state=states,
episode_start=episode_starts,
deterministic=deterministic, deterministic=deterministic,
action_masks=action_masks, action_masks=action_masks,
) )
else: else:
actions, states = model.predict(observations, state=states, deterministic=deterministic) actions, states = model.predict(
observations, state=states, episode_start=episode_starts, deterministic=deterministic
)
observations, rewards, dones, infos = env.step(actions) observations, rewards, dones, infos = env.step(actions)
current_rewards += rewards current_rewards += rewards
current_lengths += 1 current_lengths += 1
@ -110,6 +113,7 @@ def evaluate_policy( # noqa: C901
reward = rewards[i] reward = rewards[i]
done = dones[i] done = dones[i]
info = infos[i] info = infos[i]
episode_starts[i] = done
if callback is not None: if callback is not None:
callback(locals(), globals()) callback(locals(), globals())
@ -133,8 +137,6 @@ def evaluate_policy( # noqa: C901
episode_counts[i] += 1 episode_counts[i] += 1
current_rewards[i] = 0 current_rewards[i] = 0
current_lengths[i] = 0 current_lengths[i] = 0
if states is not None:
states[i] *= 0
if render: if render:
env.render() env.render()

View File

@ -36,9 +36,6 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
:param log_std_init: Initial value for the log standard deviation :param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters :param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE for the std instead of only (n_features,) when using gSDE
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
@ -74,7 +71,6 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
use_sde: bool = False, use_sde: bool = False,
log_std_init: float = 0.0, log_std_init: float = 0.0,
full_std: bool = True, full_std: bool = True,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False, use_expln: bool = False,
squash_output: bool = False, squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
@ -99,7 +95,6 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
use_sde, use_sde,
log_std_init, log_std_init,
full_std, full_std,
sde_net_arch,
use_expln, use_expln,
squash_output, squash_output,
features_extractor_class, features_extractor_class,
@ -436,9 +431,6 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
:param log_std_init: Initial value for the log standard deviation :param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters :param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE for the std instead of only (n_features,) when using gSDE
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
@ -473,7 +465,6 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
use_sde: bool = False, use_sde: bool = False,
log_std_init: float = 0.0, log_std_init: float = 0.0,
full_std: bool = True, full_std: bool = True,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False, use_expln: bool = False,
squash_output: bool = False, squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
@ -497,7 +488,6 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
use_sde, use_sde,
log_std_init, log_std_init,
full_std, full_std,
sde_net_arch,
use_expln, use_expln,
squash_output, squash_output,
features_extractor_class, features_extractor_class,
@ -528,9 +518,6 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
:param log_std_init: Initial value for the log standard deviation :param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters :param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE for the std instead of only (n_features,) when using gSDE
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
@ -565,7 +552,6 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
use_sde: bool = False, use_sde: bool = False,
log_std_init: float = 0.0, log_std_init: float = 0.0,
full_std: bool = True, full_std: bool = True,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False, use_expln: bool = False,
squash_output: bool = False, squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
@ -589,7 +575,6 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
use_sde, use_sde,
log_std_init, log_std_init,
full_std, full_std,
sde_net_arch,
use_expln, use_expln,
squash_output, squash_output,
features_extractor_class, features_extractor_class,

View File

@ -1,6 +1,5 @@
import sys import sys
import time import time
import warnings
from collections import deque from collections import deque
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
@ -61,9 +60,6 @@ class MaskablePPO(OnPolicyAlgorithm):
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
By default, there is no limit on the kl div. By default, there is no limit on the kl div.
:param tensorboard_log: the log location for tensorboard (if None, no logging) :param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically (Only available when passing string for the environment).
Caution, this parameter is deprecated and will be removed in the future.
:param policy_kwargs: additional arguments to be passed to the policy on creation :param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators :param seed: Seed for the pseudo random generators
@ -96,7 +92,6 @@ class MaskablePPO(OnPolicyAlgorithm):
max_grad_norm: float = 0.5, max_grad_norm: float = 0.5,
target_kl: Optional[float] = None, target_kl: Optional[float] = None,
tensorboard_log: Optional[str] = None, tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None, policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0, verbose: int = 0,
seed: Optional[int] = None, seed: Optional[int] = None,
@ -116,7 +111,6 @@ class MaskablePPO(OnPolicyAlgorithm):
use_sde=False, use_sde=False,
sde_sample_freq=-1, sde_sample_freq=-1,
tensorboard_log=tensorboard_log, tensorboard_log=tensorboard_log,
create_eval_env=create_eval_env,
policy_kwargs=policy_kwargs, policy_kwargs=policy_kwargs,
verbose=verbose, verbose=verbose,
seed=seed, seed=seed,
@ -179,23 +173,11 @@ class MaskablePPO(OnPolicyAlgorithm):
def _init_callback( def _init_callback(
self, self,
callback: MaybeCallback, callback: MaybeCallback,
eval_env: Optional[VecEnv] = None,
eval_freq: int = 10000,
n_eval_episodes: int = 5,
log_path: Optional[str] = None,
use_masking: bool = True, use_masking: bool = True,
progress_bar: bool = False, progress_bar: bool = False,
) -> BaseCallback: ) -> BaseCallback:
""" """
:param callback: Callback(s) called at every step with state of the algorithm. :param callback: Callback(s) called at every step with state of the algorithm.
:param eval_env: Environment to use for evaluation.
Caution, this parameter is deprecated and will be removed in the future.
Please use `MaskableEvalCallback` or a custom Callback instead.
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
Caution, this parameter is deprecated and will be removed in the future.
Please use `MaskableEvalCallback` or a custom Callback instead.
:param n_eval_episodes: How many episodes to play per evaluation
:param log_path: Path to a folder where the evaluations will be saved
:param use_masking: Whether or not to use invalid action masks during evaluation :param use_masking: Whether or not to use invalid action masks during evaluation
:param progress_bar: Display a progress bar using tqdm and rich. :param progress_bar: Display a progress bar using tqdm and rich.
:return: A hybrid callback calling `callback` and performing evaluation. :return: A hybrid callback calling `callback` and performing evaluation.
@ -212,33 +194,13 @@ class MaskablePPO(OnPolicyAlgorithm):
if progress_bar: if progress_bar:
callback = CallbackList([callback, ProgressBarCallback()]) callback = CallbackList([callback, ProgressBarCallback()])
# Create eval callback in charge of the evaluation
if eval_env is not None:
# Avoid circular import error
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
eval_callback = MaskableEvalCallback(
eval_env,
best_model_save_path=log_path,
log_path=log_path,
eval_freq=eval_freq,
n_eval_episodes=n_eval_episodes,
use_masking=use_masking,
verbose=self.verbose,
)
callback = CallbackList([callback, eval_callback])
callback.init_callback(self) callback.init_callback(self)
return callback return callback
def _setup_learn( def _setup_learn(
self, self,
total_timesteps: int, total_timesteps: int,
eval_env: Optional[GymEnv],
callback: MaybeCallback = None, callback: MaybeCallback = None,
eval_freq: int = 10000,
n_eval_episodes: int = 5,
log_path: Optional[str] = None,
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
tb_log_name: str = "run", tb_log_name: str = "run",
use_masking: bool = True, use_masking: bool = True,
@ -248,15 +210,7 @@ class MaskablePPO(OnPolicyAlgorithm):
Initialize different variables needed for training. Initialize different variables needed for training.
:param total_timesteps: The total number of samples (env steps) to train on :param total_timesteps: The total number of samples (env steps) to train on
:param eval_env: Environment to use for evaluation.
Caution, this parameter is deprecated and will be removed in the future.
Please use `MaskableEvalCallback` or a custom Callback instead.
:param callback: Callback(s) called at every step with state of the algorithm. :param callback: Callback(s) called at every step with state of the algorithm.
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
Caution, this parameter is deprecated and will be removed in the future.
Please use `MaskableEvalCallback` or a custom Callback instead.
:param n_eval_episodes: How many episodes to play per evaluation
:param log_path: Path to a folder where the evaluations will be saved
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
:param tb_log_name: the name of the run for tensorboard log :param tb_log_name: the name of the run for tensorboard log
:param use_masking: Whether or not to use invalid action masks during training :param use_masking: Whether or not to use invalid action masks during training
@ -264,17 +218,6 @@ class MaskablePPO(OnPolicyAlgorithm):
:return: :return:
""" """
if eval_env is not None or eval_freq != -1:
warnings.warn(
"Parameters `eval_env` and `eval_freq` are deprecated and will be removed in the future. "
"Please use `MaskableEvalCallback` or a custom Callback instead.",
DeprecationWarning,
# By setting the `stacklevel` we refer to the initial caller of the deprecated feature.
# This causes the `DeprecationWarning` to not be ignored and to be shown to the user. See
# https://github.com/DLR-RM/stable-baselines3/pull/1082#discussion_r989842855 for more details.
stacklevel=4,
)
self.start_time = time.time_ns() self.start_time = time.time_ns()
if self.ep_info_buffer is None or reset_num_timesteps: if self.ep_info_buffer is None or reset_num_timesteps:
# Initialize buffers if they don't exist, or reinitialize if resetting counters # Initialize buffers if they don't exist, or reinitialize if resetting counters
@ -297,17 +240,12 @@ class MaskablePPO(OnPolicyAlgorithm):
if self._vec_normalize_env is not None: if self._vec_normalize_env is not None:
self._last_original_obs = self._vec_normalize_env.get_original_obs() self._last_original_obs = self._vec_normalize_env.get_original_obs()
if eval_env is not None and self.seed is not None:
eval_env.seed(self.seed)
eval_env = self._get_eval_env(eval_env)
# Configure logger's outputs if no logger was passed # Configure logger's outputs if no logger was passed
if not self._custom_logger: if not self._custom_logger:
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps) self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
# Create eval callback if needed # Create eval callback if needed
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking, progress_bar) callback = self._init_callback(callback, use_masking, progress_bar)
return total_timesteps, callback return total_timesteps, callback
@ -564,11 +502,7 @@ class MaskablePPO(OnPolicyAlgorithm):
total_timesteps: int, total_timesteps: int,
callback: MaybeCallback = None, callback: MaybeCallback = None,
log_interval: int = 1, log_interval: int = 1,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
tb_log_name: str = "PPO", tb_log_name: str = "PPO",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
use_masking: bool = True, use_masking: bool = True,
progress_bar: bool = False, progress_bar: bool = False,
@ -577,11 +511,7 @@ class MaskablePPO(OnPolicyAlgorithm):
total_timesteps, callback = self._setup_learn( total_timesteps, callback = self._setup_learn(
total_timesteps, total_timesteps,
eval_env,
callback, callback,
eval_freq,
n_eval_episodes,
eval_log_path,
reset_num_timesteps, reset_num_timesteps,
tb_log_name, tb_log_name,
use_masking, use_masking,

View File

@ -58,9 +58,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
By default, there is no limit on the kl div. By default, there is no limit on the kl div.
:param tensorboard_log: the log location for tensorboard (if None, no logging) :param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically (Only available when passing string for the environment).
Caution, this parameter is deprecated and will be removed in the future.
:param policy_kwargs: additional arguments to be passed to the policy on creation :param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators :param seed: Seed for the pseudo random generators
@ -95,7 +92,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
sde_sample_freq: int = -1, sde_sample_freq: int = -1,
target_kl: Optional[float] = None, target_kl: Optional[float] = None,
tensorboard_log: Optional[str] = None, tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None, policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0, verbose: int = 0,
seed: Optional[int] = None, seed: Optional[int] = None,
@ -115,7 +111,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
use_sde=use_sde, use_sde=use_sde,
sde_sample_freq=sde_sample_freq, sde_sample_freq=sde_sample_freq,
tensorboard_log=tensorboard_log, tensorboard_log=tensorboard_log,
create_eval_env=create_eval_env,
policy_kwargs=policy_kwargs, policy_kwargs=policy_kwargs,
verbose=verbose, verbose=verbose,
seed=seed, seed=seed,
@ -453,11 +448,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
total_timesteps: int, total_timesteps: int,
callback: MaybeCallback = None, callback: MaybeCallback = None,
log_interval: int = 1, log_interval: int = 1,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
tb_log_name: str = "RecurrentPPO", tb_log_name: str = "RecurrentPPO",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
progress_bar: bool = False, progress_bar: bool = False,
) -> RecurrentPPOSelf: ) -> RecurrentPPOSelf:
@ -465,11 +456,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
total_timesteps, callback = self._setup_learn( total_timesteps, callback = self._setup_learn(
total_timesteps, total_timesteps,
eval_env,
callback, callback,
eval_freq,
n_eval_episodes,
eval_log_path,
reset_num_timesteps, reset_num_timesteps,
tb_log_name, tb_log_name,
progress_bar, progress_bar,

View File

@ -50,9 +50,6 @@ class QRDQN(OffPolicyAlgorithm):
:param exploration_final_eps: final value of random action probability :param exploration_final_eps: final value of random action probability
:param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping) :param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping)
:param tensorboard_log: the log location for tensorboard (if None, no logging) :param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically (Only available when passing string for the environment).
Caution, this parameter is deprecated and will be removed in the future.
:param policy_kwargs: additional arguments to be passed to the policy on creation :param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators :param seed: Seed for the pseudo random generators
@ -88,7 +85,6 @@ class QRDQN(OffPolicyAlgorithm):
exploration_final_eps: float = 0.01, exploration_final_eps: float = 0.01,
max_grad_norm: Optional[float] = None, max_grad_norm: Optional[float] = None,
tensorboard_log: Optional[str] = None, tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None, policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0, verbose: int = 0,
seed: Optional[int] = None, seed: Optional[int] = None,
@ -114,7 +110,6 @@ class QRDQN(OffPolicyAlgorithm):
tensorboard_log=tensorboard_log, tensorboard_log=tensorboard_log,
verbose=verbose, verbose=verbose,
device=device, device=device,
create_eval_env=create_eval_env,
seed=seed, seed=seed,
sde_support=False, sde_support=False,
optimize_memory_usage=optimize_memory_usage, optimize_memory_usage=optimize_memory_usage,
@ -256,11 +251,7 @@ class QRDQN(OffPolicyAlgorithm):
total_timesteps: int, total_timesteps: int,
callback: MaybeCallback = None, callback: MaybeCallback = None,
log_interval: int = 4, log_interval: int = 4,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
tb_log_name: str = "QRDQN", tb_log_name: str = "QRDQN",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
progress_bar: bool = False, progress_bar: bool = False,
) -> QRDQNSelf: ) -> QRDQNSelf:
@ -269,11 +260,7 @@ class QRDQN(OffPolicyAlgorithm):
total_timesteps=total_timesteps, total_timesteps=total_timesteps,
callback=callback, callback=callback,
log_interval=log_interval, log_interval=log_interval,
eval_env=eval_env,
eval_freq=eval_freq,
n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name, tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps, reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar, progress_bar=progress_bar,
) )

View File

@ -1,4 +1,3 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, Union from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym import gym
@ -37,9 +36,6 @@ class Actor(BasePolicy):
:param log_std_init: Initial value for the log standard deviation :param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters :param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE. for the std instead of only (n_features,) when using gSDE.
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
@ -59,7 +55,6 @@ class Actor(BasePolicy):
use_sde: bool = False, use_sde: bool = False,
log_std_init: float = -3, log_std_init: float = -3,
full_std: bool = True, full_std: bool = True,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False, use_expln: bool = False,
clip_mean: float = 2.0, clip_mean: float = 2.0,
normalize_images: bool = True, normalize_images: bool = True,
@ -79,7 +74,6 @@ class Actor(BasePolicy):
self.features_dim = features_dim self.features_dim = features_dim
self.activation_fn = activation_fn self.activation_fn = activation_fn
self.log_std_init = log_std_init self.log_std_init = log_std_init
self.sde_net_arch = sde_net_arch
self.use_expln = use_expln self.use_expln = use_expln
self.full_std = full_std self.full_std = full_std
self.clip_mean = clip_mean self.clip_mean = clip_mean
@ -89,9 +83,6 @@ class Actor(BasePolicy):
self.latent_pi = nn.Sequential(*latent_pi_net) self.latent_pi = nn.Sequential(*latent_pi_net)
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
if sde_net_arch is not None:
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
if self.use_sde: if self.use_sde:
self.action_dist = StateDependentNoiseDistribution( self.action_dist = StateDependentNoiseDistribution(
action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
@ -256,9 +247,6 @@ class TQCPolicy(BasePolicy):
:param activation_fn: Activation function :param activation_fn: Activation function
:param use_sde: Whether to use State Dependent Exploration or not :param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation :param log_std_init: Initial value for the log standard deviation
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
@ -287,7 +275,6 @@ class TQCPolicy(BasePolicy):
activation_fn: Type[nn.Module] = nn.ReLU, activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False, use_sde: bool = False,
log_std_init: float = -3, log_std_init: float = -3,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False, use_expln: bool = False,
clip_mean: float = 2.0, clip_mean: float = 2.0,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
@ -325,9 +312,6 @@ class TQCPolicy(BasePolicy):
} }
self.actor_kwargs = self.net_args.copy() self.actor_kwargs = self.net_args.copy()
if sde_net_arch is not None:
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
sde_kwargs = { sde_kwargs = {
"use_sde": use_sde, "use_sde": use_sde,
"log_std_init": log_std_init, "log_std_init": log_std_init,
@ -442,9 +426,6 @@ class CnnPolicy(TQCPolicy):
:param activation_fn: Activation function :param activation_fn: Activation function
:param use_sde: Whether to use State Dependent Exploration or not :param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation :param log_std_init: Initial value for the log standard deviation
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
@ -471,7 +452,6 @@ class CnnPolicy(TQCPolicy):
activation_fn: Type[nn.Module] = nn.ReLU, activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False, use_sde: bool = False,
log_std_init: float = -3, log_std_init: float = -3,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False, use_expln: bool = False,
clip_mean: float = 2.0, clip_mean: float = 2.0,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
@ -491,7 +471,6 @@ class CnnPolicy(TQCPolicy):
activation_fn, activation_fn,
use_sde, use_sde,
log_std_init, log_std_init,
sde_net_arch,
use_expln, use_expln,
clip_mean, clip_mean,
features_extractor_class, features_extractor_class,
@ -516,9 +495,6 @@ class MultiInputPolicy(TQCPolicy):
:param activation_fn: Activation function :param activation_fn: Activation function
:param use_sde: Whether to use State Dependent Exploration or not :param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation :param log_std_init: Initial value for the log standard deviation
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
@ -545,7 +521,6 @@ class MultiInputPolicy(TQCPolicy):
activation_fn: Type[nn.Module] = nn.ReLU, activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False, use_sde: bool = False,
log_std_init: float = -3, log_std_init: float = -3,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False, use_expln: bool = False,
clip_mean: float = 2.0, clip_mean: float = 2.0,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
@ -565,7 +540,6 @@ class MultiInputPolicy(TQCPolicy):
activation_fn, activation_fn,
use_sde, use_sde,
log_std_init, log_std_init,
sde_net_arch,
use_expln, use_expln,
clip_mean, clip_mean,
features_extractor_class, features_extractor_class,

View File

@ -57,9 +57,6 @@ class TQC(OffPolicyAlgorithm):
Default: -1 (only sample at the beginning of the rollout) Default: -1 (only sample at the beginning of the rollout)
:param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
during the warm up phase (before learning starts) during the warm up phase (before learning starts)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically (Only available when passing string for the environment).
Caution, this parameter is deprecated and will be removed in the future.
:param policy_kwargs: additional arguments to be passed to the policy on creation :param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators :param seed: Seed for the pseudo random generators
@ -98,7 +95,6 @@ class TQC(OffPolicyAlgorithm):
sde_sample_freq: int = -1, sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False, use_sde_at_warmup: bool = False,
tensorboard_log: Optional[str] = None, tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None, policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0, verbose: int = 0,
seed: Optional[int] = None, seed: Optional[int] = None,
@ -124,7 +120,6 @@ class TQC(OffPolicyAlgorithm):
tensorboard_log=tensorboard_log, tensorboard_log=tensorboard_log,
verbose=verbose, verbose=verbose,
device=device, device=device,
create_eval_env=create_eval_env,
seed=seed, seed=seed,
use_sde=use_sde, use_sde=use_sde,
sde_sample_freq=sde_sample_freq, sde_sample_freq=sde_sample_freq,
@ -293,11 +288,7 @@ class TQC(OffPolicyAlgorithm):
total_timesteps: int, total_timesteps: int,
callback: MaybeCallback = None, callback: MaybeCallback = None,
log_interval: int = 4, log_interval: int = 4,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
tb_log_name: str = "TQC", tb_log_name: str = "TQC",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
progress_bar: bool = False, progress_bar: bool = False,
) -> TQCSelf: ) -> TQCSelf:
@ -306,11 +297,7 @@ class TQC(OffPolicyAlgorithm):
total_timesteps=total_timesteps, total_timesteps=total_timesteps,
callback=callback, callback=callback,
log_interval=log_interval, log_interval=log_interval,
eval_env=eval_env,
eval_freq=eval_freq,
n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name, tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps, reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar, progress_bar=progress_bar,
) )

View File

@ -59,9 +59,6 @@ class TRPO(OnPolicyAlgorithm):
:param sub_sampling_factor: Sub-sample the batch to make computation faster :param sub_sampling_factor: Sub-sample the batch to make computation faster
see p40-42 of John Schulman thesis http://joschu.net/docs/thesis.pdf see p40-42 of John Schulman thesis http://joschu.net/docs/thesis.pdf
:param tensorboard_log: the log location for tensorboard (if None, no logging) :param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically (Only available when passing string for the environment).
Caution, this parameter is deprecated and will be removed in the future.
:param policy_kwargs: additional arguments to be passed to the policy on creation :param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators :param seed: Seed for the pseudo random generators
@ -96,7 +93,6 @@ class TRPO(OnPolicyAlgorithm):
target_kl: float = 0.01, target_kl: float = 0.01,
sub_sampling_factor: int = 1, sub_sampling_factor: int = 1,
tensorboard_log: Optional[str] = None, tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None, policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0, verbose: int = 0,
seed: Optional[int] = None, seed: Optional[int] = None,
@ -120,7 +116,6 @@ class TRPO(OnPolicyAlgorithm):
policy_kwargs=policy_kwargs, policy_kwargs=policy_kwargs,
verbose=verbose, verbose=verbose,
device=device, device=device,
create_eval_env=create_eval_env,
seed=seed, seed=seed,
_init_setup_model=False, _init_setup_model=False,
supported_action_spaces=( supported_action_spaces=(
@ -409,11 +404,7 @@ class TRPO(OnPolicyAlgorithm):
total_timesteps: int, total_timesteps: int,
callback: MaybeCallback = None, callback: MaybeCallback = None,
log_interval: int = 1, log_interval: int = 1,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
tb_log_name: str = "TRPO", tb_log_name: str = "TRPO",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
progress_bar: bool = False, progress_bar: bool = False,
) -> TRPOSelf: ) -> TRPOSelf:
@ -422,11 +413,7 @@ class TRPO(OnPolicyAlgorithm):
total_timesteps=total_timesteps, total_timesteps=total_timesteps,
callback=callback, callback=callback,
log_interval=log_interval, log_interval=log_interval,
eval_env=eval_env,
eval_freq=eval_freq,
n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name, tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps, reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar, progress_bar=progress_bar,
) )

View File

@ -1 +1 @@
1.6.2 1.7.0a0

View File

@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")], packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]}, package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[ install_requires=[
"stable_baselines3>=1.6.2", "stable_baselines3>=1.7.0a0",
], ],
description="Contrib package of Stable Baselines3, experimental code.", description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin", author="Antonin Raffin",

View File

@ -64,15 +64,6 @@ def test_bootstraping():
model.learn(128) model.learn(128)
def test_eval_env():
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
eval_env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
model = MaskablePPO("MlpPolicy", env, clip_range_vf=0.2, n_steps=32, seed=8)
with pytest.warns(DeprecationWarning): # `eval_env` is deprecated
model.learn(32, eval_env=eval_env, eval_freq=16)
model.learn(32, reset_num_timesteps=False)
def test_supports_discrete_action_space(): def test_supports_discrete_action_space():
""" """
No errors using algorithm with an env that has a discrete action space No errors using algorithm with an env that has a discrete action space

View File

@ -129,32 +129,28 @@ def test_check():
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"]) @pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
def test_run(env): def test_run(env):
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated model = RecurrentPPO(
model = RecurrentPPO( "MlpLstmPolicy",
"MlpLstmPolicy", env,
env, n_steps=16,
n_steps=16, seed=0,
seed=0, )
create_eval_env=True,
)
model.learn(total_timesteps=32, eval_freq=16) model.learn(total_timesteps=32)
def test_run_sde(): def test_run_sde():
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated model = RecurrentPPO(
model = RecurrentPPO( "MlpLstmPolicy",
"MlpLstmPolicy", "Pendulum-v1",
"Pendulum-v1", n_steps=16,
n_steps=16, seed=0,
seed=0, sde_sample_freq=4,
create_eval_env=True, use_sde=True,
sde_sample_freq=4, clip_range_vf=0.1,
use_sde=True, )
clip_range_vf=0.1,
)
model.learn(total_timesteps=200, eval_freq=150) model.learn(total_timesteps=200)
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -10,17 +10,15 @@ from sb3_contrib.common.vec_env import AsyncEval
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"]) @pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
def test_tqc(ent_coef): def test_tqc(ent_coef):
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated model = TQC(
model = TQC( "MlpPolicy",
"MlpPolicy", "Pendulum-v1",
"Pendulum-v1", policy_kwargs=dict(net_arch=[64, 64]),
policy_kwargs=dict(net_arch=[64, 64]), learning_starts=100,
learning_starts=100, verbose=1,
verbose=1, ent_coef=ent_coef,
create_eval_env=True, )
ent_coef=ent_coef, model.learn(total_timesteps=300, progress_bar=True)
)
model.learn(total_timesteps=300, eval_freq=250, progress_bar=True)
@pytest.mark.parametrize("n_critics", [1, 3]) @pytest.mark.parametrize("n_critics", [1, 3])
@ -51,18 +49,16 @@ def test_sde():
def test_qrdqn(): def test_qrdqn():
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated model = QRDQN(
model = QRDQN( "MlpPolicy",
"MlpPolicy", "CartPole-v1",
"CartPole-v1", policy_kwargs=dict(n_quantiles=25, net_arch=[64, 64]),
policy_kwargs=dict(n_quantiles=25, net_arch=[64, 64]), learning_starts=100,
learning_starts=100, buffer_size=500,
buffer_size=500, learning_rate=3e-4,
learning_rate=3e-4, verbose=1,
verbose=1, )
create_eval_env=True, model.learn(total_timesteps=500)
)
model.learn(total_timesteps=500, eval_freq=250)
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
@ -91,8 +87,7 @@ def test_trpo_params():
@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"]) @pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"])
def test_ars(policy_str, env_id): def test_ars(policy_str, env_id):
model = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0) model = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0)
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated model.learn(total_timesteps=500, log_interval=1)
model.learn(total_timesteps=500, log_interval=1, eval_freq=250)
def test_ars_multi_env(): def test_ars_multi_env():