Cleanup TQC

This commit is contained in:
Antonin RAFFIN 2020-10-12 19:50:08 +02:00
parent 99fe824f76
commit 7609c87e84
6 changed files with 113 additions and 377 deletions

View File

@ -18,3 +18,19 @@ Implemented:
``` ```
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
``` ```
## Citing the Project
To cite this repository in publications (please cite SB3 directly):
```
@misc{stable-baselines3,
author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
title = {Stable Baselines3},
year = {2019},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}},
}
```

View File

@ -17,25 +17,25 @@ class Actor(BasePolicy):
""" """
Actor network (policy) for TQC. Actor network (policy) for TQC.
:param observation_space: (gym.spaces.Space) Observation space :param observation_space: Observation space
:param action_space: (gym.spaces.Space) Action space :param action_space: Action space
:param net_arch: ([int]) Network architecture :param net_arch: Network architecture
:param features_extractor: (nn.Module) Network to extract features :param features_extractor: Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise) (a CNN when using images, a nn.Flatten() layer otherwise)
:param features_dim: (int) Number of features :param features_dim: Number of features
:param activation_fn: (Type[nn.Module]) Activation function :param activation_fn: Activation function
:param use_sde: (bool) Whether to use State Dependent Exploration or not :param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation :param log_std_init: Initial value for the log standard deviation
:param full_std: (bool) Whether to use (n_features x n_actions) parameters :param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE. for the std instead of only (n_features,) when using gSDE.
:param sde_net_arch: ([int]) Network architecture for extracting features :param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used. when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features. Pass an empty list to use the states as features.
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
:param normalize_images: (bool) Whether to normalize images or not, :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default) dividing by 255.0 (True by default)
""" """
@ -131,7 +131,7 @@ class Actor(BasePolicy):
but is slightly different when using ``expln`` function but is slightly different when using ``expln`` function
(cf StateDependentNoiseDistribution doc). (cf StateDependentNoiseDistribution doc).
:return: (th.Tensor) :return:
""" """
msg = "get_std() is only available when using gSDE" msg = "get_std() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
@ -141,7 +141,7 @@ class Actor(BasePolicy):
""" """
Sample new weights for the exploration matrix, when using gSDE. Sample new weights for the exploration matrix, when using gSDE.
:param batch_size: (int) :param batch_size:
""" """
msg = "reset_noise() is only available when using gSDE" msg = "reset_noise() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
@ -151,8 +151,8 @@ class Actor(BasePolicy):
""" """
Get the parameters for the action distribution. Get the parameters for the action distribution.
:param obs: (th.Tensor) :param obs:
:return: (Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]) :return:
Mean, standard deviation and optional keyword arguments. Mean, standard deviation and optional keyword arguments.
""" """
features = self.extract_features(obs) features = self.extract_features(obs)
@ -187,9 +187,9 @@ class Actor(BasePolicy):
""" """
Evaluate actions according to the current policy, Evaluate actions according to the current policy,
given the observations. Only useful when using SDE. given the observations. Only useful when using SDE.
:param obs: (th.Tensor) :param obs:
:param actions: (th.Tensor) :param actions:
:return: (th.Tensor) log likelihood of taking those actions :return: log likelihood of taking those actions
""" """
mean_actions, log_std, kwargs = self.get_action_dist_params(obs) mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
self.action_dist.proba_distribution(mean_actions, log_std, **kwargs) self.action_dist.proba_distribution(mean_actions, log_std, **kwargs)
@ -200,14 +200,14 @@ class Critic(BaseModel):
""" """
Critic network (q-value function) for TQC. Critic network (q-value function) for TQC.
:param observation_space: (gym.spaces.Space) Observation space :param observation_space: Observation space
:param action_space: (gym.spaces.Space) Action space :param action_space: Action space
:param net_arch: ([int]) Network architecture :param net_arch: Network architecture
:param features_extractor: (nn.Module) Network to extract features :param features_extractor: Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise) (a CNN when using images, a nn.Flatten() layer otherwise)
:param features_dim: (int) Number of features :param features_dim: Number of features
:param activation_fn: (Type[nn.Module]) Activation function :param activation_fn: Activation function
:param normalize_images: (bool) Whether to normalize images or not, :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default) dividing by 255.0 (True by default)
""" """
@ -257,28 +257,28 @@ class TQCPolicy(BasePolicy):
""" """
Policy class (with both actor and critic) for TQC. Policy class (with both actor and critic) for TQC.
:param observation_space: (gym.spaces.Space) Observation space :param observation_space: Observation space
:param action_space: (gym.spaces.Space) Action space :param action_space: Action space
:param lr_schedule: (callable) Learning rate schedule (could be constant) :param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: (Optional[List[int]]) The specification of the policy and value networks. :param net_arch: The specification of the policy and value networks.
:param activation_fn: (Type[nn.Module]) Activation function :param activation_fn: Activation function
:param use_sde: (bool) Whether to use State Dependent Exploration or not :param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation :param log_std_init: Initial value for the log standard deviation
:param sde_net_arch: ([int]) Network architecture for extracting features :param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used. when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features. Pass an empty list to use the states as features.
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. :param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments :param features_extractor_kwargs: Keyword arguments
to pass to the feature extractor. to pass to the feature extractor.
:param normalize_images: (bool) Whether to normalize images or not, :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default) dividing by 255.0 (True by default)
:param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, :param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default ``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, :param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer excluding the learning rate, to pass to the optimizer
""" """
@ -388,7 +388,7 @@ class TQCPolicy(BasePolicy):
""" """
Sample new weights for the exploration matrix, when using gSDE. Sample new weights for the exploration matrix, when using gSDE.
:param batch_size: (int) :param batch_size:
""" """
self.actor.reset_noise(batch_size=batch_size) self.actor.reset_noise(batch_size=batch_size)
@ -412,26 +412,26 @@ class CnnPolicy(TQCPolicy):
""" """
Policy class (with both actor and critic) for TQC. Policy class (with both actor and critic) for TQC.
:param observation_space: (gym.spaces.Space) Observation space :param observation_space: Observation space
:param action_space: (gym.spaces.Space) Action space :param action_space: Action space
:param lr_schedule: (callable) Learning rate schedule (could be constant) :param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: (Optional[List[int]]) The specification of the policy and value networks. :param net_arch: The specification of the policy and value networks.
:param activation_fn: (Type[nn.Module]) Activation function :param activation_fn: Activation function
:param use_sde: (bool) Whether to use State Dependent Exploration or not :param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation :param log_std_init: Initial value for the log standard deviation
:param sde_net_arch: ([int]) Network architecture for extracting features :param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used. when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features. Pass an empty list to use the states as features.
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. :param features_extractor_class: Features extractor to use.
:param normalize_images: (bool) Whether to normalize images or not, :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default) dividing by 255.0 (True by default)
:param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, :param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default ``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, :param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer excluding the learning rate, to pass to the optimizer
""" """

