Add QR-DQN (#13)
* Add QR-DQN(WIP) * Update docstring * Add quantile_huber_loss * Fix typo * Remove unnecessary lines * Update variable names and comments in quantile_huber_loss * Fix mutable arguments * Update variable names * Ignore import not used warnings * Fix default parameter of optimizer in QR-DQN * Update quantile_huber_loss to have more reasonable interface * update tests * Add assertion to quantile_huber_loss * Update variable names of quantile regression * Update comments * Reduce the number of quantiles during test * Update comment * Update quantile_huber_loss * Fix isort * Add document of QR-DQN without results * Update docs * Fix bugs * Update doc * Add comments about shape * Minor edits * Update comments * Add benchmark * Doc fixes * Update doc * Bug fix in saving/loading + update tests Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
3598ca284a
commit
b30397fff5
|
|
@ -25,6 +25,7 @@ See documentation for the full list of included features.
|
|||
|
||||
**RL Algorithms**:
|
||||
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
|
||||
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
|
||||
|
||||
**Gym Wrappers**:
|
||||
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ along with some useful characteristics: support for discrete/continuous actions,
|
|||
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
|
||||
============ =========== ============ ================= =============== ================
|
||||
TQC ✔️ ❌ ❌ ❌ ❌
|
||||
QR-DQN ️❌ ️✔️ ❌ ❌ ❌
|
||||
============ =========== ============ ================= =============== ================
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,21 @@ Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment.
|
|||
model.learn(total_timesteps=10000, log_interval=4)
|
||||
model.save("tqc_pendulum")
|
||||
|
||||
QR-DQN
|
||||
------
|
||||
|
||||
Train a Quantile Regression DQN (QR-DQN) agent on the CartPole environment.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from sb3_contrib import QRDQN
|
||||
|
||||
policy_kwargs = dict(n_quantiles=50)
|
||||
model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
|
||||
model.learn(total_timesteps=10000, log_interval=4)
|
||||
model.save("qrdqn_cartpole")
|
||||
|
||||
|
||||
.. PyBullet: Normalizing input features
|
||||
.. ------------------------------------
|
||||
..
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
|
|||
:caption: RL Algorithms
|
||||
|
||||
modules/tqc
|
||||
modules/qrdqn
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
|
|
|||
|
|
@ -13,9 +13,11 @@ Breaking Changes:
|
|||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added ``TimeFeatureWrapper`` to the wrappers
|
||||
- Added ``QR-DQN`` algorithm (`@ku2482`_)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed bug in ``TQC`` when saving/loading the policy only with non-default number of quantiles
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -24,6 +26,7 @@ Others:
|
|||
^^^^^^^
|
||||
- Updated ``TQC`` to match new SB3 version
|
||||
- Updated SB3 min version
|
||||
- Moved ``quantile_huber_loss`` to ``common/utils.py`` (@ku2482)
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
@ -62,13 +65,19 @@ Maintainers
|
|||
-----------
|
||||
|
||||
Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_), `Ashley Hill`_ (aka @hill-a),
|
||||
`Maximilian Ernestus`_ (aka @erniejunior), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_).
|
||||
`Maximilian Ernestus`_ (aka @ernestum), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_).
|
||||
|
||||
.. _Ashley Hill: https://github.com/hill-a
|
||||
.. _Antonin Raffin: https://araffin.github.io/
|
||||
.. _Maximilian Ernestus: https://github.com/erniejunior
|
||||
.. _Maximilian Ernestus: https://github.com/ernestum
|
||||
.. _Adam Gleave: https://gleave.me/
|
||||
.. _@araffin: https://github.com/araffin
|
||||
.. _@AdamGleave: https://github.com/adamgleave
|
||||
.. _Anssi Kanervisto: https://github.com/Miffyli
|
||||
.. _@Miffyli: https://github.com/Miffyli
|
||||
.. _@ku2482: https://github.com/ku2482
|
||||
|
||||
Contributors:
|
||||
-------------
|
||||
|
||||
@ku2482
|
||||
|
|
|
|||
|
|
@ -0,0 +1,150 @@
|
|||
.. _qrdqn:
|
||||
|
||||
.. automodule:: sb3_contrib.qrdqn
|
||||
|
||||
|
||||
QR-DQN
|
||||
======
|
||||
|
||||
`Quantile Regression DQN (QR-DQN) <https://arxiv.org/abs/1710.10044>`_ builds on `Deep Q-Network (DQN) <https://arxiv.org/abs/1312.5602>`_
|
||||
and make use of quantile regression to explicitly model the `distribution over returns <https://arxiv.org/abs/1707.06887>`_,
|
||||
instead of predicting the mean return (DQN).
|
||||
|
||||
|
||||
.. rubric:: Available Policies
|
||||
|
||||
.. autosummary::
|
||||
:nosignatures:
|
||||
|
||||
MlpPolicy
|
||||
CnnPolicy
|
||||
|
||||
|
||||
Notes
|
||||
-----
|
||||
|
||||
- Original paper: https://arxiv.org/abs/1710.100442
|
||||
- Distributional RL (C51): https://arxiv.org/abs/1707.06887
|
||||
|
||||
|
||||
Can I use?
|
||||
----------
|
||||
|
||||
- Recurrent policies: ❌
|
||||
- Multi processing: ❌
|
||||
- Gym spaces:
|
||||
|
||||
|
||||
============= ====== ===========
|
||||
Space Action Observation
|
||||
============= ====== ===========
|
||||
Discrete ✔ ✔
|
||||
Box ❌ ✔
|
||||
MultiDiscrete ❌ ✔
|
||||
MultiBinary ❌ ✔
|
||||
============= ====== ===========
|
||||
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import gym
|
||||
|
||||
from sb3_contrib import QRDQN
|
||||
|
||||
env = gym.make("CartPole-v1")
|
||||
|
||||
policy_kwargs = dict(n_quantiles=50)
|
||||
model = QRDQN("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
|
||||
model.learn(total_timesteps=10000, log_interval=4)
|
||||
model.save("qrdqn_cartpole")
|
||||
|
||||
del model # remove to demonstrate saving and loading
|
||||
|
||||
model = QRDQN.load("qrdqn_cartpole")
|
||||
|
||||
obs = env.reset()
|
||||
while True:
|
||||
action, _states = model.predict(obs, deterministic=True)
|
||||
obs, reward, done, info = env.step(action)
|
||||
env.render()
|
||||
if done:
|
||||
obs = env.reset()
|
||||
|
||||
|
||||
Results
|
||||
-------
|
||||
|
||||
Result on Atari environments (10M steps, Pong and Breakout) and classic control tasks using 3 and 5 seeds.
|
||||
|
||||
The complete learning curves are available in the `associated PR <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/13>`_.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
QR-DQN implementation was validated against `Intel Coach <https://github.com/IntelLabs/coach/tree/master/benchmarks/qr_dqn>`_ one
|
||||
which roughly compare to the original paper results (we trained the agent with a smaller budget).
|
||||
|
||||
|
||||
============ ========== ===========
|
||||
Environments QR-DQN DQN
|
||||
============ ========== ===========
|
||||
Breakout 413 +/- 21 ~300
|
||||
Pong 20 +/- 0 ~20
|
||||
CartPole 386 +/- 64 500 +/- 0
|
||||
MountainCar -111 +/- 4 -107 +/- 4
|
||||
LunarLander 168 +/- 39 195 +/- 28
|
||||
Acrobot -73 +/- 2 -74 +/- 2
|
||||
============ ========== ===========
|
||||
|
||||
How to replicate the results?
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Clone RL-Zoo fork and checkout the branch ``feat/qrdqn``:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
git clone https://github.com/ku2482/rl-baselines3-zoo/
|
||||
cd rl-baselines3-zoo/
|
||||
git checkout feat/qrdqn
|
||||
|
||||
Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python train.py --algo qrdqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000
|
||||
|
||||
|
||||
Plot the results:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/all_plots.py -a qrdqn -e Breakout Pong -f logs/ -o logs/qrdqn_results
|
||||
python scripts/plot_from_file.py -i logs/qrdqn_results.pkl -latex -l QR-DQN
|
||||
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
.. autoclass:: QRDQN
|
||||
:members:
|
||||
:inherited-members:
|
||||
|
||||
.. _qrdqn_policies:
|
||||
|
||||
QR-DQN Policies
|
||||
---------------
|
||||
|
||||
.. autoclass:: MlpPolicy
|
||||
:members:
|
||||
:inherited-members:
|
||||
|
||||
.. autoclass:: sb3_contrib.qrdqn.policies.QRDQNPolicy
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
.. autoclass:: CnnPolicy
|
||||
:members:
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
|
||||
# from sb3_contrib.cmaes import CMAES
|
||||
from sb3_contrib.qrdqn import QRDQN
|
||||
from sb3_contrib.tqc import TQC
|
||||
|
||||
# Read version from file
|
||||
|
|
|
|||
|
|
@ -0,0 +1,69 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch as th
|
||||
|
||||
|
||||
def quantile_huber_loss(
|
||||
current_quantiles: th.Tensor,
|
||||
target_quantiles: th.Tensor,
|
||||
cum_prob: Optional[th.Tensor] = None,
|
||||
sum_over_quantiles: bool = True,
|
||||
) -> th.Tensor:
|
||||
"""
|
||||
The quantile-regression loss, as described in the QR-DQN and TQC papers.
|
||||
Partially taken from https://github.com/bayesgroup/tqc_pytorch.
|
||||
|
||||
:param current_quantiles: current estimate of quantiles, must be either
|
||||
(batch_size, n_quantiles) or (batch_size, n_critics, n_quantiles)
|
||||
:param target_quantiles: target of quantiles, must be either (batch_size, n_target_quantiles),
|
||||
(batch_size, 1, n_target_quantiles), or (batch_size, n_critics, n_target_quantiles)
|
||||
:param cum_prob: cumulative probabilities to calculate quantiles (also called midpoints in QR-DQN paper),
|
||||
must be either (batch_size, n_quantiles), (batch_size, 1, n_quantiles), or (batch_size, n_critics, n_quantiles).
|
||||
(if None, calculating unit quantiles)
|
||||
:param sum_over_quantiles: if summing over the quantile dimension or not
|
||||
:return: the loss
|
||||
"""
|
||||
if current_quantiles.ndim != target_quantiles.ndim:
|
||||
raise ValueError(
|
||||
f"Error: The dimension of curremt_quantile ({current_quantiles.ndim}) needs to match "
|
||||
f"the dimension of target_quantiles ({target_quantiles.ndim})."
|
||||
)
|
||||
if current_quantiles.shape[0] != target_quantiles.shape[0]:
|
||||
raise ValueError(
|
||||
f"Error: The batch size of curremt_quantile ({current_quantiles.shape[0]}) needs to match "
|
||||
f"the batch size of target_quantiles ({target_quantiles.shape[0]})."
|
||||
)
|
||||
if current_quantiles.ndim not in (2, 3):
|
||||
raise ValueError(f"Error: The dimension of current_quantiles ({current_quantiles.ndim}) needs to be either 2 or 3.")
|
||||
|
||||
if cum_prob is None:
|
||||
n_quantiles = current_quantiles.shape[-1]
|
||||
# Cumulative probabilities to calculate quantiles.
|
||||
cum_prob = (th.arange(n_quantiles, device=current_quantiles.device, dtype=th.float) + 0.5) / n_quantiles
|
||||
if current_quantiles.ndim == 2:
|
||||
# For QR-DQN, current_quantiles have a shape (batch_size, n_quantiles), and make cum_prob
|
||||
# broadcastable to (batch_size, n_quantiles, n_target_quantiles)
|
||||
cum_prob = cum_prob.view(1, -1, 1)
|
||||
elif current_quantiles.ndim == 3:
|
||||
# For TQC, current_quantiles have a shape (batch_size, n_critics, n_quantiles), and make cum_prob
|
||||
# broadcastable to (batch_size, n_critics, n_quantiles, n_target_quantiles)
|
||||
cum_prob = cum_prob.view(1, 1, -1, 1)
|
||||
|
||||
# QR-DQN
|
||||
# target_quantiles: (batch_size, n_target_quantiles) -> (batch_size, 1, n_target_quantiles)
|
||||
# current_quantiles: (batch_size, n_quantiles) -> (batch_size, n_quantiles, 1)
|
||||
# pairwise_delta: (batch_size, n_target_quantiles, n_quantiles)
|
||||
# TQC
|
||||
# target_quantiles: (batch_size, 1, n_target_quantiles) -> (batch_size, 1, 1, n_target_quantiles)
|
||||
# current_quantiles: (batch_size, n_critics, n_quantiles) -> (batch_size, n_critics, n_quantiles, 1)
|
||||
# pairwise_delta: (batch_size, n_critics, n_quantiles, n_target_quantiles)
|
||||
# Note: in both cases, the loss has the same shape as pairwise_delta
|
||||
pairwise_delta = target_quantiles.unsqueeze(-2) - current_quantiles.unsqueeze(-1)
|
||||
abs_pairwise_delta = th.abs(pairwise_delta)
|
||||
huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5)
|
||||
loss = th.abs(cum_prob - (pairwise_delta.detach() < 0).float()) * huber_loss
|
||||
if sum_over_quantiles:
|
||||
loss = loss.sum(dim=-2).mean()
|
||||
else:
|
||||
loss = loss.mean()
|
||||
return loss
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy
|
||||
from sb3_contrib.qrdqn.qrdqn import QRDQN
|
||||
|
|
@ -0,0 +1,249 @@
|
|||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
from stable_baselines3.common.policies import BasePolicy, register_policy
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
from torch import nn
|
||||
|
||||
|
||||
class QuantileNetwork(BasePolicy):
|
||||
"""
|
||||
Quantile network for QR-DQN
|
||||
|
||||
:param observation_space: Observation space
|
||||
:param action_space: Action space
|
||||
:param n_quantiles: Number of quantiles
|
||||
:param net_arch: The specification of the network architecture.
|
||||
:param activation_fn: Activation function
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
features_extractor: nn.Module,
|
||||
features_dim: int,
|
||||
n_quantiles: int = 200,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
normalize_images: bool = True,
|
||||
):
|
||||
super(QuantileNetwork, self).__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
features_extractor=features_extractor,
|
||||
normalize_images=normalize_images,
|
||||
)
|
||||
|
||||
if net_arch is None:
|
||||
net_arch = [64, 64]
|
||||
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
self.features_extractor = features_extractor
|
||||
self.features_dim = features_dim
|
||||
self.n_quantiles = n_quantiles
|
||||
self.normalize_images = normalize_images
|
||||
action_dim = self.action_space.n # number of actions
|
||||
quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
|
||||
self.quantile_net = nn.Sequential(*quantile_net)
|
||||
|
||||
def forward(self, obs: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Predict the quantiles.
|
||||
|
||||
:param obs: Observation
|
||||
:return: The estimated quantiles for each action.
|
||||
"""
|
||||
quantiles = self.quantile_net(self.extract_features(obs))
|
||||
return quantiles.view(-1, self.n_quantiles, self.action_space.n)
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
q_values = self.forward(observation).mean(dim=1)
|
||||
# Greedy action
|
||||
action = q_values.argmax(dim=1).reshape(-1)
|
||||
return action
|
||||
|
||||
def _get_data(self) -> Dict[str, Any]:
|
||||
data = super()._get_data()
|
||||
|
||||
data.update(
|
||||
dict(
|
||||
net_arch=self.net_arch,
|
||||
features_dim=self.features_dim,
|
||||
n_quantiles=self.n_quantiles,
|
||||
activation_fn=self.activation_fn,
|
||||
features_extractor=self.features_extractor,
|
||||
)
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class QRDQNPolicy(BasePolicy):
|
||||
"""
|
||||
Policy class with quantile and target networks for QR-DQN.
|
||||
|
||||
:param observation_space: Observation space
|
||||
:param action_space: Action space
|
||||
:param lr_schedule: Learning rate schedule (could be constant)
|
||||
:param n_quantiles: Number of quantiles
|
||||
:param net_arch: The specification of the network architecture.
|
||||
:param activation_fn: Activation function
|
||||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the features extractor.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
n_quantiles: int = 200,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
|
||||
super(QRDQNPolicy, self).__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
features_extractor_class,
|
||||
features_extractor_kwargs,
|
||||
optimizer_class=optimizer_class,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
)
|
||||
|
||||
if net_arch is None:
|
||||
if features_extractor_class == FlattenExtractor:
|
||||
net_arch = [64, 64]
|
||||
else:
|
||||
net_arch = []
|
||||
|
||||
self.n_quantiles = n_quantiles
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
self.normalize_images = normalize_images
|
||||
|
||||
self.net_args = {
|
||||
"observation_space": self.observation_space,
|
||||
"action_space": self.action_space,
|
||||
"n_quantiles": self.n_quantiles,
|
||||
"net_arch": self.net_arch,
|
||||
"activation_fn": self.activation_fn,
|
||||
"normalize_images": normalize_images,
|
||||
}
|
||||
|
||||
self.quantile_net, self.quantile_net_target = None, None
|
||||
self._build(lr_schedule)
|
||||
|
||||
def _build(self, lr_schedule: Schedule) -> None:
|
||||
"""
|
||||
Create the network and the optimizer.
|
||||
|
||||
:param lr_schedule: Learning rate schedule
|
||||
lr_schedule(1) is the initial learning rate
|
||||
"""
|
||||
self.quantile_net = self.make_quantile_net()
|
||||
self.quantile_net_target = self.make_quantile_net()
|
||||
self.quantile_net_target.load_state_dict(self.quantile_net.state_dict())
|
||||
|
||||
# Setup optimizer with initial learning rate
|
||||
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
|
||||
def make_quantile_net(self) -> QuantileNetwork:
|
||||
# Make sure we always have separate networks for features extractors etc
|
||||
net_args = self._update_features_extractor(self.net_args, features_extractor=None)
|
||||
return QuantileNetwork(**net_args).to(self.device)
|
||||
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
return self._predict(obs, deterministic=deterministic)
|
||||
|
||||
def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
return self.quantile_net._predict(obs, deterministic=deterministic)
|
||||
|
||||
def _get_data(self) -> Dict[str, Any]:
|
||||
data = super()._get_data()
|
||||
|
||||
data.update(
|
||||
dict(
|
||||
n_quantiles=self.net_args["n_quantiles"],
|
||||
net_arch=self.net_args["net_arch"],
|
||||
activation_fn=self.net_args["activation_fn"],
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
optimizer_class=self.optimizer_class,
|
||||
optimizer_kwargs=self.optimizer_kwargs,
|
||||
features_extractor_class=self.features_extractor_class,
|
||||
features_extractor_kwargs=self.features_extractor_kwargs,
|
||||
)
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
MlpPolicy = QRDQNPolicy
|
||||
|
||||
|
||||
class CnnPolicy(QRDQNPolicy):
|
||||
"""
|
||||
Policy class for QR-DQN when using images as input.
|
||||
|
||||
:param observation_space: Observation space
|
||||
:param action_space: Action space
|
||||
:param lr_schedule: Learning rate schedule (could be constant)
|
||||
:param n_quantiles: Number of quantiles
|
||||
:param net_arch: The specification of the network architecture.
|
||||
:param activation_fn: Activation function
|
||||
:param features_extractor_class: Features extractor to use.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
n_quantiles: int = 200,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super(CnnPolicy, self).__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
lr_schedule,
|
||||
n_quantiles,
|
||||
net_arch,
|
||||
activation_fn,
|
||||
features_extractor_class,
|
||||
features_extractor_kwargs,
|
||||
normalize_images,
|
||||
optimizer_class,
|
||||
optimizer_kwargs,
|
||||
)
|
||||
|
||||
|
||||
register_policy("MlpPolicy", MlpPolicy)
|
||||
register_policy("CnnPolicy", CnnPolicy)
|
||||
|
|
@ -0,0 +1,253 @@
|
|||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
|
||||
|
||||
from sb3_contrib.common.utils import quantile_huber_loss
|
||||
from sb3_contrib.qrdqn.policies import QRDQNPolicy
|
||||
|
||||
|
||||
class QRDQN(OffPolicyAlgorithm):
|
||||
"""
|
||||
Quantile Regression Deep Q-Network (QR-DQN)
|
||||
Paper: https://arxiv.org/abs/1710.10044
|
||||
Default hyperparameters are taken from the paper and are tuned for Atari games.
|
||||
|
||||
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||
:param env: The environment to learn from (if registered in Gym, can be str)
|
||||
:param learning_rate: The learning rate, it can be a function
|
||||
of the current progress remaining (from 1 to 0)
|
||||
:param buffer_size: size of the replay buffer
|
||||
:param learning_starts: how many steps of the model to collect transitions for before learning starts
|
||||
:param batch_size: Minibatch size for each gradient update
|
||||
:param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
|
||||
:param gamma: the discount factor
|
||||
:param train_freq: Update the model every ``train_freq`` steps. Set to `-1` to disable.
|
||||
:param gradient_steps: How many gradient steps to do after each rollout
|
||||
(see ``train_freq`` and ``n_episodes_rollout``)
|
||||
Set to ``-1`` means to do as many gradient steps as steps done in the environment
|
||||
during the rollout.
|
||||
:param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes.
|
||||
Note that this cannot be used at the same time as ``train_freq``. Set to `-1` to disable.
|
||||
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
|
||||
at a cost of more complexity.
|
||||
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
|
||||
:param target_update_interval: update the target network every ``target_update_interval``
|
||||
environment steps.
|
||||
:param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
|
||||
:param exploration_initial_eps: initial value of random action probability
|
||||
:param exploration_final_eps: final value of random action probability
|
||||
:param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping)
|
||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||
:param seed: Seed for the pseudo random generators
|
||||
:param device: Device (cpu, cuda, ...) on which the code should be run.
|
||||
Setting it to auto, the code will be run on the GPU if possible.
|
||||
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Union[str, Type[QRDQNPolicy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, Schedule] = 5e-5,
|
||||
buffer_size: int = 1000000,
|
||||
learning_starts: int = 50000,
|
||||
batch_size: Optional[int] = 32,
|
||||
tau: float = 1.0,
|
||||
gamma: float = 0.99,
|
||||
train_freq: int = 4,
|
||||
gradient_steps: int = 1,
|
||||
n_episodes_rollout: int = -1,
|
||||
optimize_memory_usage: bool = False,
|
||||
target_update_interval: int = 10000,
|
||||
exploration_fraction: float = 0.005,
|
||||
exploration_initial_eps: float = 1.0,
|
||||
exploration_final_eps: float = 0.01,
|
||||
max_grad_norm: Optional[float] = None,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
create_eval_env: bool = False,
|
||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||
verbose: int = 0,
|
||||
seed: Optional[int] = None,
|
||||
device: Union[th.device, str] = "auto",
|
||||
_init_setup_model: bool = True,
|
||||
):
|
||||
|
||||
super(QRDQN, self).__init__(
|
||||
policy,
|
||||
env,
|
||||
QRDQNPolicy,
|
||||
learning_rate,
|
||||
buffer_size,
|
||||
learning_starts,
|
||||
batch_size,
|
||||
tau,
|
||||
gamma,
|
||||
train_freq,
|
||||
gradient_steps,
|
||||
n_episodes_rollout,
|
||||
action_noise=None, # No action noise
|
||||
policy_kwargs=policy_kwargs,
|
||||
tensorboard_log=tensorboard_log,
|
||||
verbose=verbose,
|
||||
device=device,
|
||||
create_eval_env=create_eval_env,
|
||||
seed=seed,
|
||||
sde_support=False,
|
||||
optimize_memory_usage=optimize_memory_usage,
|
||||
supported_action_spaces=(gym.spaces.Discrete,),
|
||||
)
|
||||
|
||||
self.exploration_initial_eps = exploration_initial_eps
|
||||
self.exploration_final_eps = exploration_final_eps
|
||||
self.exploration_fraction = exploration_fraction
|
||||
self.target_update_interval = target_update_interval
|
||||
self.max_grad_norm = max_grad_norm
|
||||
# "epsilon" for the epsilon-greedy exploration
|
||||
self.exploration_rate = 0.0
|
||||
# Linear schedule will be defined in `_setup_model()`
|
||||
self.exploration_schedule = None
|
||||
self.quantile_net, self.quantile_net_target = None, None
|
||||
|
||||
if "optimizer_class" not in self.policy_kwargs:
|
||||
self.policy_kwargs["optimizer_class"] = th.optim.Adam
|
||||
# Proposed in the QR-DQN paper where `batch_size = 32`
|
||||
self.policy_kwargs["optimizer_kwargs"] = dict(eps=0.01 / batch_size)
|
||||
|
||||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
super(QRDQN, self)._setup_model()
|
||||
self._create_aliases()
|
||||
self.exploration_schedule = get_linear_fn(
|
||||
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
|
||||
)
|
||||
|
||||
def _create_aliases(self) -> None:
|
||||
self.quantile_net = self.policy.quantile_net
|
||||
self.quantile_net_target = self.policy.quantile_net_target
|
||||
self.n_quantiles = self.policy.n_quantiles
|
||||
|
||||
def _on_step(self) -> None:
|
||||
"""
|
||||
Update the exploration rate and target network if needed.
|
||||
This method is called in ``collect_rollouts()`` after each step in the environment.
|
||||
"""
|
||||
if self.num_timesteps % self.target_update_interval == 0:
|
||||
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
|
||||
|
||||
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
|
||||
logger.record("rollout/exploration rate", self.exploration_rate)
|
||||
|
||||
def train(self, gradient_steps: int, batch_size: int = 100) -> None:
|
||||
# Update learning rate according to schedule
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
|
||||
losses = []
|
||||
for gradient_step in range(gradient_steps):
|
||||
# Sample replay buffer
|
||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
||||
|
||||
with th.no_grad():
|
||||
# Compute the quantiles of next observation
|
||||
next_quantiles = self.quantile_net_target(replay_data.next_observations)
|
||||
# Follow greedy policy: use the one with the highest value
|
||||
next_quantiles, _ = next_quantiles.max(dim=2)
|
||||
# 1-step TD target
|
||||
target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_quantiles
|
||||
|
||||
# Get current quantile estimates
|
||||
current_quantiles = self.quantile_net(replay_data.observations)
|
||||
|
||||
# Make "n_quantiles" copies of actions, and reshape to (batch_size, n_quantiles, 1).
|
||||
actions = replay_data.actions[..., None].long().expand(batch_size, self.n_quantiles, 1)
|
||||
# Retrieve the quantiles for the actions from the replay buffer
|
||||
current_quantiles = th.gather(current_quantiles, dim=2, index=actions).squeeze(dim=2)
|
||||
|
||||
# Compute Quantile Huber loss, summing over a quantile dimension as in the paper.
|
||||
loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=True)
|
||||
losses.append(loss.item())
|
||||
|
||||
# Optimize the policy
|
||||
self.policy.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
# Clip gradient norm
|
||||
if self.max_grad_norm is not None:
|
||||
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
||||
self.policy.optimizer.step()
|
||||
|
||||
# Increase update counter
|
||||
self._n_updates += gradient_steps
|
||||
|
||||
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
||||
logger.record("train/loss", np.mean(losses))
|
||||
|
||||
def predict(
|
||||
self,
|
||||
observation: np.ndarray,
|
||||
state: Optional[np.ndarray] = None,
|
||||
mask: Optional[np.ndarray] = None,
|
||||
deterministic: bool = False,
|
||||
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||
"""
|
||||
Overrides the base_class predict function to include epsilon-greedy exploration.
|
||||
|
||||
:param observation: the input observation
|
||||
:param state: The last states (can be None, used in recurrent policies)
|
||||
:param mask: The last masks (can be None, used in recurrent policies)
|
||||
:param deterministic: Whether or not to return deterministic actions.
|
||||
:return: the model's action and the next state
|
||||
(used in recurrent policies)
|
||||
"""
|
||||
if not deterministic and np.random.rand() < self.exploration_rate:
|
||||
if is_vectorized_observation(observation, self.observation_space):
|
||||
n_batch = observation.shape[0]
|
||||
action = np.array([self.action_space.sample() for _ in range(n_batch)])
|
||||
else:
|
||||
action = np.array(self.action_space.sample())
|
||||
else:
|
||||
action, state = self.policy.predict(observation, state, mask, deterministic)
|
||||
return action, state
|
||||
|
||||
def learn(
|
||||
self,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
eval_env: Optional[GymEnv] = None,
|
||||
eval_freq: int = -1,
|
||||
n_eval_episodes: int = 5,
|
||||
tb_log_name: str = "QRDQN",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> OffPolicyAlgorithm:
|
||||
|
||||
return super(QRDQN, self).learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
log_interval=log_interval,
|
||||
eval_env=eval_env,
|
||||
eval_freq=eval_freq,
|
||||
n_eval_episodes=n_eval_episodes,
|
||||
tb_log_name=tb_log_name,
|
||||
eval_log_path=eval_log_path,
|
||||
reset_num_timesteps=reset_num_timesteps,
|
||||
)
|
||||
|
||||
def _excluded_save_params(self) -> List[str]:
|
||||
return super(QRDQN, self)._excluded_save_params() + ["quantile_net", "quantile_net_target"]
|
||||
|
||||
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
||||
state_dicts = ["policy", "policy.optimizer"]
|
||||
|
||||
return state_dicts, []
|
||||
|
|
@ -391,6 +391,8 @@ class TQCPolicy(BasePolicy):
|
|||
optimizer_kwargs=self.optimizer_kwargs,
|
||||
features_extractor_class=self.features_extractor_class,
|
||||
features_extractor_kwargs=self.features_extractor_kwargs,
|
||||
n_quantiles=self.critic_kwargs["n_quantiles"],
|
||||
n_critics=self.critic_kwargs["n_critics"],
|
||||
)
|
||||
)
|
||||
return data
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
|||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
||||
from stable_baselines3.common.utils import polyak_update
|
||||
|
||||
from sb3_contrib.common.utils import quantile_huber_loss
|
||||
from sb3_contrib.tqc.policies import TQCPolicy
|
||||
|
||||
|
||||
|
|
@ -171,26 +172,6 @@ class TQC(OffPolicyAlgorithm):
|
|||
self.critic = self.policy.critic
|
||||
self.critic_target = self.policy.critic_target
|
||||
|
||||
@staticmethod
|
||||
def quantile_huber_loss(quantiles: th.Tensor, samples: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
The quantile-regression loss, as described in the QR-DQN and TQC papers.
|
||||
Taken from https://github.com/bayesgroup/tqc_pytorch
|
||||
|
||||
:param quantiles:
|
||||
:param samples:
|
||||
:return: the loss
|
||||
"""
|
||||
# batch x nets x quantiles x samples
|
||||
pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
|
||||
abs_pairwise_delta = th.abs(pairwise_delta)
|
||||
huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5)
|
||||
|
||||
n_quantiles = quantiles.shape[2]
|
||||
tau = th.arange(n_quantiles, device=quantiles.device).float() / n_quantiles + 1 / 2 / n_quantiles
|
||||
loss = (th.abs(tau[None, None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean()
|
||||
return loss
|
||||
|
||||
def train(self, gradient_steps: int, batch_size: int = 64) -> None:
|
||||
# Update optimizers learning rate
|
||||
optimizers = [self.actor.optimizer, self.critic.optimizer]
|
||||
|
|
@ -237,24 +218,27 @@ class TQC(OffPolicyAlgorithm):
|
|||
self.ent_coef_optimizer.step()
|
||||
|
||||
with th.no_grad():
|
||||
top_quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.critic.n_critics
|
||||
# Select action according to policy
|
||||
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
|
||||
# Compute and cut quantiles at the next state
|
||||
# batch x nets x quantiles
|
||||
next_z = self.critic_target(replay_data.next_observations, next_actions)
|
||||
sorted_z, _ = th.sort(next_z.reshape(batch_size, -1))
|
||||
sorted_z_part = sorted_z[:, : self.critic.quantiles_total - top_quantiles_to_drop]
|
||||
next_quantiles = self.critic_target(replay_data.next_observations, next_actions)
|
||||
|
||||
# Sort and drop top k quantiles to control overestimation.
|
||||
n_target_quantiles = self.critic.quantiles_total - self.top_quantiles_to_drop_per_net * self.critic.n_critics
|
||||
next_quantiles, _ = th.sort(next_quantiles.reshape(batch_size, -1))
|
||||
next_quantiles = next_quantiles[:, :n_target_quantiles]
|
||||
|
||||
target_q = sorted_z_part - ent_coef * next_log_prob.reshape(-1, 1)
|
||||
# td error + entropy term
|
||||
q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
|
||||
target_quantiles = next_quantiles - ent_coef * next_log_prob.reshape(-1, 1)
|
||||
target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_quantiles
|
||||
# Make target_quantiles broadcastable to (batch_size, n_critics, n_target_quantiles).
|
||||
target_quantiles.unsqueeze_(dim=1)
|
||||
|
||||
# Get current Q estimates
|
||||
# using action from the replay buffer
|
||||
current_z = self.critic(replay_data.observations, replay_data.actions)
|
||||
# Compute critic loss
|
||||
critic_loss = self.quantile_huber_loss(current_z, q_backup)
|
||||
# Get current Quantile estimates using action from the replay buffer
|
||||
current_quantiles = self.critic(replay_data.observations, replay_data.actions)
|
||||
# Compute critic loss, not summing over the quantile dimension as in the paper.
|
||||
critic_loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=False)
|
||||
critic_losses.append(critic_loss.item())
|
||||
|
||||
# Optimize the critic
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators
|
|||
# Ignore import not used when aliases are defined
|
||||
per-file-ignores =
|
||||
./sb3_contrib/__init__.py:F401
|
||||
./sb3_contrib/qrdqn/__init__.py:F401
|
||||
./sb3_contrib/tqc/__init__.py:F401
|
||||
./sb3_contrib/common/wrappers/__init__.py:F401
|
||||
exclude =
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@ import torch as th
|
|||
from stable_baselines3.common.identity_env import FakeImageEnv
|
||||
from stable_baselines3.common.utils import zip_strict
|
||||
|
||||
from sb3_contrib import TQC
|
||||
from sb3_contrib import QRDQN, TQC
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [TQC])
|
||||
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
|
||||
def test_cnn(tmp_path, model_class):
|
||||
SAVE_NAME = "cnn_model.zip"
|
||||
# Fake grayscale with frameskip
|
||||
|
|
@ -18,10 +18,13 @@ def test_cnn(tmp_path, model_class):
|
|||
# to check that the network handle it automatically
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC})
|
||||
kwargs = {}
|
||||
if model_class in {TQC}:
|
||||
if model_class in {TQC, QRDQN}:
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features
|
||||
kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
|
||||
# Reduce the size of the features and the number of quantiles
|
||||
kwargs = dict(
|
||||
buffer_size=250,
|
||||
policy_kwargs=dict(n_quantiles=25, features_extractor_kwargs=dict(features_dim=32)),
|
||||
)
|
||||
model = model_class("CnnPolicy", env, **kwargs).learn(250)
|
||||
|
||||
obs = env.reset()
|
||||
|
|
@ -39,6 +42,13 @@ def test_cnn(tmp_path, model_class):
|
|||
os.remove(str(tmp_path / SAVE_NAME))
|
||||
|
||||
|
||||
def patch_qrdqn_names_(model):
|
||||
# Small hack to make the test work with QRDQN
|
||||
if isinstance(model, QRDQN):
|
||||
model.critic = model.quantile_net
|
||||
model.critic_target = model.quantile_net_target
|
||||
|
||||
|
||||
def params_should_match(params, other_params):
|
||||
for param, other_param in zip_strict(params, other_params):
|
||||
assert th.allclose(param, other_param)
|
||||
|
|
@ -49,28 +59,36 @@ def params_should_differ(params, other_params):
|
|||
assert not th.allclose(param, other_param)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [TQC])
|
||||
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
|
||||
@pytest.mark.parametrize("share_features_extractor", [True, False])
|
||||
def test_feature_extractor_target_net(model_class, share_features_extractor):
|
||||
if model_class == QRDQN and share_features_extractor:
|
||||
pytest.skip()
|
||||
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC})
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features
|
||||
kwargs = dict(
|
||||
buffer_size=250,
|
||||
learning_starts=100,
|
||||
policy_kwargs=dict(
|
||||
features_extractor_kwargs=dict(features_dim=32),
|
||||
share_features_extractor=share_features_extractor,
|
||||
),
|
||||
)
|
||||
|
||||
if model_class in {TQC, QRDQN}:
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features and the number of quantiles
|
||||
kwargs = dict(
|
||||
buffer_size=250,
|
||||
learning_starts=100,
|
||||
policy_kwargs=dict(n_quantiles=25, features_extractor_kwargs=dict(features_dim=32)),
|
||||
)
|
||||
if model_class != QRDQN:
|
||||
kwargs["policy_kwargs"]["share_features_extractor"] = share_features_extractor
|
||||
|
||||
model = model_class("CnnPolicy", env, seed=0, **kwargs)
|
||||
|
||||
patch_qrdqn_names_(model)
|
||||
|
||||
if share_features_extractor:
|
||||
# Check that the objects are the same and not just copied
|
||||
assert id(model.policy.actor.features_extractor) == id(model.policy.critic.features_extractor)
|
||||
else:
|
||||
# Check that the objects differ
|
||||
assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor)
|
||||
if model_class != QRDQN:
|
||||
assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor)
|
||||
|
||||
# Critic and target should be equal at the begginning of training
|
||||
params_should_match(model.critic.parameters(), model.critic_target.parameters())
|
||||
|
|
@ -83,6 +101,8 @@ def test_feature_extractor_target_net(model_class, share_features_extractor):
|
|||
# Re-initialize and collect some random data (without doing gradient steps)
|
||||
model = model_class("CnnPolicy", env, seed=0, **kwargs).learn(10)
|
||||
|
||||
patch_qrdqn_names_(model)
|
||||
|
||||
original_param = deepcopy(list(model.critic.parameters()))
|
||||
original_target_param = deepcopy(list(model.critic_target.parameters()))
|
||||
|
||||
|
|
@ -103,6 +123,11 @@ def test_feature_extractor_target_net(model_class, share_features_extractor):
|
|||
model.lr_schedule = lambda _: 0.0
|
||||
# Re-activate polyak update
|
||||
model.tau = 0.01
|
||||
# Special case for QRDQN: target net is updated in the `collect_rollouts()`
|
||||
# not the `train()` method
|
||||
if model_class == QRDQN:
|
||||
model.target_update_interval = 1
|
||||
model._on_step()
|
||||
|
||||
model.train(gradient_steps=1)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from sb3_contrib import TQC
|
||||
from sb3_contrib import QRDQN, TQC
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
|
||||
|
|
@ -21,7 +21,11 @@ def test_tqc(ent_coef):
|
|||
def test_n_critics(n_critics):
|
||||
# Test TQC with different number of critics
|
||||
model = TQC(
|
||||
"MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], n_critics=n_critics), learning_starts=100, verbose=1
|
||||
"MlpPolicy",
|
||||
"Pendulum-v0",
|
||||
policy_kwargs=dict(net_arch=[64], n_critics=n_critics),
|
||||
learning_starts=100,
|
||||
verbose=1,
|
||||
)
|
||||
model.learn(total_timesteps=300)
|
||||
|
||||
|
|
@ -38,3 +42,17 @@ def test_sde():
|
|||
model.learn(total_timesteps=300)
|
||||
model.policy.reset_noise()
|
||||
model.policy.actor.get_std()
|
||||
|
||||
|
||||
def test_qrdqn():
|
||||
model = QRDQN(
|
||||
"MlpPolicy",
|
||||
"CartPole-v1",
|
||||
policy_kwargs=dict(n_quantiles=25, net_arch=[64, 64]),
|
||||
learning_starts=100,
|
||||
buffer_size=500,
|
||||
learning_rate=3e-4,
|
||||
verbose=1,
|
||||
create_eval_env=True,
|
||||
)
|
||||
model.learn(total_timesteps=500, eval_freq=250)
|
||||
|
|
|
|||
|
|
@ -7,22 +7,21 @@ import gym
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch as th
|
||||
from stable_baselines3 import DQN
|
||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
||||
from stable_baselines3.common.utils import get_device
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
from sb3_contrib import TQC
|
||||
from sb3_contrib import QRDQN, TQC
|
||||
|
||||
MODEL_LIST = [TQC]
|
||||
MODEL_LIST = [TQC, QRDQN]
|
||||
|
||||
|
||||
def select_env(model_class: BaseAlgorithm) -> gym.Env:
|
||||
"""
|
||||
Selects an environment with the correct action space as DQN only supports discrete action space
|
||||
Selects an environment with the correct action space as QRDQN only supports discrete action space
|
||||
"""
|
||||
if model_class == DQN:
|
||||
if model_class == QRDQN:
|
||||
return IdentityEnv(10)
|
||||
else:
|
||||
return IdentityEnvBox(10)
|
||||
|
|
@ -41,8 +40,13 @@ def test_save_load(tmp_path, model_class):
|
|||
|
||||
env = DummyVecEnv([lambda: select_env(model_class)])
|
||||
|
||||
policy_kwargs = dict(net_arch=[16])
|
||||
|
||||
if model_class in {QRDQN, TQC}:
|
||||
policy_kwargs.update(dict(n_quantiles=20))
|
||||
|
||||
# create model
|
||||
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
|
||||
model = model_class("MlpPolicy", env, verbose=1, policy_kwargs=policy_kwargs)
|
||||
model.learn(total_timesteps=300)
|
||||
|
||||
env.reset()
|
||||
|
|
@ -167,13 +171,18 @@ def test_set_env(model_class):
|
|||
:param model_class: (BaseAlgorithm) A RL model
|
||||
"""
|
||||
|
||||
# use discrete for DQN
|
||||
# use discrete for QRDQN
|
||||
env = DummyVecEnv([lambda: select_env(model_class)])
|
||||
env2 = DummyVecEnv([lambda: select_env(model_class)])
|
||||
env3 = select_env(model_class)
|
||||
|
||||
kwargs = dict(policy_kwargs=dict(net_arch=[16]))
|
||||
if model_class in {TQC, QRDQN}:
|
||||
kwargs.update(dict(learning_starts=100))
|
||||
kwargs["policy_kwargs"].update(dict(n_quantiles=20))
|
||||
|
||||
# create model
|
||||
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]))
|
||||
model = model_class("MlpPolicy", env, **kwargs)
|
||||
# learn
|
||||
model.learn(total_timesteps=300)
|
||||
|
||||
|
|
@ -219,7 +228,7 @@ def test_exclude_include_saved_params(tmp_path, model_class):
|
|||
os.remove(tmp_path / "test_save.zip")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [TQC])
|
||||
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
|
||||
def test_save_load_replay_buffer(tmp_path, model_class):
|
||||
path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
|
||||
path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning
|
||||
|
|
@ -254,20 +263,28 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
|||
:param model_class: (BaseAlgorithm) A RL model
|
||||
:param policy_str: (str) Name of the policy.
|
||||
"""
|
||||
kwargs = {}
|
||||
kwargs = dict(policy_kwargs=dict(net_arch=[16]))
|
||||
if policy_str == "MlpPolicy":
|
||||
env = select_env(model_class)
|
||||
else:
|
||||
if model_class in [TQC]:
|
||||
if model_class in [TQC, QRDQN]:
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features
|
||||
kwargs = dict(buffer_size=250)
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)
|
||||
kwargs = dict(
|
||||
buffer_size=250,
|
||||
learning_starts=100,
|
||||
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
||||
)
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN)
|
||||
|
||||
# Reduce number of quantiles for faster tests
|
||||
if model_class in [TQC, QRDQN]:
|
||||
kwargs["policy_kwargs"].update(dict(n_quantiles=20))
|
||||
|
||||
env = DummyVecEnv([lambda: env])
|
||||
|
||||
# create model
|
||||
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
|
||||
model = model_class(policy_str, env, verbose=1, **kwargs)
|
||||
model.learn(total_timesteps=300)
|
||||
|
||||
env.reset()
|
||||
|
|
@ -334,3 +351,83 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
|||
os.remove(tmp_path / "policy.pkl")
|
||||
if actor_class is not None:
|
||||
os.remove(tmp_path / "actor.pkl")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [QRDQN])
|
||||
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
|
||||
def test_save_load_q_net(tmp_path, model_class, policy_str):
|
||||
"""
|
||||
Test saving and loading q-network/quantile net only.
|
||||
|
||||
:param model_class: (BaseAlgorithm) A RL model
|
||||
:param policy_str: (str) Name of the policy.
|
||||
"""
|
||||
kwargs = dict(policy_kwargs=dict(net_arch=[16]))
|
||||
if policy_str == "MlpPolicy":
|
||||
env = select_env(model_class)
|
||||
else:
|
||||
if model_class in [QRDQN]:
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features
|
||||
kwargs = dict(
|
||||
buffer_size=250,
|
||||
learning_starts=100,
|
||||
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
||||
)
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN)
|
||||
|
||||
# Reduce number of quantiles for faster tests
|
||||
if model_class in [QRDQN]:
|
||||
kwargs["policy_kwargs"].update(dict(n_quantiles=20))
|
||||
|
||||
env = DummyVecEnv([lambda: env])
|
||||
|
||||
# create model
|
||||
model = model_class(policy_str, env, verbose=1, **kwargs)
|
||||
model.learn(total_timesteps=300)
|
||||
|
||||
env.reset()
|
||||
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
|
||||
|
||||
q_net = model.quantile_net
|
||||
q_net_class = q_net.__class__
|
||||
|
||||
# Get dictionary of current parameters
|
||||
params = deepcopy(q_net.state_dict())
|
||||
|
||||
# Modify all parameters to be random values
|
||||
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
|
||||
|
||||
# Update model parameters with the new random values
|
||||
q_net.load_state_dict(random_params)
|
||||
|
||||
new_params = q_net.state_dict()
|
||||
# Check that all params are different now
|
||||
for k in params:
|
||||
assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
|
||||
|
||||
params = new_params
|
||||
|
||||
# get selected actions
|
||||
selected_actions, _ = q_net.predict(observations, deterministic=True)
|
||||
|
||||
# Save and load q_net
|
||||
q_net.save(tmp_path / "q_net.pkl")
|
||||
|
||||
del q_net
|
||||
|
||||
q_net = q_net_class.load(tmp_path / "q_net.pkl")
|
||||
|
||||
# check if params are still the same after load
|
||||
new_params = q_net.state_dict()
|
||||
|
||||
# Check that all params are the same as before save load procedure now
|
||||
for key in params:
|
||||
assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load."
|
||||
|
||||
# check if model still selects the same actions
|
||||
new_selected_actions, _ = q_net.predict(observations, deterministic=True)
|
||||
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
|
||||
|
||||
# clear file from os
|
||||
os.remove(tmp_path / "q_net.pkl")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,19 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch as th
|
||||
|
||||
from sb3_contrib.common.utils import quantile_huber_loss
|
||||
|
||||
|
||||
def test_quantile_huber_loss():
|
||||
assert np.isclose(quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10)), 2.5)
|
||||
assert np.isclose(quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10), sum_over_quantiles=False), 0.25)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
quantile_huber_loss(th.zeros(1, 4, 4), th.zeros(1, 4))
|
||||
with pytest.raises(ValueError):
|
||||
quantile_huber_loss(th.zeros(1, 4), th.zeros(1, 1, 4))
|
||||
with pytest.raises(ValueError):
|
||||
quantile_huber_loss(th.zeros(4, 4), th.zeros(3, 4))
|
||||
with pytest.raises(ValueError):
|
||||
quantile_huber_loss(th.zeros(4, 4, 4, 4), th.zeros(4, 4, 4, 4))
|
||||
Loading…
Reference in New Issue