* Add TQC doc

* Polish code

* Update doc

* Update results

* Update doc

* Update doc

* Add note about PyBullet envs
Antonin RAFFIN 2020-10-22 13:43:46 +02:00 committed by GitHub
parent 3fe0a0c2f2
commit 0700c3eeb0
9 changed files with 233 additions and 44 deletions


@ -93,13 +93,19 @@ Results
A description and comparison of results (e.g. how the change improved results over the non-changed algorithm), if
applicable.
Please link the associated pull request, e.g., `Pull Request #4 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/4>`_.
Include the expected results from the work that originally proposed the method (e.g. original paper).

How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Include the code to replicate these results or a link to repository/branch where the code can be found.
Use `rl-baselines3-zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ if possible, fork it, create a new branch
and share the code to replicate results there.
If applicable, please also provide the command to replicate the plots.

Comments
--------



@ -4,27 +4,22 @@ RL Algorithms
This table displays the RL algorithms that are implemented in the Stable Baselines3 contrib project,
along with some useful characteristics: support for discrete/continuous actions, multiprocessing.

============ =========== ============ ================= =============== ================
Name         ``Box``     ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
============ =========== ============ ================= =============== ================
TQC          ✔️           ❌           ❌                 ❌               ❌
============ =========== ============ ================= =============== ================

.. note::

  Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm.

Actions ``gym.spaces``:

- ``Box``: A N-dimensional box that contains every point in the action
  space.
- ``Discrete``: A list of possible actions, where each timestep only
  one of the actions can be used.
- ``MultiDiscrete``: A list of possible actions, where each timestep only one action of each discrete set can be used.
- ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination.


@ -3,7 +3,18 @@
Examples
========

TQC
---

Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment:

.. code-block:: python

  from sb3_contrib import TQC

  model = TQC("MlpPolicy", "Pendulum-v0", top_quantiles_to_drop_per_net=2, verbose=1)
  model.learn(total_timesteps=10000, log_interval=4)
  model.save("tqc_pendulum")

.. PyBullet: Normalizing input features
.. ------------------------------------
@ -13,7 +24,7 @@ WIP
.. for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`__ environments. For that, a wrapper exists and
.. will compute a running average and standard deviation of input features (it can do the same for rewards).
..
.. .. note::
..
..   you need to install pybullet with ``pip install pybullet``


@ -27,11 +27,11 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
  guide/examples

.. toctree::
  :maxdepth: 1
  :caption: RL Algorithms

  modules/tqc

.. .. toctree::
..   :maxdepth: 1


@ -12,6 +12,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added ``TQC`` algorithm (@araffin)

Bug Fixes:
^^^^^^^^^^

docs/modules/tqc.rst Normal file

@ -0,0 +1,168 @@
.. _tqc:

.. automodule:: sb3_contrib.tqc

TQC
===

Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics (TQC).
Truncated Quantile Critics (TQC) builds on SAC, TD3 and QR-DQN, making use of quantile regression to predict a distribution for the value function (instead of a mean value).
It truncates the quantiles predicted by the different networks (a bit like what is done in TD3).

.. rubric:: Available Policies

.. autosummary::
  :nosignatures:

  MlpPolicy
  CnnPolicy

Notes
-----

- Original paper: https://arxiv.org/abs/2005.04269
- Original Implementation: https://github.com/bayesgroup/tqc_pytorch
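The truncation step described above can be sketched in a few lines. This is an illustrative snippet only: the helper name and tensor shapes are assumptions, not the actual SB3 implementation.

```python
import torch as th


def truncate_quantiles(quantiles: th.Tensor, top_quantiles_to_drop_per_net: int) -> th.Tensor:
    """Drop the largest quantiles pooled from all critic networks.

    :param quantiles: predicted quantiles, shape (batch, n_nets, n_quantiles)
    :param top_quantiles_to_drop_per_net: number of top quantiles to drop, per network
    :return: truncated, sorted quantiles, shape (batch, n_nets * n_quantiles - n_drop)
    """
    batch_size, n_nets, n_quantiles = quantiles.shape
    # Pool the quantiles of all critic networks together and sort them
    pooled = quantiles.reshape(batch_size, -1)
    sorted_quantiles, _ = th.sort(pooled, dim=1)
    # Remove the top-k largest quantiles to fight overestimation bias
    n_drop = top_quantiles_to_drop_per_net * n_nets
    return sorted_quantiles[:, : n_nets * n_quantiles - n_drop]
```

The remaining (truncated) quantiles are then averaged to build the critic target, which is what controls the overestimation bias.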
Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ❌
- Gym spaces:

============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ❌      ✔️
Box           ✔️      ✔️
MultiDiscrete ❌      ✔️
MultiBinary   ❌      ✔️
============= ====== ===========
Example
-------

.. code-block:: python

  import gym
  import numpy as np

  from sb3_contrib import TQC

  env = gym.make("Pendulum-v0")

  policy_kwargs = dict(n_critics=2, n_quantiles=25)
  model = TQC("MlpPolicy", env, top_quantiles_to_drop_per_net=2, verbose=1, policy_kwargs=policy_kwargs)
  model.learn(total_timesteps=10000, log_interval=4)
  model.save("tqc_pendulum")

  del model  # remove to demonstrate saving and loading

  model = TQC.load("tqc_pendulum")

  obs = env.reset()
  while True:
      action, _states = model.predict(obs, deterministic=True)
      obs, reward, done, info = env.step(action)
      env.render()
      if done:
          obs = env.reset()
Results
-------

Results on the PyBullet benchmark (1M steps) and on BipedalWalkerHardcore-v3 (2M steps),
using 3 seeds.
The complete learning curves are available in the `associated PR <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/4>`_.

The main difference with SAC is on the harder environments (BipedalWalkerHardcore, Walker2D).

.. note::

  Hyperparameters from the `gSDE paper <https://arxiv.org/abs/2005.05719>`_ were used (as they are tuned for SAC on PyBullet envs),
  including using gSDE for the exploration and not the unstructured Gaussian noise,
  but this should not affect results in simulation.

.. note::

  We are using the open source PyBullet environments and not the MuJoCo simulator (as done in the original paper).
  You can find a complete benchmark on PyBullet envs in the `gSDE paper <https://arxiv.org/abs/2005.05719>`_
  if you want to compare TQC results to those of A2C/PPO/SAC/TD3.

===================== ============ ============
Environments          SAC          TQC
===================== ============ ============
\                     gSDE         gSDE
HalfCheetah           2984 +/- 202 3041 +/- 157
Ant                   3102 +/- 37  3700 +/- 37
Hopper                2262 +/- 1   2401 +/- 62*
Walker2D              2136 +/- 67  2535 +/- 94
BipedalWalkerHardcore 13 +/- 18    228 +/- 18
===================== ============ ============

\* with tuned hyperparameter ``top_quantiles_to_drop_per_net`` taken from the original paper
How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone RL-Zoo and checkout the branch ``feat/tqc``:

.. code-block:: bash

  git clone https://github.com/DLR-RM/rl-baselines3-zoo
  cd rl-baselines3-zoo/
  git checkout feat/tqc

Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):

.. code-block:: bash

  python train.py --algo tqc --env $ENV_ID --eval-episodes 10 --eval-freq 10000

Plot the results:

.. code-block:: bash

  python scripts/all_plots.py -a tqc -e HalfCheetah Ant Hopper Walker2D BipedalWalkerHardcore -f logs/ -o logs/tqc_results
  python scripts/plot_from_file.py -i logs/tqc_results.pkl -latex -l TQC
Comments
--------

This implementation is based on the SB3 SAC implementation and uses the code from the original TQC implementation for the quantile Huber loss.

Parameters
----------

.. autoclass:: TQC
  :members:
  :inherited-members:

.. _tqc_policies:

TQC Policies
------------

.. autoclass:: MlpPolicy
  :members:
  :inherited-members:

.. autoclass:: sb3_contrib.tqc.policies.TQCPolicy
  :members:
  :noindex:

.. autoclass:: CnnPolicy
  :members:


@ -5,7 +5,13 @@ import torch as th
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_features_extractor, register_policy
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    FlattenExtractor,
    NatureCNN,
    create_mlp,
    get_actor_critic_arch,
)
from torch import nn as nn

# CAP the standard deviation of the actor
@ -306,6 +312,8 @@ class TQCPolicy(BasePolicy):
        else:
            net_arch = []

        actor_arch, critic_arch = get_actor_critic_arch(net_arch)

        # Create shared features extractor
        self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
        self.features_dim = self.features_extractor.features_dim
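The newly added `get_actor_critic_arch` call splits `net_arch` into separate actor and critic architectures. A minimal sketch of the behavior this code relies on (simplified; assumed to mirror SB3's helper, with `"pi"`/`"qf"` dict keys as an assumption):

```python
from typing import Dict, List, Tuple, Union


def get_actor_critic_arch(
    net_arch: Union[List[int], Dict[str, List[int]]]
) -> Tuple[List[int], List[int]]:
    """Split ``net_arch`` into actor ("pi") and critic ("qf") layer sizes.

    A plain list of layer sizes is shared by both networks;
    a dict can give each network its own architecture.
    """
    if isinstance(net_arch, list):
        return net_arch, net_arch
    # Assumed dict keys: "pi" for the actor, "qf" for the critic
    return net_arch.get("pi", []), net_arch.get("qf", [])


# A shared architecture, and separate actor/critic architectures:
assert get_actor_critic_arch([256, 256]) == ([256, 256], [256, 256])
assert get_actor_critic_arch({"pi": [64], "qf": [400, 300]}) == ([64], [400, 300])
```

Passing `actor_arch` and `critic_arch` separately is what lets `net_args` (below) configure the actor while the critic gets its own `net_arch` via `tqc_kwargs`.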
@ -317,12 +325,11 @@ class TQCPolicy(BasePolicy):
"action_space": self.action_space, "action_space": self.action_space,
"features_extractor": self.features_extractor, "features_extractor": self.features_extractor,
"features_dim": self.features_dim, "features_dim": self.features_dim,
"net_arch": self.net_arch, "net_arch": actor_arch,
"activation_fn": self.activation_fn, "activation_fn": self.activation_fn,
"normalize_images": normalize_images, "normalize_images": normalize_images,
} }
self.actor_kwargs = self.net_args.copy() self.actor_kwargs = self.net_args.copy()
self.critic_kwargs = self.net_args.copy()
        sde_kwargs = {
            "use_sde": use_sde,
            "log_std_init": log_std_init,
@ -331,7 +338,8 @@ class TQCPolicy(BasePolicy):
"clip_mean": clip_mean, "clip_mean": clip_mean,
} }
self.actor_kwargs.update(sde_kwargs) self.actor_kwargs.update(sde_kwargs)
tqc_kwargs = {"n_quantiles": n_quantiles, "n_critics": n_critics} self.critic_kwargs = self.net_args.copy()
tqc_kwargs = {"n_quantiles": n_quantiles, "n_critics": n_critics, "net_arch": critic_arch}
self.critic_kwargs.update(tqc_kwargs) self.critic_kwargs.update(tqc_kwargs)
self.actor, self.actor_target = None, None self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None self.critic, self.critic_target = None, None
@ -356,7 +364,7 @@ class TQCPolicy(BasePolicy):
        data.update(
            dict(
                net_arch=self.net_arch,
                activation_fn=self.net_args["activation_fn"],
                use_sde=self.actor_kwargs["use_sde"],
                log_std_init=self.actor_kwargs["log_std_init"],


@ -16,6 +16,7 @@ class TQC(OffPolicyAlgorithm):
    Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics.
    Paper: https://arxiv.org/abs/2005.04269
    This implementation uses the SB3 SAC implementation as its base.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
@ -169,6 +170,14 @@ class TQC(OffPolicyAlgorithm):
    @staticmethod
    def quantile_huber_loss(quantiles: th.Tensor, samples: th.Tensor) -> th.Tensor:
        """
        The quantile-regression loss, as described in the QR-DQN and TQC papers.
        Taken from https://github.com/bayesgroup/tqc_pytorch

        :param quantiles: current quantile estimates
        :param samples: target values
        :return: the loss
        """
        # batch x nets x quantiles x samples
        pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
        abs_pairwise_delta = th.abs(pairwise_delta)
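The full loss built on top of the `pairwise_delta` computation shown above can be sketched as a self-contained function (an approximation with the Huber threshold fixed at 1, not copied verbatim from the repository):

```python
import torch as th


def quantile_huber_loss(quantiles: th.Tensor, samples: th.Tensor) -> th.Tensor:
    """Sketch of the quantile Huber loss (QR-DQN / TQC style), with Huber threshold kappa=1.

    :param quantiles: current quantile estimates, shape (batch, n_nets, n_quantiles)
    :param samples: target values, shape (batch, n_samples)
    :return: scalar loss
    """
    # batch x nets x quantiles x samples
    pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
    abs_pairwise_delta = th.abs(pairwise_delta)
    # Huber loss: quadratic near zero, linear beyond the threshold (kappa=1)
    huber_loss = th.where(
        abs_pairwise_delta > 1,
        abs_pairwise_delta - 0.5,
        pairwise_delta ** 2 * 0.5,
    )
    n_quantiles = quantiles.shape[2]
    # Quantile midpoints: tau_i = (i + 0.5) / n_quantiles
    tau = (th.arange(n_quantiles, device=quantiles.device).float() + 0.5) / n_quantiles
    # Asymmetric weighting: each quantile penalizes over- and under-estimation differently
    weight = th.abs(tau[None, None, :, None] - (pairwise_delta < 0).float())
    return (weight * huber_loss).mean()
```

The asymmetric `tau` weighting is what makes each output head converge to a different quantile of the return distribution rather than to the mean.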
@ -251,7 +260,7 @@ class TQC(OffPolicyAlgorithm):
        self.critic.optimizer.step()

        # Compute actor loss
        qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
        actor_loss = (ent_coef * log_prob - qf_pi).mean()
        actor_losses.append(actor_loss.item())
@ -299,19 +308,10 @@ class TQC(OffPolicyAlgorithm):
        )

    def _excluded_save_params(self) -> List[str]:
        # Exclude aliases
        return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
        saved_pytorch_variables = ["log_ent_coef"]
        if self.ent_coef_optimizer is not None: