Add TQC (#4)
* Add TQC doc
* Polish code
* Update doc
* Update results
* Update doc
* Update doc
* Add note about PyBullet envs
This commit is contained in:
parent 3fe0a0c2f2
commit 0700c3eeb0
@@ -93,13 +93,19 @@ Results
A description and comparison of results (e.g. how the change improved results over the non-changed algorithm), if
applicable.
Please link the associated pull request, e.g., `Pull Request #4 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/4>`_.

Include the expected results from the work that originally proposed the method (e.g. the original paper).

How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Include the code to replicate these results or a link to the repository/branch where the code can be found.
Use `rl-baselines3-zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ if possible: fork it, create a new branch
and share the code to replicate the results there.

If applicable, please also provide the command to replicate the plots.

Comments
--------
Binary file not shown.
Before Width: | Height: | Size: 156 KiB
@@ -4,27 +4,22 @@ RL Algorithms
This table displays the RL algorithms that are implemented in the Stable Baselines3 contrib project,
along with some useful characteristics: support for discrete/continuous actions, multiprocessing.

..
.. ============ =========== ============ ================= =============== ================
.. Name         ``Box``     ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
.. ============ =========== ============ ================= =============== ================
.. A2C          ✔️           ✔️           ✔️                ✔️              ✔️
.. DDPG         ✔️           ❌           ❌                ❌              ❌
.. DQN          ❌           ✔️           ❌                ❌              ❌
.. PPO          ✔️           ✔️           ✔️                ✔️              ✔️
.. SAC          ✔️           ❌           ❌                ❌              ❌
.. TD3          ✔️           ❌           ❌                ❌              ❌
.. ============ =========== ============ ================= =============== ================

============ =========== ============ ================= =============== ================
Name         ``Box``     ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
============ =========== ============ ================= =============== ================
TQC          ✔️           ❌           ❌                ❌              ❌
============ =========== ============ ================= =============== ================


.. .. note::
..     Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm.
..
..     Actions ``gym.spaces``:
..
..     - ``Box``: An N-dimensional box that contains every point in the action
..       space.
..     - ``Discrete``: A list of possible actions, where each timestep only
..       one of the actions can be used.
..     - ``MultiDiscrete``: A list of possible actions, where each timestep only one action of each discrete set can be used.
..     - ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination.

.. note::

    Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm.

    Actions ``gym.spaces``:

    - ``Box``: An N-dimensional box that contains every point in the action
      space.
    - ``Discrete``: A list of possible actions, where each timestep only
      one of the actions can be used.
    - ``MultiDiscrete``: A list of possible actions, where each timestep only one action of each discrete set can be used.
    - ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination.
@@ -3,7 +3,18 @@
Examples
========

WIP
TQC
---

Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment.

.. code-block:: python

    from sb3_contrib import TQC

    model = TQC("MlpPolicy", "Pendulum-v0", top_quantiles_to_drop_per_net=2, verbose=1)
    model.learn(total_timesteps=10000, log_interval=4)
    model.save("tqc_pendulum")

.. PyBullet: Normalizing input features
.. ------------------------------------
@@ -13,7 +24,7 @@ WIP
.. for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`__ environments. For that, a wrapper exists and
.. will compute a running average and standard deviation of input features (it can do the same for rewards).
..

..
.. .. note::
..
..     you need to install pybullet with ``pip install pybullet``
@@ -27,11 +27,11 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
  guide/examples


.. .. toctree::
..   :maxdepth: 1
..   :caption: RL Algorithms
..
..   modules/a2c
.. toctree::
  :maxdepth: 1
  :caption: RL Algorithms

  modules/tqc

.. .. toctree::
..   :maxdepth: 1
@@ -12,6 +12,7 @@ Breaking Changes:

New Features:
^^^^^^^^^^^^^
- Added ``TQC`` algorithm (@araffin)

Bug Fixes:
^^^^^^^^^^
@@ -0,0 +1,168 @@
.. _tqc:

.. automodule:: sb3_contrib.tqc


TQC
===

Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics (TQC).
Truncated Quantile Critics (TQC) builds on SAC, TD3 and QR-DQN, making use of quantile regression to predict a distribution for the value function (instead of a mean value).
It truncates the quantiles predicted by the different networks (a bit like it is done in TD3).
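The truncation step can be sketched as follows: pool the quantile estimates of all critic networks, sort them, and drop the top ``top_quantiles_to_drop_per_net`` atoms per network before building the target. A minimal NumPy illustration (the helper name and shapes here are assumptions for exposition, not the SB3 code):

```python
import numpy as np

def truncated_target_atoms(quantiles: np.ndarray, top_quantiles_to_drop_per_net: int) -> np.ndarray:
    """Pool the quantiles of all critic nets, sort them, and drop the
    largest ones to curb overestimation (the core idea of TQC)."""
    n_nets, n_quantiles = quantiles.shape
    pooled = np.sort(quantiles.reshape(-1))          # n_nets * n_quantiles sorted atoms
    n_drop = top_quantiles_to_drop_per_net * n_nets  # drop k atoms per net
    return pooled[: pooled.shape[0] - n_drop]

# Two critics with five quantiles each; dropping 1 atom per net keeps the 8 smallest
q = np.array([[1.0, 2.0, 3.0, 4.0, 5.0],
              [1.5, 2.5, 3.5, 4.5, 5.5]])
atoms = truncated_target_atoms(q, top_quantiles_to_drop_per_net=1)
```

Keeping only the smallest pooled atoms biases the target value downward, which is how TQC counteracts the overestimation that plain quantile critics suffer from.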
.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy
    CnnPolicy


Notes
-----

- Original paper: https://arxiv.org/abs/2005.04269
- Original Implementation: https://github.com/bayesgroup/tqc_pytorch


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ❌
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ❌      ✔️
Box           ✔️      ✔️
MultiDiscrete ❌      ✔️
MultiBinary   ❌      ✔️
============= ====== ===========
Example
-------

.. code-block:: python

    import gym
    import numpy as np

    from sb3_contrib import TQC

    env = gym.make("Pendulum-v0")

    policy_kwargs = dict(n_critics=2, n_quantiles=25)
    model = TQC("MlpPolicy", env, top_quantiles_to_drop_per_net=2, verbose=1, policy_kwargs=policy_kwargs)
    model.learn(total_timesteps=10000, log_interval=4)
    model.save("tqc_pendulum")

    del model  # remove to demonstrate saving and loading

    model = TQC.load("tqc_pendulum")

    obs = env.reset()
    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        env.render()
        if done:
            obs = env.reset()
Results
-------

Results on the PyBullet benchmark (1M steps) and on BipedalWalkerHardcore-v3 (2M steps),
using 3 seeds.
The complete learning curves are available in the `associated PR <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/4>`_.

The main difference with SAC is on the harder environments (BipedalWalkerHardcore, Walker2D).


.. note::

    Hyperparameters from the `gSDE paper <https://arxiv.org/abs/2005.05719>`_ were used (as they are tuned for SAC on PyBullet envs),
    including using gSDE for exploration instead of unstructured Gaussian noise,
    but this should not affect results in simulation.


.. note::

    We are using the open source PyBullet environments and not the MuJoCo simulator (as done in the original paper).
    You can find a complete benchmark on PyBullet envs in the `gSDE paper <https://arxiv.org/abs/2005.05719>`_
    if you want to compare TQC results to those of A2C/PPO/SAC/TD3.


===================== ============ ============
Environments          SAC          TQC
===================== ============ ============
\                     gSDE         gSDE
HalfCheetah           2984 +/- 202 3041 +/- 157
Ant                   3102 +/- 37  3700 +/- 37
Hopper                2262 +/- 1   2401 +/- 62*
Walker2D              2136 +/- 67  2535 +/- 94
BipedalWalkerHardcore 13 +/- 18    228 +/- 18
===================== ============ ============

\* with the tuned hyperparameter ``top_quantiles_to_drop_per_net`` taken from the original paper
How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone RL-Zoo and checkout the branch ``feat/tqc``:

.. code-block:: bash

    git clone https://github.com/DLR-RM/rl-baselines3-zoo
    cd rl-baselines3-zoo/
    git checkout feat/tqc

Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):

.. code-block:: bash

    python train.py --algo tqc --env $ENV_ID --eval-episodes 10 --eval-freq 10000


Plot the results:

.. code-block:: bash

    python scripts/all_plots.py -a tqc -e HalfCheetah Ant Hopper Walker2D BipedalWalkerHardcore -f logs/ -o logs/tqc_results
    python scripts/plot_from_file.py -i logs/tqc_results.pkl -latex -l TQC
Comments
--------

This implementation is based on the SB3 SAC implementation and uses the code from the original TQC implementation for the quantile huber loss.


Parameters
----------

.. autoclass:: TQC
    :members:
    :inherited-members:

.. _tqc_policies:

TQC Policies
------------

.. autoclass:: MlpPolicy
    :members:
    :inherited-members:

.. autoclass:: sb3_contrib.tqc.policies.TQCPolicy
    :members:
    :noindex:

.. autoclass:: CnnPolicy
    :members:
@@ -5,7 +5,13 @@ import torch as th
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_features_extractor, register_policy
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    FlattenExtractor,
    NatureCNN,
    create_mlp,
    get_actor_critic_arch,
)
from torch import nn as nn

# CAP the standard deviation of the actor
@@ -306,6 +312,8 @@ class TQCPolicy(BasePolicy):
        else:
            net_arch = []

        actor_arch, critic_arch = get_actor_critic_arch(net_arch)

        # Create shared features extractor
        self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
        self.features_dim = self.features_extractor.features_dim
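The new ``get_actor_critic_arch`` helper lets the actor and critic use different hidden-layer sizes. A simplified stand-in showing its intended contract (the real helper lives in ``stable_baselines3.common.torch_layers``; this sketch is an assumption for illustration, not the SB3 source):

```python
from typing import Dict, List, Tuple, Union

def split_actor_critic_arch(
    net_arch: Union[List[int], Dict[str, List[int]]]
) -> Tuple[List[int], List[int]]:
    """A plain list is shared by actor and critic; a dict with 'pi' and
    'qf' keys gives each network its own hidden-layer sizes."""
    if isinstance(net_arch, list):
        return net_arch, net_arch
    return net_arch["pi"], net_arch["qf"]

# Shared architecture: both networks get the same layers
actor_arch, critic_arch = split_actor_critic_arch([256, 256])
# Separate architectures: actor ('pi') and critic ('qf') differ
pi_arch, qf_arch = split_actor_critic_arch({"pi": [64, 64], "qf": [400, 300]})
```

This is why the diff below replaces the single ``self.net_arch`` entry with ``actor_arch`` for the actor and adds ``net_arch: critic_arch`` to the critic kwargs.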
@@ -317,12 +325,11 @@ class TQCPolicy(BasePolicy):
            "action_space": self.action_space,
            "features_extractor": self.features_extractor,
            "features_dim": self.features_dim,
            "net_arch": self.net_arch,
            "net_arch": actor_arch,
            "activation_fn": self.activation_fn,
            "normalize_images": normalize_images,
        }
        self.actor_kwargs = self.net_args.copy()
        self.critic_kwargs = self.net_args.copy()
        sde_kwargs = {
            "use_sde": use_sde,
            "log_std_init": log_std_init,
@@ -331,7 +338,8 @@ class TQCPolicy(BasePolicy):
            "clip_mean": clip_mean,
        }
        self.actor_kwargs.update(sde_kwargs)
        tqc_kwargs = {"n_quantiles": n_quantiles, "n_critics": n_critics}
        self.critic_kwargs = self.net_args.copy()
        tqc_kwargs = {"n_quantiles": n_quantiles, "n_critics": n_critics, "net_arch": critic_arch}
        self.critic_kwargs.update(tqc_kwargs)
        self.actor, self.actor_target = None, None
        self.critic, self.critic_target = None, None
@@ -356,7 +364,7 @@ class TQCPolicy(BasePolicy):

        data.update(
            dict(
                net_arch=self.net_args["net_arch"],
                net_arch=self.net_arch,
                activation_fn=self.net_args["activation_fn"],
                use_sde=self.actor_kwargs["use_sde"],
                log_std_init=self.actor_kwargs["log_std_init"],
@@ -16,6 +16,7 @@ class TQC(OffPolicyAlgorithm):

    Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics.
    Paper: https://arxiv.org/abs/2005.04269
    This implementation uses the SB3 SAC implementation as its base.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
@@ -169,6 +170,14 @@ class TQC(OffPolicyAlgorithm):

    @staticmethod
    def quantile_huber_loss(quantiles: th.Tensor, samples: th.Tensor) -> th.Tensor:
        """
        The quantile-regression loss, as described in the QR-DQN and TQC papers.
        Taken from https://github.com/bayesgroup/tqc_pytorch

        :param quantiles:
        :param samples:
        :return: the loss
        """
        # batch x nets x quantiles x samples
        pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
        abs_pairwise_delta = th.abs(pairwise_delta)
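The hunk above is truncated before the loss is assembled. As a rough, self-contained NumPy sketch of the quantile huber loss (assuming the standard QR-DQN convention: Huber threshold kappa = 1 and quantile midpoints as taus; not a copy of the SB3/TQC source):

```python
import numpy as np

def quantile_huber_loss_np(quantiles: np.ndarray, samples: np.ndarray) -> float:
    """quantiles: (batch, n_nets, n_quantiles); samples: (batch, n_samples)."""
    n_quantiles = quantiles.shape[2]
    # Pairwise TD errors: batch x nets x quantiles x samples
    pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
    abs_delta = np.abs(pairwise_delta)
    # Huber loss with threshold kappa = 1: quadratic near 0, linear in the tails
    huber = np.where(abs_delta > 1, abs_delta - 0.5, 0.5 * pairwise_delta ** 2)
    # Quantile midpoints tau_i = (i + 0.5) / n_quantiles
    tau = (np.arange(n_quantiles) + 0.5) / n_quantiles
    # Asymmetric weighting penalizes over/under-estimation differently per quantile
    weight = np.abs(tau[None, None, :, None] - (pairwise_delta < 0))
    return float((weight * huber).mean())

quantiles = np.zeros((1, 2, 4))
samples = np.zeros((1, 3))
loss_zero = quantile_huber_loss_np(quantiles, samples)       # perfect fit
loss_pos = quantile_huber_loss_np(quantiles, samples + 1.0)  # under-estimation
```

The asymmetric weight is what makes each atom converge to a different quantile of the target distribution rather than all collapsing onto the mean.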
@@ -251,7 +260,7 @@ class TQC(OffPolicyAlgorithm):
            self.critic.optimizer.step()

            # Compute actor loss
            qf_pi = self.critic(replay_data.observations, actions_pi).mean(2).mean(1, keepdim=True)
            qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
            actor_loss = (ent_coef * log_prob - qf_pi).mean()
            actor_losses.append(actor_loss.item())
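The change only makes the reduced dimensions explicit: the critic output has shape ``batch x n_nets x n_quantiles``, and the actor loss averages first over quantiles, then over critic nets, keeping the batch dimension. A small NumPy check of that shape bookkeeping (the array values are made up; only the shapes match the code above):

```python
import numpy as np

batch, n_nets, n_quantiles = 4, 2, 3
critic_out = np.arange(batch * n_nets * n_quantiles, dtype=np.float64).reshape(
    batch, n_nets, n_quantiles
)
# Average over quantiles (dim=2), then over critic nets (dim=1), keeping the batch dim
qf_pi = critic_out.mean(axis=2).mean(axis=1, keepdims=True)  # shape: (batch, 1)
```

Averaging over both quantiles and nets turns the distributional estimate back into a scalar Q-value per sample, which is what the SAC-style actor loss expects.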
@@ -299,19 +308,10 @@ class TQC(OffPolicyAlgorithm):
        )

    def _excluded_save_params(self) -> List[str]:
        """
        Returns the names of the parameters that should be excluded by default
        when saving the model.

        :return: List of parameters that should be excluded from save
        """
        # Exclude aliases
        return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """
        cf base class
        """
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
        saved_pytorch_variables = ["log_ent_coef"]
        if self.ent_coef_optimizer is not None: