Add TQC (#4)
* Add TQC doc * Polish code * Update doc * Update results * Update doc * Update doc * Add note about PyBullet envs
This commit is contained in:
parent
3fe0a0c2f2
commit
0700c3eeb0
|
|
@ -93,13 +93,19 @@ Results
|
||||||
|
|
||||||
A description and comparison of results (e.g. how the change improved results over the non-changed algorithm), if
|
A description and comparison of results (e.g. how the change improved results over the non-changed algorithm), if
|
||||||
applicable.
|
applicable.
|
||||||
|
Please link the associated pull request, e.g., `Pull Request #4 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/4>`_.
|
||||||
|
|
||||||
Include the expected results from the work that originally proposed the method (e.g. original paper).
|
Include the expected results from the work that originally proposed the method (e.g. original paper).
|
||||||
|
|
||||||
|
How to replicate the results?
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
Include the code to replicate these results or a link to repository/branch where the code can be found.
|
Include the code to replicate these results or a link to repository/branch where the code can be found.
|
||||||
Use `rl-baselines3-zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ if possible, fork it, create a new branch
|
Use `rl-baselines3-zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ if possible, fork it, create a new branch
|
||||||
and share the code to replicate results there.
|
and share the code to replicate results there.
|
||||||
|
|
||||||
|
If applicable, please also provide the command to replicate the plots.
|
||||||
|
|
||||||
Comments
|
Comments
|
||||||
--------
|
--------
|
||||||
|
|
||||||
|
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 156 KiB |
|
|
@ -4,27 +4,22 @@ RL Algorithms
|
||||||
This table displays the rl algorithms that are implemented in the Stable Baselines3 contrib project,
|
This table displays the rl algorithms that are implemented in the Stable Baselines3 contrib project,
|
||||||
along with some useful characteristics: support for discrete/continuous actions, multiprocessing.
|
along with some useful characteristics: support for discrete/continuous actions, multiprocessing.
|
||||||
|
|
||||||
..
|
|
||||||
.. ============ =========== ============ ================= =============== ================
|
============ =========== ============ ================= =============== ================
|
||||||
.. Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
|
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
|
||||||
.. ============ =========== ============ ================= =============== ================
|
============ =========== ============ ================= =============== ================
|
||||||
.. A2C ✔️ ✔️ ✔️ ✔️ ✔️
|
TQC ✔️ ❌ ❌ ❌ ❌
|
||||||
.. DDPG ✔️ ❌ ❌ ❌ ❌
|
============ =========== ============ ================= =============== ================
|
||||||
.. DQN ❌ ✔️ ❌ ❌ ❌
|
|
||||||
.. PPO ✔️ ✔️ ✔️ ✔️ ✔️
|
|
||||||
.. SAC ✔️ ❌ ❌ ❌ ❌
|
|
||||||
.. TD3 ✔️ ❌ ❌ ❌ ❌
|
|
||||||
.. ============ =========== ============ ================= =============== ================
|
|
||||||
|
|
||||||
|
|
||||||
.. .. note::
|
.. note::
|
||||||
.. Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm.
|
Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm.
|
||||||
..
|
|
||||||
.. Actions ``gym.spaces``:
|
Actions ``gym.spaces``:
|
||||||
..
|
|
||||||
.. - ``Box``: A N-dimensional box that contains every point in the action
|
- ``Box``: A N-dimensional box that contains every point in the action
|
||||||
.. space.
|
space.
|
||||||
.. - ``Discrete``: A list of possible actions, where each timestep only
|
- ``Discrete``: A list of possible actions, where each timestep only
|
||||||
.. one of the actions can be used.
|
one of the actions can be used.
|
||||||
.. - ``MultiDiscrete``: A list of possible actions, where each timestep only one action of each discrete set can be used.
|
- ``MultiDiscrete``: A list of possible actions, where each timestep only one action of each discrete set can be used.
|
||||||
.. - ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination.
|
- ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination.
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,18 @@
|
||||||
Examples
|
Examples
|
||||||
========
|
========
|
||||||
|
|
||||||
WIP
|
TQC
|
||||||
|
---
|
||||||
|
|
||||||
|
Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from sb3_contrib import TQC
|
||||||
|
|
||||||
|
model = TQC("MlpPolicy", "Pendulum-v0", top_quantiles_to_drop_per_net=2, verbose=1)
|
||||||
|
model.learn(total_timesteps=10000, log_interval=4)
|
||||||
|
model.save("tqc_pendulum")
|
||||||
|
|
||||||
.. PyBullet: Normalizing input features
|
.. PyBullet: Normalizing input features
|
||||||
.. ------------------------------------
|
.. ------------------------------------
|
||||||
|
|
@ -13,7 +24,7 @@ WIP
|
||||||
.. for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`__ environments. For that, a wrapper exists and
|
.. for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`__ environments. For that, a wrapper exists and
|
||||||
.. will compute a running average and standard deviation of input features (it can do the same for rewards).
|
.. will compute a running average and standard deviation of input features (it can do the same for rewards).
|
||||||
..
|
..
|
||||||
|
..
|
||||||
.. .. note::
|
.. .. note::
|
||||||
..
|
..
|
||||||
.. you need to install pybullet with ``pip install pybullet``
|
.. you need to install pybullet with ``pip install pybullet``
|
||||||
|
|
|
||||||
|
|
@ -27,11 +27,11 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
|
||||||
guide/examples
|
guide/examples
|
||||||
|
|
||||||
|
|
||||||
.. .. toctree::
|
.. toctree::
|
||||||
.. :maxdepth: 1
|
:maxdepth: 1
|
||||||
.. :caption: RL Algorithms
|
:caption: RL Algorithms
|
||||||
..
|
|
||||||
.. modules/a2c
|
modules/tqc
|
||||||
|
|
||||||
.. .. toctree::
|
.. .. toctree::
|
||||||
.. :maxdepth: 1
|
.. :maxdepth: 1
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ Breaking Changes:
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
- Added ``TQC`` algorithm (@araffin)
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,168 @@
|
||||||
|
.. _tqc:
|
||||||
|
|
||||||
|
.. automodule:: sb3_contrib.tqc
|
||||||
|
|
||||||
|
|
||||||
|
TQC
|
||||||
|
===
|
||||||
|
|
||||||
|
Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics (TQC).
|
||||||
|
Truncated Quantile Critics (TQC) builds on SAC, TD3 and QR-DQN, making use of quantile regression to predict a distribution for the value function (instead of a mean value).
|
||||||
|
It truncates the quantiles predicted by the different networks (somewhat like the clipped double Q-learning used in TD3).
|
||||||
|
|
||||||
|
|
||||||
|
.. rubric:: Available Policies
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:nosignatures:
|
||||||
|
|
||||||
|
MlpPolicy
|
||||||
|
CnnPolicy
|
||||||
|
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
|
||||||
|
- Original paper: https://arxiv.org/abs/2005.04269
|
||||||
|
- Original Implementation: https://github.com/bayesgroup/tqc_pytorch
|
||||||
|
|
||||||
|
|
||||||
|
Can I use?
|
||||||
|
----------
|
||||||
|
|
||||||
|
- Recurrent policies: ❌
|
||||||
|
- Multi processing: ❌
|
||||||
|
- Gym spaces:
|
||||||
|
|
||||||
|
|
||||||
|
============= ====== ===========
|
||||||
|
Space Action Observation
|
||||||
|
============= ====== ===========
|
||||||
|
Discrete ❌ ✔️
|
||||||
|
Box ✔️ ✔️
|
||||||
|
MultiDiscrete ❌ ✔️
|
||||||
|
MultiBinary ❌ ✔️
|
||||||
|
============= ====== ===========
|
||||||
|
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from sb3_contrib import TQC
|
||||||
|
|
||||||
|
env = gym.make("Pendulum-v0")
|
||||||
|
|
||||||
|
policy_kwargs = dict(n_critics=2, n_quantiles=25)
|
||||||
|
model = TQC("MlpPolicy", env, top_quantiles_to_drop_per_net=2, verbose=1, policy_kwargs=policy_kwargs)
|
||||||
|
model.learn(total_timesteps=10000, log_interval=4)
|
||||||
|
model.save("tqc_pendulum")
|
||||||
|
|
||||||
|
del model # remove to demonstrate saving and loading
|
||||||
|
|
||||||
|
model = TQC.load("tqc_pendulum")
|
||||||
|
|
||||||
|
obs = env.reset()
|
||||||
|
while True:
|
||||||
|
action, _states = model.predict(obs, deterministic=True)
|
||||||
|
obs, reward, done, info = env.step(action)
|
||||||
|
env.render()
|
||||||
|
if done:
|
||||||
|
obs = env.reset()
|
||||||
|
|
||||||
|
|
||||||
|
Results
|
||||||
|
-------
|
||||||
|
|
||||||
|
Results on the PyBullet benchmark (1M steps) and on BipedalWalkerHardcore-v3 (2M steps)
|
||||||
|
using 3 seeds.
|
||||||
|
The complete learning curves are available in the `associated PR <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/4>`_.
|
||||||
|
|
||||||
|
The main difference compared to SAC is seen on harder environments (BipedalWalkerHardcore, Walker2D).
|
||||||
|
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Hyperparameters from the `gSDE paper <https://arxiv.org/abs/2005.05719>`_ were used (as they are tuned for SAC on PyBullet envs),
|
||||||
|
including using gSDE for the exploration and not the unstructured Gaussian noise
|
||||||
|
but this should not affect results in simulation.
|
||||||
|
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
We are using the open source PyBullet environments and not the MuJoCo simulator (as done in the original paper).
|
||||||
|
You can find a complete benchmark on PyBullet envs in the `gSDE paper <https://arxiv.org/abs/2005.05719>`_
|
||||||
|
if you want to compare TQC results to those of A2C/PPO/SAC/TD3.
|
||||||
|
|
||||||
|
|
||||||
|
===================== ============ ============
|
||||||
|
Environments SAC TQC
|
||||||
|
===================== ============ ============
|
||||||
|
\ gSDE gSDE
|
||||||
|
HalfCheetah 2984 +/- 202 3041 +/- 157
|
||||||
|
Ant 3102 +/- 37 3700 +/- 37
|
||||||
|
Hopper 2262 +/- 1 2401 +/- 62*
|
||||||
|
Walker2D 2136 +/- 67 2535 +/- 94
|
||||||
|
BipedalWalkerHardcore 13 +/- 18 228 +/- 18
|
||||||
|
===================== ============ ============
|
||||||
|
|
||||||
|
\* with tuned hyperparameter ``top_quantiles_to_drop_per_net`` taken from the original paper
|
||||||
|
|
||||||
|
|
||||||
|
How to replicate the results?
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Clone RL-Zoo and checkout the branch ``feat/tqc``:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
git clone https://github.com/DLR-RM/rl-baselines3-zoo
|
||||||
|
cd rl-baselines3-zoo/
|
||||||
|
git checkout feat/tqc
|
||||||
|
|
||||||
|
Run the benchmark (replace ``$ENV_ID`` with one of the envs mentioned above):
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
python train.py --algo tqc --env $ENV_ID --eval-episodes 10 --eval-freq 10000
|
||||||
|
|
||||||
|
|
||||||
|
Plot the results:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
python scripts/all_plots.py -a tqc -e HalfCheetah Ant Hopper Walker2D BipedalWalkerHardcore -f logs/ -o logs/tqc_results
|
||||||
|
python scripts/plot_from_file.py -i logs/tqc_results.pkl -latex -l TQC
|
||||||
|
|
||||||
|
Comments
|
||||||
|
--------
|
||||||
|
|
||||||
|
This implementation is based on the SB3 SAC implementation and uses the code from the original TQC implementation for the quantile Huber loss.
|
||||||
|
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
|
||||||
|
.. autoclass:: TQC
|
||||||
|
:members:
|
||||||
|
:inherited-members:
|
||||||
|
|
||||||
|
.. _tqc_policies:
|
||||||
|
|
||||||
|
TQC Policies
|
||||||
|
-------------
|
||||||
|
|
||||||
|
.. autoclass:: MlpPolicy
|
||||||
|
:members:
|
||||||
|
:inherited-members:
|
||||||
|
|
||||||
|
.. autoclass:: sb3_contrib.tqc.policies.TQCPolicy
|
||||||
|
:members:
|
||||||
|
:noindex:
|
||||||
|
|
||||||
|
.. autoclass:: CnnPolicy
|
||||||
|
:members:
|
||||||
|
|
@ -5,7 +5,13 @@ import torch as th
|
||||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||||
from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_features_extractor, register_policy
|
from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_features_extractor, register_policy
|
||||||
from stable_baselines3.common.preprocessing import get_action_dim
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
|
from stable_baselines3.common.torch_layers import (
|
||||||
|
BaseFeaturesExtractor,
|
||||||
|
FlattenExtractor,
|
||||||
|
NatureCNN,
|
||||||
|
create_mlp,
|
||||||
|
get_actor_critic_arch,
|
||||||
|
)
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
# CAP the standard deviation of the actor
|
# CAP the standard deviation of the actor
|
||||||
|
|
@ -306,6 +312,8 @@ class TQCPolicy(BasePolicy):
|
||||||
else:
|
else:
|
||||||
net_arch = []
|
net_arch = []
|
||||||
|
|
||||||
|
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
|
||||||
|
|
||||||
# Create shared features extractor
|
# Create shared features extractor
|
||||||
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
|
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
|
||||||
self.features_dim = self.features_extractor.features_dim
|
self.features_dim = self.features_extractor.features_dim
|
||||||
|
|
@ -317,12 +325,11 @@ class TQCPolicy(BasePolicy):
|
||||||
"action_space": self.action_space,
|
"action_space": self.action_space,
|
||||||
"features_extractor": self.features_extractor,
|
"features_extractor": self.features_extractor,
|
||||||
"features_dim": self.features_dim,
|
"features_dim": self.features_dim,
|
||||||
"net_arch": self.net_arch,
|
"net_arch": actor_arch,
|
||||||
"activation_fn": self.activation_fn,
|
"activation_fn": self.activation_fn,
|
||||||
"normalize_images": normalize_images,
|
"normalize_images": normalize_images,
|
||||||
}
|
}
|
||||||
self.actor_kwargs = self.net_args.copy()
|
self.actor_kwargs = self.net_args.copy()
|
||||||
self.critic_kwargs = self.net_args.copy()
|
|
||||||
sde_kwargs = {
|
sde_kwargs = {
|
||||||
"use_sde": use_sde,
|
"use_sde": use_sde,
|
||||||
"log_std_init": log_std_init,
|
"log_std_init": log_std_init,
|
||||||
|
|
@ -331,7 +338,8 @@ class TQCPolicy(BasePolicy):
|
||||||
"clip_mean": clip_mean,
|
"clip_mean": clip_mean,
|
||||||
}
|
}
|
||||||
self.actor_kwargs.update(sde_kwargs)
|
self.actor_kwargs.update(sde_kwargs)
|
||||||
tqc_kwargs = {"n_quantiles": n_quantiles, "n_critics": n_critics}
|
self.critic_kwargs = self.net_args.copy()
|
||||||
|
tqc_kwargs = {"n_quantiles": n_quantiles, "n_critics": n_critics, "net_arch": critic_arch}
|
||||||
self.critic_kwargs.update(tqc_kwargs)
|
self.critic_kwargs.update(tqc_kwargs)
|
||||||
self.actor, self.actor_target = None, None
|
self.actor, self.actor_target = None, None
|
||||||
self.critic, self.critic_target = None, None
|
self.critic, self.critic_target = None, None
|
||||||
|
|
@ -356,7 +364,7 @@ class TQCPolicy(BasePolicy):
|
||||||
|
|
||||||
data.update(
|
data.update(
|
||||||
dict(
|
dict(
|
||||||
net_arch=self.net_args["net_arch"],
|
net_arch=self.net_arch,
|
||||||
activation_fn=self.net_args["activation_fn"],
|
activation_fn=self.net_args["activation_fn"],
|
||||||
use_sde=self.actor_kwargs["use_sde"],
|
use_sde=self.actor_kwargs["use_sde"],
|
||||||
log_std_init=self.actor_kwargs["log_std_init"],
|
log_std_init=self.actor_kwargs["log_std_init"],
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
|
|
||||||
Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics.
|
Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics.
|
||||||
Paper: https://arxiv.org/abs/2005.04269
|
Paper: https://arxiv.org/abs/2005.04269
|
||||||
|
This implementation uses SB3 SAC implementation as base.
|
||||||
|
|
||||||
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||||
:param env: The environment to learn from (if registered in Gym, can be str)
|
:param env: The environment to learn from (if registered in Gym, can be str)
|
||||||
|
|
@ -169,6 +170,14 @@ class TQC(OffPolicyAlgorithm):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def quantile_huber_loss(quantiles: th.Tensor, samples: th.Tensor) -> th.Tensor:
|
def quantile_huber_loss(quantiles: th.Tensor, samples: th.Tensor) -> th.Tensor:
|
||||||
|
"""
|
||||||
|
The quantile-regression loss, as described in the QR-DQN and TQC papers.
|
||||||
|
Taken from https://github.com/bayesgroup/tqc_pytorch
|
||||||
|
|
||||||
|
:param quantiles:
|
||||||
|
:param samples:
|
||||||
|
:return: the loss
|
||||||
|
"""
|
||||||
# batch x nets x quantiles x samples
|
# batch x nets x quantiles x samples
|
||||||
pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
|
pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
|
||||||
abs_pairwise_delta = th.abs(pairwise_delta)
|
abs_pairwise_delta = th.abs(pairwise_delta)
|
||||||
|
|
@ -251,7 +260,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
self.critic.optimizer.step()
|
self.critic.optimizer.step()
|
||||||
|
|
||||||
# Compute actor loss
|
# Compute actor loss
|
||||||
qf_pi = self.critic(replay_data.observations, actions_pi).mean(2).mean(1, keepdim=True)
|
qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
|
||||||
actor_loss = (ent_coef * log_prob - qf_pi).mean()
|
actor_loss = (ent_coef * log_prob - qf_pi).mean()
|
||||||
actor_losses.append(actor_loss.item())
|
actor_losses.append(actor_loss.item())
|
||||||
|
|
||||||
|
|
@ -299,19 +308,10 @@ class TQC(OffPolicyAlgorithm):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _excluded_save_params(self) -> List[str]:
|
def _excluded_save_params(self) -> List[str]:
|
||||||
"""
|
|
||||||
Returns the names of the parameters that should be excluded by default
|
|
||||||
when saving the model.
|
|
||||||
|
|
||||||
:return: List of parameters that should be excluded from save
|
|
||||||
"""
|
|
||||||
# Exclude aliases
|
# Exclude aliases
|
||||||
return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]
|
return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]
|
||||||
|
|
||||||
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
||||||
"""
|
|
||||||
cf base class
|
|
||||||
"""
|
|
||||||
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
|
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
|
||||||
saved_pytorch_variables = ["log_ent_coef"]
|
saved_pytorch_variables = ["log_ent_coef"]
|
||||||
if self.ent_coef_optimizer is not None:
|
if self.ent_coef_optimizer is not None:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue