Add QR-DQN (#13)
* Add QR-DQN (WIP)
* Update docstring
* Add quantile_huber_loss
* Fix typo
* Remove unnecessary lines
* Update variable names and comments in quantile_huber_loss
* Fix mutable arguments
* Update variable names
* Ignore import not used warnings
* Fix default parameter of optimizer in QR-DQN
* Update quantile_huber_loss to have more reasonable interface
* Update tests
* Add assertion to quantile_huber_loss
* Update variable names of quantile regression
* Update comments
* Reduce the number of quantiles during test
* Update quantile_huber_loss
* Fix isort
* Add document of QR-DQN without results
* Update docs
* Fix bugs
* Add comments about shape
* Minor edits
* Add benchmark
* Doc fixes
* Bug fix in saving/loading + update tests

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
3598ca284a
commit
b30397fff5
@@ -25,6 +25,7 @@ See documentation for the full list of included features.

**RL Algorithms**:
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)

**Gym Wrappers**:
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
@@ -9,6 +9,7 @@ along with some useful characteristics: support for discrete/continuous actions,

============ =========== ============ ================= =============== ================
Name         ``Box``     ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
============ =========== ============ ================= =============== ================
TQC          ✔️           ❌           ❌                 ❌               ❌
QR-DQN       ❌           ✔️           ❌                 ❌               ❌
============ =========== ============ ================= =============== ================
@@ -16,6 +16,21 @@ Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment.

    model.learn(total_timesteps=10000, log_interval=4)
    model.save("tqc_pendulum")

QR-DQN
------

Train a Quantile Regression DQN (QR-DQN) agent on the CartPole environment.

.. code-block:: python

    from sb3_contrib import QRDQN

    policy_kwargs = dict(n_quantiles=50)

    model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(total_timesteps=10000, log_interval=4)
    model.save("qrdqn_cartpole")

.. PyBullet: Normalizing input features
.. ------------------------------------
..
@@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d

   :caption: RL Algorithms

   modules/tqc
   modules/qrdqn

.. toctree::
   :maxdepth: 1
@@ -13,9 +13,11 @@ Breaking Changes:

New Features:
^^^^^^^^^^^^^
- Added ``TimeFeatureWrapper`` to the wrappers
- Added ``QR-DQN`` algorithm (`@ku2482`_)

Bug Fixes:
^^^^^^^^^^
- Fixed bug in ``TQC`` when saving/loading the policy only with non-default number of quantiles

Deprecations:
^^^^^^^^^^^^^
@@ -24,6 +26,7 @@ Others:

^^^^^^^
- Updated ``TQC`` to match new SB3 version
- Updated SB3 min version
- Moved ``quantile_huber_loss`` to ``common/utils.py`` (@ku2482)

Documentation:
^^^^^^^^^^^^^^
@@ -62,13 +65,19 @@ Maintainers

-----------

Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_), `Ashley Hill`_ (aka @hill-a),
`Maximilian Ernestus`_ (aka @ernestum), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_).

.. _Ashley Hill: https://github.com/hill-a
.. _Antonin Raffin: https://araffin.github.io/
.. _Maximilian Ernestus: https://github.com/ernestum
.. _Adam Gleave: https://gleave.me/
.. _@araffin: https://github.com/araffin
.. _@AdamGleave: https://github.com/adamgleave
.. _Anssi Kanervisto: https://github.com/Miffyli
.. _@Miffyli: https://github.com/Miffyli
.. _@ku2482: https://github.com/ku2482

Contributors:
-------------

@ku2482
@@ -0,0 +1,150 @@

.. _qrdqn:

.. automodule:: sb3_contrib.qrdqn

QR-DQN
======

`Quantile Regression DQN (QR-DQN) <https://arxiv.org/abs/1710.10044>`_ builds on `Deep Q-Network (DQN) <https://arxiv.org/abs/1312.5602>`_
and makes use of quantile regression to explicitly model the `distribution over returns <https://arxiv.org/abs/1707.06887>`_,
instead of predicting only the mean return as DQN does.
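The idea above can be made concrete in a few lines: instead of a single Q-value per action, the network outputs ``n_quantiles`` return estimates located at the quantile midpoints ``tau_i = (2i + 1) / (2N)`` (as in the QR-DQN paper), and the usual scalar Q-value is recovered as the mean of the quantiles. A minimal plain-Python sketch; the action names and numbers are made up for illustration and this is not the library's API:

```python
# Hypothetical quantile estimates of the return for two actions
# (n_quantiles = 4). In QR-DQN these come from the network.
quantiles = {
    "left": [0.0, 1.0, 2.0, 3.0],
    "right": [1.0, 1.5, 2.0, 2.5],
}

# Quantile midpoints tau_i = (2i + 1) / (2N), as in the QR-DQN paper.
n = 4
taus = [(2 * i + 1) / (2 * n) for i in range(n)]

# The scalar Q-value is the mean of the quantiles; the greedy action
# is the argmax over those means, exactly as in DQN.
q_values = {action: sum(qs) / len(qs) for action, qs in quantiles.items()}
greedy = max(q_values, key=q_values.get)
```

Even when two actions have similar mean returns, the full set of quantiles exposes how spread-out (risky) each action's return distribution is, which is the extra information distributional RL keeps around.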
.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy
    CnnPolicy

Notes
-----

- Original paper: https://arxiv.org/abs/1710.10044
- Distributional RL (C51): https://arxiv.org/abs/1707.06887

Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ❌
- Gym spaces:

============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ✔      ✔
Box           ❌      ✔
MultiDiscrete ❌      ✔
MultiBinary   ❌      ✔
============= ====== ===========

Example
-------

.. code-block:: python

    import gym

    from sb3_contrib import QRDQN

    env = gym.make("CartPole-v1")

    policy_kwargs = dict(n_quantiles=50)
    model = QRDQN("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
    model.learn(total_timesteps=10000, log_interval=4)
    model.save("qrdqn_cartpole")

    del model  # remove to demonstrate saving and loading

    model = QRDQN.load("qrdqn_cartpole")

    obs = env.reset()
    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        env.render()
        if done:
            obs = env.reset()

Results
-------

Results on Atari environments (10M steps, Pong and Breakout) and classic control tasks, using 3 and 5 seeds.

The complete learning curves are available in the `associated PR <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/13>`_.

.. note::

    The QR-DQN implementation was validated against the `Intel Coach <https://github.com/IntelLabs/coach/tree/master/benchmarks/qr_dqn>`_ one;
    the results roughly match those of the original paper (we trained the agent with a smaller budget).

============ ========== ===========
Environments QR-DQN     DQN
============ ========== ===========
Breakout     413 +/- 21 ~300
Pong         20 +/- 0   ~20
CartPole     386 +/- 64 500 +/- 0
MountainCar  -111 +/- 4 -107 +/- 4
LunarLander  168 +/- 39 195 +/- 28
Acrobot      -73 +/- 2  -74 +/- 2
============ ========== ===========

How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone the RL-Zoo fork and checkout the branch ``feat/qrdqn``:

.. code-block:: bash

    git clone https://github.com/ku2482/rl-baselines3-zoo/
    cd rl-baselines3-zoo/
    git checkout feat/qrdqn

Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):

.. code-block:: bash

    python train.py --algo qrdqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000

Plot the results:

.. code-block:: bash

    python scripts/all_plots.py -a qrdqn -e Breakout Pong -f logs/ -o logs/qrdqn_results
    python scripts/plot_from_file.py -i logs/qrdqn_results.pkl -latex -l QR-DQN

Parameters
----------

.. autoclass:: QRDQN
    :members:
    :inherited-members:

.. _qrdqn_policies:

QR-DQN Policies
---------------

.. autoclass:: MlpPolicy
    :members:
    :inherited-members:

.. autoclass:: sb3_contrib.qrdqn.policies.QRDQNPolicy
    :members:
    :noindex:

.. autoclass:: CnnPolicy
    :members:
@@ -1,6 +1,6 @@

import os

from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC

# Read version from file
@@ -0,0 +1,69 @@

from typing import Optional

import torch as th


def quantile_huber_loss(
    current_quantiles: th.Tensor,
    target_quantiles: th.Tensor,
    cum_prob: Optional[th.Tensor] = None,
    sum_over_quantiles: bool = True,
) -> th.Tensor:
    """
    The quantile-regression loss, as described in the QR-DQN and TQC papers.
    Partially taken from https://github.com/bayesgroup/tqc_pytorch.

    :param current_quantiles: current estimate of quantiles, must be either
        (batch_size, n_quantiles) or (batch_size, n_critics, n_quantiles)
    :param target_quantiles: target of quantiles, must be either (batch_size, n_target_quantiles),
        (batch_size, 1, n_target_quantiles), or (batch_size, n_critics, n_target_quantiles)
    :param cum_prob: cumulative probabilities to calculate quantiles (also called midpoints in QR-DQN paper),
        must be either (batch_size, n_quantiles), (batch_size, 1, n_quantiles), or (batch_size, n_critics, n_quantiles).
        (if None, unit quantiles are computed)
    :param sum_over_quantiles: whether to sum over the quantile dimension or not
    :return: the loss
    """
    if current_quantiles.ndim != target_quantiles.ndim:
        raise ValueError(
            f"Error: The dimension of current_quantiles ({current_quantiles.ndim}) needs to match "
            f"the dimension of target_quantiles ({target_quantiles.ndim})."
        )
    if current_quantiles.shape[0] != target_quantiles.shape[0]:
        raise ValueError(
            f"Error: The batch size of current_quantiles ({current_quantiles.shape[0]}) needs to match "
            f"the batch size of target_quantiles ({target_quantiles.shape[0]})."
        )
    if current_quantiles.ndim not in (2, 3):
        raise ValueError(f"Error: The dimension of current_quantiles ({current_quantiles.ndim}) needs to be either 2 or 3.")

    if cum_prob is None:
        n_quantiles = current_quantiles.shape[-1]
        # Cumulative probabilities to calculate quantiles.
        cum_prob = (th.arange(n_quantiles, device=current_quantiles.device, dtype=th.float) + 0.5) / n_quantiles
        if current_quantiles.ndim == 2:
            # For QR-DQN, current_quantiles have a shape (batch_size, n_quantiles), and make cum_prob
            # broadcastable to (batch_size, n_quantiles, n_target_quantiles)
            cum_prob = cum_prob.view(1, -1, 1)
        elif current_quantiles.ndim == 3:
            # For TQC, current_quantiles have a shape (batch_size, n_critics, n_quantiles), and make cum_prob
            # broadcastable to (batch_size, n_critics, n_quantiles, n_target_quantiles)
            cum_prob = cum_prob.view(1, 1, -1, 1)

    # QR-DQN
    # target_quantiles: (batch_size, n_target_quantiles) -> (batch_size, 1, n_target_quantiles)
    # current_quantiles: (batch_size, n_quantiles) -> (batch_size, n_quantiles, 1)
    # pairwise_delta: (batch_size, n_quantiles, n_target_quantiles)
    # TQC
    # target_quantiles: (batch_size, 1, n_target_quantiles) -> (batch_size, 1, 1, n_target_quantiles)
    # current_quantiles: (batch_size, n_critics, n_quantiles) -> (batch_size, n_critics, n_quantiles, 1)
    # pairwise_delta: (batch_size, n_critics, n_quantiles, n_target_quantiles)
    # Note: in both cases, the loss has the same shape as pairwise_delta
    pairwise_delta = target_quantiles.unsqueeze(-2) - current_quantiles.unsqueeze(-1)
    abs_pairwise_delta = th.abs(pairwise_delta)
    huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5)
    loss = th.abs(cum_prob - (pairwise_delta.detach() < 0).float()) * huber_loss
    if sum_over_quantiles:
        loss = loss.sum(dim=-2).mean()
    else:
        loss = loss.mean()
    return loss
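To make the asymmetric weighting concrete, here is a deliberately simplified scalar re-implementation of the loss above: plain Python, no batching or broadcasting, with the Huber threshold ``kappa = 1`` as in the tensor version. It illustrates the formula only and is not a replacement for ``quantile_huber_loss``:

```python
def huber(x):
    # Huber loss with kappa = 1, matching the tensor version above.
    return abs(x) - 0.5 if abs(x) > 1 else 0.5 * x ** 2


def quantile_huber(current, target):
    # Scalar version of the loss above: each pairwise delta is weighted by
    # |tau - 1{delta < 0}|, summed over the current-quantile dimension and
    # averaged over the target quantiles.
    n = len(current)
    taus = [(2 * i + 1) / (2 * n) for i in range(n)]
    total = 0.0
    for target_q in target:
        for current_q, tau in zip(current, taus):
            delta = target_q - current_q
            total += abs(tau - (1.0 if delta < 0 else 0.0)) * huber(delta)
    return total / len(target)
```

Note the asymmetry: for the quantile at level ``tau``, underestimation errors (``delta > 0``) are weighted by ``tau`` while overestimation errors are weighted by ``1 - tau``, which is what pushes each estimate toward its target quantile rather than toward the mean.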
@@ -0,0 +1,2 @@

from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy
from sb3_contrib.qrdqn.qrdqn import QRDQN
@@ -0,0 +1,249 @@

from typing import Any, Dict, List, Optional, Type

import gym
import torch as th
from stable_baselines3.common.policies import BasePolicy, register_policy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
from stable_baselines3.common.type_aliases import Schedule
from torch import nn


class QuantileNetwork(BasePolicy):
    """
    Quantile network for QR-DQN

    :param observation_space: Observation space
    :param action_space: Action space
    :param n_quantiles: Number of quantiles
    :param net_arch: The specification of the network architecture.
    :param activation_fn: Activation function
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        features_extractor: nn.Module,
        features_dim: int,
        n_quantiles: int = 200,
        net_arch: Optional[List[int]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        normalize_images: bool = True,
    ):
        super(QuantileNetwork, self).__init__(
            observation_space,
            action_space,
            features_extractor=features_extractor,
            normalize_images=normalize_images,
        )

        if net_arch is None:
            net_arch = [64, 64]

        self.net_arch = net_arch
        self.activation_fn = activation_fn
        self.features_extractor = features_extractor
        self.features_dim = features_dim
        self.n_quantiles = n_quantiles
        self.normalize_images = normalize_images
        action_dim = self.action_space.n  # number of actions
        quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
        self.quantile_net = nn.Sequential(*quantile_net)

    def forward(self, obs: th.Tensor) -> th.Tensor:
        """
        Predict the quantiles.

        :param obs: Observation
        :return: The estimated quantiles for each action.
        """
        quantiles = self.quantile_net(self.extract_features(obs))
        return quantiles.view(-1, self.n_quantiles, self.action_space.n)

    def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
        q_values = self.forward(observation).mean(dim=1)
        # Greedy action
        action = q_values.argmax(dim=1).reshape(-1)
        return action

    def _get_data(self) -> Dict[str, Any]:
        data = super()._get_data()

        data.update(
            dict(
                net_arch=self.net_arch,
                features_dim=self.features_dim,
                n_quantiles=self.n_quantiles,
                activation_fn=self.activation_fn,
                features_extractor=self.features_extractor,
            )
        )
        return data


class QRDQNPolicy(BasePolicy):
    """
    Policy class with quantile and target networks for QR-DQN.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param n_quantiles: Number of quantiles
    :param net_arch: The specification of the network architecture.
    :param activation_fn: Activation function
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        n_quantiles: int = 200,
        net_arch: Optional[List[int]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super(QRDQNPolicy, self).__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
        )

        if net_arch is None:
            if features_extractor_class == FlattenExtractor:
                net_arch = [64, 64]
            else:
                net_arch = []

        self.n_quantiles = n_quantiles
        self.net_arch = net_arch
        self.activation_fn = activation_fn
        self.normalize_images = normalize_images

        self.net_args = {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "n_quantiles": self.n_quantiles,
            "net_arch": self.net_arch,
            "activation_fn": self.activation_fn,
            "normalize_images": normalize_images,
        }

        self.quantile_net, self.quantile_net_target = None, None
        self._build(lr_schedule)

    def _build(self, lr_schedule: Schedule) -> None:
        """
        Create the network and the optimizer.

        :param lr_schedule: Learning rate schedule
            lr_schedule(1) is the initial learning rate
        """
        self.quantile_net = self.make_quantile_net()
        self.quantile_net_target = self.make_quantile_net()
        self.quantile_net_target.load_state_dict(self.quantile_net.state_dict())

        # Setup optimizer with initial learning rate
        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

    def make_quantile_net(self) -> QuantileNetwork:
        # Make sure we always have separate networks for features extractors etc
        net_args = self._update_features_extractor(self.net_args, features_extractor=None)
        return QuantileNetwork(**net_args).to(self.device)

    def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
        return self._predict(obs, deterministic=deterministic)

    def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
        return self.quantile_net._predict(obs, deterministic=deterministic)

    def _get_data(self) -> Dict[str, Any]:
        data = super()._get_data()

        data.update(
            dict(
                n_quantiles=self.net_args["n_quantiles"],
                net_arch=self.net_args["net_arch"],
                activation_fn=self.net_args["activation_fn"],
                lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
            )
        )
        return data


MlpPolicy = QRDQNPolicy


class CnnPolicy(QRDQNPolicy):
    """
    Policy class for QR-DQN when using images as input.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param n_quantiles: Number of quantiles
    :param net_arch: The specification of the network architecture.
    :param activation_fn: Activation function
    :param features_extractor_class: Features extractor to use.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        n_quantiles: int = 200,
        net_arch: Optional[List[int]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super(CnnPolicy, self).__init__(
            observation_space,
            action_space,
            lr_schedule,
            n_quantiles,
            net_arch,
            activation_fn,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )


register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)
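The reduction performed by ``QuantileNetwork._predict`` above (mean over the quantile dimension, then argmax over actions) can be traced without torch. The nested lists below mirror ``forward``'s ``(batch, n_quantiles, n_actions)`` output layout; the values are made up for illustration:

```python
# Fake quantile output with shape (batch=1, n_quantiles=3, n_actions=2),
# mirroring the layout returned by QuantileNetwork.forward (made-up values).
quantiles = [
    [[1.0, 0.0],
     [2.0, 1.0],
     [3.0, 6.0]],
]


def predict_greedy(batch):
    actions = []
    for sample in batch:
        n_quantiles = len(sample)
        n_actions = len(sample[0])
        # Mean over the quantile dimension (dim=1 in the torch code) ...
        q_values = [
            sum(sample[i][a] for i in range(n_quantiles)) / n_quantiles
            for a in range(n_actions)
        ]
        # ... then the greedy action is the argmax over these Q-values.
        actions.append(max(range(n_actions), key=lambda a: q_values[a]))
    return actions
```

This is exactly ``self.forward(observation).mean(dim=1)`` followed by ``argmax(dim=1)``, spelled out over plain lists.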
@@ -0,0 +1,253 @@

from typing import Any, Dict, List, Optional, Tuple, Type, Union

import gym
import numpy as np
import torch as th
from stable_baselines3.common import logger
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update

from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.qrdqn.policies import QRDQNPolicy


class QRDQN(OffPolicyAlgorithm):
    """
    Quantile Regression Deep Q-Network (QR-DQN)
    Paper: https://arxiv.org/abs/1710.10044
    Default hyperparameters are taken from the paper and are tuned for Atari games.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1), default 1 for hard update
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Set to ``-1`` to disable.
    :param gradient_steps: How many gradient steps to do after each rollout
        (see ``train_freq`` and ``n_episodes_rollout``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes.
        Note that this cannot be used at the same time as ``train_freq``. Set to ``-1`` to disable.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param target_update_interval: update the target network every ``target_update_interval``
        environment steps.
    :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
    :param exploration_initial_eps: initial value of random action probability
    :param exploration_final_eps: final value of random action probability
    :param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping)
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing string for the environment)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    def __init__(
        self,
        policy: Union[str, Type[QRDQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 5e-5,
        buffer_size: int = 1000000,
        learning_starts: int = 50000,
        batch_size: Optional[int] = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: int = 4,
        gradient_steps: int = 1,
        n_episodes_rollout: int = -1,
        optimize_memory_usage: bool = False,
        target_update_interval: int = 10000,
        exploration_fraction: float = 0.005,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.01,
        max_grad_norm: Optional[float] = None,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):

        super(QRDQN, self).__init__(
            policy,
            env,
            QRDQNPolicy,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            n_episodes_rollout,
            action_noise=None,  # No action noise
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            create_eval_env=create_eval_env,
            seed=seed,
            sde_support=False,
            optimize_memory_usage=optimize_memory_usage,
            supported_action_spaces=(gym.spaces.Discrete,),
        )

        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_final_eps = exploration_final_eps
        self.exploration_fraction = exploration_fraction
        self.target_update_interval = target_update_interval
        self.max_grad_norm = max_grad_norm
        # "epsilon" for the epsilon-greedy exploration
        self.exploration_rate = 0.0
        # Linear schedule will be defined in `_setup_model()`
        self.exploration_schedule = None
        self.quantile_net, self.quantile_net_target = None, None

        if "optimizer_class" not in self.policy_kwargs:
            self.policy_kwargs["optimizer_class"] = th.optim.Adam
            # Proposed in the QR-DQN paper where `batch_size = 32`
            self.policy_kwargs["optimizer_kwargs"] = dict(eps=0.01 / batch_size)

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        super(QRDQN, self)._setup_model()
        self._create_aliases()
        self.exploration_schedule = get_linear_fn(
            self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
        )

    def _create_aliases(self) -> None:
        self.quantile_net = self.policy.quantile_net
        self.quantile_net_target = self.policy.quantile_net_target
        self.n_quantiles = self.policy.n_quantiles

    def _on_step(self) -> None:
        """
        Update the exploration rate and target network if needed.
        This method is called in ``collect_rollouts()`` after each step in the environment.
        """
        if self.num_timesteps % self.target_update_interval == 0:
            polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)

        self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
        logger.record("rollout/exploration rate", self.exploration_rate)

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for gradient_step in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            with th.no_grad():
                # Compute the quantiles of next observation
                next_quantiles = self.quantile_net_target(replay_data.next_observations)
|
||||||
|
# Follow greedy policy: use the one with the highest value
|
||||||
|
next_quantiles, _ = next_quantiles.max(dim=2)
|
||||||
|
# 1-step TD target
|
||||||
|
target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_quantiles
|
||||||
|
|
||||||
|
# Get current quantile estimates
|
||||||
|
current_quantiles = self.quantile_net(replay_data.observations)
|
||||||
|
|
||||||
|
# Make "n_quantiles" copies of actions, and reshape to (batch_size, n_quantiles, 1).
|
||||||
|
actions = replay_data.actions[..., None].long().expand(batch_size, self.n_quantiles, 1)
|
||||||
|
# Retrieve the quantiles for the actions from the replay buffer
|
||||||
|
current_quantiles = th.gather(current_quantiles, dim=2, index=actions).squeeze(dim=2)
|
||||||
|
|
||||||
|
# Compute Quantile Huber loss, summing over a quantile dimension as in the paper.
|
||||||
|
loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=True)
|
||||||
|
losses.append(loss.item())
|
||||||
|
|
||||||
|
# Optimize the policy
|
||||||
|
self.policy.optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
# Clip gradient norm
|
||||||
|
if self.max_grad_norm is not None:
|
||||||
|
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
||||||
|
self.policy.optimizer.step()
|
||||||
|
|
||||||
|
# Increase update counter
|
||||||
|
self._n_updates += gradient_steps
|
||||||
|
|
||||||
|
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
||||||
|
logger.record("train/loss", np.mean(losses))
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
self,
|
||||||
|
observation: np.ndarray,
|
||||||
|
state: Optional[np.ndarray] = None,
|
||||||
|
mask: Optional[np.ndarray] = None,
|
||||||
|
deterministic: bool = False,
|
||||||
|
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||||
|
"""
|
||||||
|
Overrides the base_class predict function to include epsilon-greedy exploration.
|
||||||
|
|
||||||
|
:param observation: the input observation
|
||||||
|
:param state: The last states (can be None, used in recurrent policies)
|
||||||
|
:param mask: The last masks (can be None, used in recurrent policies)
|
||||||
|
:param deterministic: Whether or not to return deterministic actions.
|
||||||
|
:return: the model's action and the next state
|
||||||
|
(used in recurrent policies)
|
||||||
|
"""
|
||||||
|
if not deterministic and np.random.rand() < self.exploration_rate:
|
||||||
|
if is_vectorized_observation(observation, self.observation_space):
|
||||||
|
n_batch = observation.shape[0]
|
||||||
|
action = np.array([self.action_space.sample() for _ in range(n_batch)])
|
||||||
|
else:
|
||||||
|
action = np.array(self.action_space.sample())
|
||||||
|
else:
|
||||||
|
action, state = self.policy.predict(observation, state, mask, deterministic)
|
||||||
|
return action, state
|
||||||
|
|
||||||
|
def learn(
|
||||||
|
self,
|
||||||
|
total_timesteps: int,
|
||||||
|
callback: MaybeCallback = None,
|
||||||
|
log_interval: int = 4,
|
||||||
|
eval_env: Optional[GymEnv] = None,
|
||||||
|
eval_freq: int = -1,
|
||||||
|
n_eval_episodes: int = 5,
|
||||||
|
tb_log_name: str = "QRDQN",
|
||||||
|
eval_log_path: Optional[str] = None,
|
||||||
|
reset_num_timesteps: bool = True,
|
||||||
|
) -> OffPolicyAlgorithm:
|
||||||
|
|
||||||
|
return super(QRDQN, self).learn(
|
||||||
|
total_timesteps=total_timesteps,
|
||||||
|
callback=callback,
|
||||||
|
log_interval=log_interval,
|
||||||
|
eval_env=eval_env,
|
||||||
|
eval_freq=eval_freq,
|
||||||
|
n_eval_episodes=n_eval_episodes,
|
||||||
|
tb_log_name=tb_log_name,
|
||||||
|
eval_log_path=eval_log_path,
|
||||||
|
reset_num_timesteps=reset_num_timesteps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _excluded_save_params(self) -> List[str]:
|
||||||
|
return super(QRDQN, self)._excluded_save_params() + ["quantile_net", "quantile_net_target"]
|
||||||
|
|
||||||
|
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
||||||
|
state_dicts = ["policy", "policy.optimizer"]
|
||||||
|
|
||||||
|
return state_dicts, []
|
||||||
|
|
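The exploration rate above is produced by `get_linear_fn` from `stable_baselines3.common.utils`. A minimal standalone sketch of that schedule (an assumption about its linear-interpolation semantics, not the library source) looks like this:

```python
def get_linear_fn(start: float, end: float, end_fraction: float):
    """Return a schedule mapping progress_remaining (1.0 at the start of
    training, 0.0 at the end) to a value that decays linearly from `start`
    to `end` over the first `end_fraction` of training, then stays at `end`."""

    def func(progress_remaining: float) -> float:
        if (1 - progress_remaining) > end_fraction:
            return end
        return start + (1 - progress_remaining) * (end - start) / end_fraction

    return func


# Mirrors the QRDQN defaults above: epsilon decays from 1.0 to 0.01
# over the first 0.5% of training
schedule = get_linear_fn(1.0, 0.01, 0.005)
```

With these defaults, `schedule(1.0)` returns 1.0, and any call after the first 0.5% of timesteps returns the floor value 0.01.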
@@ -391,6 +391,8 @@ class TQCPolicy(BasePolicy):
                 optimizer_kwargs=self.optimizer_kwargs,
                 features_extractor_class=self.features_extractor_class,
                 features_extractor_kwargs=self.features_extractor_kwargs,
+                n_quantiles=self.critic_kwargs["n_quantiles"],
+                n_critics=self.critic_kwargs["n_critics"],
             )
         )
         return data
@@ -9,6 +9,7 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
 from stable_baselines3.common.utils import polyak_update
 
+from sb3_contrib.common.utils import quantile_huber_loss
 from sb3_contrib.tqc.policies import TQCPolicy
@@ -171,26 +172,6 @@ class TQC(OffPolicyAlgorithm):
         self.critic = self.policy.critic
         self.critic_target = self.policy.critic_target
 
-    @staticmethod
-    def quantile_huber_loss(quantiles: th.Tensor, samples: th.Tensor) -> th.Tensor:
-        """
-        The quantile-regression loss, as described in the QR-DQN and TQC papers.
-        Taken from https://github.com/bayesgroup/tqc_pytorch
-
-        :param quantiles:
-        :param samples:
-        :return: the loss
-        """
-        # batch x nets x quantiles x samples
-        pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
-        abs_pairwise_delta = th.abs(pairwise_delta)
-        huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5)
-
-        n_quantiles = quantiles.shape[2]
-        tau = th.arange(n_quantiles, device=quantiles.device).float() / n_quantiles + 1 / 2 / n_quantiles
-        loss = (th.abs(tau[None, None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean()
-        return loss
-
     def train(self, gradient_steps: int, batch_size: int = 64) -> None:
         # Update optimizers learning rate
         optimizers = [self.actor.optimizer, self.critic.optimizer]
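The removed staticmethod is superseded by the shared `quantile_huber_loss` helper in `sb3_contrib.common.utils`. The computation can be sketched in NumPy as follows (a simplified stand-in for illustration, assuming `(batch, n_quantiles)` inputs rather than the torch implementation's full broadcasting and validation):

```python
import numpy as np


def quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=True):
    """NumPy sketch of the quantile Huber loss for (batch, n_quantiles) inputs."""
    n_quantiles = current_quantiles.shape[1]
    # Quantile midpoints: tau_i = (2i + 1) / (2 * n_quantiles)
    tau = (np.arange(n_quantiles) + 0.5) / n_quantiles
    # Pairwise TD errors, shape (batch, n_quantiles, n_target_quantiles)
    pairwise_delta = target_quantiles[:, None, :] - current_quantiles[:, :, None]
    abs_delta = np.abs(pairwise_delta)
    # Huber loss with threshold 1: quadratic near zero, linear in the tails
    huber = np.where(abs_delta > 1, abs_delta - 0.5, 0.5 * pairwise_delta ** 2)
    # Asymmetric weighting by |tau - 1{delta < 0}| gives the quantile loss
    loss = np.abs(tau[None, :, None] - (pairwise_delta < 0)) * huber
    if sum_over_quantiles:
        # Sum over the current-quantile (tau) dimension, mean over the rest
        return loss.sum(axis=1).mean()
    return loss.mean()
```

This reproduces the values checked in the new test file: with 10 quantiles at zero and targets at one, the loss is 2.5 when summing over quantiles and 0.25 otherwise.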
@@ -237,24 +218,27 @@ class TQC(OffPolicyAlgorithm):
                 self.ent_coef_optimizer.step()
 
             with th.no_grad():
-                top_quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.critic.n_critics
                 # Select action according to policy
                 next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
                 # Compute and cut quantiles at the next state
                 # batch x nets x quantiles
-                next_z = self.critic_target(replay_data.next_observations, next_actions)
-                sorted_z, _ = th.sort(next_z.reshape(batch_size, -1))
-                sorted_z_part = sorted_z[:, : self.critic.quantiles_total - top_quantiles_to_drop]
+                next_quantiles = self.critic_target(replay_data.next_observations, next_actions)
+                # Sort and drop top k quantiles to control overestimation.
+                n_target_quantiles = self.critic.quantiles_total - self.top_quantiles_to_drop_per_net * self.critic.n_critics
+                next_quantiles, _ = th.sort(next_quantiles.reshape(batch_size, -1))
+                next_quantiles = next_quantiles[:, :n_target_quantiles]
 
-                target_q = sorted_z_part - ent_coef * next_log_prob.reshape(-1, 1)
                 # td error + entropy term
-                q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
+                target_quantiles = next_quantiles - ent_coef * next_log_prob.reshape(-1, 1)
+                target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_quantiles
+                # Make target_quantiles broadcastable to (batch_size, n_critics, n_target_quantiles).
+                target_quantiles.unsqueeze_(dim=1)
 
-            # Get current Q estimates
-            # using action from the replay buffer
-            current_z = self.critic(replay_data.observations, replay_data.actions)
-            # Compute critic loss
-            critic_loss = self.quantile_huber_loss(current_z, q_backup)
+            # Get current Quantile estimates using action from the replay buffer
+            current_quantiles = self.critic(replay_data.observations, replay_data.actions)
+            # Compute critic loss, not summing over the quantile dimension as in the paper.
+            critic_loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=False)
             critic_losses.append(critic_loss.item())
 
             # Optimize the critic
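The sort-and-drop step above, which controls overestimation in TQC, can be sketched in NumPy (a hypothetical standalone helper mirroring those lines, not part of the library):

```python
import numpy as np


def truncate_quantiles(next_quantiles: np.ndarray, top_quantiles_to_drop_per_net: int, n_critics: int) -> np.ndarray:
    """Pool the quantiles from all critics, sort them, and drop the
    top k quantiles per critic network to curb overestimation."""
    batch_size = next_quantiles.shape[0]
    pooled = next_quantiles.reshape(batch_size, -1)
    n_target_quantiles = pooled.shape[1] - top_quantiles_to_drop_per_net * n_critics
    sorted_quantiles = np.sort(pooled, axis=1)
    # Keep only the smallest n_target_quantiles values
    return sorted_quantiles[:, :n_target_quantiles]
```

For example, with 2 critics of 3 quantiles each and 1 quantile dropped per net, the 2 largest of the 6 pooled quantiles are discarded.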
@@ -25,6 +25,7 @@ ignore = W503,W504,E203,E231  # line breaks before and after binary operators
 # Ignore import not used when aliases are defined
 per-file-ignores =
     ./sb3_contrib/__init__.py:F401
+    ./sb3_contrib/qrdqn/__init__.py:F401
     ./sb3_contrib/tqc/__init__.py:F401
     ./sb3_contrib/common/wrappers/__init__.py:F401
 exclude =
@@ -7,10 +7,10 @@ import torch as th
 from stable_baselines3.common.identity_env import FakeImageEnv
 from stable_baselines3.common.utils import zip_strict
 
-from sb3_contrib import TQC
+from sb3_contrib import QRDQN, TQC
 
 
-@pytest.mark.parametrize("model_class", [TQC])
+@pytest.mark.parametrize("model_class", [TQC, QRDQN])
 def test_cnn(tmp_path, model_class):
     SAVE_NAME = "cnn_model.zip"
     # Fake grayscale with frameskip
@@ -18,10 +18,13 @@ def test_cnn(tmp_path, model_class):
     # to check that the network handle it automatically
     env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC})
     kwargs = {}
-    if model_class in {TQC}:
+    if model_class in {TQC, QRDQN}:
         # Avoid memory error when using replay buffer
-        # Reduce the size of the features
-        kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
+        # Reduce the size of the features and the number of quantiles
+        kwargs = dict(
+            buffer_size=250,
+            policy_kwargs=dict(n_quantiles=25, features_extractor_kwargs=dict(features_dim=32)),
+        )
     model = model_class("CnnPolicy", env, **kwargs).learn(250)
 
     obs = env.reset()
@@ -39,6 +42,13 @@ def test_cnn(tmp_path, model_class):
     os.remove(str(tmp_path / SAVE_NAME))
 
 
+def patch_qrdqn_names_(model):
+    # Small hack to make the test work with QRDQN
+    if isinstance(model, QRDQN):
+        model.critic = model.quantile_net
+        model.critic_target = model.quantile_net_target
+
+
 def params_should_match(params, other_params):
     for param, other_param in zip_strict(params, other_params):
         assert th.allclose(param, other_param)
@@ -49,28 +59,36 @@ def params_should_differ(params, other_params):
         assert not th.allclose(param, other_param)
 
 
-@pytest.mark.parametrize("model_class", [TQC])
+@pytest.mark.parametrize("model_class", [TQC, QRDQN])
 @pytest.mark.parametrize("share_features_extractor", [True, False])
 def test_feature_extractor_target_net(model_class, share_features_extractor):
+    if model_class == QRDQN and share_features_extractor:
+        pytest.skip()
+
     env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC})
-    # Avoid memory error when using replay buffer
-    # Reduce the size of the features
-    kwargs = dict(
-        buffer_size=250,
-        learning_starts=100,
-        policy_kwargs=dict(
-            features_extractor_kwargs=dict(features_dim=32),
-            share_features_extractor=share_features_extractor,
-        ),
-    )
+    if model_class in {TQC, QRDQN}:
+        # Avoid memory error when using replay buffer
+        # Reduce the size of the features and the number of quantiles
+        kwargs = dict(
+            buffer_size=250,
+            learning_starts=100,
+            policy_kwargs=dict(n_quantiles=25, features_extractor_kwargs=dict(features_dim=32)),
+        )
+    if model_class != QRDQN:
+        kwargs["policy_kwargs"]["share_features_extractor"] = share_features_extractor
+
     model = model_class("CnnPolicy", env, seed=0, **kwargs)
+
+    patch_qrdqn_names_(model)
+
     if share_features_extractor:
         # Check that the objects are the same and not just copied
         assert id(model.policy.actor.features_extractor) == id(model.policy.critic.features_extractor)
     else:
         # Check that the objects differ
-        assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor)
+        if model_class != QRDQN:
+            assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor)
 
     # Critic and target should be equal at the beginning of training
     params_should_match(model.critic.parameters(), model.critic_target.parameters())
@@ -83,6 +101,8 @@ def test_feature_extractor_target_net(model_class, share_features_extractor):
     # Re-initialize and collect some random data (without doing gradient steps)
     model = model_class("CnnPolicy", env, seed=0, **kwargs).learn(10)
 
+    patch_qrdqn_names_(model)
+
     original_param = deepcopy(list(model.critic.parameters()))
     original_target_param = deepcopy(list(model.critic_target.parameters()))
@@ -103,6 +123,11 @@ def test_feature_extractor_target_net(model_class, share_features_extractor):
     model.lr_schedule = lambda _: 0.0
     # Re-activate polyak update
     model.tau = 0.01
+    # Special case for QRDQN: the target net is updated in `collect_rollouts()`,
+    # not in the `train()` method
+    if model_class == QRDQN:
+        model.target_update_interval = 1
+        model._on_step()
 
     model.train(gradient_steps=1)
@@ -1,6 +1,6 @@
 import pytest
 
-from sb3_contrib import TQC
+from sb3_contrib import QRDQN, TQC
 
 
 @pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
@@ -21,7 +21,11 @@ def test_tqc(ent_coef):
 def test_n_critics(n_critics):
     # Test TQC with different number of critics
     model = TQC(
-        "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], n_critics=n_critics), learning_starts=100, verbose=1
+        "MlpPolicy",
+        "Pendulum-v0",
+        policy_kwargs=dict(net_arch=[64], n_critics=n_critics),
+        learning_starts=100,
+        verbose=1,
     )
     model.learn(total_timesteps=300)
@@ -38,3 +42,17 @@ def test_sde():
     model.learn(total_timesteps=300)
     model.policy.reset_noise()
     model.policy.actor.get_std()
+
+
+def test_qrdqn():
+    model = QRDQN(
+        "MlpPolicy",
+        "CartPole-v1",
+        policy_kwargs=dict(n_quantiles=25, net_arch=[64, 64]),
+        learning_starts=100,
+        buffer_size=500,
+        learning_rate=3e-4,
+        verbose=1,
+        create_eval_env=True,
+    )
+    model.learn(total_timesteps=500, eval_freq=250)
@@ -7,22 +7,21 @@ import gym
 import numpy as np
 import pytest
 import torch as th
-from stable_baselines3 import DQN
 from stable_baselines3.common.base_class import BaseAlgorithm
 from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
 from stable_baselines3.common.utils import get_device
 from stable_baselines3.common.vec_env import DummyVecEnv
 
-from sb3_contrib import TQC
+from sb3_contrib import QRDQN, TQC
 
-MODEL_LIST = [TQC]
+MODEL_LIST = [TQC, QRDQN]
 
 
 def select_env(model_class: BaseAlgorithm) -> gym.Env:
     """
-    Selects an environment with the correct action space as DQN only supports discrete action space
+    Selects an environment with the correct action space as QRDQN only supports discrete action space
     """
-    if model_class == DQN:
+    if model_class == QRDQN:
         return IdentityEnv(10)
     else:
         return IdentityEnvBox(10)
@@ -41,8 +40,13 @@ def test_save_load(tmp_path, model_class):
     env = DummyVecEnv([lambda: select_env(model_class)])
 
+    policy_kwargs = dict(net_arch=[16])
+
+    if model_class in {QRDQN, TQC}:
+        policy_kwargs.update(dict(n_quantiles=20))
+
     # create model
-    model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
+    model = model_class("MlpPolicy", env, verbose=1, policy_kwargs=policy_kwargs)
     model.learn(total_timesteps=300)
 
     env.reset()
@@ -167,13 +171,18 @@ def test_set_env(model_class):
     :param model_class: (BaseAlgorithm) A RL model
     """
 
-    # use discrete for DQN
+    # use discrete for QRDQN
     env = DummyVecEnv([lambda: select_env(model_class)])
     env2 = DummyVecEnv([lambda: select_env(model_class)])
     env3 = select_env(model_class)
 
+    kwargs = dict(policy_kwargs=dict(net_arch=[16]))
+    if model_class in {TQC, QRDQN}:
+        kwargs.update(dict(learning_starts=100))
+        kwargs["policy_kwargs"].update(dict(n_quantiles=20))
+
     # create model
-    model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]))
+    model = model_class("MlpPolicy", env, **kwargs)
     # learn
     model.learn(total_timesteps=300)
@@ -219,7 +228,7 @@ def test_exclude_include_saved_params(tmp_path, model_class):
     os.remove(tmp_path / "test_save.zip")
 
 
-@pytest.mark.parametrize("model_class", [TQC])
+@pytest.mark.parametrize("model_class", [TQC, QRDQN])
 def test_save_load_replay_buffer(tmp_path, model_class):
     path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
     path.parent.mkdir(exist_ok=True, parents=True)  # to not raise a warning
@@ -254,20 +263,28 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
     :param model_class: (BaseAlgorithm) A RL model
     :param policy_str: (str) Name of the policy.
     """
-    kwargs = {}
+    kwargs = dict(policy_kwargs=dict(net_arch=[16]))
     if policy_str == "MlpPolicy":
         env = select_env(model_class)
     else:
-        if model_class in [TQC]:
+        if model_class in [TQC, QRDQN]:
             # Avoid memory error when using replay buffer
             # Reduce the size of the features
-            kwargs = dict(buffer_size=250)
-        env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)
+            kwargs = dict(
+                buffer_size=250,
+                learning_starts=100,
+                policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
+            )
+        env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN)
+
+    # Reduce number of quantiles for faster tests
+    if model_class in [TQC, QRDQN]:
+        kwargs["policy_kwargs"].update(dict(n_quantiles=20))
+
     env = DummyVecEnv([lambda: env])
 
     # create model
-    model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
+    model = model_class(policy_str, env, verbose=1, **kwargs)
     model.learn(total_timesteps=300)
 
     env.reset()
@@ -334,3 +351,83 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
     os.remove(tmp_path / "policy.pkl")
     if actor_class is not None:
         os.remove(tmp_path / "actor.pkl")
+
+
+@pytest.mark.parametrize("model_class", [QRDQN])
+@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
+def test_save_load_q_net(tmp_path, model_class, policy_str):
+    """
+    Test saving and loading q-network/quantile net only.
+
+    :param model_class: (BaseAlgorithm) A RL model
+    :param policy_str: (str) Name of the policy.
+    """
+    kwargs = dict(policy_kwargs=dict(net_arch=[16]))
+    if policy_str == "MlpPolicy":
+        env = select_env(model_class)
+    else:
+        if model_class in [QRDQN]:
+            # Avoid memory error when using replay buffer
+            # Reduce the size of the features
+            kwargs = dict(
+                buffer_size=250,
+                learning_starts=100,
+                policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
+            )
+        env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN)
+
+    # Reduce number of quantiles for faster tests
+    if model_class in [QRDQN]:
+        kwargs["policy_kwargs"].update(dict(n_quantiles=20))
+
+    env = DummyVecEnv([lambda: env])
+
+    # create model
+    model = model_class(policy_str, env, verbose=1, **kwargs)
+    model.learn(total_timesteps=300)
+
+    env.reset()
+    observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
+
+    q_net = model.quantile_net
+    q_net_class = q_net.__class__
+
+    # Get dictionary of current parameters
+    params = deepcopy(q_net.state_dict())
+
+    # Modify all parameters to be random values
+    random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
+
+    # Update model parameters with the new random values
+    q_net.load_state_dict(random_params)
+
+    new_params = q_net.state_dict()
+    # Check that all params are different now
+    for k in params:
+        assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
+
+    params = new_params
+
+    # get selected actions
+    selected_actions, _ = q_net.predict(observations, deterministic=True)
+
+    # Save and load q_net
+    q_net.save(tmp_path / "q_net.pkl")
+
+    del q_net
+
+    q_net = q_net_class.load(tmp_path / "q_net.pkl")
+
+    # check if params are still the same after load
+    new_params = q_net.state_dict()
+
+    # Check that all params are the same as before save load procedure now
+    for key in params:
+        assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load."
+
+    # check if model still selects the same actions
+    new_selected_actions, _ = q_net.predict(observations, deterministic=True)
+    assert np.allclose(selected_actions, new_selected_actions, 1e-4)
+
+    # clear file from os
+    os.remove(tmp_path / "q_net.pkl")
@@ -0,0 +1,19 @@
+import numpy as np
+import pytest
+import torch as th
+
+from sb3_contrib.common.utils import quantile_huber_loss
+
+
+def test_quantile_huber_loss():
+    assert np.isclose(quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10)), 2.5)
+    assert np.isclose(quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10), sum_over_quantiles=False), 0.25)
+
+    with pytest.raises(ValueError):
+        quantile_huber_loss(th.zeros(1, 4, 4), th.zeros(1, 4))
+    with pytest.raises(ValueError):
+        quantile_huber_loss(th.zeros(1, 4), th.zeros(1, 1, 4))
+    with pytest.raises(ValueError):
+        quantile_huber_loss(th.zeros(4, 4), th.zeros(3, 4))
+    with pytest.raises(ValueError):
+        quantile_huber_loss(th.zeros(4, 4, 4, 4), th.zeros(4, 4, 4, 4))