diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8ebb1ec..d51f9f1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -93,13 +93,19 @@ Results A description and comparison of results (e.g. how the change improved results over the non-changed algorithm), if applicable. +Please link the associated pull request, e.g., `Pull Request #4 `_. Include the expected results from the work that originally proposed the method (e.g. original paper). +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + Include the code to replicate these results or a link to repository/branch where the code can be found. Use `rl-baselines3-zoo `_ if possible, fork it, create a new branch and share the code to replicate results there. +If applicable, please also provide the command to replicate the plots. + Comments -------- diff --git a/docs/_static/img/logo.png~ b/docs/_static/img/logo.png~ deleted file mode 100644 index ef592f5..0000000 Binary files a/docs/_static/img/logo.png~ and /dev/null differ diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 1d863ea..1d5ce80 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -4,27 +4,22 @@ RL Algorithms This table displays the rl algorithms that are implemented in the Stable Baselines3 contrib project, along with some useful characteristics: support for discrete/continuous actions, multiprocessing. -.. -.. ============ =========== ============ ================= =============== ================ -.. Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing -.. ============ =========== ============ ================= =============== ================ -.. A2C ✔️ ✔️ ✔️ ✔️ ✔️ -.. DDPG ✔️ ❌ ❌ ❌ ❌ -.. DQN ❌ ✔️ ❌ ❌ ❌ -.. PPO ✔️ ✔️ ✔️ ✔️ ✔️ -.. SAC ✔️ ❌ ❌ ❌ ❌ -.. TD3 ✔️ ❌ ❌ ❌ ❌ -.. 
============ =========== ============ ================= =============== ================ + +============ =========== ============ ================= =============== ================ +Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing +============ =========== ============ ================= =============== ================ +TQC ✔️ ❌ ❌ ❌ ❌ +============ =========== ============ ================= =============== ================ -.. .. note:: -.. Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm. -.. -.. Actions ``gym.spaces``: -.. -.. - ``Box``: A N-dimensional box that contains every point in the action -.. space. -.. - ``Discrete``: A list of possible actions, where each timestep only -.. one of the actions can be used. -.. - ``MultiDiscrete``: A list of possible actions, where each timestep only one action of each discrete set can be used. -.. - ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination. +.. note:: + Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm. + +Actions ``gym.spaces``: + +- ``Box``: A N-dimensional box that contains every point in the action + space. +- ``Discrete``: A list of possible actions, where each timestep only + one of the actions can be used. +- ``MultiDiscrete``: A list of possible actions, where each timestep only one action of each discrete set can be used. +- ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination. diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 65b2d08..c535bfd 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -3,7 +3,18 @@ Examples ======== -WIP +TQC +--- + +Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment. + +.. 
code-block:: python + + from sb3_contrib import TQC + + model = TQC("MlpPolicy", "Pendulum-v0", top_quantiles_to_drop_per_net=2, verbose=1) + model.learn(total_timesteps=10000, log_interval=4) + model.save("tqc_pendulum") .. PyBullet: Normalizing input features .. ------------------------------------ @@ -13,7 +24,7 @@ WIP .. for instance when training on `PyBullet `__ environments. For that, a wrapper exists and .. will compute a running average and standard deviation of input features (it can do the same for rewards). .. - +.. .. .. note:: .. .. you need to install pybullet with ``pip install pybullet`` diff --git a/docs/index.rst b/docs/index.rst index d590f25..f701110 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -27,11 +27,11 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d guide/examples -.. .. toctree:: -.. :maxdepth: 1 -.. :caption: RL Algorithms -.. -.. modules/a2c +.. toctree:: + :maxdepth: 1 + :caption: RL Algorithms + + modules/tqc .. .. toctree:: .. :maxdepth: 1 diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index db44b6d..f2ea8b0 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -12,6 +12,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Added ``TQC`` algorithm (@araffin) Bug Fixes: ^^^^^^^^^^ diff --git a/docs/modules/tqc.rst b/docs/modules/tqc.rst new file mode 100644 index 0000000..28279c7 --- /dev/null +++ b/docs/modules/tqc.rst @@ -0,0 +1,168 @@ +.. _tqc: + +.. automodule:: sb3_contrib.tqc + + +TQC +=== + +Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics (TQC). +Truncated Quantile Critics (TQC) builds on SAC, TD3 and QR-DQN, making use of quantile regression to predict a distribution for the value function (instead of a mean value). +It truncates the quantiles predicted by different networks (a bit as it is done in TD3). + + +.. rubric:: Available Policies + +.. 
autosummary:: + :nosignatures: + + MlpPolicy + CnnPolicy + + +Notes +----- + +- Original paper: https://arxiv.org/abs/2005.04269 +- Original Implementation: https://github.com/bayesgroup/tqc_pytorch + + +Can I use? +---------- + +- Recurrent policies: ❌ +- Multi processing: ❌ +- Gym spaces: + + +============= ====== =========== +Space Action Observation +============= ====== =========== +Discrete ❌ ✔️ +Box ✔️ ✔️ +MultiDiscrete ❌ ✔️ +MultiBinary ❌ ✔️ +============= ====== =========== + + +Example +------- + +.. code-block:: python + + import gym + import numpy as np + + from sb3_contrib import TQC + + env = gym.make("Pendulum-v0") + + policy_kwargs = dict(n_critics=2, n_quantiles=25) + model = TQC("MlpPolicy", env, top_quantiles_to_drop_per_net=2, verbose=1, policy_kwargs=policy_kwargs) + model.learn(total_timesteps=10000, log_interval=4) + model.save("tqc_pendulum") + + del model # remove to demonstrate saving and loading + + model = TQC.load("tqc_pendulum") + + obs = env.reset() + while True: + action, _states = model.predict(obs, deterministic=True) + obs, reward, done, info = env.step(action) + env.render() + if done: + obs = env.reset() + + +Results +------- + +Result on the PyBullet benchmark (1M steps) and on BipedalWalkerHardcore-v3 (2M steps) +using 3 seeds. +The complete learning curves are available in the `associated PR `_. + +The main difference with SAC is on harder environments (BipedalWalkerHardcore, Walker2D). + + +.. note:: + + Hyperparameters from the `gSDE paper `_ were used (as they are tuned for SAC on PyBullet envs), + including using gSDE for the exploration and not the unstructured Gaussian noise + but this should not affect results in simulation. + + +.. note:: + + We are using the open source PyBullet environments and not the MuJoCo simulator (as done in the original paper). + You can find a complete benchmark on PyBullet envs in the `gSDE paper `_ + if you want to compare TQC results to those of A2C/PPO/SAC/TD3. 
+ + +===================== ============ ============ +Environments SAC TQC +===================== ============ ============ +\ gSDE gSDE +HalfCheetah 2984 +/- 202 3041 +/- 157 +Ant 3102 +/- 37 3700 +/- 37 +Hopper 2262 +/- 1 2401 +/- 62* +Walker2D 2136 +/- 67 2535 +/- 94 +BipedalWalkerHardcore 13 +/- 18 228 +/- 18 +===================== ============ ============ + +\* with tuned hyperparameter ``top_quantiles_to_drop_per_net`` taken from the original paper + + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone RL-Zoo and checkout the branch ``feat/tqc``: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + git checkout feat/tqc + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo tqc --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + + +Plot the results: + +.. code-block:: bash + + python scripts/all_plots.py -a tqc -e HalfCheetah Ant Hopper Walker2D BipedalWalkerHardcore -f logs/ -o logs/tqc_results + python scripts/plot_from_file.py -i logs/tqc_results.pkl -latex -l TQC + +Comments +-------- + +This implementation is based on SB3 SAC implementation and uses the code from the original TQC implementation for the quantile huber loss. + + +Parameters +---------- + +.. autoclass:: TQC + :members: + :inherited-members: + +.. _tqc_policies: + +TQC Policies +------------- + +.. autoclass:: MlpPolicy + :members: + :inherited-members: + +.. autoclass:: sb3_contrib.tqc.policies.TQCPolicy + :members: + :noindex: + +.. 
autoclass:: CnnPolicy + :members: diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index f728c3a..788a7ba 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -5,7 +5,13 @@ import torch as th from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_features_extractor, register_policy from stable_baselines3.common.preprocessing import get_action_dim -from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + FlattenExtractor, + NatureCNN, + create_mlp, + get_actor_critic_arch, +) from torch import nn as nn # CAP the standard deviation of the actor @@ -306,6 +312,8 @@ class TQCPolicy(BasePolicy): else: net_arch = [] + actor_arch, critic_arch = get_actor_critic_arch(net_arch) + # Create shared features extractor self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) self.features_dim = self.features_extractor.features_dim @@ -317,12 +325,11 @@ class TQCPolicy(BasePolicy): "action_space": self.action_space, "features_extractor": self.features_extractor, "features_dim": self.features_dim, - "net_arch": self.net_arch, + "net_arch": actor_arch, "activation_fn": self.activation_fn, "normalize_images": normalize_images, } self.actor_kwargs = self.net_args.copy() - self.critic_kwargs = self.net_args.copy() sde_kwargs = { "use_sde": use_sde, "log_std_init": log_std_init, @@ -331,7 +338,8 @@ class TQCPolicy(BasePolicy): "clip_mean": clip_mean, } self.actor_kwargs.update(sde_kwargs) - tqc_kwargs = {"n_quantiles": n_quantiles, "n_critics": n_critics} + self.critic_kwargs = self.net_args.copy() + tqc_kwargs = {"n_quantiles": n_quantiles, "n_critics": n_critics, "net_arch": critic_arch} 
self.critic_kwargs.update(tqc_kwargs) self.actor, self.actor_target = None, None self.critic, self.critic_target = None, None @@ -356,7 +364,7 @@ class TQCPolicy(BasePolicy): data.update( dict( - net_arch=self.net_args["net_arch"], + net_arch=self.net_arch, activation_fn=self.net_args["activation_fn"], use_sde=self.actor_kwargs["use_sde"], log_std_init=self.actor_kwargs["log_std_init"], diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index ea6818e..d8fd647 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -16,6 +16,7 @@ class TQC(OffPolicyAlgorithm): Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics. Paper: https://arxiv.org/abs/2005.04269 + This implementation uses SB3 SAC implementation as base. :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) :param env: The environment to learn from (if registered in Gym, can be str) @@ -169,6 +170,14 @@ class TQC(OffPolicyAlgorithm): @staticmethod def quantile_huber_loss(quantiles: th.Tensor, samples: th.Tensor) -> th.Tensor: + """ + The quantile-regression loss, as described in the QR-DQN and TQC papers. 
+ Taken from https://github.com/bayesgroup/tqc_pytorch + + :param quantiles: predicted quantiles, of shape (batch, n_nets, n_quantiles) + :param samples: target distribution samples, of shape (batch, n_samples) + :return: the loss + """ # batch x nets x quantiles x samples pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None] abs_pairwise_delta = th.abs(pairwise_delta) @@ -251,7 +260,7 @@ class TQC(OffPolicyAlgorithm): self.critic.optimizer.step() # Compute actor loss - qf_pi = self.critic(replay_data.observations, actions_pi).mean(2).mean(1, keepdim=True) + qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True) actor_loss = (ent_coef * log_prob - qf_pi).mean() actor_losses.append(actor_loss.item()) @@ -299,19 +308,10 @@ class TQC(OffPolicyAlgorithm): ) def _excluded_save_params(self) -> List[str]: - """ - Returns the names of the parameters that should be excluded by default - when saving the model. - - :return: List of parameters that should be excluded from save - """ # Exclude aliases return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"] def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: - """ - cf base class - """ state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] saved_pytorch_variables = ["log_ent_coef"] if self.ent_coef_optimizer is not None: