Add Trust Region Policy Optimization (TRPO) (#40)
* Feat: adding TRPO algorithm (WIP) WIP - Trust Region Policy Algorithm Currently the Hessian vector product is not working (see inline comments for more detail) * Feat: adding TRPO algorithm (WIP) Adding no_grad block for the line search Additional assert in the conjugate solver to help debugging * Feat: adding TRPO algorithm (WIP) - Adding ActorCriticPolicy.get_distribution - Using the Distribution object to compute the KL divergence - Checking for objective improvement in the line search - Moving magic numbers to instance variables * Feat: adding TRPO algorithm (WIP) Improving numerical stability of the conjugate gradient algorithm Critic updates * Feat: adding TRPO algorithm (WIP) Changes around the alpha of the line search Adding TRPO to __init__ files * feat: TRPO - addressing PR comments - renaming cg_solver to conjugate_gradient_solver and renaming parameter Avp_fun to matrix_vector_dot_func + docstring - extra comments + better variable names in trpo.py - defining a method for the hessian vector product instead of an inline function - fix registering correct policies for TRPO and using correct policy base in constructor * refactor: TRPO - policier - refactoring sb3_contrib.common.policies to reuse as much code as possible from sb3 * feat: using updated ActorCriticPolicy from SB3 - get_distribution will be added directly to the SB3 version of ActorCriticPolicy, this commit reflects this * Bump version for `get_distribution` support * Add basic test * Reformat * [ci skip] Fix changelog * fix: setting train mode for trpo * fix: batch_size type hint in trpo.py * style: renaming variables + docstring in trpo.py * Rename + cleanup * Move grad computation to separate method * Remove grad norm clipping * Remove n epochs and add sub-sampling * Update defaults * Add Doc * Add more test and fixes for CNN * Update doc + add benchmark * Add tests + update doc * Fix doc * Improve names for conjugate gradient * Update comments * Update changelog Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
b44689b0ea
commit
59be198da0
|
|
@ -28,6 +28,7 @@ See documentation for the full list of included features.
|
||||||
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
|
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
|
||||||
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
|
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
|
||||||
- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
|
- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
|
||||||
|
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
|
||||||
|
|
||||||
**Gym Wrappers**:
|
**Gym Wrappers**:
|
||||||
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
|
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
.. _utils:
|
||||||
|
|
||||||
|
Utils
|
||||||
|
=====
|
||||||
|
|
||||||
|
.. automodule:: sb3_contrib.common.utils
|
||||||
|
:members:
|
||||||
|
|
@ -9,6 +9,7 @@ along with some useful characteristics: support for discrete/continuous actions,
|
||||||
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
|
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
|
||||||
============ =========== ============ ================= =============== ================
|
============ =========== ============ ================= =============== ================
|
||||||
TQC ✔️ ❌ ❌ ❌ ✔️
|
TQC ✔️ ❌ ❌ ❌ ✔️
|
||||||
|
TRPO ✔️ ✔️ ✔️ ✔️ ✔️
|
||||||
QR-DQN ️❌ ️✔️ ❌ ❌ ✔️
|
QR-DQN ️❌ ️✔️ ❌ ❌ ✔️
|
||||||
============ =========== ============ ================= =============== ================
|
============ =========== ============ ================= =============== ================
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -44,3 +44,16 @@ Train a PPO with invalid action masking agent on a toy environment.
|
||||||
model = MaskablePPO("MlpPolicy", env, verbose=1)
|
model = MaskablePPO("MlpPolicy", env, verbose=1)
|
||||||
model.learn(5000)
|
model.learn(5000)
|
||||||
model.save("qrdqn_cartpole")
|
model.save("qrdqn_cartpole")
|
||||||
|
|
||||||
|
TRPO
|
||||||
|
----
|
||||||
|
|
||||||
|
Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environment.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from sb3_contrib import TRPO
|
||||||
|
|
||||||
|
model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1)
|
||||||
|
model.learn(total_timesteps=100_000, log_interval=4)
|
||||||
|
model.save("trpo_pendulum")
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
|
||||||
:caption: RL Algorithms
|
:caption: RL Algorithms
|
||||||
|
|
||||||
modules/tqc
|
modules/tqc
|
||||||
|
modules/trpo
|
||||||
modules/qrdqn
|
modules/qrdqn
|
||||||
modules/ppo_mask
|
modules/ppo_mask
|
||||||
|
|
||||||
|
|
@ -39,6 +40,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
|
||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
:caption: Common
|
:caption: Common
|
||||||
|
|
||||||
|
common/utils
|
||||||
common/wrappers
|
common/wrappers
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,9 @@ Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
|
||||||
Release 1.3.1a6 (WIP)
|
Release 1.3.1a7 (WIP)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
**Add TRPO**
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^
|
||||||
|
|
@ -15,6 +16,7 @@ Breaking Changes:
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
- Added ``TRPO`` (@cyprienc)
|
||||||
- Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported)
|
- Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported)
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
|
|
@ -34,7 +36,7 @@ Documentation:
|
||||||
Release 1.3.0 (2021-10-23)
|
Release 1.3.0 (2021-10-23)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
**Invalid action masking for PPO**
|
**Add Invalid action masking for PPO**
|
||||||
|
|
||||||
.. warning::
|
.. warning::
|
||||||
|
|
||||||
|
|
@ -52,6 +54,7 @@ New Features:
|
||||||
- Added ``MaskablePPO`` algorithm (@kronion)
|
- Added ``MaskablePPO`` algorithm (@kronion)
|
||||||
- ``MaskablePPO`` Dictionary Observation support (@glmcdona)
|
- ``MaskablePPO`` Dictionary Observation support (@glmcdona)
|
||||||
|
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
|
@ -75,9 +78,6 @@ Breaking Changes:
|
||||||
^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^
|
||||||
- Upgraded to Stable-Baselines3 >= 1.2.0
|
- Upgraded to Stable-Baselines3 >= 1.2.0
|
||||||
|
|
||||||
New Features:
|
|
||||||
^^^^^^^^^^^^^
|
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
- QR-DQN and TQC updated so that their policies are switched between train and eval mode at the correct time (@ayeright)
|
- QR-DQN and TQC updated so that their policies are switched between train and eval mode at the correct time (@ayeright)
|
||||||
|
|
@ -221,4 +221,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
|
||||||
Contributors:
|
Contributors:
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona
|
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,151 @@
|
||||||
|
.. _tqc:
|
||||||
|
|
||||||
|
.. automodule:: sb3_contrib.trpo
|
||||||
|
|
||||||
|
TRPO
|
||||||
|
====
|
||||||
|
|
||||||
|
`Trust Region Policy Optimization (TRPO) <https://arxiv.org/abs/1502.05477>`_
|
||||||
|
is an iterative approach for optimizing policies with guaranteed monotonic improvement.
|
||||||
|
|
||||||
|
.. rubric:: Available Policies
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:nosignatures:
|
||||||
|
|
||||||
|
MlpPolicy
|
||||||
|
CnnPolicy
|
||||||
|
MultiInputPolicy
|
||||||
|
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
|
||||||
|
- Original paper: https://arxiv.org/abs/1502.05477
|
||||||
|
- OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/
|
||||||
|
|
||||||
|
|
||||||
|
Can I use?
|
||||||
|
----------
|
||||||
|
|
||||||
|
- Recurrent policies: ❌
|
||||||
|
- Multi processing: ✔️
|
||||||
|
- Gym spaces:
|
||||||
|
|
||||||
|
|
||||||
|
============= ====== ===========
|
||||||
|
Space Action Observation
|
||||||
|
============= ====== ===========
|
||||||
|
Discrete ✔️ ✔️
|
||||||
|
Box ✔️ ✔️
|
||||||
|
MultiDiscrete ✔️ ✔️
|
||||||
|
MultiBinary ✔️ ✔️
|
||||||
|
Dict ❌ ✔️
|
||||||
|
============= ====== ===========
|
||||||
|
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from sb3_contrib import TRPO
|
||||||
|
|
||||||
|
env = gym.make("Pendulum-v0")
|
||||||
|
|
||||||
|
model = TRPO("MlpPolicy", env, verbose=1)
|
||||||
|
model.learn(total_timesteps=10000, log_interval=4)
|
||||||
|
model.save("trpo_pendulum")
|
||||||
|
|
||||||
|
del model # remove to demonstrate saving and loading
|
||||||
|
|
||||||
|
model = TRPO.load("trpo_pendulum")
|
||||||
|
|
||||||
|
obs = env.reset()
|
||||||
|
while True:
|
||||||
|
action, _states = model.predict(obs, deterministic=True)
|
||||||
|
obs, reward, done, info = env.step(action)
|
||||||
|
env.render()
|
||||||
|
if done:
|
||||||
|
obs = env.reset()
|
||||||
|
|
||||||
|
|
||||||
|
Results
|
||||||
|
-------
|
||||||
|
|
||||||
|
Result on the MuJoCo benchmark (1M steps on ``-v3`` envs with MuJoCo v2.1.0) using 3 seeds.
|
||||||
|
The complete learning curves are available in the `associated PR <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/40>`_.
|
||||||
|
|
||||||
|
|
||||||
|
===================== ============
|
||||||
|
Environments TRPO
|
||||||
|
===================== ============
|
||||||
|
HalfCheetah 1803 +/- 46
|
||||||
|
Ant 3554 +/- 591
|
||||||
|
Hopper 3372 +/- 215
|
||||||
|
Walker2d 4502 +/- 234
|
||||||
|
Swimmer 359 +/- 2
|
||||||
|
===================== ============
|
||||||
|
|
||||||
|
|
||||||
|
How to replicate the results?
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Clone RL-Zoo and checkout the branch ``feat/trpo``:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
git clone https://github.com/cyprienc/rl-baselines3-zoo
|
||||||
|
cd rl-baselines3-zoo/
|
||||||
|
|
||||||
|
Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
python train.py --algo tqc --env $ENV_ID --n-eval-envs 10 --eval-episodes 20 --eval-freq 50000
|
||||||
|
|
||||||
|
|
||||||
|
Plot the results:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
python scripts/all_plots.py -a trpo -e HalfCheetah Ant Hopper Walker2d Swimmer -f logs/ -o logs/trpo_results
|
||||||
|
python scripts/plot_from_file.py -i logs/trpo_results.pkl -latex -l TRPO
|
||||||
|
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
|
||||||
|
.. autoclass:: TRPO
|
||||||
|
:members:
|
||||||
|
:inherited-members:
|
||||||
|
|
||||||
|
.. _trpo_policies:
|
||||||
|
|
||||||
|
TRPO Policies
|
||||||
|
-------------
|
||||||
|
|
||||||
|
.. autoclass:: MlpPolicy
|
||||||
|
:members:
|
||||||
|
:inherited-members:
|
||||||
|
|
||||||
|
.. autoclass:: stable_baselines3.common.policies.ActorCriticPolicy
|
||||||
|
:members:
|
||||||
|
:noindex:
|
||||||
|
|
||||||
|
.. autoclass:: CnnPolicy
|
||||||
|
:members:
|
||||||
|
|
||||||
|
.. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy
|
||||||
|
:members:
|
||||||
|
:noindex:
|
||||||
|
|
||||||
|
.. autoclass:: MultiInputPolicy
|
||||||
|
:members:
|
||||||
|
|
||||||
|
.. autoclass:: stable_baselines3.common.policies.MultiInputActorCriticPolicy
|
||||||
|
:members:
|
||||||
|
:noindex:
|
||||||
|
|
@ -3,6 +3,7 @@ import os
|
||||||
from sb3_contrib.ppo_mask import MaskablePPO
|
from sb3_contrib.ppo_mask import MaskablePPO
|
||||||
from sb3_contrib.qrdqn import QRDQN
|
from sb3_contrib.qrdqn import QRDQN
|
||||||
from sb3_contrib.tqc import TQC
|
from sb3_contrib.tqc import TQC
|
||||||
|
from sb3_contrib.trpo import TRPO
|
||||||
|
|
||||||
# Read version from file
|
# Read version from file
|
||||||
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
|
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from typing import Optional
|
from typing import Callable, Optional, Sequence
|
||||||
|
|
||||||
import torch as th
|
import torch as th
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
def quantile_huber_loss(
|
def quantile_huber_loss(
|
||||||
|
|
@ -67,3 +68,96 @@ def quantile_huber_loss(
|
||||||
else:
|
else:
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def conjugate_gradient_solver(
|
||||||
|
matrix_vector_dot_fn: Callable[[th.Tensor], th.Tensor],
|
||||||
|
b,
|
||||||
|
max_iter=10,
|
||||||
|
residual_tol=1e-10,
|
||||||
|
) -> th.Tensor:
|
||||||
|
"""
|
||||||
|
Finds an approximate solution to a set of linear equations Ax = b
|
||||||
|
|
||||||
|
Sources:
|
||||||
|
- https://github.com/ajlangley/trpo-pytorch/blob/master/conjugate_gradient.py
|
||||||
|
- https://github.com/joschu/modular_rl/blob/master/modular_rl/trpo.py#L122
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- https://epubs.siam.org/doi/abs/10.1137/1.9781611971446.ch6
|
||||||
|
|
||||||
|
:param matrix_vector_dot_fn:
|
||||||
|
a function that right multiplies a matrix A by a vector v
|
||||||
|
:param b:
|
||||||
|
the right hand term in the set of linear equations Ax = b
|
||||||
|
:param max_iter:
|
||||||
|
the maximum number of iterations (default is 10)
|
||||||
|
:param residual_tol:
|
||||||
|
residual tolerance for early stopping of the solving (default is 1e-10)
|
||||||
|
:return x:
|
||||||
|
the approximate solution to the system of equations defined by `matrix_vector_dot_fn`
|
||||||
|
and b
|
||||||
|
"""
|
||||||
|
|
||||||
|
# The vector is not initialized at 0 because of the instability issues when the gradient becomes small.
|
||||||
|
# A small random gaussian noise is used for the initialization.
|
||||||
|
x = 1e-4 * th.randn_like(b)
|
||||||
|
residual = b - matrix_vector_dot_fn(x)
|
||||||
|
# Equivalent to th.linalg.norm(residual) ** 2 (L2 norm squared)
|
||||||
|
residual_squared_norm = th.matmul(residual, residual)
|
||||||
|
|
||||||
|
if residual_squared_norm < residual_tol:
|
||||||
|
# If the gradient becomes extremely small
|
||||||
|
# The denominator in alpha will become zero
|
||||||
|
# Leading to a division by zero
|
||||||
|
return x
|
||||||
|
|
||||||
|
p = residual.clone()
|
||||||
|
|
||||||
|
for i in range(max_iter):
|
||||||
|
# A @ p (matrix vector multiplication)
|
||||||
|
A_dot_p = matrix_vector_dot_fn(p)
|
||||||
|
|
||||||
|
alpha = residual_squared_norm / p.dot(A_dot_p)
|
||||||
|
x += alpha * p
|
||||||
|
|
||||||
|
if i == max_iter - 1:
|
||||||
|
return x
|
||||||
|
|
||||||
|
residual -= alpha * A_dot_p
|
||||||
|
new_residual_squared_norm = th.matmul(residual, residual)
|
||||||
|
|
||||||
|
if new_residual_squared_norm < residual_tol:
|
||||||
|
return x
|
||||||
|
|
||||||
|
beta = new_residual_squared_norm / residual_squared_norm
|
||||||
|
residual_squared_norm = new_residual_squared_norm
|
||||||
|
p = residual + beta * p
|
||||||
|
|
||||||
|
|
||||||
|
def flat_grad(
|
||||||
|
output,
|
||||||
|
parameters: Sequence[nn.parameter.Parameter],
|
||||||
|
create_graph: bool = False,
|
||||||
|
retain_graph: bool = False,
|
||||||
|
) -> th.Tensor:
|
||||||
|
"""
|
||||||
|
Returns the gradients of the passed sequence of parameters into a flat gradient.
|
||||||
|
Order of parameters is preserved.
|
||||||
|
|
||||||
|
:param output: functional output to compute the gradient for
|
||||||
|
:param parameters: sequence of ``Parameter``
|
||||||
|
:param retain_graph: – If ``False``, the graph used to compute the grad will be freed.
|
||||||
|
Defaults to the value of ``create_graph``.
|
||||||
|
:param create_graph: – If ``True``, graph of the derivative will be constructed,
|
||||||
|
allowing to compute higher order derivative products. Default: ``False``.
|
||||||
|
:return: Tensor containing the flattened gradients
|
||||||
|
"""
|
||||||
|
grads = th.autograd.grad(
|
||||||
|
output,
|
||||||
|
parameters,
|
||||||
|
create_graph=create_graph,
|
||||||
|
retain_graph=retain_graph,
|
||||||
|
allow_unused=True,
|
||||||
|
)
|
||||||
|
return th.cat([th.ravel(grad) for grad in grads if grad is not None])
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||||
|
from sb3_contrib.trpo.trpo import TRPO
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
# This file is here just to define MlpPolicy/CnnPolicy
|
||||||
|
# that work for TRPO
|
||||||
|
from stable_baselines3.common.policies import (
|
||||||
|
ActorCriticCnnPolicy,
|
||||||
|
ActorCriticPolicy,
|
||||||
|
MultiInputActorCriticPolicy,
|
||||||
|
register_policy,
|
||||||
|
)
|
||||||
|
|
||||||
|
MlpPolicy = ActorCriticPolicy
|
||||||
|
CnnPolicy = ActorCriticCnnPolicy
|
||||||
|
MultiInputPolicy = MultiInputActorCriticPolicy
|
||||||
|
|
||||||
|
register_policy("MlpPolicy", ActorCriticPolicy)
|
||||||
|
register_policy("CnnPolicy", ActorCriticCnnPolicy)
|
||||||
|
register_policy("MultiInputPolicy", MultiInputPolicy)
|
||||||
|
|
@ -0,0 +1,421 @@
|
||||||
|
import copy
|
||||||
|
import warnings
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch as th
|
||||||
|
from gym import spaces
|
||||||
|
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||||
|
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||||
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule
|
||||||
|
from stable_baselines3.common.utils import explained_variance
|
||||||
|
from torch import nn
|
||||||
|
from torch.distributions import kl_divergence
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
|
||||||
|
|
||||||
|
|
||||||
|
class TRPO(OnPolicyAlgorithm):
|
||||||
|
"""
|
||||||
|
Trust Region Policy Optimization (TRPO)
|
||||||
|
|
||||||
|
Paper: https://arxiv.org/abs/1502.05477
|
||||||
|
Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/)
|
||||||
|
and Stable Baselines (TRPO from https://github.com/hill-a/stable-baselines)
|
||||||
|
|
||||||
|
Introduction to TRPO: https://spinningup.openai.com/en/latest/algorithms/trpo.html
|
||||||
|
|
||||||
|
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||||
|
:param env: The environment to learn from (if registered in Gym, can be str)
|
||||||
|
:param learning_rate: The learning rate for the value function, it can be a function
|
||||||
|
of the current progress remaining (from 1 to 0)
|
||||||
|
:param n_steps: The number of steps to run for each environment per update
|
||||||
|
(i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel)
|
||||||
|
NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization)
|
||||||
|
See https://github.com/pytorch/pytorch/issues/29372
|
||||||
|
:param batch_size: Minibatch size for the value function
|
||||||
|
:param gamma: Discount factor
|
||||||
|
:param cg_max_steps: maximum number of steps in the Conjugate Gradient algorithm
|
||||||
|
for computing the Hessian vector product
|
||||||
|
:param cg_damping: damping in the Hessian vector product computation
|
||||||
|
:param line_search_shrinking_factor: step-size reduction factor for the line-search
|
||||||
|
(i.e., ``theta_new = theta + alpha^i * step``)
|
||||||
|
:param line_search_max_iter: maximum number of iteration
|
||||||
|
for the backtracking line-search
|
||||||
|
:param n_critic_updates: number of critic updates per policy update
|
||||||
|
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
|
||||||
|
:param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
|
||||||
|
instead of action noise exploration (default: False)
|
||||||
|
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
||||||
|
Default: -1 (only sample at the beginning of the rollout)
|
||||||
|
:param normalize_advantage: Whether to normalize or not the advantage
|
||||||
|
:param target_kl: Target Kullback-Leibler divergence between updates.
|
||||||
|
Should be small for stability. Values like 0.01, 0.05.
|
||||||
|
:param sub_sampling_factor: Sub-sample the batch to make computation faster
|
||||||
|
see p40-42 of John Schulman thesis http://joschu.net/docs/thesis.pdf
|
||||||
|
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||||
|
:param create_eval_env: Whether to create a second environment that will be
|
||||||
|
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||||
|
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||||
|
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||||
|
:param seed: Seed for the pseudo random generators
|
||||||
|
:param device: Device (cpu, cuda, ...) on which the code should be run.
|
||||||
|
Setting it to auto, the code will be run on the GPU if possible.
|
||||||
|
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
policy: Union[str, Type[ActorCriticPolicy]],
|
||||||
|
env: Union[GymEnv, str],
|
||||||
|
learning_rate: Union[float, Schedule] = 1e-3,
|
||||||
|
n_steps: int = 2048,
|
||||||
|
batch_size: int = 128,
|
||||||
|
gamma: float = 0.99,
|
||||||
|
cg_max_steps: int = 15,
|
||||||
|
cg_damping: float = 0.1,
|
||||||
|
line_search_shrinking_factor: float = 0.8,
|
||||||
|
line_search_max_iter: int = 10,
|
||||||
|
n_critic_updates: int = 10,
|
||||||
|
gae_lambda: float = 0.95,
|
||||||
|
use_sde: bool = False,
|
||||||
|
sde_sample_freq: int = -1,
|
||||||
|
normalize_advantage: bool = True,
|
||||||
|
target_kl: float = 0.01,
|
||||||
|
sub_sampling_factor: int = 1,
|
||||||
|
tensorboard_log: Optional[str] = None,
|
||||||
|
create_eval_env: bool = False,
|
||||||
|
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
verbose: int = 0,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
device: Union[th.device, str] = "auto",
|
||||||
|
_init_setup_model: bool = True,
|
||||||
|
):
|
||||||
|
|
||||||
|
super(TRPO, self).__init__(
|
||||||
|
policy,
|
||||||
|
env,
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
n_steps=n_steps,
|
||||||
|
gamma=gamma,
|
||||||
|
gae_lambda=gae_lambda,
|
||||||
|
ent_coef=0.0, # entropy bonus is not used by TRPO
|
||||||
|
vf_coef=0.0, # value function is optimized separately
|
||||||
|
max_grad_norm=0.0,
|
||||||
|
use_sde=use_sde,
|
||||||
|
sde_sample_freq=sde_sample_freq,
|
||||||
|
policy_base=ActorCriticPolicy,
|
||||||
|
tensorboard_log=tensorboard_log,
|
||||||
|
policy_kwargs=policy_kwargs,
|
||||||
|
verbose=verbose,
|
||||||
|
device=device,
|
||||||
|
create_eval_env=create_eval_env,
|
||||||
|
seed=seed,
|
||||||
|
_init_setup_model=False,
|
||||||
|
supported_action_spaces=(
|
||||||
|
spaces.Box,
|
||||||
|
spaces.Discrete,
|
||||||
|
spaces.MultiDiscrete,
|
||||||
|
spaces.MultiBinary,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.normalize_advantage = normalize_advantage
|
||||||
|
# Sanity check, otherwise it will lead to noisy gradient and NaN
|
||||||
|
# because of the advantage normalization
|
||||||
|
if self.env is not None:
|
||||||
|
# Check that `n_steps * n_envs > 1` to avoid NaN
|
||||||
|
# when doing advantage normalization
|
||||||
|
buffer_size = self.env.num_envs * self.n_steps
|
||||||
|
if normalize_advantage:
|
||||||
|
assert buffer_size > 1, (
|
||||||
|
"`n_steps * n_envs` must be greater than 1. "
|
||||||
|
f"Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
|
||||||
|
)
|
||||||
|
# Check that the rollout buffer size is a multiple of the mini-batch size
|
||||||
|
untruncated_batches = buffer_size // batch_size
|
||||||
|
if buffer_size % batch_size > 0:
|
||||||
|
warnings.warn(
|
||||||
|
f"You have specified a mini-batch size of {batch_size},"
|
||||||
|
f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
|
||||||
|
f" after every {untruncated_batches} untruncated mini-batches,"
|
||||||
|
f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
|
||||||
|
f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
|
||||||
|
f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
|
||||||
|
)
|
||||||
|
self.batch_size = batch_size
|
||||||
|
# Conjugate gradients parameters
|
||||||
|
self.cg_max_steps = cg_max_steps
|
||||||
|
self.cg_damping = cg_damping
|
||||||
|
# Backtracking line search parameters
|
||||||
|
self.line_search_shrinking_factor = line_search_shrinking_factor
|
||||||
|
self.line_search_max_iter = line_search_max_iter
|
||||||
|
self.target_kl = target_kl
|
||||||
|
self.n_critic_updates = n_critic_updates
|
||||||
|
self.sub_sampling_factor = sub_sampling_factor
|
||||||
|
|
||||||
|
if _init_setup_model:
|
||||||
|
self._setup_model()
|
||||||
|
|
||||||
|
def _compute_actor_grad(
|
||||||
|
self, kl_div: th.Tensor, policy_objective: th.Tensor
|
||||||
|
) -> Tuple[List[nn.Parameter], th.Tensor, th.Tensor, List[Tuple[int, ...]]]:
|
||||||
|
"""
|
||||||
|
Compute actor gradients for kl div and surrogate objectives.
|
||||||
|
|
||||||
|
:param kl_div: The KL divergence objective
|
||||||
|
:param policy_objective: The surrogate objective ("classic" policy gradient)
|
||||||
|
:return: List of actor params, gradients and gradients shape.
|
||||||
|
"""
|
||||||
|
# This is necessary because not all the parameters in the policy have gradients w.r.t. the KL divergence
|
||||||
|
# The policy objective is also called surrogate objective
|
||||||
|
policy_objective_gradients = []
|
||||||
|
# Contains the gradients of the KL divergence
|
||||||
|
grad_kl = []
|
||||||
|
# Contains the shape of the gradients of the KL divergence w.r.t each parameter
|
||||||
|
# This way the flattened gradient can be reshaped back into the original shapes and applied to
|
||||||
|
# the parameters
|
||||||
|
grad_shape = []
|
||||||
|
# Contains the parameters which have non-zeros KL divergence gradients
|
||||||
|
# The list is used during the line-search to apply the step to each parameters
|
||||||
|
actor_params = []
|
||||||
|
|
||||||
|
for name, param in self.policy.named_parameters():
|
||||||
|
# Skip parameters related to value function based on name
|
||||||
|
# this work for built-in policies only (not custom ones)
|
||||||
|
if "value" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# For each parameter we compute the gradient of the KL divergence w.r.t to that parameter
|
||||||
|
kl_param_grad, *_ = th.autograd.grad(
|
||||||
|
kl_div,
|
||||||
|
param,
|
||||||
|
create_graph=True,
|
||||||
|
retain_graph=True,
|
||||||
|
allow_unused=True,
|
||||||
|
only_inputs=True,
|
||||||
|
)
|
||||||
|
# If the gradient is not zero (not None), we store the parameter in the actor_params list
|
||||||
|
# and add the gradient and its shape to grad_kl and grad_shape respectively
|
||||||
|
if kl_param_grad is not None:
|
||||||
|
# If the parameter impacts the KL divergence (i.e. the policy)
|
||||||
|
# we compute the gradient of the policy objective w.r.t to the parameter
|
||||||
|
# this avoids computing the gradient if it's not going to be used in the conjugate gradient step
|
||||||
|
policy_objective_grad, *_ = th.autograd.grad(policy_objective, param, retain_graph=True, only_inputs=True)
|
||||||
|
|
||||||
|
grad_shape.append(kl_param_grad.shape)
|
||||||
|
grad_kl.append(kl_param_grad.view(-1))
|
||||||
|
policy_objective_gradients.append(policy_objective_grad.view(-1))
|
||||||
|
actor_params.append(param)
|
||||||
|
|
||||||
|
# Gradients are concatenated before the conjugate gradient step
|
||||||
|
policy_objective_gradients = th.cat(policy_objective_gradients)
|
||||||
|
grad_kl = th.cat(grad_kl)
|
||||||
|
return actor_params, policy_objective_gradients, grad_kl, grad_shape
|
||||||
|
|
||||||
|
def train(self) -> None:
|
||||||
|
"""
|
||||||
|
Update policy using the currently gathered rollout buffer.
|
||||||
|
"""
|
||||||
|
# Switch to train mode (this affects batch norm / dropout)
|
||||||
|
self.policy.set_training_mode(True)
|
||||||
|
# Update optimizer learning rate
|
||||||
|
self._update_learning_rate(self.policy.optimizer)
|
||||||
|
|
||||||
|
policy_objective_values = []
|
||||||
|
kl_divergences = []
|
||||||
|
line_search_results = []
|
||||||
|
value_losses = []
|
||||||
|
|
||||||
|
# This will only loop once (get all data in one go)
|
||||||
|
for rollout_data in self.rollout_buffer.get(batch_size=None):
|
||||||
|
|
||||||
|
# Optional: sub-sample data for faster computation
|
||||||
|
if self.sub_sampling_factor > 1:
|
||||||
|
rollout_data = RolloutBufferSamples(
|
||||||
|
rollout_data.observations[:: self.sub_sampling_factor],
|
||||||
|
rollout_data.actions[:: self.sub_sampling_factor],
|
||||||
|
None, # old values, not used here
|
||||||
|
rollout_data.old_log_prob[:: self.sub_sampling_factor],
|
||||||
|
rollout_data.advantages[:: self.sub_sampling_factor],
|
||||||
|
None, # returns, not used here
|
||||||
|
)
|
||||||
|
|
||||||
|
actions = rollout_data.actions
|
||||||
|
if isinstance(self.action_space, spaces.Discrete):
|
||||||
|
# Convert discrete action from float to long
|
||||||
|
actions = rollout_data.actions.long().flatten()
|
||||||
|
|
||||||
|
# Re-sample the noise matrix because the log_std has changed
|
||||||
|
if self.use_sde:
|
||||||
|
# batch_size is only used for the value function
|
||||||
|
self.policy.reset_noise(actions.shape[0])
|
||||||
|
|
||||||
|
with th.no_grad():
|
||||||
|
# Note: is copy enough, no need for deepcopy?
|
||||||
|
# If using gSDE and deepcopy, we need to use `old_distribution.distribution`
|
||||||
|
# directly to avoid PyTorch errors.
|
||||||
|
old_distribution = copy.copy(self.policy.get_distribution(rollout_data.observations))
|
||||||
|
|
||||||
|
distribution = self.policy.get_distribution(rollout_data.observations)
|
||||||
|
log_prob = distribution.log_prob(actions)
|
||||||
|
|
||||||
|
advantages = rollout_data.advantages
|
||||||
|
if self.normalize_advantage:
|
||||||
|
advantages = (advantages - advantages.mean()) / (rollout_data.advantages.std() + 1e-8)
|
||||||
|
|
||||||
|
# ratio between old and new policy, should be one at the first iteration
|
||||||
|
ratio = th.exp(log_prob - rollout_data.old_log_prob)
|
||||||
|
|
||||||
|
# surrogate policy objective
|
||||||
|
policy_objective = (advantages * ratio).mean()
|
||||||
|
|
||||||
|
# KL divergence
|
||||||
|
kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean()
|
||||||
|
|
||||||
|
# Surrogate & KL gradient
|
||||||
|
self.policy.optimizer.zero_grad()
|
||||||
|
|
||||||
|
actor_params, policy_objective_gradients, grad_kl, grad_shape = self._compute_actor_grad(kl_div, policy_objective)
|
||||||
|
|
||||||
|
# Hessian-vector dot product function used in the conjugate gradient step
|
||||||
|
hessian_vector_product_fn = partial(self.hessian_vector_product, actor_params, grad_kl)
|
||||||
|
|
||||||
|
# Computing search direction
|
||||||
|
search_direction = conjugate_gradient_solver(
|
||||||
|
hessian_vector_product_fn,
|
||||||
|
policy_objective_gradients,
|
||||||
|
max_iter=self.cg_max_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Maximal step length
|
||||||
|
line_search_max_step_size = 2 * self.target_kl
|
||||||
|
line_search_max_step_size /= th.matmul(
|
||||||
|
search_direction, hessian_vector_product_fn(search_direction, retain_graph=False)
|
||||||
|
)
|
||||||
|
line_search_max_step_size = th.sqrt(line_search_max_step_size)
|
||||||
|
|
||||||
|
line_search_backtrack_coeff = 1.0
|
||||||
|
original_actor_params = [param.detach().clone() for param in actor_params]
|
||||||
|
|
||||||
|
is_line_search_success = False
|
||||||
|
with th.no_grad():
|
||||||
|
# Line-search (backtracking)
|
||||||
|
for _ in range(self.line_search_max_iter):
|
||||||
|
|
||||||
|
start_idx = 0
|
||||||
|
# Applying the scaled step direction
|
||||||
|
for param, original_param, shape in zip(actor_params, original_actor_params, grad_shape):
|
||||||
|
n_params = param.numel()
|
||||||
|
param.data = (
|
||||||
|
original_param.data
|
||||||
|
+ line_search_backtrack_coeff
|
||||||
|
* line_search_max_step_size
|
||||||
|
* search_direction[start_idx : (start_idx + n_params)].view(shape)
|
||||||
|
)
|
||||||
|
start_idx += n_params
|
||||||
|
|
||||||
|
# Recomputing the policy log-probabilities
|
||||||
|
distribution = self.policy.get_distribution(rollout_data.observations)
|
||||||
|
log_prob = distribution.log_prob(actions)
|
||||||
|
|
||||||
|
# New policy objective
|
||||||
|
ratio = th.exp(log_prob - rollout_data.old_log_prob)
|
||||||
|
new_policy_objective = (advantages * ratio).mean()
|
||||||
|
|
||||||
|
# New KL-divergence
|
||||||
|
kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean()
|
||||||
|
|
||||||
|
# Constraint criteria:
|
||||||
|
# we need to improve the surrogate policy objective
|
||||||
|
# while being close enough (in term of kl div) to the old policy
|
||||||
|
if (kl_div < self.target_kl) and (new_policy_objective > policy_objective):
|
||||||
|
is_line_search_success = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# Reducing step size if line-search wasn't successful
|
||||||
|
line_search_backtrack_coeff *= self.line_search_shrinking_factor
|
||||||
|
|
||||||
|
line_search_results.append(is_line_search_success)
|
||||||
|
|
||||||
|
if not is_line_search_success:
|
||||||
|
# If the line-search wasn't successful we revert to the original parameters
|
||||||
|
for param, original_param in zip(actor_params, original_actor_params):
|
||||||
|
param.data = original_param.data.clone()
|
||||||
|
|
||||||
|
policy_objective_values.append(policy_objective.item())
|
||||||
|
kl_divergences.append(0)
|
||||||
|
else:
|
||||||
|
policy_objective_values.append(new_policy_objective.item())
|
||||||
|
kl_divergences.append(kl_div.item())
|
||||||
|
|
||||||
|
# Critic update
|
||||||
|
for _ in range(self.n_critic_updates):
|
||||||
|
for rollout_data in self.rollout_buffer.get(self.batch_size):
|
||||||
|
values_pred = self.policy.predict_values(rollout_data.observations)
|
||||||
|
value_loss = F.mse_loss(rollout_data.returns, values_pred.flatten())
|
||||||
|
value_losses.append(value_loss.item())
|
||||||
|
|
||||||
|
self.policy.optimizer.zero_grad()
|
||||||
|
value_loss.backward()
|
||||||
|
# Removing gradients of parameters shared with the actor
|
||||||
|
# otherwise it defeats the purposes of the KL constraint
|
||||||
|
for param in actor_params:
|
||||||
|
param.grad = None
|
||||||
|
self.policy.optimizer.step()
|
||||||
|
|
||||||
|
self._n_updates += 1
|
||||||
|
explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
self.logger.record("train/policy_objective", np.mean(policy_objective_values))
|
||||||
|
self.logger.record("train/value_loss", np.mean(value_losses))
|
||||||
|
self.logger.record("train/kl_divergence_loss", np.mean(kl_divergences))
|
||||||
|
self.logger.record("train/explained_variance", explained_var)
|
||||||
|
self.logger.record("train/is_line_search_success", np.mean(line_search_results))
|
||||||
|
if hasattr(self.policy, "log_std"):
|
||||||
|
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
|
||||||
|
|
||||||
|
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
||||||
|
|
||||||
|
def hessian_vector_product(
|
||||||
|
self, params: List[nn.Parameter], grad_kl: th.Tensor, vector: th.Tensor, retain_graph: bool = True
|
||||||
|
) -> th.Tensor:
|
||||||
|
"""
|
||||||
|
Computes the matrix-vector product with the Fisher information matrix.
|
||||||
|
|
||||||
|
:param params: list of parameters used to compute the Hessian
|
||||||
|
:param grad_kl: flattened gradient of the KL divergence between the old and new policy
|
||||||
|
:param vector: vector to compute the dot product the hessian-vector dot product with
|
||||||
|
:param retain_graph: if True, the graph will be kept after computing the Hessian
|
||||||
|
:return: Hessian-vector dot product (with damping)
|
||||||
|
"""
|
||||||
|
jacobian_vector_product = (grad_kl * vector).sum()
|
||||||
|
return flat_grad(jacobian_vector_product, params, retain_graph=retain_graph) + self.cg_damping * vector
|
||||||
|
|
||||||
|
def learn(
|
||||||
|
self,
|
||||||
|
total_timesteps: int,
|
||||||
|
callback: MaybeCallback = None,
|
||||||
|
log_interval: int = 1,
|
||||||
|
eval_env: Optional[GymEnv] = None,
|
||||||
|
eval_freq: int = -1,
|
||||||
|
n_eval_episodes: int = 5,
|
||||||
|
tb_log_name: str = "TRPO",
|
||||||
|
eval_log_path: Optional[str] = None,
|
||||||
|
reset_num_timesteps: bool = True,
|
||||||
|
) -> OnPolicyAlgorithm:
|
||||||
|
|
||||||
|
return super(TRPO, self).learn(
|
||||||
|
total_timesteps=total_timesteps,
|
||||||
|
callback=callback,
|
||||||
|
log_interval=log_interval,
|
||||||
|
eval_env=eval_env,
|
||||||
|
eval_freq=eval_freq,
|
||||||
|
n_eval_episodes=n_eval_episodes,
|
||||||
|
tb_log_name=tb_log_name,
|
||||||
|
eval_log_path=eval_log_path,
|
||||||
|
reset_num_timesteps=reset_num_timesteps,
|
||||||
|
)
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.3.1a6
|
1.3.1a7
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ per-file-ignores =
|
||||||
./sb3_contrib/ppo_mask/__init__.py:F401
|
./sb3_contrib/ppo_mask/__init__.py:F401
|
||||||
./sb3_contrib/qrdqn/__init__.py:F401
|
./sb3_contrib/qrdqn/__init__.py:F401
|
||||||
./sb3_contrib/tqc/__init__.py:F401
|
./sb3_contrib/tqc/__init__.py:F401
|
||||||
|
./sb3_contrib/trpo/__init__.py:F401
|
||||||
./sb3_contrib/common/vec_env/wrappers/__init__.py:F401
|
./sb3_contrib/common/vec_env/wrappers/__init__.py:F401
|
||||||
./sb3_contrib/common/wrappers/__init__.py:F401
|
./sb3_contrib/common/wrappers/__init__.py:F401
|
||||||
./sb3_contrib/common/envs/__init__.py:F401
|
./sb3_contrib/common/envs/__init__.py:F401
|
||||||
|
|
|
||||||
|
|
@ -8,10 +8,10 @@ from stable_baselines3.common.envs import FakeImageEnv
|
||||||
from stable_baselines3.common.utils import zip_strict
|
from stable_baselines3.common.utils import zip_strict
|
||||||
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped
|
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped
|
||||||
|
|
||||||
from sb3_contrib import QRDQN, TQC
|
from sb3_contrib import QRDQN, TQC, TRPO
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
|
@pytest.mark.parametrize("model_class", [TQC, QRDQN, TRPO])
|
||||||
def test_cnn(tmp_path, model_class):
|
def test_cnn(tmp_path, model_class):
|
||||||
SAVE_NAME = "cnn_model.zip"
|
SAVE_NAME = "cnn_model.zip"
|
||||||
# Fake grayscale with frameskip
|
# Fake grayscale with frameskip
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from stable_baselines3.common.envs import SimpleMultiObsEnv
|
||||||
from stable_baselines3.common.evaluation import evaluate_policy
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
|
||||||
|
|
||||||
from sb3_contrib import QRDQN, TQC
|
from sb3_contrib import QRDQN, TQC, TRPO
|
||||||
|
|
||||||
|
|
||||||
class DummyDictEnv(gym.Env):
|
class DummyDictEnv(gym.Env):
|
||||||
|
|
@ -78,7 +78,7 @@ class DummyDictEnv(gym.Env):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
|
@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
|
||||||
def test_consistency(model_class):
|
def test_consistency(model_class):
|
||||||
"""
|
"""
|
||||||
Make sure that dict obs with vector only vs using flatten obs is equivalent.
|
Make sure that dict obs with vector only vs using flatten obs is equivalent.
|
||||||
|
|
@ -94,7 +94,7 @@ def test_consistency(model_class):
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
n_steps = 256
|
n_steps = 256
|
||||||
|
|
||||||
if model_class in {}:
|
if model_class in {TRPO}:
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
n_steps=128,
|
n_steps=128,
|
||||||
)
|
)
|
||||||
|
|
@ -124,7 +124,7 @@ def test_consistency(model_class):
|
||||||
assert np.allclose(action_1, action_2)
|
assert np.allclose(action_1, action_2)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
|
@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
|
||||||
@pytest.mark.parametrize("channel_last", [False, True])
|
@pytest.mark.parametrize("channel_last", [False, True])
|
||||||
def test_dict_spaces(model_class, channel_last):
|
def test_dict_spaces(model_class, channel_last):
|
||||||
"""
|
"""
|
||||||
|
|
@ -138,11 +138,11 @@ def test_dict_spaces(model_class, channel_last):
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
n_steps = 256
|
n_steps = 256
|
||||||
|
|
||||||
if model_class in {}:
|
if model_class in {TRPO}:
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
n_steps=128,
|
n_steps=128,
|
||||||
policy_kwargs=dict(
|
policy_kwargs=dict(
|
||||||
net_arch=[32],
|
net_arch=[dict(pi=[32], vf=[32])],
|
||||||
features_extractor_kwargs=dict(cnn_output_dim=32),
|
features_extractor_kwargs=dict(cnn_output_dim=32),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
@ -169,7 +169,7 @@ def test_dict_spaces(model_class, channel_last):
|
||||||
evaluate_policy(model, env, n_eval_episodes=5, warn=False)
|
evaluate_policy(model, env, n_eval_episodes=5, warn=False)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
|
@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
|
||||||
@pytest.mark.parametrize("channel_last", [False, True])
|
@pytest.mark.parametrize("channel_last", [False, True])
|
||||||
def test_dict_vec_framestack(model_class, channel_last):
|
def test_dict_vec_framestack(model_class, channel_last):
|
||||||
"""
|
"""
|
||||||
|
|
@ -187,11 +187,11 @@ def test_dict_vec_framestack(model_class, channel_last):
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
n_steps = 256
|
n_steps = 256
|
||||||
|
|
||||||
if model_class in {}:
|
if model_class in {TRPO}:
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
n_steps=128,
|
n_steps=128,
|
||||||
policy_kwargs=dict(
|
policy_kwargs=dict(
|
||||||
net_arch=[32],
|
net_arch=[dict(pi=[32], vf=[32])],
|
||||||
features_extractor_kwargs=dict(cnn_output_dim=32),
|
features_extractor_kwargs=dict(cnn_output_dim=32),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
@ -218,7 +218,7 @@ def test_dict_vec_framestack(model_class, channel_last):
|
||||||
evaluate_policy(model, env, n_eval_episodes=5, warn=False)
|
evaluate_policy(model, env, n_eval_episodes=5, warn=False)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
|
@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
|
||||||
def test_vec_normalize(model_class):
|
def test_vec_normalize(model_class):
|
||||||
"""
|
"""
|
||||||
Additional tests to check observation space support
|
Additional tests to check observation space support
|
||||||
|
|
@ -230,11 +230,11 @@ def test_vec_normalize(model_class):
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
n_steps = 256
|
n_steps = 256
|
||||||
|
|
||||||
if model_class in {}:
|
if model_class in {TRPO}:
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
n_steps=128,
|
n_steps=128,
|
||||||
policy_kwargs=dict(
|
policy_kwargs=dict(
|
||||||
net_arch=[32],
|
net_arch=[dict(pi=[32], vf=[32])],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import gym
|
||||||
import pytest
|
import pytest
|
||||||
from stable_baselines3.common.env_util import make_vec_env
|
from stable_baselines3.common.env_util import make_vec_env
|
||||||
|
|
||||||
from sb3_contrib import QRDQN, TQC
|
from sb3_contrib import QRDQN, TQC, TRPO
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
|
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
|
||||||
|
|
@ -60,6 +60,28 @@ def test_qrdqn():
|
||||||
model.learn(total_timesteps=500, eval_freq=250)
|
model.learn(total_timesteps=500, eval_freq=250)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
|
||||||
|
def test_trpo(env_id):
|
||||||
|
model = TRPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1)
|
||||||
|
model.learn(total_timesteps=500)
|
||||||
|
|
||||||
|
|
||||||
|
def test_trpo_params():
|
||||||
|
# Test with gSDE and subsampling
|
||||||
|
model = TRPO(
|
||||||
|
"MlpPolicy",
|
||||||
|
"Pendulum-v0",
|
||||||
|
n_steps=64,
|
||||||
|
batch_size=32,
|
||||||
|
use_sde=True,
|
||||||
|
sub_sampling_factor=4,
|
||||||
|
seed=0,
|
||||||
|
policy_kwargs=dict(net_arch=[dict(pi=[32], vf=[32])]),
|
||||||
|
verbose=1,
|
||||||
|
)
|
||||||
|
model.learn(total_timesteps=500)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
|
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
|
||||||
def test_offpolicy_multi_env(model_class):
|
def test_offpolicy_multi_env(model_class):
|
||||||
if model_class in [TQC]:
|
if model_class in [TQC]:
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,9 @@ from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnv
|
||||||
from stable_baselines3.common.utils import get_device
|
from stable_baselines3.common.utils import get_device
|
||||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||||
|
|
||||||
from sb3_contrib import QRDQN, TQC
|
from sb3_contrib import QRDQN, TQC, TRPO
|
||||||
|
|
||||||
MODEL_LIST = [TQC, QRDQN]
|
MODEL_LIST = [TQC, QRDQN, TRPO]
|
||||||
|
|
||||||
|
|
||||||
def select_env(model_class: BaseAlgorithm) -> gym.Env:
|
def select_env(model_class: BaseAlgorithm) -> gym.Env:
|
||||||
|
|
@ -277,6 +277,11 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
||||||
learning_starts=100,
|
learning_starts=100,
|
||||||
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
kwargs = dict(
|
||||||
|
n_steps=128,
|
||||||
|
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
||||||
|
)
|
||||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN)
|
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN)
|
||||||
|
|
||||||
# Reduce number of quantiles for faster tests
|
# Reduce number of quantiles for faster tests
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch as th
|
import torch as th
|
||||||
|
from stable_baselines3.common.utils import set_random_seed
|
||||||
|
|
||||||
from sb3_contrib.common.utils import quantile_huber_loss
|
from sb3_contrib import TRPO
|
||||||
|
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad, quantile_huber_loss
|
||||||
|
|
||||||
|
|
||||||
def test_quantile_huber_loss():
|
def test_quantile_huber_loss():
|
||||||
|
|
@ -17,3 +19,46 @@ def test_quantile_huber_loss():
|
||||||
quantile_huber_loss(th.zeros(4, 4), th.zeros(3, 4))
|
quantile_huber_loss(th.zeros(4, 4), th.zeros(3, 4))
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
quantile_huber_loss(th.zeros(4, 4, 4, 4), th.zeros(4, 4, 4, 4))
|
quantile_huber_loss(th.zeros(4, 4, 4, 4), th.zeros(4, 4, 4, 4))
|
||||||
|
|
||||||
|
|
||||||
|
def test_cg():
|
||||||
|
# Test that conjugate gradient can actually solve
|
||||||
|
# Ax = b when the A^-1 is known
|
||||||
|
set_random_seed(4)
|
||||||
|
A = th.ones(3, 3)
|
||||||
|
# Symmetric matrix
|
||||||
|
A[0, 1] = 2
|
||||||
|
A[1, 0] = 2
|
||||||
|
x = th.ones(3) + th.rand(3)
|
||||||
|
b = A @ x
|
||||||
|
|
||||||
|
def matrix_vector_dot_func(vector):
|
||||||
|
return A @ vector
|
||||||
|
|
||||||
|
x_approx = conjugate_gradient_solver(matrix_vector_dot_func, b, max_iter=5, residual_tol=1e-10)
|
||||||
|
assert th.allclose(x_approx, x)
|
||||||
|
|
||||||
|
|
||||||
|
def test_flat_grad():
|
||||||
|
n_parameters = 12 # 3 * (2 * 2)
|
||||||
|
x = th.nn.Parameter(th.ones(2, 2, requires_grad=True))
|
||||||
|
y = (x ** 2).sum()
|
||||||
|
flat_grad_out = flat_grad(y, [x, x, x])
|
||||||
|
assert len(flat_grad_out.shape) == 1
|
||||||
|
# dy/dx = 2
|
||||||
|
assert th.allclose(flat_grad_out, th.ones(n_parameters) * 2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_trpo_warnings():
|
||||||
|
"""Test that TRPO warns and errors correctly on
|
||||||
|
problematic rollout buffer sizes"""
|
||||||
|
|
||||||
|
# Only 1 step: advantage normalization will return NaN
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
TRPO("MlpPolicy", "Pendulum-v0", n_steps=1)
|
||||||
|
# One step not advantage normalization: ok
|
||||||
|
TRPO("MlpPolicy", "Pendulum-v0", n_steps=1, normalize_advantage=False, batch_size=1)
|
||||||
|
|
||||||
|
# Truncated mini-batch
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
|
TRPO("MlpPolicy", "Pendulum-v0", n_steps=6, batch_size=8)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue