Recurrent PPO (#53)

* Running (not working yet) version of recurrent PPO

* Fixes for multi envs

* Save WIP, rework the sampling

* Add Box support

* Fix sample order

* Begin cleanup, code is broken (again)

* First working version (no shared lstm)

* Start cleanup

* Try rnn with value function

* Re-enable batch size

* Deactivate vf rnn

* Allow any batch size

* Add support for evaluation

* Add CNN support

* Fix start of sequence

* Allow shared LSTM

* Rename mask to episode_start

* Fix type hint

* Enable LSTM for critic

* Clean code

* Fix for CNN LSTM

* Fix sampling with n_layers > 1

* Add std logger

* Update wording

* Rename and add dict obs support

* Fixes for dict obs support

* Do not run slow tests

* Fix doc

* Update recurrent PPO example

* Update README

* Use Pendulum-v1 for tests

* Fix image env

* Speedup LSTM forward pass (#63)

* added more efficient lstm implementation

* Rename and add comment

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>

* Fixes

* Remove OpenAI sampling and improve coverage

* Sync with SB3 PPO

* Pass state shape and allow lstm kwargs

* Update tests

* Add masking for padded sequences

* Update default in perf test

* Remove TODO, mask is now working

* Add helper to remove duplicated code, remove hack for padding

* Enable LSTM critic and raise threshold for cartpole with no vel

* Fix tests

* Update doc and tests

* Doc fix

* Fix for new Sphinx version

* Fix doc note

* Switch to batch first, no more additional swap

* Add comments and mask entropy loss

Co-authored-by: Neville Walo <43504521+Walon1998@users.noreply.github.com>
Antonin RAFFIN, 2022-05-29 22:31:12 -04:00 (committed by GitHub)
parent cd592a111f
commit 75b2de1399
23 changed files with 1988 additions and 28 deletions


@@ -25,11 +25,12 @@ We hope this allows us to provide reliable implementations following stable-base
See documentation for the full list of included features.
**RL Algorithms**:
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
- [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/)
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)
**Gym Wrappers**:
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)


@@ -9,14 +9,17 @@ along with some useful characteristics: support for discrete/continuous actions,
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
============ =========== ============ ================= =============== ================
ARS ✔️ ❌️ ❌ ❌ ✔️
MaskablePPO ❌ ✔️ ✔️ ✔️ ✔️
QR-DQN ️❌ ️✔️ ❌ ❌ ✔️
RecurrentPPO ✔️ ✔️ ✔️ ✔️ ✔️
TQC ✔️ ❌ ❌ ❌ ✔️
TRPO ✔️ ✔️ ✔️ ✔️ ✔️
============ =========== ============ ================= =============== ================
.. note::
Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm.
``Tuple`` observation spaces are not supported by any algorithm;
however, single-level ``Dict`` observation spaces are supported.
Actions ``gym.spaces``:


@@ -71,3 +71,38 @@ Train an agent using Augmented Random Search (ARS) agent on the Pendulum environ
model = ARS("LinearPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("ars_pendulum")
RecurrentPPO
------------
Train a PPO agent with a recurrent policy on the CartPole environment.
.. note::
It is particularly important to pass the ``lstm_states``
and ``episode_start`` arguments to the ``predict()`` method,
so the cell and hidden states of the LSTM are correctly updated.
.. code-block:: python
import numpy as np
from sb3_contrib import RecurrentPPO
model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
model.learn(5000)
env = model.get_env()
obs = env.reset()
# cell and hidden state of the LSTM
lstm_states = None
num_envs = 1
# Episode start signals are used to reset the lstm states
episode_starts = np.ones((num_envs,), dtype=bool)
while True:
action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
obs, rewards, dones, info = env.step(action)
episode_starts = dones
env.render()


@@ -33,6 +33,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
modules/ars
modules/ppo_mask
modules/ppo_recurrent
modules/qrdqn
modules/tqc
modules/trpo


@@ -3,9 +3,11 @@
Changelog
==========
Release 1.5.1a7 (WIP)
Release 1.5.1a8 (WIP)
-------------------------------
**Add RecurrentPPO (aka PPO LSTM)**
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 1.5.1a7
@@ -17,6 +19,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added ``RecurrentPPO`` (aka PPO LSTM)
Bug Fixes:
^^^^^^^^^^
@@ -34,7 +37,8 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Allow PPO to turn off advantage normalization (see `PR #61 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/61>`_) @vwxyzjn
- Allow PPO to turn off advantage normalization (see `PR #61 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/61>`_) (@vwxyzjn)
Bug Fixes:
^^^^^^^^^^
@@ -46,6 +50,9 @@ Deprecations:
Others:
^^^^^^^
Documentation:
^^^^^^^^^^^^^^
Release 1.4.0 (2022-01-19)
-------------------------------
**Add Trust Region Policy Optimization (TRPO) and Augmented Random Search (ARS) algorithms**


@@ -5,7 +5,7 @@
Maskable PPO
============
Implementation of `invalid action masking <https://arxiv.org/abs/2006.14171>`_ for the Proximal Policy Optimization(PPO)
Implementation of `invalid action masking <https://arxiv.org/abs/2006.14171>`_ for the Proximal Policy Optimization (PPO)
algorithm. Other than adding support for action masking, the behavior is the same as in SB3's core PPO algorithm.


@@ -0,0 +1,153 @@
.. _ppo_lstm:
.. automodule:: sb3_contrib.ppo_recurrent
Recurrent PPO
=============
Implementation of recurrent policies for the Proximal Policy Optimization (PPO)
algorithm. Other than adding support for recurrent policies (LSTM here), the behavior is the same as in SB3's core PPO algorithm.
.. rubric:: Available Policies
.. autosummary::
:nosignatures:
MlpLstmPolicy
CnnLstmPolicy
MultiInputLstmPolicy
Notes
-----
- Blog post: https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/
Can I use?
----------
- Recurrent policies: ✔️
- Multi processing: ✔️
- Gym spaces:
============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ✔️ ✔️
Box ✔️ ✔️
MultiDiscrete ✔️ ✔️
MultiBinary ✔️ ✔️
Dict ❌ ✔️
============= ====== ===========
Example
-------
.. note::
It is particularly important to pass the ``lstm_states``
and ``episode_start`` arguments to the ``predict()`` method,
so the cell and hidden states of the LSTM are correctly updated.
.. code-block:: python
import numpy as np
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.evaluation import evaluate_policy
model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
model.learn(5000)
env = model.get_env()
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20, warn=False)
print(mean_reward)
model.save("ppo_recurrent")
del model # remove to demonstrate saving and loading
model = RecurrentPPO.load("ppo_recurrent")
obs = env.reset()
# cell and hidden state of the LSTM
lstm_states = None
num_envs = 1
# Episode start signals are used to reset the lstm states
episode_starts = np.ones((num_envs,), dtype=bool)
while True:
action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
obs, rewards, dones, info = env.step(action)
episode_starts = dones
env.render()
Results
-------
Report on environments with masked velocity (with and without framestack) can be found here: https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4
``RecurrentPPO`` was evaluated against PPO on:
- PendulumNoVel-v1
- LunarLanderNoVel-v2
- CartPoleNoVel-v1
- MountainCarContinuousNoVel-v0
- CarRacing-v0
How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Clone the repo for the experiment:
.. code-block:: bash
git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo
git checkout feat/recurrent-ppo
Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):
.. code-block:: bash
python train.py --algo ppo_lstm --env $ENV_ID --eval-episodes 10 --eval-freq 10000
Parameters
----------
.. autoclass:: RecurrentPPO
:members:
:inherited-members:
RecurrentPPO Policies
---------------------
.. autoclass:: MlpLstmPolicy
:members:
:inherited-members:
.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentActorCriticPolicy
:members:
:noindex:
.. autoclass:: CnnLstmPolicy
:members:
.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentActorCriticCnnPolicy
:members:
:noindex:
.. autoclass:: MultiInputLstmPolicy
:members:
.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentMultiInputActorCriticPolicy
:members:
:noindex:


@@ -2,6 +2,7 @@ import os
from sb3_contrib.ars import ARS
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.ppo_recurrent import RecurrentPPO
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC
from sb3_contrib.trpo import TRPO


@@ -215,12 +215,12 @@ class MaskableActorCriticPolicy(BasePolicy):
action_masks: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the policy action and state from an observation (and optional state).
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
:param deterministic: Whether or not to return deterministic actions.
:param action_masks: Action masks to apply to the action distribution
:return: the model's action and the next state
@@ -229,8 +229,8 @@
# TODO (GH/1): add support for RNN policies
# if state is None:
# state = self.initial_state
# if mask is None:
# mask = [False for _ in range(self.n_envs)]
# if episode_start is None:
# episode_start = [False for _ in range(self.n_envs)]
# Switch to eval mode (this affects batch norm / dropout)
self.set_training_mode(False)
@@ -256,7 +256,7 @@
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
actions = actions[0]
return actions, state
return actions, None
def evaluate_actions(
self,


@@ -0,0 +1,384 @@
from functools import partial
from typing import Callable, Generator, Optional, Tuple, Union
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.vec_env import VecNormalize
from sb3_contrib.common.recurrent.type_aliases import (
RecurrentDictRolloutBufferSamples,
RecurrentRolloutBufferSamples,
RNNStates,
)
def pad(
seq_start_indices: np.ndarray,
seq_end_indices: np.ndarray,
device: th.device,
tensor: np.ndarray,
padding_value: float = 0.0,
) -> th.Tensor:
"""
Chunk sequences and pad them to have constant dimensions.
:param seq_start_indices: Indices of the transitions that start a sequence
:param seq_end_indices: Indices of the transitions that end a sequence
:param device: PyTorch device
:param tensor: Tensor of shape (batch_size, *tensor_shape)
:param padding_value: Value used to pad sequence to the same length
(zero padding by default)
:return: (n_seq, max_length, *tensor_shape)
"""
# Create sequences given start and end
seq = [th.tensor(tensor[start : end + 1], device=device) for start, end in zip(seq_start_indices, seq_end_indices)]
return th.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=padding_value)
def pad_and_flatten(
seq_start_indices: np.ndarray,
seq_end_indices: np.ndarray,
device: th.device,
tensor: np.ndarray,
padding_value: float = 0.0,
) -> th.Tensor:
"""
Pad and flatten the sequences of scalar values,
while keeping the sequence order.
From (batch_size, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,)
:param seq_start_indices: Indices of the transitions that start a sequence
:param seq_end_indices: Indices of the transitions that end a sequence
:param device: PyTorch device (cpu, gpu, ...)
:param tensor: Tensor of shape (batch_size, 1)
:param padding_value: Value used to pad sequence to the same length
(zero padding by default)
:return: (n_seq * max_length,) aka (padded_batch_size,)
"""
return pad(seq_start_indices, seq_end_indices, device, tensor, padding_value).flatten()
def create_sequencers(
episode_starts: np.ndarray,
env_change: np.ndarray,
device: th.device,
) -> Tuple[np.ndarray, Callable, Callable]:
"""
Create the utility function to chunk data into
sequences and pad them to create fixed size tensors.
:param episode_starts: Indices where an episode starts
:param env_change: Indices where the data collected
come from a different env (when using multiple env for data collection)
:param device: PyTorch device
:return: Indices of the transitions that start a sequence,
pad and pad_and_flatten utilities tailored for this batch
(sequence starts and ends indices are fixed)
"""
# Create sequence if env changes too
seq_start = np.logical_or(episode_starts, env_change).flatten()
# First index is always the beginning of a sequence
seq_start[0] = True
# Retrieve indices of sequence starts
seq_start_indices = np.where(seq_start == True)[0] # noqa: E712
# End of sequence are just before sequence starts
# Last index is also always end of a sequence
seq_end_indices = np.concatenate([(seq_start_indices - 1)[1:], np.array([len(episode_starts)])])
# Create padding method for this minibatch
# to avoid repeating arguments (seq_start_indices, seq_end_indices)
local_pad = partial(pad, seq_start_indices, seq_end_indices, device)
local_pad_and_flatten = partial(pad_and_flatten, seq_start_indices, seq_end_indices, device)
return seq_start_indices, local_pad, local_pad_and_flatten
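A quick sketch of what ``create_sequencers`` produces on toy data (the arrays and sizes below are illustrative, not part of the commit; the import path is the one added by this diff):

import numpy as np
import torch as th

from sb3_contrib.common.recurrent.buffers import create_sequencers

# Flattened batch of 8 transitions from a single env;
# new episodes begin at indices 0, 3 and 6
episode_starts = np.array([1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0])
env_change = np.zeros(8)  # single env: no env boundary inside the batch
seq_start_indices, local_pad, local_pad_and_flatten = create_sequencers(
    episode_starts, env_change, th.device("cpu")
)
print(seq_start_indices)  # [0 3 6] -> three sequences of lengths 3, 3 and 2
values = np.arange(8, dtype=np.float32).reshape(8, 1)
print(local_pad(values).shape)  # torch.Size([3, 3, 1]) = (n_seq, max_length, 1)
print(local_pad_and_flatten(values))  # 9 values: the shortest sequence is zero-padded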
class RecurrentRolloutBuffer(RolloutBuffer):
"""
Rollout buffer that also stores the LSTM cell and hidden states.
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param hidden_state_shape: Shape of the buffer that will collect lstm states
(n_steps, lstm.num_layers, n_envs, lstm.hidden_size)
:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
:param n_envs: Number of parallel environments
"""
def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
hidden_state_shape: Tuple[int, int, int, int],
device: Union[th.device, str] = "cpu",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
self.hidden_state_shape = hidden_state_shape
self.seq_start_indices, self.seq_end_indices = None, None
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
def reset(self):
super().reset()
self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)
self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)
def add(self, *args, lstm_states: RNNStates, **kwargs) -> None:
"""
:param lstm_states: LSTM cell and hidden states
"""
self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy())
self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy())
self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy())
self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy())
super().add(*args, **kwargs)
def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]:
assert self.full, "Rollout buffer must be full before sampling from it"
# Prepare the data
if not self.generator_ready:
# hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)
# swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size)
for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]:
self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2)
# flatten but keep the sequence order
# 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape)
# 2. (n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape)
for tensor in [
"observations",
"actions",
"values",
"log_probs",
"advantages",
"returns",
"hidden_states_pi",
"cell_states_pi",
"hidden_states_vf",
"cell_states_vf",
"episode_starts",
]:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
self.generator_ready = True
# Return everything, don't create minibatches
if batch_size is None:
batch_size = self.buffer_size * self.n_envs
# Sampling strategy that allows any mini batch size but requires
# more complexity and use of padding
# Trick to shuffle a bit: keep the sequence order
# but split the indices in two
split_index = np.random.randint(self.buffer_size * self.n_envs)
indices = np.arange(self.buffer_size * self.n_envs)
indices = np.concatenate((indices[split_index:], indices[:split_index]))
env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs)
# Flag first timestep as change of environment
env_change[0, :] = 1.0
env_change = self.swap_and_flatten(env_change)
start_idx = 0
while start_idx < self.buffer_size * self.n_envs:
batch_inds = indices[start_idx : start_idx + batch_size]
yield self._get_samples(batch_inds, env_change)
start_idx += batch_size
def _get_samples(
self,
batch_inds: np.ndarray,
env_change: np.ndarray,
env: Optional[VecNormalize] = None,
) -> RecurrentRolloutBufferSamples:
# Retrieve sequence starts and utility function
self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers(
self.episode_starts[batch_inds], env_change[batch_inds], self.device
)
n_layers = self.hidden_states_pi.shape[1]
# Number of sequences
n_seq = len(self.seq_start_indices)
max_length = self.pad(self.actions[batch_inds]).shape[1]
padded_batch_size = n_seq * max_length
# We retrieve the lstm hidden states that will allow
# to properly initialize the LSTM at the beginning of each sequence
lstm_states_pi = (
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
)
lstm_states_vf = (
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
)
lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1]))
lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1]))
return RecurrentRolloutBufferSamples(
# (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size,) + self.obs_shape),
actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]),
old_values=self.pad_and_flatten(self.values[batch_inds]),
old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]),
advantages=self.pad_and_flatten(self.advantages[batch_inds]),
returns=self.pad_and_flatten(self.returns[batch_inds]),
lstm_states=RNNStates(lstm_states_pi, lstm_states_vf),
episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]),
mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])),
)
class RecurrentDictRolloutBuffer(DictRolloutBuffer):
"""
Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
Extends the RecurrentRolloutBuffer to use dictionary observations
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param hidden_state_shape: Shape of the buffer that will collect lstm states
:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
:param n_envs: Number of parallel environments
"""
def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
hidden_state_shape: Tuple[int, int, int, int],
device: Union[th.device, str] = "cpu",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
self.hidden_state_shape = hidden_state_shape
self.seq_start_indices, self.seq_end_indices = None, None
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)
def reset(self):
super().reset()
self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)
self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)
def add(self, *args, lstm_states: RNNStates, **kwargs) -> None:
"""
:param lstm_states: LSTM cell and hidden states
"""
self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy())
self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy())
self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy())
self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy())
super().add(*args, **kwargs)
def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSamples, None, None]:
assert self.full, "Rollout buffer must be full before sampling from it"
# Prepare the data
if not self.generator_ready:
# hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)
# swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size)
for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]:
self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2)
for key, obs in self.observations.items():
self.observations[key] = self.swap_and_flatten(obs)
for tensor in [
"actions",
"values",
"log_probs",
"advantages",
"returns",
"hidden_states_pi",
"cell_states_pi",
"hidden_states_vf",
"cell_states_vf",
"episode_starts",
]:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
self.generator_ready = True
# Return everything, don't create minibatches
if batch_size is None:
batch_size = self.buffer_size * self.n_envs
# Trick to shuffle a bit: keep the sequence order
# but split the indices in two
split_index = np.random.randint(self.buffer_size * self.n_envs)
indices = np.arange(self.buffer_size * self.n_envs)
indices = np.concatenate((indices[split_index:], indices[:split_index]))
env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs)
# Flag first timestep as change of environment
env_change[0, :] = 1.0
env_change = self.swap_and_flatten(env_change)
start_idx = 0
while start_idx < self.buffer_size * self.n_envs:
batch_inds = indices[start_idx : start_idx + batch_size]
yield self._get_samples(batch_inds, env_change)
start_idx += batch_size
def _get_samples(
self,
batch_inds: np.ndarray,
env_change: np.ndarray,
env: Optional[VecNormalize] = None,
) -> RecurrentDictRolloutBufferSamples:
# Retrieve sequence starts and utility function
self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers(
self.episode_starts[batch_inds], env_change[batch_inds], self.device
)
n_layers = self.hidden_states_pi.shape[1]
n_seq = len(self.seq_start_indices)
max_length = self.pad(self.actions[batch_inds]).shape[1]
padded_batch_size = n_seq * max_length
# We retrieve the lstm hidden states that will allow
# to properly initialize the LSTM at the beginning of each sequence
lstm_states_pi = (
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
)
lstm_states_vf = (
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
)
lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1]))
lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1]))
observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()}
observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()}
return RecurrentDictRolloutBufferSamples(
observations=observations,
actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]),
old_values=self.pad_and_flatten(self.values[batch_inds]),
old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]),
advantages=self.pad_and_flatten(self.advantages[batch_inds]),
returns=self.pad_and_flatten(self.returns[batch_inds]),
lstm_states=RNNStates(lstm_states_pi, lstm_states_vf),
episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]),
mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])),
)
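Note how both ``get()`` methods above shuffle by rotating the flattened indices rather than permuting them, so transitions stay in sequence order. A minimal sketch of that trick (sizes are illustrative):

import numpy as np

buffer_size, n_envs = 5, 1
indices = np.arange(buffer_size * n_envs)
split_index = np.random.randint(buffer_size * n_envs)
# A rotation, not a random permutation: order inside each sequence is kept
indices = np.concatenate((indices[split_index:], indices[:split_index]))
print(indices)  # e.g. [3 4 0 1 2]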


@@ -0,0 +1,601 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import numpy as np
import torch as th
from stable_baselines3.common.distributions import Distribution
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
MlpExtractor,
NatureCNN,
)
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import zip_strict
from torch import nn
from sb3_contrib.common.recurrent.type_aliases import RNNStates
class RecurrentActorCriticPolicy(ActorCriticPolicy):
"""
Recurrent policy class for actor-critic algorithms (has both policy and value prediction).
To be used with A2C, PPO and the likes.
It assumes that both the actor and the critic LSTM
have the same architecture.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: Whether to squash the output using a tanh function,
this allows to ensure boundaries when using gSDE.
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param lstm_hidden_size: Number of hidden units for each LSTM layer.
:param n_lstm_layers: Number of LSTM layers.
:param shared_lstm: Whether the LSTM is shared between the actor and the critic
(in that case, only the actor gradient is used)
By default, the actor and the critic have two separate LSTMs.
:param enable_critic_lstm: Use a separate LSTM for the critic.
:param lstm_kwargs: Additional keyword arguments to pass to the LSTM
constructor.
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
lstm_hidden_size: int = 256,
n_lstm_layers: int = 1,
shared_lstm: bool = False,
enable_critic_lstm: bool = True,
lstm_kwargs: Optional[Dict[str, Any]] = None,
):
self.lstm_output_dim = lstm_hidden_size
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
ortho_init,
use_sde,
log_std_init,
full_std,
sde_net_arch,
use_expln,
squash_output,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
)
self.lstm_kwargs = lstm_kwargs or {}
self.shared_lstm = shared_lstm
self.enable_critic_lstm = enable_critic_lstm
self.lstm_actor = nn.LSTM(
self.features_dim,
lstm_hidden_size,
num_layers=n_lstm_layers,
**self.lstm_kwargs,
)
# For the predict() method, to initialize hidden states
# (n_lstm_layers, batch_size, lstm_hidden_size)
self.lstm_hidden_state_shape = (n_lstm_layers, 1, lstm_hidden_size)
self.critic = None
self.lstm_critic = None
assert not (
self.shared_lstm and self.enable_critic_lstm
), "You must choose between shared LSTM, seperate or no LSTM for the critic"
# No LSTM for the critic, we still need to convert
# output of features extractor to the correct size
# (size of the output of the actor lstm)
if not (self.shared_lstm or self.enable_critic_lstm):
self.critic = nn.Linear(self.features_dim, lstm_hidden_size)
# Use a separate LSTM for the critic
if self.enable_critic_lstm:
self.lstm_critic = nn.LSTM(
self.features_dim,
lstm_hidden_size,
num_layers=n_lstm_layers,
**self.lstm_kwargs,
)
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
def _build_mlp_extractor(self) -> None:
"""
Create the policy and value networks.
Part of the layers can be shared.
"""
self.mlp_extractor = MlpExtractor(
self.lstm_output_dim,
net_arch=self.net_arch,
activation_fn=self.activation_fn,
device=self.device,
)
@staticmethod
def _process_sequence(
features: th.Tensor,
lstm_states: Tuple[th.Tensor, th.Tensor],
episode_starts: th.Tensor,
lstm: nn.LSTM,
) -> Tuple[th.Tensor, th.Tensor]:
"""
Do a forward pass in the LSTM network.
:param features: Input tensor
:param lstm_states: previous cell and hidden states of the LSTM
:param episode_starts: Indicates when a new episode starts,
in that case, we need to reset LSTM states.
:param lstm: LSTM object.
:return: LSTM output and updated LSTM states.
"""
# LSTM logic
# (sequence length, batch size, features dim)
# (batch size = n_envs for data collection or n_seq when doing gradient update)
n_seq = lstm_states[0].shape[1]
# Batch to sequence
# (padded batch size, features_dim) -> (n_seq, max length, features_dim) -> (max length, n_seq, features_dim)
# note: max length (max sequence length) is always 1 during data collection
features_sequence = features.reshape((n_seq, -1, lstm.input_size)).swapaxes(0, 1)
episode_starts = episode_starts.reshape((n_seq, -1)).swapaxes(0, 1)
# If we don't have to reset the state in the middle of a sequence
# we can avoid the for loop, which speeds up things
if th.all(episode_starts == 0.0):
lstm_output, lstm_states = lstm(features_sequence, lstm_states)
lstm_output = th.flatten(lstm_output.transpose(0, 1), start_dim=0, end_dim=1)
return lstm_output, lstm_states
lstm_output = []
# Iterate over the sequence
for features, episode_start in zip_strict(features_sequence, episode_starts):
hidden, lstm_states = lstm(
features.unsqueeze(dim=0),
(
# Reset the states at the beginning of a new episode
(1.0 - episode_start).view(1, n_seq, 1) * lstm_states[0],
(1.0 - episode_start).view(1, n_seq, 1) * lstm_states[1],
),
)
lstm_output += [hidden]
# Sequence to batch
# (sequence length, n_seq, lstm_out_dim) -> (batch_size, lstm_out_dim)
lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1)
return lstm_output, lstm_states
def forward(
self,
obs: th.Tensor,
lstm_states: RNNStates,
episode_starts: th.Tensor,
deterministic: bool = False,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor, RNNStates]:
"""
Forward pass in all the networks (actor and critic)
:param obs: Observation
:param lstm_states: The last hidden and memory states for the LSTM.
:param episode_starts: Whether the observations correspond to new episodes
or not (we reset the lstm states in that case).
:param deterministic: Whether to sample or use deterministic actions
:return: action, value and log probability of the action
"""
# Preprocess the observation if needed
features = self.extract_features(obs)
# latent_pi, latent_vf = self.mlp_extractor(features)
latent_pi, lstm_states_pi = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor)
if self.lstm_critic is not None:
latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic)
elif self.shared_lstm:
# Re-use LSTM features but do not backpropagate
latent_vf = latent_pi.detach()
lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach())
else:
# Critic only has a feedforward network
latent_vf = self.critic(features)
lstm_states_vf = lstm_states_pi
latent_pi = self.mlp_extractor.forward_actor(latent_pi)
latent_vf = self.mlp_extractor.forward_critic(latent_vf)
# Evaluate the values for the given observations
values = self.value_net(latent_vf)
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
return actions, values, log_prob, RNNStates(lstm_states_pi, lstm_states_vf)
def get_distribution(
self,
obs: th.Tensor,
lstm_states: Tuple[th.Tensor, th.Tensor],
episode_starts: th.Tensor,
) -> Tuple[Distribution, Tuple[th.Tensor, ...]]:
"""
Get the current policy distribution given the observations.
:param obs: Observation.
:param lstm_states: The last hidden and memory states for the LSTM.
:param episode_starts: Whether the observations correspond to new episodes
or not (we reset the lstm states in that case).
:return: the action distribution and new hidden states.
"""
features = self.extract_features(obs)
latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor)
latent_pi = self.mlp_extractor.forward_actor(latent_pi)
return self._get_action_dist_from_latent(latent_pi), lstm_states
def predict_values(
self,
obs: th.Tensor,
lstm_states: Tuple[th.Tensor, th.Tensor],
episode_starts: th.Tensor,
) -> th.Tensor:
"""
Get the estimated values according to the current policy given the observations.
:param obs: Observation.
:param lstm_states: The last hidden and memory states for the LSTM.
:param episode_starts: Whether the observations correspond to new episodes
or not (we reset the lstm states in that case).
:return: the estimated values.
"""
features = self.extract_features(obs)
if self.lstm_critic is not None:
latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic)
elif self.shared_lstm:
# Use LSTM from the actor
latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor)
latent_vf = latent_pi.detach()
else:
latent_vf = self.critic(features)
latent_vf = self.mlp_extractor.forward_critic(latent_vf)
return self.value_net(latent_vf)
def evaluate_actions(
self,
obs: th.Tensor,
actions: th.Tensor,
lstm_states: RNNStates,
episode_starts: th.Tensor,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
Evaluate actions according to the current policy,
given the observations.
:param obs: Observation.
:param actions: Actions to evaluate
:param lstm_states: The last hidden and memory states for the LSTM.
:param episode_starts: Whether the observations correspond to new episodes
or not (we reset the lstm states in that case).
:return: estimated value, log likelihood of taking those actions
and entropy of the action distribution.
"""
# Preprocess the observation if needed
features = self.extract_features(obs)
latent_pi, _ = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor)
if self.lstm_critic is not None:
latent_vf, _ = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic)
elif self.shared_lstm:
latent_vf = latent_pi.detach()
else:
latent_vf = self.critic(features)
latent_pi = self.mlp_extractor.forward_actor(latent_pi)
latent_vf = self.mlp_extractor.forward_critic(latent_vf)
distribution = self._get_action_dist_from_latent(latent_pi)
log_prob = distribution.log_prob(actions)
values = self.value_net(latent_vf)
return values, log_prob, distribution.entropy()
def _predict(
self,
observation: th.Tensor,
lstm_states: Tuple[th.Tensor, th.Tensor],
episode_starts: th.Tensor,
deterministic: bool = False,
) -> Tuple[th.Tensor, Tuple[th.Tensor, ...]]:
"""
Get the action according to the policy for a given observation.
:param observation: the input observation
:param lstm_states: The last hidden and memory states for the LSTM.
:param episode_starts: Whether the observations correspond to new episodes
or not (we reset the lstm states in that case).
:param deterministic: Whether to use stochastic or deterministic actions
:return: Taken action according to the policy and hidden states of the RNN
"""
distribution, lstm_states = self.get_distribution(observation, lstm_states, episode_starts)
return distribution.get_actions(deterministic=deterministic), lstm_states
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:param state: The last hidden and memory states for the LSTM.
:param episode_start: Whether the observations correspond to new episodes
or not (we reset the lstm states in that case).
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next hidden state
(used in recurrent policies)
"""
# Switch to eval mode (this affects batch norm / dropout)
self.set_training_mode(False)
observation, vectorized_env = self.obs_to_tensor(observation)
if isinstance(observation, dict):
n_envs = observation[list(observation.keys())[0]].shape[0]
else:
n_envs = observation.shape[0]
# state : (n_layers, n_envs, dim)
if state is None:
# Initialize hidden states to zeros
state = np.concatenate([np.zeros(self.lstm_hidden_state_shape) for _ in range(n_envs)], axis=1)
state = (state, state)
if episode_start is None:
episode_start = np.array([False for _ in range(n_envs)])
with th.no_grad():
# Convert to PyTorch tensors
states = th.tensor(state[0]).float().to(self.device), th.tensor(state[1]).float().to(self.device)
episode_starts = th.tensor(episode_start).float().to(self.device)
actions, states = self._predict(
observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic
)
states = (states[0].cpu().numpy(), states[1].cpu().numpy())
# Convert to numpy
actions = actions.cpu().numpy()
if isinstance(self.action_space, gym.spaces.Box):
if self.squash_output:
# Rescale to proper domain when using squashing
actions = self.unscale_action(actions)
else:
# Actions could be on arbitrary scale, so clip the actions to avoid
# out of bound error (e.g. if sampling from a Gaussian distribution)
actions = np.clip(actions, self.action_space.low, self.action_space.high)
# Remove batch dimension if needed
if not vectorized_env:
actions = actions[0]
return actions, states
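The episode-boundary handling in ``_process_sequence`` above boils down to multiplying the previous states by ``(1.0 - episode_start)``. A minimal sketch of that masking on toy tensors (shapes are illustrative):

import torch as th

n_seq, hidden_size = 2, 4
hidden_state = th.ones(1, n_seq, hidden_size)  # (n_layers, n_seq, dim)
episode_start = th.tensor([1.0, 0.0])  # first sequence begins a new episode
masked = (1.0 - episode_start).view(1, n_seq, 1) * hidden_state
print(masked[0, 0])  # tensor([0., 0., 0., 0.]) -> state reset
print(masked[0, 1])  # tensor([1., 1., 1., 1.]) -> state carried over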
class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
"""
CNN recurrent policy class for actor-critic algorithms (has both policy and value prediction).
Used by A2C, PPO and the likes.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: Whether to squash the output using a tanh function,
this allows to ensure boundaries when using gSDE.
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param lstm_hidden_size: Number of hidden units for each LSTM layer.
:param n_lstm_layers: Number of LSTM layers.
:param shared_lstm: Whether the LSTM is shared between the actor and the critic.
By default, only the actor has a recurrent network.
:param enable_critic_lstm: Use a separate LSTM for the critic.
:param lstm_kwargs: Additional keyword arguments to pass to the LSTM
constructor.
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
lstm_hidden_size: int = 256,
n_lstm_layers: int = 1,
shared_lstm: bool = False,
enable_critic_lstm: bool = True,
lstm_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
ortho_init,
use_sde,
log_std_init,
full_std,
sde_net_arch,
use_expln,
squash_output,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
lstm_hidden_size,
n_lstm_layers,
shared_lstm,
enable_critic_lstm,
lstm_kwargs,
)
class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
"""
Multi-input recurrent policy class for actor-critic algorithms (has both policy and value prediction).
Used by A2C, PPO and the likes.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: Whether to squash the output using a tanh function,
this allows to ensure boundaries when using gSDE.
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param lstm_hidden_size: Number of hidden units for each LSTM layer.
:param n_lstm_layers: Number of LSTM layers.
:param shared_lstm: Whether the LSTM is shared between the actor and the critic.
By default, only the actor has a recurrent network.
:param enable_critic_lstm: Use a separate LSTM for the critic.
:param lstm_kwargs: Additional keyword arguments to pass to the LSTM
constructor.
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
lstm_hidden_size: int = 256,
n_lstm_layers: int = 1,
shared_lstm: bool = False,
enable_critic_lstm: bool = True,
lstm_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
ortho_init,
use_sde,
log_std_init,
full_std,
sde_net_arch,
use_expln,
squash_output,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
lstm_hidden_size,
n_lstm_layers,
shared_lstm,
enable_critic_lstm,
lstm_kwargs,
)
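The LSTM-specific arguments documented above can be set through ``policy_kwargs``; a small sketch (the values are illustrative, not tuned):

from sb3_contrib import RecurrentPPO

model = RecurrentPPO(
    "MlpLstmPolicy",
    "CartPole-v1",
    policy_kwargs=dict(
        lstm_hidden_size=64,  # units per LSTM layer
        n_lstm_layers=2,  # stacked LSTM layers
        shared_lstm=False,  # keep the actor and critic LSTMs separate...
        enable_critic_lstm=True,  # ...and give the critic its own LSTM
    ),
    verbose=1,
)
model.learn(1_000)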


@@ -0,0 +1,33 @@
from typing import NamedTuple, Tuple
import torch as th
from stable_baselines3.common.type_aliases import TensorDict
class RNNStates(NamedTuple):
pi: Tuple[th.Tensor, ...]
vf: Tuple[th.Tensor, ...]
class RecurrentRolloutBufferSamples(NamedTuple):
observations: th.Tensor
actions: th.Tensor
old_values: th.Tensor
old_log_prob: th.Tensor
advantages: th.Tensor
returns: th.Tensor
lstm_states: RNNStates
episode_starts: th.Tensor
mask: th.Tensor
class RecurrentDictRolloutBufferSamples(NamedTuple):
observations: TensorDict
actions: th.Tensor
old_values: th.Tensor
old_log_prob: th.Tensor
advantages: th.Tensor
returns: th.Tensor
lstm_states: RNNStates
episode_starts: th.Tensor
mask: th.Tensor
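``RNNStates`` is just a named pair of (hidden, cell) tuples. A sketch of building zero-initialized states of shape (n_lstm_layers, n_envs, hidden_size), mirroring what ``RecurrentPPO._setup_model`` does later in this diff (the sizes here are illustrative):

import torch as th

from sb3_contrib.common.recurrent.type_aliases import RNNStates

single_hidden_state_shape = (1, 4, 16)  # (n_lstm_layers, n_envs, hidden_size)
zeros = (th.zeros(single_hidden_state_shape), th.zeros(single_hidden_state_shape))
states = RNNStates(pi=zeros, vf=zeros)
print(states.pi[0].shape)  # torch.Size([1, 4, 16])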


@@ -392,14 +392,17 @@ class MaskablePPO(OnPolicyAlgorithm):
action_masks: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the model's action(s) from an observation.
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param state: The last hidden states (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
this corresponds to the beginning of episodes,
where the hidden states of the RNN must be reset.
:param deterministic: Whether or not to return deterministic actions.
:param action_masks: Action masks to apply to the action distribution.
:return: the model's action and the next state (used in recurrent policies)
:return: the model's action and the next hidden state
(used in recurrent policies)
"""
return self.policy.predict(observation, state, episode_start, deterministic, action_masks=action_masks)


@@ -0,0 +1,2 @@
from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy
from sb3_contrib.ppo_recurrent.ppo_recurrent import RecurrentPPO


@@ -0,0 +1,9 @@
from sb3_contrib.common.recurrent.policies import (
RecurrentActorCriticCnnPolicy,
RecurrentActorCriticPolicy,
RecurrentMultiInputActorCriticPolicy,
)
MlpLstmPolicy = RecurrentActorCriticPolicy
CnnLstmPolicy = RecurrentActorCriticCnnPolicy
MultiInputLstmPolicy = RecurrentMultiInputActorCriticPolicy
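Since these aliases are plain assignments, the policy name strings resolve to the recurrent policy classes themselves; a quick sanity check:

from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
from sb3_contrib.ppo_recurrent.policies import MlpLstmPolicy

assert MlpLstmPolicy is RecurrentActorCriticPolicy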


@@ -0,0 +1,530 @@
import time
from copy import deepcopy
from typing import Any, Dict, Optional, Tuple, Type, Union
import gym
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer
from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
from sb3_contrib.common.recurrent.type_aliases import RNNStates
from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy
class RecurrentPPO(OnPolicyAlgorithm):
"""
Proximal Policy Optimization algorithm (PPO) (clip version)
with support for recurrent policies (LSTM).
Based on the original Stable Baselines 3 implementation.
Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
of the current progress remaining (from 1 to 0)
:param n_steps: The number of steps to run for each environment per update
(i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
:param batch_size: Minibatch size
:param n_epochs: Number of epochs when optimizing the surrogate loss
:param gamma: Discount factor
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
:param clip_range: Clipping parameter, it can be a function of the current progress
remaining (from 1 to 0).
:param clip_range_vf: Clipping parameter for the value function,
it can be a function of the current progress remaining (from 1 to 0).
This is a parameter specific to the OpenAI implementation. If None is passed (default),
no clipping will be done on the value function.
IMPORTANT: this clipping depends on the reward scaling.
:param normalize_advantage: Whether to normalize or not the advantage
:param ent_coef: Entropy coefficient for the loss calculation
:param vf_coef: Value function coefficient for the loss calculation
:param max_grad_norm: The maximum value for the gradient clipping
:param target_kl: Limit the KL divergence between updates,
because the clipping is not enough to prevent large updates
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
By default, there is no limit on the kl div.
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpLstmPolicy": MlpLstmPolicy,
"CnnLstmPolicy": CnnLstmPolicy,
"MultiInputLstmPolicy": MultiInputLstmPolicy,
}
def __init__(
self,
policy: Union[str, Type[RecurrentActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 3e-4,
n_steps: int = 128,
batch_size: Optional[int] = 128,
n_epochs: int = 10,
gamma: float = 0.99,
gae_lambda: float = 0.95,
clip_range: Union[float, Schedule] = 0.2,
clip_range_vf: Union[None, float, Schedule] = None,
normalize_advantage: bool = True,
ent_coef: float = 0.0,
vf_coef: float = 0.5,
max_grad_norm: float = 0.5,
use_sde: bool = False,
sde_sample_freq: int = -1,
target_kl: Optional[float] = None,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
learning_rate=learning_rate,
n_steps=n_steps,
gamma=gamma,
gae_lambda=gae_lambda,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
tensorboard_log=tensorboard_log,
create_eval_env=create_eval_env,
policy_kwargs=policy_kwargs,
verbose=verbose,
seed=seed,
device=device,
_init_setup_model=False,
supported_action_spaces=(
spaces.Box,
spaces.Discrete,
spaces.MultiDiscrete,
spaces.MultiBinary,
),
)
self.batch_size = batch_size
self.n_epochs = n_epochs
self.clip_range = clip_range
self.clip_range_vf = clip_range_vf
self.normalize_advantage = normalize_advantage
self.target_kl = target_kl
self._last_lstm_states = None
if _init_setup_model:
self._setup_model()
def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)
buffer_cls = (
RecurrentDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RecurrentRolloutBuffer
)
self.policy = self.policy_class(
self.observation_space,
self.action_space,
self.lr_schedule,
use_sde=self.use_sde,
**self.policy_kwargs, # pytype:disable=not-instantiable
)
self.policy = self.policy.to(self.device)
# We assume that LSTM for the actor and the critic
# have the same architecture
lstm = self.policy.lstm_actor
if not isinstance(self.policy, RecurrentActorCriticPolicy):
raise ValueError("Policy must subclass RecurrentActorCriticPolicy")
single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size)
# hidden and cell states for actor and critic
self._last_lstm_states = RNNStates(
(
th.zeros(single_hidden_state_shape).to(self.device),
th.zeros(single_hidden_state_shape).to(self.device),
),
(
th.zeros(single_hidden_state_shape).to(self.device),
th.zeros(single_hidden_state_shape).to(self.device),
),
)
hidden_state_buffer_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)
self.rollout_buffer = buffer_cls(
self.n_steps,
self.observation_space,
self.action_space,
hidden_state_buffer_shape,
self.device,
gamma=self.gamma,
gae_lambda=self.gae_lambda,
n_envs=self.n_envs,
)
# Initialize schedules for policy/value clipping
self.clip_range = get_schedule_fn(self.clip_range)
if self.clip_range_vf is not None:
if isinstance(self.clip_range_vf, (float, int)):
assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, pass `None` to deactivate vf clipping"
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
def _setup_learn(
self,
total_timesteps: int,
eval_env: Optional[GymEnv],
callback: MaybeCallback = None,
eval_freq: int = 10000,
n_eval_episodes: int = 5,
log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
tb_log_name: str = "RecurrentPPO",
) -> Tuple[int, BaseCallback]:
"""
Initialize different variables needed for training.
:param total_timesteps: The total number of samples (env steps) to train on
:param eval_env: Environment to use for evaluation.
:param callback: Callback(s) called at every step with state of the algorithm.
:param eval_freq: How many steps between evaluations
:param n_eval_episodes: How many episodes to play per evaluation
:param log_path: Path to a folder where the evaluations will be saved
        :param reset_num_timesteps: Whether or not to reset the ``num_timesteps`` attribute
:param tb_log_name: the name of the run for tensorboard log
        :return: Total timesteps and callback(s)
"""
total_timesteps, callback = super()._setup_learn(
total_timesteps,
eval_env,
callback,
eval_freq,
n_eval_episodes,
log_path,
reset_num_timesteps,
tb_log_name,
)
return total_timesteps, callback
def collect_rollouts(
self,
env: VecEnv,
callback: BaseCallback,
rollout_buffer: RolloutBuffer,
n_rollout_steps: int,
) -> bool:
"""
Collect experiences using the current policy and fill a ``RolloutBuffer``.
The term rollout here refers to the model-free notion and should not
be used with the concept of rollout used in model-based RL or planning.
:param env: The training environment
:param callback: Callback that will be called at each step
(and at the beginning and end of the rollout)
:param rollout_buffer: Buffer to fill with rollouts
        :param n_rollout_steps: Number of experiences to collect per environment
:return: True if function returned with at least `n_rollout_steps`
collected, False if callback terminated rollout prematurely.
"""
assert isinstance(
rollout_buffer, (RecurrentRolloutBuffer, RecurrentDictRolloutBuffer)
), f"{rollout_buffer} doesn't support recurrent policy"
assert self._last_obs is not None, "No previous observation was provided"
# Switch to eval mode (this affects batch norm / dropout)
self.policy.set_training_mode(False)
n_steps = 0
rollout_buffer.reset()
# Sample new weights for the state dependent exploration
if self.use_sde:
self.policy.reset_noise(env.num_envs)
callback.on_rollout_start()
lstm_states = deepcopy(self._last_lstm_states)
while n_steps < n_rollout_steps:
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
# Sample a new noise matrix
self.policy.reset_noise(env.num_envs)
with th.no_grad():
# Convert to pytorch tensor or to TensorDict
obs_tensor = obs_as_tensor(self._last_obs, self.device)
episode_starts = th.tensor(self._last_episode_starts).float().to(self.device)
actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)
actions = actions.cpu().numpy()
# Rescale and perform action
clipped_actions = actions
# Clip the actions to avoid out of bound error
if isinstance(self.action_space, gym.spaces.Box):
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
new_obs, rewards, dones, infos = env.step(clipped_actions)
self.num_timesteps += env.num_envs
# Give access to local variables
callback.update_locals(locals())
if callback.on_step() is False:
return False
self._update_info_buffer(infos)
n_steps += 1
if isinstance(self.action_space, gym.spaces.Discrete):
# Reshape in case of discrete action
actions = actions.reshape(-1, 1)
            # Handle timeout by bootstrapping with the value function
# see GitHub issue #633
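            # (a truncated episode's return continues past the last observed step,
            # so gamma * V(terminal_obs) is added back to the reward, using the
            # critic's LSTM state from just before the reset)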
for idx, done_ in enumerate(dones):
if (
done_
and infos[idx].get("terminal_observation") is not None
and infos[idx].get("TimeLimit.truncated", False)
):
terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
with th.no_grad():
terminal_lstm_state = (
lstm_states.vf[0][:, idx : idx + 1, :],
lstm_states.vf[1][:, idx : idx + 1, :],
)
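                        # the episode is truncated, not starting, so the LSTM state is not reset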
episode_starts = th.tensor([False]).float().to(self.device)
terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0]
rewards[idx] += self.gamma * terminal_value
rollout_buffer.add(
self._last_obs,
actions,
rewards,
self._last_episode_starts,
values,
log_probs,
lstm_states=self._last_lstm_states,
)
self._last_obs = new_obs
self._last_episode_starts = dones
self._last_lstm_states = lstm_states
with th.no_grad():
# Compute value for the last timestep
episode_starts = th.tensor(dones).float().to(self.device)
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts)
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
callback.on_rollout_end()
return True
def train(self) -> None:
"""
Update policy using the currently gathered rollout buffer.
"""
# Switch to train mode (this affects batch norm / dropout)
self.policy.set_training_mode(True)
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)
# Compute current clip range
clip_range = self.clip_range(self._current_progress_remaining)
# Optional: clip range for the value function
if self.clip_range_vf is not None:
clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
entropy_losses = []
pg_losses, value_losses = [], []
clip_fractions = []
continue_training = True
# train for n_epochs epochs
for epoch in range(self.n_epochs):
approx_kl_divs = []
# Do a complete pass on the rollout buffer
for rollout_data in self.rollout_buffer.get(self.batch_size):
actions = rollout_data.actions
if isinstance(self.action_space, spaces.Discrete):
# Convert discrete action from float to long
actions = rollout_data.actions.long().flatten()
# Re-sample the noise matrix because the log_std has changed
if self.use_sde:
self.policy.reset_noise(self.batch_size)
values, log_prob, entropy = self.policy.evaluate_actions(
rollout_data.observations,
actions,
rollout_data.lstm_states,
rollout_data.episode_starts,
)
values = values.flatten()
# Normalize advantage
advantages = rollout_data.advantages
if self.normalize_advantage:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration
ratio = th.exp(log_prob - rollout_data.old_log_prob)
# clipped surrogate loss
policy_loss_1 = advantages * ratio
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
# Mask padded sequences
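                # (the recurrent buffer pads sampled sequences to a common length;
                # mask is 0 on padded steps so they do not contribute to the loss)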
policy_loss_1 = policy_loss_1 * rollout_data.mask
policy_loss_2 = policy_loss_2 * rollout_data.mask
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
# Logging
pg_losses.append(policy_loss.item())
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
clip_fractions.append(clip_fraction)
if self.clip_range_vf is None:
# No clipping
values_pred = values
else:
                # Clip the difference between the old and new value
# NOTE: this depends on the reward scaling
values_pred = rollout_data.old_values + th.clamp(
values - rollout_data.old_values, -clip_range_vf, clip_range_vf
)
# Value loss using the TD(gae_lambda) target
# Mask padded sequences
value_loss = th.mean(((rollout_data.returns - values_pred) * rollout_data.mask) ** 2)
value_losses.append(value_loss.item())
                # Entropy loss favors exploration
if entropy is None:
# Approximate entropy when no analytical form
entropy_loss = -th.mean(-(log_prob * rollout_data.mask))
else:
entropy_loss = -th.mean(entropy * rollout_data.mask)
entropy_losses.append(entropy_loss.item())
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
# Calculate approximate form of reverse KL Divergence for early stopping
# see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
# and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
# and Schulman blog: http://joschu.net/blog/kl-approx.html
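                # approx_kl = mean((r - 1) - log(r)) with r = exp(log_ratio);
                # this estimator is always non-negative and has lower variance
                # than the naive -mean(log_ratio)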
with th.no_grad():
log_ratio = log_prob - rollout_data.old_log_prob
approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
approx_kl_divs.append(approx_kl_div)
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
continue_training = False
if self.verbose >= 1:
print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
break
# Optimization step
self.policy.optimizer.zero_grad()
loss.backward()
# Clip grad norm
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
if not continue_training:
break
self._n_updates += self.n_epochs
explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
# Logs
self.logger.record("train/entropy_loss", np.mean(entropy_losses))
self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
self.logger.record("train/value_loss", np.mean(value_losses))
self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
self.logger.record("train/clip_fraction", np.mean(clip_fractions))
self.logger.record("train/loss", loss.item())
self.logger.record("train/explained_variance", explained_var)
if hasattr(self.policy, "log_std"):
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/clip_range", clip_range)
if self.clip_range_vf is not None:
self.logger.record("train/clip_range_vf", clip_range_vf)
def learn(
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
tb_log_name: str = "RecurrentPPO",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> "RecurrentPPO":
iteration = 0
total_timesteps, callback = self._setup_learn(
total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
)
callback.on_training_start(locals(), globals())
while self.num_timesteps < total_timesteps:
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
if continue_training is False:
break
iteration += 1
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
# Display training infos
if log_interval is not None and iteration % log_interval == 0:
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
self.logger.record("time/iterations", iteration, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
self.logger.record("time/fps", fps)
self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
self.logger.dump(step=self.num_timesteps)
self.train()
callback.on_training_end()
return self
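For reference, here is a minimal usage sketch of the class above (not part of this commit). It assumes only the ``predict`` API exercised in the tests below, where ``state`` carries the LSTM states between calls and ``episode_start`` signals when they must be reset:

import numpy as np

from sb3_contrib import RecurrentPPO

model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=5_000)

env = model.get_env()
obs = env.reset()
# None means "start from zeroed hidden and cell states"
lstm_states = None
episode_starts = np.ones((env.num_envs,), dtype=bool)
for _ in range(500):
    action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
    # the dones returned by the env become the next episode_start signals
    obs, rewards, episode_starts, infos = env.step(action)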

View File

@ -218,13 +218,16 @@ class QRDQN(OffPolicyAlgorithm):
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Overrides the base_class predict function to include epsilon-greedy exploration.
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param state: The last hidden states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies);
            this corresponds to the beginning of episodes,
            where the hidden states of the RNN must be reset.
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next state
:return: the model's action and the next hidden state
(used in recurrent policies)
"""
if not deterministic and np.random.rand() < self.exploration_rate:

View File

@ -1 +1 @@
1.5.1a7
1.5.1a8

View File

@ -1,2 +1,2 @@
#!/bin/bash
python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes
python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes -m "not slow"
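# slow tests (marked with @pytest.mark.slow) can still be run explicitly with: python3 -m pytest -m slow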

View File

@ -13,6 +13,8 @@ filterwarnings =
ignore:Parameters to load are deprecated.:DeprecationWarning
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
ignore::UserWarning:gym
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
[pytype]
inputs = sb3_contrib
@ -24,6 +26,7 @@ per-file-ignores =
./sb3_contrib/__init__.py:F401
./sb3_contrib/ars/__init__.py:F401
./sb3_contrib/ppo_mask/__init__.py:F401
./sb3_contrib/ppo_recurrent/__init__.py:F401
./sb3_contrib/qrdqn/__init__.py:F401
./sb3_contrib/tqc/__init__.py:F401
./sb3_contrib/trpo/__init__.py:F401

View File

@ -3,7 +3,7 @@ from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.vec_env import VecNormalize
from sb3_contrib import ARS, QRDQN, TQC
from sb3_contrib import ARS, QRDQN, TQC, RecurrentPPO
from sb3_contrib.common.vec_env import AsyncEval
N_STEPS_TRAINING = 500
@ -11,7 +11,7 @@ SEED = 0
ARS_MULTI = "ars_multi"
@pytest.mark.parametrize("algo", [ARS, QRDQN, TQC, ARS_MULTI])
@pytest.mark.parametrize("algo", [ARS, QRDQN, TQC, ARS_MULTI, RecurrentPPO])
def test_deterministic_training_common(algo):
results = [[], []]
rewards = [[], []]
@ -32,9 +32,12 @@ def test_deterministic_training_common(algo):
kwargs.update({"learning_starts": 100, "target_update_interval": 100})
elif algo == ARS:
kwargs.update({"n_delta": 2})
elif algo == RecurrentPPO:
kwargs.update({"policy_kwargs": dict(net_arch=[], enable_critic_lstm=True, lstm_hidden_size=8)})
kwargs.update({"n_steps": 50, "n_epochs": 4})
policy_str = "MlpLstmPolicy" if algo == RecurrentPPO else "MlpPolicy"
for i in range(2):
model = algo("MlpPolicy", env_id, seed=SEED, **kwargs)
model = algo(policy_str, env_id, seed=SEED, **kwargs)
learn_kwargs = {"total_timesteps": N_STEPS_TRAINING}
if ars_multi:
@ -46,9 +49,11 @@ def test_deterministic_training_common(algo):
model.learn(**learn_kwargs)
env = model.get_env()
obs = env.reset()
states = None
episode_start = None
for _ in range(100):
action, _ = model.predict(obs, deterministic=False)
obs, reward, _, _ = env.step(action)
action, states = model.predict(obs, state=states, episode_start=episode_start, deterministic=False)
obs, reward, episode_start, _ = env.step(action)
results[i].append(action)
rewards[i].append(reward)
assert sum(results[0]) == sum(results[1]), results

tests/test_lstm.py Normal file
View File

@ -0,0 +1,186 @@
import gym
import numpy as np
import pytest
from gym import spaces
from gym.envs.classic_control import CartPoleEnv
from gym.wrappers.time_limit import TimeLimit
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import FakeImageEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecNormalize
from sb3_contrib import RecurrentPPO
class ToDictWrapper(gym.Wrapper):
"""
    Simple wrapper to test MultiInputPolicy on Dict obs.
"""
def __init__(self, env):
super().__init__(env)
self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space})
def reset(self):
return {"obs": self.env.reset()}
def step(self, action):
obs, reward, done, infos = self.env.step(action)
return {"obs": obs}, reward, done, infos
class CartPoleNoVelEnv(CartPoleEnv):
"""Variant of CartPoleEnv with velocity information removed. This task requires memory to solve."""
def __init__(self):
super().__init__()
high = np.array(
[
self.x_threshold * 2,
self.theta_threshold_radians * 2,
]
)
self.observation_space = spaces.Box(-high, high, dtype=np.float32)
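        # only cart position and pole angle remain observable;
        # the velocities must be inferred from the history of observations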
@staticmethod
def _pos_obs(full_obs):
xpos, _xvel, thetapos, _thetavel = full_obs
return xpos, thetapos
def reset(self):
full_obs = super().reset()
return CartPoleNoVelEnv._pos_obs(full_obs)
def step(self, action):
full_obs, rew, done, info = super().step(action)
return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info
def test_cnn():
model = RecurrentPPO(
"CnnLstmPolicy",
FakeImageEnv(screen_height=40, screen_width=40, n_channels=3),
n_steps=16,
seed=0,
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
)
model.learn(total_timesteps=32)
@pytest.mark.parametrize(
"policy_kwargs",
[
{},
dict(shared_lstm=True, enable_critic_lstm=False),
dict(
enable_critic_lstm=True,
lstm_hidden_size=4,
lstm_kwargs=dict(dropout=0.5),
),
dict(
enable_critic_lstm=False,
lstm_hidden_size=4,
lstm_kwargs=dict(dropout=0.5),
),
],
)
def test_policy_kwargs(policy_kwargs):
model = RecurrentPPO(
"MlpLstmPolicy",
"CartPole-v1",
n_steps=16,
seed=0,
policy_kwargs=policy_kwargs,
)
model.learn(total_timesteps=32)
def test_check():
policy_kwargs = dict(shared_lstm=True, enable_critic_lstm=True)
with pytest.raises(AssertionError):
RecurrentPPO(
"MlpLstmPolicy",
"CartPole-v1",
n_steps=16,
seed=0,
policy_kwargs=policy_kwargs,
)
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
def test_run(env):
model = RecurrentPPO(
"MlpLstmPolicy",
env,
n_steps=16,
seed=0,
create_eval_env=True,
)
model.learn(total_timesteps=32, eval_freq=16)
def test_run_sde():
model = RecurrentPPO(
"MlpLstmPolicy",
"Pendulum-v1",
n_steps=16,
seed=0,
create_eval_env=True,
sde_sample_freq=4,
use_sde=True,
clip_range_vf=0.1,
)
model.learn(total_timesteps=200, eval_freq=150)
def test_dict_obs():
env = make_vec_env("CartPole-v1", n_envs=1, wrapper_class=ToDictWrapper)
model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32).learn(64)
evaluate_policy(model, env, warn=False)
@pytest.mark.slow
def test_ppo_lstm_performance():
def make_env():
env = CartPoleNoVelEnv()
env = TimeLimit(env, max_episode_steps=500)
return env
env = VecNormalize(make_vec_env(make_env, n_envs=8))
eval_callback = EvalCallback(
VecNormalize(make_vec_env(make_env, n_envs=4), training=False, norm_reward=False),
n_eval_episodes=20,
eval_freq=5000 // env.num_envs,
)
model = RecurrentPPO(
"MlpLstmPolicy",
env,
n_steps=128,
learning_rate=0.0007,
verbose=1,
batch_size=256,
seed=1,
n_epochs=10,
max_grad_norm=1,
gae_lambda=0.98,
policy_kwargs=dict(
net_arch=[dict(vf=[64])],
lstm_hidden_size=64,
ortho_init=False,
enable_critic_lstm=True,
),
)
model.learn(total_timesteps=50_000, callback=eval_callback)
# Maximum episode reward is 500.
# In CartPole-v1, a non-recurrent policy can easily get >= 450.
# In CartPoleNoVelEnv, a non-recurrent policy doesn't get more than ~50.
evaluate_policy(model, env, reward_threshold=450)
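As a side note on the masking used in ``train()`` above, here is a tiny standalone sketch (hypothetical tensors, not code from this commit) of how padded steps are zeroed out of the loss:

import torch as th

# two sequences padded to a common length of 4; 1 marks a real step, 0 a padded one
mask = th.tensor([[1.0, 1.0, 1.0, 0.0],
                  [1.0, 1.0, 0.0, 0.0]])
per_step_loss = th.rand(2, 4)

# padded steps are multiplied by 0, so they contribute neither loss nor gradient,
# mirroring `policy_loss_1 * rollout_data.mask` in train() above
loss = th.mean(per_step_loss * mask)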