View File

@ -7,7 +7,6 @@ from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import polyak_update from stable_baselines3.common.utils import polyak_update
from tqdm import tqdm
from sb3_contrib.tqc.policies import TQCPolicy from sb3_contrib.tqc.policies import TQCPolicy
@ -15,48 +14,48 @@ from sb3_contrib.tqc.policies import TQCPolicy
class TQC(OffPolicyAlgorithm): class TQC(OffPolicyAlgorithm):
""" """
Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics.
Paper: https://arxiv.org/abs/2005.04269 Paper: https://arxiv.org/abs/2005.04269
:param policy: (TQCPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...) :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str) :param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: (float or callable) learning rate for adam optimizer, :param learning_rate: learning rate for adam optimizer,
the same learning rate will be used for all networks (Q-Values, Actor and Value function) the same learning rate will be used for all networks (Q-Values, Actor and Value function)
it can be a function of the current progress remaining (from 1 to 0) it can be a function of the current progress remaining (from 1 to 0)
:param buffer_size: (int) size of the replay buffer :param buffer_size: size of the replay buffer
:param learning_starts: (int) how many steps of the model to collect transitions for before learning starts :param learning_starts: how many steps of the model to collect transitions for before learning starts
:param batch_size: (int) Minibatch size for each gradient update :param batch_size: Minibatch size for each gradient update
:param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1) :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
:param gamma: (float) the discount factor :param gamma: the discount factor
:param train_freq: (int) Update the model every ``train_freq`` steps. :param train_freq: Update the model every ``train_freq`` steps.
:param gradient_steps: (int) How many gradient updates after each step :param gradient_steps: How many gradient updates after each step
:param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes. :param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes.
Note that this cannot be used at the same time as ``train_freq`` Note that this cannot be used at the same time as ``train_freq``
:param action_noise: (ActionNoise) the action noise type (None by default), this can help :param action_noise: the action noise type (None by default), this can help
for hard exploration problem. Cf common.noise for the different action noise type. for hard exploration problem. Cf common.noise for the different action noise type.
:param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity. at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param ent_coef: (str or float) Entropy regularization coefficient. (Equivalent to :param ent_coef: Entropy regularization coefficient. (Equivalent to
inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off. inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off.
Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value) Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value)
:param target_update_interval: (int) update the target network every ``target_network_update_freq`` :param target_update_interval: update the target network every ``target_network_update_freq``
gradient steps. gradient steps.
:param target_entropy: (str or float) target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``) :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
:param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE) :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False) instead of action noise exploration (default: False)
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout) Default: -1 (only sample at the beginning of the rollout)
:param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
during the warm up phase (before learning starts) during the warm up phase (before learning starts)
:param create_eval_env: (bool) Whether to create a second environment that will be :param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment) used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation :param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: (int) Seed for the pseudo random generators :param seed: Seed for the pseudo random generators
:param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run. :param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible. Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance :param _init_setup_model: Whether or not to build the network at the creation of the instance
""" """
def __init__( def __init__(
@ -274,171 +273,6 @@ class TQC(OffPolicyAlgorithm):
if len(ent_coef_losses) > 0: if len(ent_coef_losses) > 0:
logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
def pretrain(
self,
gradient_steps: int,
batch_size: int = 64,
n_action_samples: int = -1,
target_update_interval: int = 1,
tau: float = 0.005,
strategy: str = "exp",
reduce: str = "mean",
exp_temperature: float = 1.0,
off_policy_update_freq: int = -1,
) -> None:
"""
Pretrain with Critic Regularized Regression (CRR)
Paper: https://arxiv.org/abs/2006.15134
"""
# Update optimizers learning rate
optimizers = [self.actor.optimizer, self.critic.optimizer]
if self.ent_coef_optimizer is not None:
optimizers += [self.ent_coef_optimizer]
# Update learning rate according to lr schedule
self._update_learning_rate(optimizers)
actor_losses, critic_losses = [], []
for gradient_step in tqdm(range(gradient_steps)):
if off_policy_update_freq > 0 and gradient_step % off_policy_update_freq == 0:
self.train(gradient_steps=1, batch_size=batch_size)
continue
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
# We need to sample because `log_std` may have changed between two gradient steps
if self.use_sde:
self.actor.reset_noise()
# Action by the current actor for the sampled state
_, log_prob = self.actor.action_log_prob(replay_data.observations)
log_prob = log_prob.reshape(-1, 1)
ent_coef_loss = None
if self.ent_coef_optimizer is not None:
# Important: detach the variable from the graph
# so we don't change it with other losses
# see https://github.com/rail-berkeley/softlearning/issues/60
ent_coef = th.exp(self.log_ent_coef.detach())
ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
else:
ent_coef = self.ent_coef_tensor
self.replay_buffer.ent_coef = ent_coef.item()
# Optimize entropy coefficient, also called
# entropy temperature or alpha in the paper
if ent_coef_loss is not None:
self.ent_coef_optimizer.zero_grad()
ent_coef_loss.backward()
self.ent_coef_optimizer.step()
with th.no_grad():
top_quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.critic.n_critics
# Select action according to policy
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
# Compute and cut quantiles at the next state
# batch x nets x quantiles
next_z = self.critic_target(replay_data.next_observations, next_actions)
sorted_z, _ = th.sort(next_z.reshape(batch_size, -1))
sorted_z_part = sorted_z[:, : self.critic.quantiles_total - top_quantiles_to_drop]
target_q = sorted_z_part - ent_coef * next_log_prob.reshape(-1, 1)
# td error + entropy term
q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
# Get current Q estimates
# using action from the replay buffer
current_z = self.critic(replay_data.observations, replay_data.actions)
# Compute critic loss
critic_loss = self.quantile_huber_loss(current_z, q_backup)
critic_losses.append(critic_loss.item())
# Optimize the critic
self.critic.optimizer.zero_grad()
critic_loss.backward()
self.critic.optimizer.step()
if strategy == "bc":
# Behavior cloning
weight = 1
else:
# Tensor version: TODO: check that the reshape works as expected
# cleaner but not faster on cpu for large batch size
# with th.no_grad():
# # Q-value for the action in the buffer
# qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True)
# # Create tensor to avoid loop
# # Note: For SDE, we need to sample several matrices
# obs_ = replay_data.observations.repeat(n_action_samples, 1)
# if self.use_sde:
# self.actor.reset_noise(batch_size * n_action_samples)
# actions_pi, _ = self.actor.action_log_prob(obs_)
# qf_pi = self.critic(obs_, actions_pi.detach()).mean(2).mean(1, keepdim=True)
# # Aggregate: reduce mean or reduce max
# if reduce == "max":
# _, qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).max(axis=0)
# else:
# qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).mean(axis=0)
with th.no_grad():
qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True)
# Use the mean (as done in AWAC, cf rlkit)
if n_action_samples == -1:
actions_pi = self.actor.forward(replay_data.observations, deterministic=True)
qf_agg = self.critic(replay_data.observations, actions_pi).mean(2).mean(1, keepdim=True)
else:
qf_agg = None
for _ in range(n_action_samples):
if self.use_sde:
self.actor.reset_noise()
actions_pi, _ = self.actor.action_log_prob(replay_data.observations)
qf_pi = self.critic(replay_data.observations, actions_pi.detach()).mean(2).mean(1, keepdim=True)
if qf_agg is None:
if reduce == "max":
qf_agg = qf_pi
else:
qf_agg = qf_pi / n_action_samples
else:
if reduce == "max":
qf_agg = th.max(qf_pi, qf_agg)
else:
qf_agg += qf_pi / n_action_samples
advantage = qf_buffer - qf_agg
if strategy == "binary":
# binary advantage
weight = advantage > 0
else:
# exp advantage
exp_clip = 20.0
weight = th.clamp(th.exp(advantage / exp_temperature), 0.0, exp_clip)
# Log prob by the current actor for the sampled state and action
log_prob = self.actor.evaluate_actions(replay_data.observations, replay_data.actions)
log_prob = log_prob.reshape(-1, 1)
# weighted regression loss (close to policy gradient loss)
actor_loss = (-log_prob * weight).mean()
# actor_loss = ((actions_pi - replay_data.actions * weight) ** 2).mean()
actor_losses.append(actor_loss.item())
# Optimize the actor
self.actor.optimizer.zero_grad()
actor_loss.backward()
self.actor.optimizer.step()
# Update target networks
if gradient_step % target_update_interval == 0:
polyak_update(self.critic.parameters(), self.critic_target.parameters(), tau)
if self.use_sde:
print(f"std={(self.actor.get_std()).mean().item()}")
def learn( def learn(
self, self,
total_timesteps: int, total_timesteps: int,
@ -469,7 +303,7 @@ class TQC(OffPolicyAlgorithm):
Returns the names of the parameters that should be excluded by default Returns the names of the parameters that should be excluded by default
when saving the model. when saving the model.
:return: (List[str]) List of parameters that should be excluded from save :return: List of parameters that should be excluded from save
""" """
# Exclude aliases # Exclude aliases
return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"] return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]

View File

@ -23,11 +23,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")], packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]}, package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[ install_requires=[
"stable_baselines3[tests,docs]>=0.9.0a0", "stable_baselines3[tests,docs]>=0.9.0",
# For progress bar when using CRR
"tqdm"
# Enable CMA
# "cma",
], ],
description="Contrib package of Stable Baselines3, experimental code.", description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin", author="Antonin Raffin",

View File

@ -14,41 +14,19 @@ def test_tqc(ent_coef):
create_eval_env=True, create_eval_env=True,
ent_coef=ent_coef, ent_coef=ent_coef,
) )
model.learn(total_timesteps=1000, eval_freq=500) model.learn(total_timesteps=500, eval_freq=250)
@pytest.mark.parametrize("n_critics", [1, 3]) @pytest.mark.parametrize("n_critics", [1, 3])
def test_n_critics(n_critics): def test_n_critics(n_critics):
# Test TQC with different number of critics # Test TQC with different number of critics
model = TQC( model = TQC(
"MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1 "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], n_critics=n_critics), learning_starts=100, verbose=1
) )
model.learn(total_timesteps=1000) model.learn(total_timesteps=500)
def test_sde():
# "CartPole-v1"
# @pytest.mark.parametrize("env_id", ["MountainCarContinuous-v0"])
# def test_cmaes(env_id):
# if CMAES is None:
# return
# model = CMAES("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[64]), verbose=1, create_eval_env=True)
# model.learn(total_timesteps=50000, eval_freq=10000)
@pytest.mark.parametrize("strategy", ["exp", "bc", "binary"])
@pytest.mark.parametrize("reduce", ["mean", "max"])
def test_crr(tmp_path, strategy, reduce):
model = TQC( model = TQC(
"MlpPolicy", "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64]), use_sde=True, learning_starts=100, verbose=1
"Pendulum-v0",
policy_kwargs=dict(net_arch=[64]),
learning_starts=1000,
verbose=1,
create_eval_env=True,
action_noise=None,
use_sde=False,
) )
model.learn(total_timesteps=500)
model.learn(total_timesteps=1000, eval_freq=0)
for n_action_samples in [1, 2, -1]:
model.pretrain(gradient_steps=32, batch_size=32, n_action_samples=n_action_samples, strategy=strategy, reduce=reduce)

