Add QR-DQN (#13)

* Add QR-DQN (WIP)

* Update docstring

* Add quantile_huber_loss

* Fix typo

* Remove unnecessary lines

* Update variable names and comments in quantile_huber_loss

* Fix mutable arguments

* Update variable names

* Ignore import not used warnings

* Fix default parameter of optimizer in QR-DQN

* Update quantile_huber_loss to have more reasonable interface

* update tests

* Add assertion to quantile_huber_loss

* Update variable names of quantile regression

* Update comments

* Reduce the number of quantiles during test

* Update comment

* Update quantile_huber_loss

* Fix isort

* Add document of QR-DQN without results

* Update docs

* Fix bugs

* Update doc

* Add comments about shape

* Minor edits

* Update comments

* Add benchmark

* Doc fixes

* Update doc

* Bug fix in saving/loading + update tests

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Toshiki Watanabe 2020-12-21 19:17:48 +09:00 committed by GitHub
parent 3598ca284a
commit b30397fff5
18 changed files with 963 additions and 67 deletions


@ -25,6 +25,7 @@ See documentation for the full list of included features.
**RL Algorithms**:
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
**Gym Wrappers**:
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)


@ -9,6 +9,7 @@ along with some useful characteristics: support for discrete/continuous actions,
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
============ =========== ============ ================= =============== ================
TQC ✔️ ❌ ❌ ❌ ❌
QR-DQN       ❌          ✔️           ❌                 ❌               ❌
============ =========== ============ ================= =============== ================


@ -16,6 +16,21 @@ Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment.
model.learn(total_timesteps=10000, log_interval=4)
model.save("tqc_pendulum")
QR-DQN
------
Train a Quantile Regression DQN (QR-DQN) agent on the CartPole environment.
.. code-block:: python
from sb3_contrib import QRDQN
policy_kwargs = dict(n_quantiles=50)
model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("qrdqn_cartpole")
.. PyBullet: Normalizing input features
.. ------------------------------------
..


@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
:caption: RL Algorithms
modules/tqc
modules/qrdqn
.. toctree::
:maxdepth: 1


@ -13,9 +13,11 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added ``TimeFeatureWrapper`` to the wrappers
- Added ``QR-DQN`` algorithm (`@ku2482`_)
Bug Fixes:
^^^^^^^^^^
- Fixed bug in ``TQC`` when saving/loading the policy only with non-default number of quantiles
Deprecations:
^^^^^^^^^^^^^
@ -24,6 +26,7 @@ Others:
^^^^^^^
- Updated ``TQC`` to match new SB3 version
- Updated SB3 min version
- Moved ``quantile_huber_loss`` to ``common/utils.py`` (@ku2482)
Documentation:
^^^^^^^^^^^^^^
@ -62,13 +65,19 @@ Maintainers
-----------
Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_), `Ashley Hill`_ (aka @hill-a),
`Maximilian Ernestus`_ (aka @erniejunior), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_).
`Maximilian Ernestus`_ (aka @ernestum), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_).
.. _Ashley Hill: https://github.com/hill-a
.. _Antonin Raffin: https://araffin.github.io/
.. _Maximilian Ernestus: https://github.com/erniejunior
.. _Maximilian Ernestus: https://github.com/ernestum
.. _Adam Gleave: https://gleave.me/
.. _@araffin: https://github.com/araffin
.. _@AdamGleave: https://github.com/adamgleave
.. _Anssi Kanervisto: https://github.com/Miffyli
.. _@Miffyli: https://github.com/Miffyli
.. _@ku2482: https://github.com/ku2482
Contributors:
-------------
@ku2482

docs/modules/qrdqn.rst Normal file

@ -0,0 +1,150 @@
.. _qrdqn:
.. automodule:: sb3_contrib.qrdqn
QR-DQN
======
`Quantile Regression DQN (QR-DQN) <https://arxiv.org/abs/1710.10044>`_ builds on `Deep Q-Network (DQN) <https://arxiv.org/abs/1312.5602>`_
and makes use of quantile regression to explicitly model the `distribution over returns <https://arxiv.org/abs/1707.06887>`_,
instead of predicting the mean return (DQN).
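As a short note (matching both the paper and the ``quantile_huber_loss`` used in this implementation), the return distribution is approximated by ``n_quantiles`` quantile estimates located at the uniform midpoints :math:`\hat{\tau}_i = (2i - 1) / (2N)`, :math:`i = 1, \dots, N`, which are the cumulative probabilities used by the quantile Huber loss.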
.. rubric:: Available Policies
.. autosummary::
:nosignatures:
MlpPolicy
CnnPolicy
Notes
-----
- Original paper: https://arxiv.org/abs/1710.10044
- Distributional RL (C51): https://arxiv.org/abs/1707.06887
Can I use?
----------
- Recurrent policies: ❌
- Multi processing: ❌
- Gym spaces:
============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ✔ ✔
Box ❌ ✔
MultiDiscrete ❌ ✔
MultiBinary ❌ ✔
============= ====== ===========
Example
-------
.. code-block:: python
import gym
from sb3_contrib import QRDQN
env = gym.make("CartPole-v1")
policy_kwargs = dict(n_quantiles=50)
model = QRDQN("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("qrdqn_cartpole")
del model # remove to demonstrate saving and loading
model = QRDQN.load("qrdqn_cartpole")
obs = env.reset()
while True:
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
env.render()
if done:
obs = env.reset()
Results
-------
Results on Atari environments (10M steps, Pong and Breakout) and classic control tasks using 3 and 5 seeds.
The complete learning curves are available in the `associated PR <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/13>`_.
.. note::
The QR-DQN implementation was validated against the `Intel Coach <https://github.com/IntelLabs/coach/tree/master/benchmarks/qr_dqn>`_ one,
and the results are roughly comparable to those reported in the original paper (we trained the agent with a smaller budget).
============ ========== ===========
Environments QR-DQN DQN
============ ========== ===========
Breakout 413 +/- 21 ~300
Pong 20 +/- 0 ~20
CartPole 386 +/- 64 500 +/- 0
MountainCar -111 +/- 4 -107 +/- 4
LunarLander 168 +/- 39 195 +/- 28
Acrobot -73 +/- 2 -74 +/- 2
============ ========== ===========
How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Clone the RL-Zoo fork and check out the branch ``feat/qrdqn``:
.. code-block:: bash
git clone https://github.com/ku2482/rl-baselines3-zoo/
cd rl-baselines3-zoo/
git checkout feat/qrdqn
Run the benchmark (replace ``$ENV_ID`` with the envs mentioned above):
.. code-block:: bash
python train.py --algo qrdqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000
Plot the results:
.. code-block:: bash
python scripts/all_plots.py -a qrdqn -e Breakout Pong -f logs/ -o logs/qrdqn_results
python scripts/plot_from_file.py -i logs/qrdqn_results.pkl -latex -l QR-DQN
Parameters
----------
.. autoclass:: QRDQN
:members:
:inherited-members:
.. _qrdqn_policies:
QR-DQN Policies
---------------
.. autoclass:: MlpPolicy
:members:
:inherited-members:
.. autoclass:: sb3_contrib.qrdqn.policies.QRDQNPolicy
:members:
:noindex:
.. autoclass:: CnnPolicy
:members:


@ -1,6 +1,6 @@
import os
# from sb3_contrib.cmaes import CMAES
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC
# Read version from file


@ -0,0 +1,69 @@
from typing import Optional
import torch as th
def quantile_huber_loss(
current_quantiles: th.Tensor,
target_quantiles: th.Tensor,
cum_prob: Optional[th.Tensor] = None,
sum_over_quantiles: bool = True,
) -> th.Tensor:
"""
The quantile-regression loss, as described in the QR-DQN and TQC papers.
Partially taken from https://github.com/bayesgroup/tqc_pytorch.
:param current_quantiles: current estimate of quantiles, must be either
(batch_size, n_quantiles) or (batch_size, n_critics, n_quantiles)
:param target_quantiles: target of quantiles, must be either (batch_size, n_target_quantiles),
(batch_size, 1, n_target_quantiles), or (batch_size, n_critics, n_target_quantiles)
:param cum_prob: cumulative probabilities to calculate quantiles (also called midpoints in the QR-DQN paper),
must be either (batch_size, n_quantiles), (batch_size, 1, n_quantiles), or (batch_size, n_critics, n_quantiles)
(if None, uniform quantile midpoints are computed)
:param sum_over_quantiles: whether to sum over the quantile dimension or not
:return: the loss
"""
if current_quantiles.ndim != target_quantiles.ndim:
raise ValueError(
f"Error: The dimension of curremt_quantile ({current_quantiles.ndim}) needs to match "
f"the dimension of target_quantiles ({target_quantiles.ndim})."
)
if current_quantiles.shape[0] != target_quantiles.shape[0]:
raise ValueError(
f"Error: The batch size of curremt_quantile ({current_quantiles.shape[0]}) needs to match "
f"the batch size of target_quantiles ({target_quantiles.shape[0]})."
)
if current_quantiles.ndim not in (2, 3):
raise ValueError(f"Error: The dimension of current_quantiles ({current_quantiles.ndim}) needs to be either 2 or 3.")
if cum_prob is None:
n_quantiles = current_quantiles.shape[-1]
# Cumulative probabilities to calculate quantiles.
cum_prob = (th.arange(n_quantiles, device=current_quantiles.device, dtype=th.float) + 0.5) / n_quantiles
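# e.g. for n_quantiles=4, cum_prob = [0.125, 0.375, 0.625, 0.875] (the uniform quantile midpoints)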
if current_quantiles.ndim == 2:
# For QR-DQN, current_quantiles has a shape (batch_size, n_quantiles), so make cum_prob
# broadcastable to (batch_size, n_quantiles, n_target_quantiles)
cum_prob = cum_prob.view(1, -1, 1)
elif current_quantiles.ndim == 3:
# For TQC, current_quantiles has a shape (batch_size, n_critics, n_quantiles), so make cum_prob
# broadcastable to (batch_size, n_critics, n_quantiles, n_target_quantiles)
cum_prob = cum_prob.view(1, 1, -1, 1)
# QR-DQN
# target_quantiles: (batch_size, n_target_quantiles) -> (batch_size, 1, n_target_quantiles)
# current_quantiles: (batch_size, n_quantiles) -> (batch_size, n_quantiles, 1)
# pairwise_delta: (batch_size, n_quantiles, n_target_quantiles)
# TQC
# target_quantiles: (batch_size, 1, n_target_quantiles) -> (batch_size, 1, 1, n_target_quantiles)
# current_quantiles: (batch_size, n_critics, n_quantiles) -> (batch_size, n_critics, n_quantiles, 1)
# pairwise_delta: (batch_size, n_critics, n_quantiles, n_target_quantiles)
# Note: in both cases, the loss has the same shape as pairwise_delta
pairwise_delta = target_quantiles.unsqueeze(-2) - current_quantiles.unsqueeze(-1)
abs_pairwise_delta = th.abs(pairwise_delta)
huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5)
loss = th.abs(cum_prob - (pairwise_delta.detach() < 0).float()) * huber_loss
if sum_over_quantiles:
loss = loss.sum(dim=-2).mean()
else:
loss = loss.mean()
return loss
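# Illustrative usage (not part of the diff above; the expected values match tests/test_utils.py):
# for a batch of one sample, with 10 current quantiles equal to 0 and 10 target quantiles equal to 1,
# every pairwise delta is 1, so the element-wise loss is cum_prob * 0.5.
#   quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10))                            # tensor(2.5000)
#   quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10), sum_over_quantiles=False)  # tensor(0.2500)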


@ -0,0 +1,2 @@
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy
from sb3_contrib.qrdqn.qrdqn import QRDQN


@ -0,0 +1,249 @@
from typing import Any, Dict, List, Optional, Type
import gym
import torch as th
from stable_baselines3.common.policies import BasePolicy, register_policy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
from stable_baselines3.common.type_aliases import Schedule
from torch import nn
class QuantileNetwork(BasePolicy):
"""
Quantile network for QR-DQN
:param observation_space: Observation space
:param action_space: Action space
:param n_quantiles: Number of quantiles
:param net_arch: The specification of the network architecture.
:param activation_fn: Activation function
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
features_extractor: nn.Module,
features_dim: int,
n_quantiles: int = 200,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
):
super(QuantileNetwork, self).__init__(
observation_space,
action_space,
features_extractor=features_extractor,
normalize_images=normalize_images,
)
if net_arch is None:
net_arch = [64, 64]
self.net_arch = net_arch
self.activation_fn = activation_fn
self.features_extractor = features_extractor
self.features_dim = features_dim
self.n_quantiles = n_quantiles
self.normalize_images = normalize_images
action_dim = self.action_space.n # number of actions
quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
self.quantile_net = nn.Sequential(*quantile_net)
def forward(self, obs: th.Tensor) -> th.Tensor:
"""
Predict the quantiles.
:param obs: Observation
:return: The estimated quantiles for each action.
"""
quantiles = self.quantile_net(self.extract_features(obs))
return quantiles.view(-1, self.n_quantiles, self.action_space.n)
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
q_values = self.forward(observation).mean(dim=1)
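# q_values: (batch_size, n_actions), obtained by averaging over the quantile dimension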
# Greedy action
action = q_values.argmax(dim=1).reshape(-1)
return action
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
data.update(
dict(
net_arch=self.net_arch,
features_dim=self.features_dim,
n_quantiles=self.n_quantiles,
activation_fn=self.activation_fn,
features_extractor=self.features_extractor,
)
)
return data
class QRDQNPolicy(BasePolicy):
"""
Policy class with quantile and target networks for QR-DQN.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param n_quantiles: Number of quantiles
:param net_arch: The specification of the network architecture.
:param activation_fn: Activation function
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
n_quantiles: int = 200,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super(QRDQNPolicy, self).__init__(
observation_space,
action_space,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
)
if net_arch is None:
if features_extractor_class == FlattenExtractor:
net_arch = [64, 64]
else:
net_arch = []
self.n_quantiles = n_quantiles
self.net_arch = net_arch
self.activation_fn = activation_fn
self.normalize_images = normalize_images
self.net_args = {
"observation_space": self.observation_space,
"action_space": self.action_space,
"n_quantiles": self.n_quantiles,
"net_arch": self.net_arch,
"activation_fn": self.activation_fn,
"normalize_images": normalize_images,
}
self.quantile_net, self.quantile_net_target = None, None
self._build(lr_schedule)
def _build(self, lr_schedule: Schedule) -> None:
"""
Create the network and the optimizer.
:param lr_schedule: Learning rate schedule
lr_schedule(1) is the initial learning rate
"""
self.quantile_net = self.make_quantile_net()
self.quantile_net_target = self.make_quantile_net()
self.quantile_net_target.load_state_dict(self.quantile_net.state_dict())
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
def make_quantile_net(self) -> QuantileNetwork:
# Make sure we always have separate networks for features extractors etc
net_args = self._update_features_extractor(self.net_args, features_extractor=None)
return QuantileNetwork(**net_args).to(self.device)
def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
return self.quantile_net._predict(obs, deterministic=deterministic)
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
data.update(
dict(
n_quantiles=self.net_args["n_quantiles"],
net_arch=self.net_args["net_arch"],
activation_fn=self.net_args["activation_fn"],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs,
)
)
return data
MlpPolicy = QRDQNPolicy
class CnnPolicy(QRDQNPolicy):
"""
Policy class for QR-DQN when using images as input.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param n_quantiles: Number of quantiles
:param net_arch: The specification of the network architecture.
:param activation_fn: Activation function
:param features_extractor_class: Features extractor to use.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
n_quantiles: int = 200,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super(CnnPolicy, self).__init__(
observation_space,
action_space,
lr_schedule,
n_quantiles,
net_arch,
activation_fn,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
)
register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)

sb3_contrib/qrdqn/qrdqn.py Normal file

@ -0,0 +1,253 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import numpy as np
import torch as th
from stable_baselines3.common import logger
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.qrdqn.policies import QRDQNPolicy
class QRDQN(OffPolicyAlgorithm):
"""
Quantile Regression Deep Q-Network (QR-DQN)
Paper: https://arxiv.org/abs/1710.10044
Default hyperparameters are taken from the paper and are tuned for Atari games.
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
of the current progress remaining (from 1 to 0)
:param buffer_size: size of the replay buffer
:param learning_starts: how many steps of the model to collect transitions for before learning starts
:param batch_size: Minibatch size for each gradient update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1), default 1 for hard update
:param gamma: the discount factor
:param train_freq: Update the model every ``train_freq`` steps. Set to `-1` to disable.
:param gradient_steps: How many gradient steps to do after each rollout
(see ``train_freq`` and ``n_episodes_rollout``)
Setting it to ``-1`` means doing as many gradient steps as steps done in the environment
during the rollout.
:param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes.
Note that this cannot be used at the same time as ``train_freq``. Set to `-1` to disable.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param target_update_interval: update the target network every ``target_update_interval``
environment steps.
:param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
:param exploration_initial_eps: initial value of random action probability
:param exploration_final_eps: final value of random action probability
:param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping)
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
def __init__(
self,
policy: Union[str, Type[QRDQNPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 5e-5,
buffer_size: int = 1000000,
learning_starts: int = 50000,
batch_size: Optional[int] = 32,
tau: float = 1.0,
gamma: float = 0.99,
train_freq: int = 4,
gradient_steps: int = 1,
n_episodes_rollout: int = -1,
optimize_memory_usage: bool = False,
target_update_interval: int = 10000,
exploration_fraction: float = 0.005,
exploration_initial_eps: float = 1.0,
exploration_final_eps: float = 0.01,
max_grad_norm: Optional[float] = None,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super(QRDQN, self).__init__(
policy,
env,
QRDQNPolicy,
learning_rate,
buffer_size,
learning_starts,
batch_size,
tau,
gamma,
train_freq,
gradient_steps,
n_episodes_rollout,
action_noise=None, # No action noise
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
verbose=verbose,
device=device,
create_eval_env=create_eval_env,
seed=seed,
sde_support=False,
optimize_memory_usage=optimize_memory_usage,
supported_action_spaces=(gym.spaces.Discrete,),
)
self.exploration_initial_eps = exploration_initial_eps
self.exploration_final_eps = exploration_final_eps
self.exploration_fraction = exploration_fraction
self.target_update_interval = target_update_interval
self.max_grad_norm = max_grad_norm
# "epsilon" for the epsilon-greedy exploration
self.exploration_rate = 0.0
# Linear schedule will be defined in `_setup_model()`
self.exploration_schedule = None
self.quantile_net, self.quantile_net_target = None, None
if "optimizer_class" not in self.policy_kwargs:
self.policy_kwargs["optimizer_class"] = th.optim.Adam
# Proposed in the QR-DQN paper where `batch_size = 32`
self.policy_kwargs["optimizer_kwargs"] = dict(eps=0.01 / batch_size)
if _init_setup_model:
self._setup_model()
def _setup_model(self) -> None:
super(QRDQN, self)._setup_model()
self._create_aliases()
self.exploration_schedule = get_linear_fn(
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
)
def _create_aliases(self) -> None:
self.quantile_net = self.policy.quantile_net
self.quantile_net_target = self.policy.quantile_net_target
self.n_quantiles = self.policy.n_quantiles
def _on_step(self) -> None:
"""
Update the exploration rate and target network if needed.
This method is called in ``collect_rollouts()`` after each step in the environment.
"""
if self.num_timesteps % self.target_update_interval == 0:
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
logger.record("rollout/exploration rate", self.exploration_rate)
def train(self, gradient_steps: int, batch_size: int = 100) -> None:
# Update learning rate according to schedule
self._update_learning_rate(self.policy.optimizer)
losses = []
for gradient_step in range(gradient_steps):
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
with th.no_grad():
# Compute the quantiles of next observation
next_quantiles = self.quantile_net_target(replay_data.next_observations)
# Follow greedy policy: use the one with the highest value
next_quantiles, _ = next_quantiles.max(dim=2)
# 1-step TD target
target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_quantiles
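# target_quantiles: (batch_size, n_quantiles)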
# Get current quantile estimates
current_quantiles = self.quantile_net(replay_data.observations)
# Make "n_quantiles" copies of actions, and reshape to (batch_size, n_quantiles, 1).
actions = replay_data.actions[..., None].long().expand(batch_size, self.n_quantiles, 1)
# Retrieve the quantiles for the actions from the replay buffer
current_quantiles = th.gather(current_quantiles, dim=2, index=actions).squeeze(dim=2)
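# current_quantiles: (batch_size, n_quantiles)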
# Compute Quantile Huber loss, summing over the quantile dimension as in the paper.
loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=True)
losses.append(loss.item())
# Optimize the policy
self.policy.optimizer.zero_grad()
loss.backward()
# Clip gradient norm
if self.max_grad_norm is not None:
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
# Increase update counter
self._n_updates += gradient_steps
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
logger.record("train/loss", np.mean(losses))
def predict(
self,
observation: np.ndarray,
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Overrides the base_class predict function to include epsilon-greedy exploration.
:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next state
(used in recurrent policies)
"""
if not deterministic and np.random.rand() < self.exploration_rate:
if is_vectorized_observation(observation, self.observation_space):
n_batch = observation.shape[0]
action = np.array([self.action_space.sample() for _ in range(n_batch)])
else:
action = np.array(self.action_space.sample())
else:
action, state = self.policy.predict(observation, state, mask, deterministic)
return action, state
def learn(
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
tb_log_name: str = "QRDQN",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm:
return super(QRDQN, self).learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
eval_env=eval_env,
eval_freq=eval_freq,
n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
)
def _excluded_save_params(self) -> List[str]:
return super(QRDQN, self)._excluded_save_params() + ["quantile_net", "quantile_net_target"]
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "policy.optimizer"]
return state_dicts, []


@ -391,6 +391,8 @@ class TQCPolicy(BasePolicy):
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs,
n_quantiles=self.critic_kwargs["n_quantiles"],
n_critics=self.critic_kwargs["n_critics"],
)
)
return data


@ -9,6 +9,7 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import polyak_update
from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.tqc.policies import TQCPolicy
@ -171,26 +172,6 @@ class TQC(OffPolicyAlgorithm):
self.critic = self.policy.critic
self.critic_target = self.policy.critic_target
@staticmethod
def quantile_huber_loss(quantiles: th.Tensor, samples: th.Tensor) -> th.Tensor:
"""
The quantile-regression loss, as described in the QR-DQN and TQC papers.
Taken from https://github.com/bayesgroup/tqc_pytorch
:param quantiles:
:param samples:
:return: the loss
"""
# batch x nets x quantiles x samples
pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
abs_pairwise_delta = th.abs(pairwise_delta)
huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5)
n_quantiles = quantiles.shape[2]
tau = th.arange(n_quantiles, device=quantiles.device).float() / n_quantiles + 1 / 2 / n_quantiles
loss = (th.abs(tau[None, None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean()
return loss
def train(self, gradient_steps: int, batch_size: int = 64) -> None:
# Update optimizers learning rate
optimizers = [self.actor.optimizer, self.critic.optimizer]
@ -237,24 +218,27 @@ class TQC(OffPolicyAlgorithm):
self.ent_coef_optimizer.step()
with th.no_grad():
top_quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.critic.n_critics
# Select action according to policy
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
# Compute and cut quantiles at the next state
# batch x nets x quantiles
next_z = self.critic_target(replay_data.next_observations, next_actions)
sorted_z, _ = th.sort(next_z.reshape(batch_size, -1))
sorted_z_part = sorted_z[:, : self.critic.quantiles_total - top_quantiles_to_drop]
next_quantiles = self.critic_target(replay_data.next_observations, next_actions)
# Sort and drop top k quantiles to control overestimation.
n_target_quantiles = self.critic.quantiles_total - self.top_quantiles_to_drop_per_net * self.critic.n_critics
next_quantiles, _ = th.sort(next_quantiles.reshape(batch_size, -1))
next_quantiles = next_quantiles[:, :n_target_quantiles]
target_q = sorted_z_part - ent_coef * next_log_prob.reshape(-1, 1)
# td error + entropy term
q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
target_quantiles = next_quantiles - ent_coef * next_log_prob.reshape(-1, 1)
target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_quantiles
# Make target_quantiles broadcastable to (batch_size, n_critics, n_target_quantiles).
target_quantiles.unsqueeze_(dim=1)
# Get current Q estimates
# using action from the replay buffer
current_z = self.critic(replay_data.observations, replay_data.actions)
# Compute critic loss
critic_loss = self.quantile_huber_loss(current_z, q_backup)
# Get current Quantile estimates using action from the replay buffer
current_quantiles = self.critic(replay_data.observations, replay_data.actions)
# Compute critic loss, not summing over the quantile dimension as in the paper.
critic_loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=False)
critic_losses.append(critic_loss.item())
# Optimize the critic


@ -25,6 +25,7 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators
# Ignore import not used when aliases are defined
per-file-ignores =
./sb3_contrib/__init__.py:F401
./sb3_contrib/qrdqn/__init__.py:F401
./sb3_contrib/tqc/__init__.py:F401
./sb3_contrib/common/wrappers/__init__.py:F401
exclude =


@ -7,10 +7,10 @@ import torch as th
from stable_baselines3.common.identity_env import FakeImageEnv
from stable_baselines3.common.utils import zip_strict
from sb3_contrib import TQC
from sb3_contrib import QRDQN, TQC
@pytest.mark.parametrize("model_class", [TQC])
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
def test_cnn(tmp_path, model_class):
SAVE_NAME = "cnn_model.zip"
# Fake grayscale with frameskip
@ -18,10 +18,13 @@ def test_cnn(tmp_path, model_class):
# to check that the network handle it automatically
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC})
kwargs = {}
if model_class in {TQC}:
if model_class in {TQC, QRDQN}:
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
# Reduce the size of the features and the number of quantiles
kwargs = dict(
buffer_size=250,
policy_kwargs=dict(n_quantiles=25, features_extractor_kwargs=dict(features_dim=32)),
)
model = model_class("CnnPolicy", env, **kwargs).learn(250)
obs = env.reset()
@ -39,6 +42,13 @@ def test_cnn(tmp_path, model_class):
os.remove(str(tmp_path / SAVE_NAME))
def patch_qrdqn_names_(model):
# Small hack to make the test work with QRDQN
if isinstance(model, QRDQN):
model.critic = model.quantile_net
model.critic_target = model.quantile_net_target
def params_should_match(params, other_params):
for param, other_param in zip_strict(params, other_params):
assert th.allclose(param, other_param)
@ -49,28 +59,36 @@ def params_should_differ(params, other_params):
assert not th.allclose(param, other_param)
@pytest.mark.parametrize("model_class", [TQC])
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
@pytest.mark.parametrize("share_features_extractor", [True, False])
def test_feature_extractor_target_net(model_class, share_features_extractor):
if model_class == QRDQN and share_features_extractor:
pytest.skip()
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC})
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(
buffer_size=250,
learning_starts=100,
policy_kwargs=dict(
features_extractor_kwargs=dict(features_dim=32),
share_features_extractor=share_features_extractor,
),
)
if model_class in {TQC, QRDQN}:
# Avoid memory error when using replay buffer
# Reduce the size of the features and the number of quantiles
kwargs = dict(
buffer_size=250,
learning_starts=100,
policy_kwargs=dict(n_quantiles=25, features_extractor_kwargs=dict(features_dim=32)),
)
if model_class != QRDQN:
kwargs["policy_kwargs"]["share_features_extractor"] = share_features_extractor
model = model_class("CnnPolicy", env, seed=0, **kwargs)
patch_qrdqn_names_(model)
if share_features_extractor:
# Check that the objects are the same and not just copied
assert id(model.policy.actor.features_extractor) == id(model.policy.critic.features_extractor)
else:
# Check that the objects differ
assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor)
if model_class != QRDQN:
assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor)
# Critic and target should be equal at the beginning of training
params_should_match(model.critic.parameters(), model.critic_target.parameters())
@ -83,6 +101,8 @@ def test_feature_extractor_target_net(model_class, share_features_extractor):
# Re-initialize and collect some random data (without doing gradient steps)
model = model_class("CnnPolicy", env, seed=0, **kwargs).learn(10)
patch_qrdqn_names_(model)
original_param = deepcopy(list(model.critic.parameters()))
original_target_param = deepcopy(list(model.critic_target.parameters()))
@ -103,6 +123,11 @@ def test_feature_extractor_target_net(model_class, share_features_extractor):
model.lr_schedule = lambda _: 0.0
# Re-activate polyak update
model.tau = 0.01
# Special case for QRDQN: target net is updated in the `collect_rollouts()`
# not the `train()` method
if model_class == QRDQN:
model.target_update_interval = 1
model._on_step()
model.train(gradient_steps=1)


@ -1,6 +1,6 @@
import pytest
from sb3_contrib import TQC
from sb3_contrib import QRDQN, TQC
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
@ -21,7 +21,11 @@ def test_tqc(ent_coef):
def test_n_critics(n_critics):
# Test TQC with different number of critics
model = TQC(
"MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], n_critics=n_critics), learning_starts=100, verbose=1
"MlpPolicy",
"Pendulum-v0",
policy_kwargs=dict(net_arch=[64], n_critics=n_critics),
learning_starts=100,
verbose=1,
)
model.learn(total_timesteps=300)
@ -38,3 +42,17 @@ def test_sde():
model.learn(total_timesteps=300)
model.policy.reset_noise()
model.policy.actor.get_std()
def test_qrdqn():
model = QRDQN(
"MlpPolicy",
"CartPole-v1",
policy_kwargs=dict(n_quantiles=25, net_arch=[64, 64]),
learning_starts=100,
buffer_size=500,
learning_rate=3e-4,
verbose=1,
create_eval_env=True,
)
model.learn(total_timesteps=500, eval_freq=250)


@ -7,22 +7,21 @@ import gym
import numpy as np
import pytest
import torch as th
from stable_baselines3 import DQN
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv
from sb3_contrib import TQC
from sb3_contrib import QRDQN, TQC
MODEL_LIST = [TQC]
MODEL_LIST = [TQC, QRDQN]
def select_env(model_class: BaseAlgorithm) -> gym.Env:
"""
Selects an environment with the correct action space as DQN only supports discrete action space
Selects an environment with the correct action space as QRDQN only supports discrete action space
"""
if model_class == DQN:
if model_class == QRDQN:
return IdentityEnv(10)
else:
return IdentityEnvBox(10)
@ -41,8 +40,13 @@ def test_save_load(tmp_path, model_class):
env = DummyVecEnv([lambda: select_env(model_class)])
policy_kwargs = dict(net_arch=[16])
if model_class in {QRDQN, TQC}:
policy_kwargs.update(dict(n_quantiles=20))
# create model
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
model = model_class("MlpPolicy", env, verbose=1, policy_kwargs=policy_kwargs)
model.learn(total_timesteps=300)
env.reset()
@ -167,13 +171,18 @@ def test_set_env(model_class):
:param model_class: (BaseAlgorithm) A RL model
"""
# use discrete for DQN
# use discrete for QRDQN
env = DummyVecEnv([lambda: select_env(model_class)])
env2 = DummyVecEnv([lambda: select_env(model_class)])
env3 = select_env(model_class)
kwargs = dict(policy_kwargs=dict(net_arch=[16]))
if model_class in {TQC, QRDQN}:
kwargs.update(dict(learning_starts=100))
kwargs["policy_kwargs"].update(dict(n_quantiles=20))
# create model
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]))
model = model_class("MlpPolicy", env, **kwargs)
# learn
model.learn(total_timesteps=300)
@ -219,7 +228,7 @@ def test_exclude_include_saved_params(tmp_path, model_class):
os.remove(tmp_path / "test_save.zip")
@pytest.mark.parametrize("model_class", [TQC])
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
def test_save_load_replay_buffer(tmp_path, model_class):
path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning
@ -254,20 +263,28 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
:param model_class: (BaseAlgorithm) A RL model
:param policy_str: (str) Name of the policy.
"""
kwargs = {}
kwargs = dict(policy_kwargs=dict(net_arch=[16]))
if policy_str == "MlpPolicy":
env = select_env(model_class)
else:
if model_class in [TQC]:
if model_class in [TQC, QRDQN]:
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(buffer_size=250)
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)
kwargs = dict(
buffer_size=250,
learning_starts=100,
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
)
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN)
# Reduce number of quantiles for faster tests
if model_class in [TQC, QRDQN]:
kwargs["policy_kwargs"].update(dict(n_quantiles=20))
env = DummyVecEnv([lambda: env])
# create model
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
model = model_class(policy_str, env, verbose=1, **kwargs)
model.learn(total_timesteps=300)
env.reset()
@ -334,3 +351,83 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
os.remove(tmp_path / "policy.pkl")
if actor_class is not None:
os.remove(tmp_path / "actor.pkl")
@pytest.mark.parametrize("model_class", [QRDQN])
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
def test_save_load_q_net(tmp_path, model_class, policy_str):
"""
Test saving and loading q-network/quantile net only.
:param model_class: (BaseAlgorithm) A RL model
:param policy_str: (str) Name of the policy.
"""
kwargs = dict(policy_kwargs=dict(net_arch=[16]))
if policy_str == "MlpPolicy":
env = select_env(model_class)
else:
if model_class in [QRDQN]:
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(
buffer_size=250,
learning_starts=100,
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
)
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN)
# Reduce number of quantiles for faster tests
if model_class in [QRDQN]:
kwargs["policy_kwargs"].update(dict(n_quantiles=20))
env = DummyVecEnv([lambda: env])
# create model
model = model_class(policy_str, env, verbose=1, **kwargs)
model.learn(total_timesteps=300)
env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
q_net = model.quantile_net
q_net_class = q_net.__class__
# Get dictionary of current parameters
params = deepcopy(q_net.state_dict())
# Modify all parameters to be random values
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
# Update model parameters with the new random values
q_net.load_state_dict(random_params)
new_params = q_net.state_dict()
# Check that all params are different now
for k in params:
assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
params = new_params
# get selected actions
selected_actions, _ = q_net.predict(observations, deterministic=True)
# Save and load q_net
q_net.save(tmp_path / "q_net.pkl")
del q_net
q_net = q_net_class.load(tmp_path / "q_net.pkl")
# check if params are still the same after load
new_params = q_net.state_dict()
# Check that all params are the same as before save load procedure now
for key in params:
assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load."
# check if model still selects the same actions
new_selected_actions, _ = q_net.predict(observations, deterministic=True)
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
# clear file from os
os.remove(tmp_path / "q_net.pkl")

tests/test_utils.py Normal file

@ -0,0 +1,19 @@
import numpy as np
import pytest
import torch as th
from sb3_contrib.common.utils import quantile_huber_loss
def test_quantile_huber_loss():
assert np.isclose(quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10)), 2.5)
assert np.isclose(quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10), sum_over_quantiles=False), 0.25)
with pytest.raises(ValueError):
quantile_huber_loss(th.zeros(1, 4, 4), th.zeros(1, 4))
with pytest.raises(ValueError):
quantile_huber_loss(th.zeros(1, 4), th.zeros(1, 1, 4))
with pytest.raises(ValueError):
quantile_huber_loss(th.zeros(4, 4), th.zeros(3, 4))
with pytest.raises(ValueError):
quantile_huber_loss(th.zeros(4, 4, 4, 4), th.zeros(4, 4, 4, 4))