Remove deprecated features (#108)
* Remove deprecated features * Upgrade SB3 * Fix tests
This commit is contained in:
parent
52795a307e
commit
c75ad7dd58
|
|
@ -3,6 +3,27 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.7.0a0 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
|
||||
please use an ``EvalCallback`` instead
|
||||
- Removed deprecated ``sde_net_arch`` parameter
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
|
||||
|
||||
Release 1.6.2 (2022-10-10)
|
||||
--------------------------
|
||||
|
|
|
|||
|
|
@ -309,10 +309,6 @@ class ARS(BaseAlgorithm):
|
|||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
tb_log_name: str = "ARS",
|
||||
eval_env: Optional[GymEnv] = None,
|
||||
eval_freq: int = -1,
|
||||
n_eval_episodes: int = 5,
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
async_eval: Optional[AsyncEval] = None,
|
||||
progress_bar: bool = False,
|
||||
|
|
@ -324,14 +320,6 @@ class ARS(BaseAlgorithm):
|
|||
:param callback: callback(s) called at every step with state of the algorithm.
|
||||
:param log_interval: The number of timesteps before logging.
|
||||
:param tb_log_name: the name of the run for TensorBoard logging
|
||||
:param eval_env: Environment to use for evaluation.
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `EvalCallback` or a custom Callback instead.
|
||||
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `EvalCallback` or a custom Callback instead.
|
||||
:param n_eval_episodes: Number of episode to evaluate the agent
|
||||
:param eval_log_path: Path to a folder where the evaluations will be saved
|
||||
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
|
||||
:param async_eval: The object for asynchronous evaluation of candidates.
|
||||
:param progress_bar: Display a progress bar using tqdm and rich.
|
||||
|
|
@ -340,11 +328,7 @@ class ARS(BaseAlgorithm):
|
|||
|
||||
total_steps, callback = self._setup_learn(
|
||||
total_timesteps,
|
||||
eval_env,
|
||||
callback,
|
||||
eval_freq,
|
||||
n_eval_episodes,
|
||||
eval_log_path,
|
||||
reset_num_timesteps,
|
||||
tb_log_name,
|
||||
progress_bar,
|
||||
|
|
|
|||
|
|
@ -88,18 +88,21 @@ def evaluate_policy( # noqa: C901
|
|||
current_lengths = np.zeros(n_envs, dtype="int")
|
||||
observations = env.reset()
|
||||
states = None
|
||||
|
||||
episode_starts = np.ones((env.num_envs,), dtype=bool)
|
||||
while (episode_counts < episode_count_targets).any():
|
||||
if use_masking:
|
||||
action_masks = get_action_masks(env)
|
||||
actions, state = model.predict(
|
||||
observations,
|
||||
state=states,
|
||||
episode_start=episode_starts,
|
||||
deterministic=deterministic,
|
||||
action_masks=action_masks,
|
||||
)
|
||||
else:
|
||||
actions, states = model.predict(observations, state=states, deterministic=deterministic)
|
||||
actions, states = model.predict(
|
||||
observations, state=states, episode_start=episode_starts, deterministic=deterministic
|
||||
)
|
||||
observations, rewards, dones, infos = env.step(actions)
|
||||
current_rewards += rewards
|
||||
current_lengths += 1
|
||||
|
|
@ -110,6 +113,7 @@ def evaluate_policy( # noqa: C901
|
|||
reward = rewards[i]
|
||||
done = dones[i]
|
||||
info = infos[i]
|
||||
episode_starts[i] = done
|
||||
|
||||
if callback is not None:
|
||||
callback(locals(), globals())
|
||||
|
|
@ -133,8 +137,6 @@ def evaluate_policy( # noqa: C901
|
|||
episode_counts[i] += 1
|
||||
current_rewards[i] = 0
|
||||
current_lengths[i] = 0
|
||||
if states is not None:
|
||||
states[i] *= 0
|
||||
|
||||
if render:
|
||||
env.render()
|
||||
|
|
|
|||
|
|
@ -36,9 +36,6 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
:param log_std_init: Initial value for the log standard deviation
|
||||
:param full_std: Whether to use (n_features x n_actions) parameters
|
||||
for the std instead of only (n_features,) when using gSDE
|
||||
:param sde_net_arch: Network architecture for extracting features
|
||||
when using gSDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
|
||||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
|
|
@ -74,7 +71,6 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
use_sde: bool = False,
|
||||
log_std_init: float = 0.0,
|
||||
full_std: bool = True,
|
||||
sde_net_arch: Optional[List[int]] = None,
|
||||
use_expln: bool = False,
|
||||
squash_output: bool = False,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||
|
|
@ -99,7 +95,6 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
use_sde,
|
||||
log_std_init,
|
||||
full_std,
|
||||
sde_net_arch,
|
||||
use_expln,
|
||||
squash_output,
|
||||
features_extractor_class,
|
||||
|
|
@ -436,9 +431,6 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
|||
:param log_std_init: Initial value for the log standard deviation
|
||||
:param full_std: Whether to use (n_features x n_actions) parameters
|
||||
for the std instead of only (n_features,) when using gSDE
|
||||
:param sde_net_arch: Network architecture for extracting features
|
||||
when using gSDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
|
||||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
|
|
@ -473,7 +465,6 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
|||
use_sde: bool = False,
|
||||
log_std_init: float = 0.0,
|
||||
full_std: bool = True,
|
||||
sde_net_arch: Optional[List[int]] = None,
|
||||
use_expln: bool = False,
|
||||
squash_output: bool = False,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||
|
|
@ -497,7 +488,6 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
|||
use_sde,
|
||||
log_std_init,
|
||||
full_std,
|
||||
sde_net_arch,
|
||||
use_expln,
|
||||
squash_output,
|
||||
features_extractor_class,
|
||||
|
|
@ -528,9 +518,6 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
|||
:param log_std_init: Initial value for the log standard deviation
|
||||
:param full_std: Whether to use (n_features x n_actions) parameters
|
||||
for the std instead of only (n_features,) when using gSDE
|
||||
:param sde_net_arch: Network architecture for extracting features
|
||||
when using gSDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
|
||||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
|
|
@ -565,7 +552,6 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
|||
use_sde: bool = False,
|
||||
log_std_init: float = 0.0,
|
||||
full_std: bool = True,
|
||||
sde_net_arch: Optional[List[int]] = None,
|
||||
use_expln: bool = False,
|
||||
squash_output: bool = False,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
||||
|
|
@ -589,7 +575,6 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
|||
use_sde,
|
||||
log_std_init,
|
||||
full_std,
|
||||
sde_net_arch,
|
||||
use_expln,
|
||||
squash_output,
|
||||
features_extractor_class,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from collections import deque
|
||||
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
|
|
@ -61,9 +60,6 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
|
||||
By default, there is no limit on the kl div.
|
||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||
:param seed: Seed for the pseudo random generators
|
||||
|
|
@ -96,7 +92,6 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
max_grad_norm: float = 0.5,
|
||||
target_kl: Optional[float] = None,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
create_eval_env: bool = False,
|
||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||
verbose: int = 0,
|
||||
seed: Optional[int] = None,
|
||||
|
|
@ -116,7 +111,6 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
use_sde=False,
|
||||
sde_sample_freq=-1,
|
||||
tensorboard_log=tensorboard_log,
|
||||
create_eval_env=create_eval_env,
|
||||
policy_kwargs=policy_kwargs,
|
||||
verbose=verbose,
|
||||
seed=seed,
|
||||
|
|
@ -179,23 +173,11 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
def _init_callback(
|
||||
self,
|
||||
callback: MaybeCallback,
|
||||
eval_env: Optional[VecEnv] = None,
|
||||
eval_freq: int = 10000,
|
||||
n_eval_episodes: int = 5,
|
||||
log_path: Optional[str] = None,
|
||||
use_masking: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> BaseCallback:
|
||||
"""
|
||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
||||
:param eval_env: Environment to use for evaluation.
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `MaskableEvalCallback` or a custom Callback instead.
|
||||
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `MaskableEvalCallback` or a custom Callback instead.
|
||||
:param n_eval_episodes: How many episodes to play per evaluation
|
||||
:param log_path: Path to a folder where the evaluations will be saved
|
||||
:param use_masking: Whether or not to use invalid action masks during evaluation
|
||||
:param progress_bar: Display a progress bar using tqdm and rich.
|
||||
:return: A hybrid callback calling `callback` and performing evaluation.
|
||||
|
|
@ -212,33 +194,13 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
if progress_bar:
|
||||
callback = CallbackList([callback, ProgressBarCallback()])
|
||||
|
||||
# Create eval callback in charge of the evaluation
|
||||
if eval_env is not None:
|
||||
# Avoid circular import error
|
||||
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
|
||||
|
||||
eval_callback = MaskableEvalCallback(
|
||||
eval_env,
|
||||
best_model_save_path=log_path,
|
||||
log_path=log_path,
|
||||
eval_freq=eval_freq,
|
||||
n_eval_episodes=n_eval_episodes,
|
||||
use_masking=use_masking,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
callback = CallbackList([callback, eval_callback])
|
||||
|
||||
callback.init_callback(self)
|
||||
return callback
|
||||
|
||||
def _setup_learn(
|
||||
self,
|
||||
total_timesteps: int,
|
||||
eval_env: Optional[GymEnv],
|
||||
callback: MaybeCallback = None,
|
||||
eval_freq: int = 10000,
|
||||
n_eval_episodes: int = 5,
|
||||
log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
tb_log_name: str = "run",
|
||||
use_masking: bool = True,
|
||||
|
|
@ -248,15 +210,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
Initialize different variables needed for training.
|
||||
|
||||
:param total_timesteps: The total number of samples (env steps) to train on
|
||||
:param eval_env: Environment to use for evaluation.
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `MaskableEvalCallback` or a custom Callback instead.
|
||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
||||
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
Please use `MaskableEvalCallback` or a custom Callback instead.
|
||||
:param n_eval_episodes: How many episodes to play per evaluation
|
||||
:param log_path: Path to a folder where the evaluations will be saved
|
||||
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
|
||||
:param tb_log_name: the name of the run for tensorboard log
|
||||
:param use_masking: Whether or not to use invalid action masks during training
|
||||
|
|
@ -264,17 +218,6 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
:return:
|
||||
"""
|
||||
|
||||
if eval_env is not None or eval_freq != -1:
|
||||
warnings.warn(
|
||||
"Parameters `eval_env` and `eval_freq` are deprecated and will be removed in the future. "
|
||||
"Please use `MaskableEvalCallback` or a custom Callback instead.",
|
||||
DeprecationWarning,
|
||||
# By setting the `stacklevel` we refer to the initial caller of the deprecated feature.
|
||||
# This causes the the `DepricationWarning` to not be ignored and to be shown to the user. See
|
||||
# https://github.com/DLR-RM/stable-baselines3/pull/1082#discussion_r989842855 for more details.
|
||||
stacklevel=4,
|
||||
)
|
||||
|
||||
self.start_time = time.time_ns()
|
||||
if self.ep_info_buffer is None or reset_num_timesteps:
|
||||
# Initialize buffers if they don't exist, or reinitialize if resetting counters
|
||||
|
|
@ -297,17 +240,12 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
if self._vec_normalize_env is not None:
|
||||
self._last_original_obs = self._vec_normalize_env.get_original_obs()
|
||||
|
||||
if eval_env is not None and self.seed is not None:
|
||||
eval_env.seed(self.seed)
|
||||
|
||||
eval_env = self._get_eval_env(eval_env)
|
||||
|
||||
# Configure logger's outputs if no logger was passed
|
||||
if not self._custom_logger:
|
||||
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
|
||||
|
||||
# Create eval callback if needed
|
||||
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking, progress_bar)
|
||||
callback = self._init_callback(callback, use_masking, progress_bar)
|
||||
|
||||
return total_timesteps, callback
|
||||
|
||||
|
|
@ -564,11 +502,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
eval_env: Optional[GymEnv] = None,
|
||||
eval_freq: int = -1,
|
||||
n_eval_episodes: int = 5,
|
||||
tb_log_name: str = "PPO",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
use_masking: bool = True,
|
||||
progress_bar: bool = False,
|
||||
|
|
@ -577,11 +511,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
total_timesteps,
|
||||
eval_env,
|
||||
callback,
|
||||
eval_freq,
|
||||
n_eval_episodes,
|
||||
eval_log_path,
|
||||
reset_num_timesteps,
|
||||
tb_log_name,
|
||||
use_masking,
|
||||
|
|
|
|||
|
|
@ -58,9 +58,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
|
||||
By default, there is no limit on the kl div.
|
||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||
:param seed: Seed for the pseudo random generators
|
||||
|
|
@ -95,7 +92,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
sde_sample_freq: int = -1,
|
||||
target_kl: Optional[float] = None,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
create_eval_env: bool = False,
|
||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||
verbose: int = 0,
|
||||
seed: Optional[int] = None,
|
||||
|
|
@ -115,7 +111,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
use_sde=use_sde,
|
||||
sde_sample_freq=sde_sample_freq,
|
||||
tensorboard_log=tensorboard_log,
|
||||
create_eval_env=create_eval_env,
|
||||
policy_kwargs=policy_kwargs,
|
||||
verbose=verbose,
|
||||
seed=seed,
|
||||
|
|
@ -453,11 +448,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
eval_env: Optional[GymEnv] = None,
|
||||
eval_freq: int = -1,
|
||||
n_eval_episodes: int = 5,
|
||||
tb_log_name: str = "RecurrentPPO",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> RecurrentPPOSelf:
|
||||
|
|
@ -465,11 +456,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
total_timesteps,
|
||||
eval_env,
|
||||
callback,
|
||||
eval_freq,
|
||||
n_eval_episodes,
|
||||
eval_log_path,
|
||||
reset_num_timesteps,
|
||||
tb_log_name,
|
||||
progress_bar,
|
||||
|
|
|
|||
|
|
@ -50,9 +50,6 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
:param exploration_final_eps: final value of random action probability
|
||||
:param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping)
|
||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||
:param seed: Seed for the pseudo random generators
|
||||
|
|
@ -88,7 +85,6 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
exploration_final_eps: float = 0.01,
|
||||
max_grad_norm: Optional[float] = None,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
create_eval_env: bool = False,
|
||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||
verbose: int = 0,
|
||||
seed: Optional[int] = None,
|
||||
|
|
@ -114,7 +110,6 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
tensorboard_log=tensorboard_log,
|
||||
verbose=verbose,
|
||||
device=device,
|
||||
create_eval_env=create_eval_env,
|
||||
seed=seed,
|
||||
sde_support=False,
|
||||
optimize_memory_usage=optimize_memory_usage,
|
||||
|
|
@ -256,11 +251,7 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
eval_env: Optional[GymEnv] = None,
|
||||
eval_freq: int = -1,
|
||||
n_eval_episodes: int = 5,
|
||||
tb_log_name: str = "QRDQN",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> QRDQNSelf:
|
||||
|
|
@ -269,11 +260,7 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
log_interval=log_interval,
|
||||
eval_env=eval_env,
|
||||
eval_freq=eval_freq,
|
||||
n_eval_episodes=n_eval_episodes,
|
||||
tb_log_name=tb_log_name,
|
||||
eval_log_path=eval_log_path,
|
||||
reset_num_timesteps=reset_num_timesteps,
|
||||
progress_bar=progress_bar,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
|
|
@ -37,9 +36,6 @@ class Actor(BasePolicy):
|
|||
:param log_std_init: Initial value for the log standard deviation
|
||||
:param full_std: Whether to use (n_features x n_actions) parameters
|
||||
for the std instead of only (n_features,) when using gSDE.
|
||||
:param sde_net_arch: Network architecture for extracting features
|
||||
when using gSDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
||||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
|
|
@ -59,7 +55,6 @@ class Actor(BasePolicy):
|
|||
use_sde: bool = False,
|
||||
log_std_init: float = -3,
|
||||
full_std: bool = True,
|
||||
sde_net_arch: Optional[List[int]] = None,
|
||||
use_expln: bool = False,
|
||||
clip_mean: float = 2.0,
|
||||
normalize_images: bool = True,
|
||||
|
|
@ -79,7 +74,6 @@ class Actor(BasePolicy):
|
|||
self.features_dim = features_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.log_std_init = log_std_init
|
||||
self.sde_net_arch = sde_net_arch
|
||||
self.use_expln = use_expln
|
||||
self.full_std = full_std
|
||||
self.clip_mean = clip_mean
|
||||
|
|
@ -89,9 +83,6 @@ class Actor(BasePolicy):
|
|||
self.latent_pi = nn.Sequential(*latent_pi_net)
|
||||
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
|
||||
|
||||
if sde_net_arch is not None:
|
||||
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
||||
|
||||
if self.use_sde:
|
||||
self.action_dist = StateDependentNoiseDistribution(
|
||||
action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
|
||||
|
|
@ -256,9 +247,6 @@ class TQCPolicy(BasePolicy):
|
|||
:param activation_fn: Activation function
|
||||
:param use_sde: Whether to use State Dependent Exploration or not
|
||||
:param log_std_init: Initial value for the log standard deviation
|
||||
:param sde_net_arch: Network architecture for extracting features
|
||||
when using gSDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
||||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
|
|
@ -287,7 +275,6 @@ class TQCPolicy(BasePolicy):
|
|||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
use_sde: bool = False,
|
||||
log_std_init: float = -3,
|
||||
sde_net_arch: Optional[List[int]] = None,
|
||||
use_expln: bool = False,
|
||||
clip_mean: float = 2.0,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||
|
|
@ -325,9 +312,6 @@ class TQCPolicy(BasePolicy):
|
|||
}
|
||||
self.actor_kwargs = self.net_args.copy()
|
||||
|
||||
if sde_net_arch is not None:
|
||||
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
||||
|
||||
sde_kwargs = {
|
||||
"use_sde": use_sde,
|
||||
"log_std_init": log_std_init,
|
||||
|
|
@ -442,9 +426,6 @@ class CnnPolicy(TQCPolicy):
|
|||
:param activation_fn: Activation function
|
||||
:param use_sde: Whether to use State Dependent Exploration or not
|
||||
:param log_std_init: Initial value for the log standard deviation
|
||||
:param sde_net_arch: Network architecture for extracting features
|
||||
when using gSDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
||||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
|
|
@ -471,7 +452,6 @@ class CnnPolicy(TQCPolicy):
|
|||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
use_sde: bool = False,
|
||||
log_std_init: float = -3,
|
||||
sde_net_arch: Optional[List[int]] = None,
|
||||
use_expln: bool = False,
|
||||
clip_mean: float = 2.0,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||
|
|
@ -491,7 +471,6 @@ class CnnPolicy(TQCPolicy):
|
|||
activation_fn,
|
||||
use_sde,
|
||||
log_std_init,
|
||||
sde_net_arch,
|
||||
use_expln,
|
||||
clip_mean,
|
||||
features_extractor_class,
|
||||
|
|
@ -516,9 +495,6 @@ class MultiInputPolicy(TQCPolicy):
|
|||
:param activation_fn: Activation function
|
||||
:param use_sde: Whether to use State Dependent Exploration or not
|
||||
:param log_std_init: Initial value for the log standard deviation
|
||||
:param sde_net_arch: Network architecture for extracting features
|
||||
when using gSDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
||||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
|
|
@ -545,7 +521,6 @@ class MultiInputPolicy(TQCPolicy):
|
|||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
use_sde: bool = False,
|
||||
log_std_init: float = -3,
|
||||
sde_net_arch: Optional[List[int]] = None,
|
||||
use_expln: bool = False,
|
||||
clip_mean: float = 2.0,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
||||
|
|
@ -565,7 +540,6 @@ class MultiInputPolicy(TQCPolicy):
|
|||
activation_fn,
|
||||
use_sde,
|
||||
log_std_init,
|
||||
sde_net_arch,
|
||||
use_expln,
|
||||
clip_mean,
|
||||
features_extractor_class,
|
||||
|
|
|
|||
|
|
@ -57,9 +57,6 @@ class TQC(OffPolicyAlgorithm):
|
|||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
|
||||
during the warm up phase (before learning starts)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||
:param seed: Seed for the pseudo random generators
|
||||
|
|
@ -98,7 +95,6 @@ class TQC(OffPolicyAlgorithm):
|
|||
sde_sample_freq: int = -1,
|
||||
use_sde_at_warmup: bool = False,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
create_eval_env: bool = False,
|
||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||
verbose: int = 0,
|
||||
seed: Optional[int] = None,
|
||||
|
|
@ -124,7 +120,6 @@ class TQC(OffPolicyAlgorithm):
|
|||
tensorboard_log=tensorboard_log,
|
||||
verbose=verbose,
|
||||
device=device,
|
||||
create_eval_env=create_eval_env,
|
||||
seed=seed,
|
||||
use_sde=use_sde,
|
||||
sde_sample_freq=sde_sample_freq,
|
||||
|
|
@ -293,11 +288,7 @@ class TQC(OffPolicyAlgorithm):
|
|||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
eval_env: Optional[GymEnv] = None,
|
||||
eval_freq: int = -1,
|
||||
n_eval_episodes: int = 5,
|
||||
tb_log_name: str = "TQC",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> TQCSelf:
|
||||
|
|
@ -306,11 +297,7 @@ class TQC(OffPolicyAlgorithm):
|
|||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
log_interval=log_interval,
|
||||
eval_env=eval_env,
|
||||
eval_freq=eval_freq,
|
||||
n_eval_episodes=n_eval_episodes,
|
||||
tb_log_name=tb_log_name,
|
||||
eval_log_path=eval_log_path,
|
||||
reset_num_timesteps=reset_num_timesteps,
|
||||
progress_bar=progress_bar,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -59,9 +59,6 @@ class TRPO(OnPolicyAlgorithm):
|
|||
:param sub_sampling_factor: Sub-sample the batch to make computation faster
|
||||
see p40-42 of John Schulman thesis http://joschu.net/docs/thesis.pdf
|
||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||
Caution, this parameter is deprecated and will be removed in the future.
|
||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||
:param seed: Seed for the pseudo random generators
|
||||
|
|
@ -96,7 +93,6 @@ class TRPO(OnPolicyAlgorithm):
|
|||
target_kl: float = 0.01,
|
||||
sub_sampling_factor: int = 1,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
create_eval_env: bool = False,
|
||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||
verbose: int = 0,
|
||||
seed: Optional[int] = None,
|
||||
|
|
@ -120,7 +116,6 @@ class TRPO(OnPolicyAlgorithm):
|
|||
policy_kwargs=policy_kwargs,
|
||||
verbose=verbose,
|
||||
device=device,
|
||||
create_eval_env=create_eval_env,
|
||||
seed=seed,
|
||||
_init_setup_model=False,
|
||||
supported_action_spaces=(
|
||||
|
|
@ -409,11 +404,7 @@ class TRPO(OnPolicyAlgorithm):
|
|||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
eval_env: Optional[GymEnv] = None,
|
||||
eval_freq: int = -1,
|
||||
n_eval_episodes: int = 5,
|
||||
tb_log_name: str = "TRPO",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> TRPOSelf:
|
||||
|
|
@ -422,11 +413,7 @@ class TRPO(OnPolicyAlgorithm):
|
|||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
log_interval=log_interval,
|
||||
eval_env=eval_env,
|
||||
eval_freq=eval_freq,
|
||||
n_eval_episodes=n_eval_episodes,
|
||||
tb_log_name=tb_log_name,
|
||||
eval_log_path=eval_log_path,
|
||||
reset_num_timesteps=reset_num_timesteps,
|
||||
progress_bar=progress_bar,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.6.2
|
||||
1.7.0a0
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=1.6.2",
|
||||
"stable_baselines3>=1.7.0a0",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
|
|
@ -64,15 +64,6 @@ def test_bootstraping():
|
|||
model.learn(128)
|
||||
|
||||
|
||||
def test_eval_env():
|
||||
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
||||
eval_env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
||||
model = MaskablePPO("MlpPolicy", env, clip_range_vf=0.2, n_steps=32, seed=8)
|
||||
with pytest.warns(DeprecationWarning): # `eval_env` is deprecated
|
||||
model.learn(32, eval_env=eval_env, eval_freq=16)
|
||||
model.learn(32, reset_num_timesteps=False)
|
||||
|
||||
|
||||
def test_supports_discrete_action_space():
|
||||
"""
|
||||
No errors using algorithm with an env that has a discrete action space
|
||||
|
|
|
|||
|
|
@ -129,32 +129,28 @@ def test_check():
|
|||
|
||||
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
|
||||
def test_run(env):
|
||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||
model = RecurrentPPO(
|
||||
"MlpLstmPolicy",
|
||||
env,
|
||||
n_steps=16,
|
||||
seed=0,
|
||||
create_eval_env=True,
|
||||
)
|
||||
model = RecurrentPPO(
|
||||
"MlpLstmPolicy",
|
||||
env,
|
||||
n_steps=16,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
model.learn(total_timesteps=32, eval_freq=16)
|
||||
model.learn(total_timesteps=32)
|
||||
|
||||
|
||||
def test_run_sde():
|
||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||
model = RecurrentPPO(
|
||||
"MlpLstmPolicy",
|
||||
"Pendulum-v1",
|
||||
n_steps=16,
|
||||
seed=0,
|
||||
create_eval_env=True,
|
||||
sde_sample_freq=4,
|
||||
use_sde=True,
|
||||
clip_range_vf=0.1,
|
||||
)
|
||||
model = RecurrentPPO(
|
||||
"MlpLstmPolicy",
|
||||
"Pendulum-v1",
|
||||
n_steps=16,
|
||||
seed=0,
|
||||
sde_sample_freq=4,
|
||||
use_sde=True,
|
||||
clip_range_vf=0.1,
|
||||
)
|
||||
|
||||
model.learn(total_timesteps=200, eval_freq=150)
|
||||
model.learn(total_timesteps=200)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
|||
|
|
@ -10,17 +10,15 @@ from sb3_contrib.common.vec_env import AsyncEval
|
|||
|
||||
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
|
||||
def test_tqc(ent_coef):
|
||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||
model = TQC(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v1",
|
||||
policy_kwargs=dict(net_arch=[64, 64]),
|
||||
learning_starts=100,
|
||||
verbose=1,
|
||||
create_eval_env=True,
|
||||
ent_coef=ent_coef,
|
||||
)
|
||||
model.learn(total_timesteps=300, eval_freq=250, progress_bar=True)
|
||||
model = TQC(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v1",
|
||||
policy_kwargs=dict(net_arch=[64, 64]),
|
||||
learning_starts=100,
|
||||
verbose=1,
|
||||
ent_coef=ent_coef,
|
||||
)
|
||||
model.learn(total_timesteps=300, progress_bar=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_critics", [1, 3])
|
||||
|
|
@ -51,18 +49,16 @@ def test_sde():
|
|||
|
||||
|
||||
def test_qrdqn():
|
||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||
model = QRDQN(
|
||||
"MlpPolicy",
|
||||
"CartPole-v1",
|
||||
policy_kwargs=dict(n_quantiles=25, net_arch=[64, 64]),
|
||||
learning_starts=100,
|
||||
buffer_size=500,
|
||||
learning_rate=3e-4,
|
||||
verbose=1,
|
||||
create_eval_env=True,
|
||||
)
|
||||
model.learn(total_timesteps=500, eval_freq=250)
|
||||
model = QRDQN(
|
||||
"MlpPolicy",
|
||||
"CartPole-v1",
|
||||
policy_kwargs=dict(n_quantiles=25, net_arch=[64, 64]),
|
||||
learning_starts=100,
|
||||
buffer_size=500,
|
||||
learning_rate=3e-4,
|
||||
verbose=1,
|
||||
)
|
||||
model.learn(total_timesteps=500)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
|
||||
|
|
@ -91,8 +87,7 @@ def test_trpo_params():
|
|||
@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"])
|
||||
def test_ars(policy_str, env_id):
|
||||
model = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0)
|
||||
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||
model.learn(total_timesteps=500, log_interval=1, eval_freq=250)
|
||||
model.learn(total_timesteps=500, log_interval=1)
|
||||
|
||||
|
||||
def test_ars_multi_env():
|
||||
|
|
|
|||
Loading…
Reference in New Issue