View File

@ -45,7 +45,7 @@ def test_save_load(tmp_path, model_class):
# create model # create model
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1) model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
model.learn(total_timesteps=500, eval_freq=250) model.learn(total_timesteps=500)
env.reset() env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
@ -154,7 +154,7 @@ def test_save_load(tmp_path, model_class):
assert np.allclose(selected_actions, new_selected_actions, 1e-4) assert np.allclose(selected_actions, new_selected_actions, 1e-4)
# check if learn still works # check if learn still works
model.learn(total_timesteps=1000, eval_freq=500) model.learn(total_timesteps=500)
del model del model
@ -177,17 +177,17 @@ def test_set_env(model_class):
# create model # create model
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16])) model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]))
# learn # learn
model.learn(total_timesteps=1000, eval_freq=500) model.learn(total_timesteps=500)
# change env # change env
model.set_env(env2) model.set_env(env2)
# learn again # learn again
model.learn(total_timesteps=1000, eval_freq=500) model.learn(total_timesteps=500)
# change env test wrapping # change env test wrapping
model.set_env(env3) model.set_env(env3)
# learn again # learn again
model.learn(total_timesteps=1000, eval_freq=500) model.learn(total_timesteps=500)
@pytest.mark.parametrize("model_class", MODEL_LIST) @pytest.mark.parametrize("model_class", MODEL_LIST)
@ -247,45 +247,6 @@ def test_save_load_replay_buffer(tmp_path, model_class):
) )
@pytest.mark.parametrize("model_class", [TQC])
@pytest.mark.parametrize("optimize_memory_usage", [False, True])
def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
"""
When using memory efficient replay buffer,
a warning must be emitted when calling `.learn()`
multiple times.
See https://github.com/DLR-RM/stable-baselines3/issues/46
"""
# remove gym warnings
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")
model = model_class(
"MlpPolicy",
select_env(model_class),
buffer_size=100,
optimize_memory_usage=optimize_memory_usage,
policy_kwargs=dict(net_arch=[64]),
learning_starts=10,
)
model.learn(150)
model.learn(150, reset_num_timesteps=False)
# Check that there is no warning
assert len(recwarn) == 0
model.learn(150)
if optimize_memory_usage:
assert len(recwarn) == 1
warning = recwarn.pop(UserWarning)
assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
else:
assert len(recwarn) == 0
@pytest.mark.parametrize("model_class", MODEL_LIST) @pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"]) @pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
def test_save_load_policy(tmp_path, model_class, policy_str): def test_save_load_policy(tmp_path, model_class, policy_str):
@ -309,7 +270,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
# create model # create model
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs) model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
model.learn(total_timesteps=500, eval_freq=250) model.learn(total_timesteps=500)
env.reset() env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
@ -375,52 +336,3 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
os.remove(tmp_path / "policy.pkl") os.remove(tmp_path / "policy.pkl")
if actor_class is not None: if actor_class is not None:
os.remove(tmp_path / "actor.pkl") os.remove(tmp_path / "actor.pkl")
@pytest.mark.parametrize("pathtype", [str, pathlib.Path])
def test_open_file_str_pathlib(tmp_path, pathtype):
# check that suffix isn't added because we used open_path first
with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo"
assert not record
# test custom suffix
with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo"
assert not record
# test without suffix
with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo"
assert not record
# test that a warning is raised when the path doesn't exist
with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo"
assert len(record) == 0
with pytest.warns(None) as record:
assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo"
assert len(record) == 1
fp = pathlib.Path(f"{tmp_path}/t2").open("w")
fp.write("rubbish")
fp.close()
# test that a warning is only raised when verbose = 0
with pytest.warns(None) as record:
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close()
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close()
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close()
assert len(record) == 1