Recurrent PPO (#53)
* Running (not working yet) version of recurrent PPO
* Fixes for multi envs
* Save WIP, rework the sampling
* Add Box support
* Fix sample order
* Begin cleanup, code is broken (again)
* First working version (no shared lstm)
* Start cleanup
* Try rnn with value function
* Re-enable batch size
* Deactivate vf rnn
* Allow any batch size
* Add support for evaluation
* Add CNN support
* Fix start of sequence
* Allow shared LSTM
* Rename mask to episode_start
* Fix type hint
* Enable LSTM for critic
* Clean code
* Fix for CNN LSTM
* Fix sampling with n_layers > 1
* Add std logger
* Update wording
* Rename and add dict obs support
* Fixes for dict obs support
* Do not run slow tests
* Fix doc
* Update recurrent PPO example
* Update README
* Use Pendulum-v1 for tests
* Fix image env
* Speedup LSTM forward pass (#63)
* Added more efficient LSTM implementation
* Rename and add comment

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>

* Fixes
* Remove OpenAI sampling and improve coverage
* Sync with SB3 PPO
* Pass state shape and allow lstm kwargs
* Update tests
* Add masking for padded sequences
* Update default in perf test
* Remove TODO, mask is now working
* Add helper to remove duplicated code, remove hack for padding
* Enable LSTM critic and raise threshold for cartpole with no vel
* Fix tests
* Update doc and tests
* Doc fix
* Fix for new Sphinx version
* Fix doc note
* Switch to batch first, no more additional swap
* Add comments and mask entropy loss

Co-authored-by: Neville Walo <43504521+Walon1998@users.noreply.github.com>
This commit is contained in:
parent
cd592a111f
commit
75b2de1399
@@ -25,11 +25,12 @@ We hope this allows us to provide reliable implementations following stable-base
 See documentation for the full list of included features.

 **RL Algorithms**:

-- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
-- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)
 - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
 - [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
+- [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/)
+- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
 - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
+- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)

 **Gym Wrappers**:

 - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
@@ -9,14 +9,17 @@ along with some useful characteristics: support for discrete/continuous actions,

 Name          ``Box``  ``Discrete``  ``MultiDiscrete``  ``MultiBinary``  Multi Processing
 ============  =======  ============  =================  ===============  ================
 ARS           ✔️       ❌            ❌                 ❌               ✔️
 MaskablePPO   ❌       ✔️            ✔️                 ✔️               ✔️
 QR-DQN        ❌       ✔️            ❌                 ❌               ✔️
+RecurrentPPO  ✔️       ✔️            ✔️                 ✔️               ✔️
 TQC           ✔️       ❌            ❌                 ❌               ✔️
 TRPO          ✔️       ✔️            ✔️                 ✔️               ✔️
 ============  =======  ============  =================  ===============  ================


 .. note::
-    Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm.
+    ``Tuple`` observation spaces are not supported by any environment,
+    however, single-level ``Dict`` spaces are

 Actions ``gym.spaces``:
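The single-level ``Dict`` support mentioned in the note above can be exercised through ``RecurrentPPO``'s ``MultiInputLstmPolicy``. A minimal sketch, assuming SB3's built-in toy dict-observation env ``SimpleMultiObsEnv`` (which returns "vec" and "img" keys) behaves as in stable-baselines3 1.5:

    from sb3_contrib import RecurrentPPO
    from stable_baselines3.common.envs import SimpleMultiObsEnv

    # Dict observation space with "vec" and "img" entries
    env = SimpleMultiObsEnv(random_start=False)
    model = RecurrentPPO("MultiInputLstmPolicy", env, verbose=1)
    model.learn(250)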
@@ -71,3 +71,38 @@ Train an agent using Augmented Random Search (ARS) agent on the Pendulum environ
    model = ARS("LinearPolicy", "Pendulum-v1", verbose=1)
    model.learn(total_timesteps=10000, log_interval=4)
    model.save("ars_pendulum")

RecurrentPPO
------------

Train a PPO agent with a recurrent policy on the CartPole environment.

.. note::

    It is particularly important to pass the ``lstm_states``
    and ``episode_start`` arguments to the ``predict()`` method,
    so the cell and hidden states of the LSTM are correctly updated.

.. code-block:: python

    import numpy as np

    from sb3_contrib import RecurrentPPO

    model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
    model.learn(5000)

    env = model.get_env()
    obs = env.reset()
    # cell and hidden state of the LSTM
    lstm_states = None
    num_envs = 1
    # Episode start signals are used to reset the lstm states
    episode_starts = np.ones((num_envs,), dtype=bool)
    while True:
        action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
        obs, rewards, dones, info = env.step(action)
        episode_starts = dones
        env.render()
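The same loop works with several parallel environments; a hedged sketch assuming four CartPole copies created with SB3's ``make_vec_env`` helper:

    import numpy as np
    from stable_baselines3.common.env_util import make_vec_env
    from sb3_contrib import RecurrentPPO

    vec_env = make_vec_env("CartPole-v1", n_envs=4)
    model = RecurrentPPO("MlpLstmPolicy", vec_env, verbose=1)
    model.learn(5000)

    obs = vec_env.reset()
    lstm_states = None
    # one episode-start flag per environment
    episode_starts = np.ones((vec_env.num_envs,), dtype=bool)
    for _ in range(100):
        action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts)
        obs, rewards, dones, infos = vec_env.step(action)
        episode_starts = dones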
@@ -33,6 +33,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d

   modules/ars
   modules/ppo_mask
+  modules/ppo_recurrent
   modules/qrdqn
   modules/tqc
   modules/trpo
@@ -3,9 +3,11 @@
 Changelog
 ==========

-Release 1.5.1a7 (WIP)
+Release 1.5.1a8 (WIP)
 -------------------------------

+**Add RecurrentPPO (aka PPO LSTM)**
+
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 - Upgraded to Stable-Baselines3 >= 1.5.1a7
@@ -17,6 +19,7 @@ Breaking Changes:

 New Features:
 ^^^^^^^^^^^^^
+- Added ``RecurrentPPO`` (aka PPO LSTM)

 Bug Fixes:
 ^^^^^^^^^^
@@ -34,7 +37,8 @@ Breaking Changes:

 New Features:
 ^^^^^^^^^^^^^
-- Allow PPO to turn off advantage normalization (see `PR #61 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/61>`_) @vwxyzjn
+- Allow PPO to turn off advantage normalization (see `PR #61 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/61>`_) (@vwxyzjn)

 Bug Fixes:
 ^^^^^^^^^^
@@ -46,6 +50,9 @@ Deprecations:

 Others:
 ^^^^^^^

 Documentation:
 ^^^^^^^^^^^^^^

 Release 1.4.0 (2022-01-19)
 -------------------------------
 **Add Trust Region Policy Optimization (TRPO) and Augmented Random Search (ARS) algorithms**
@@ -5,7 +5,7 @@

 Maskable PPO
 ============

-Implementation of `invalid action masking <https://arxiv.org/abs/2006.14171>`_ for the Proximal Policy Optimization(PPO)
+Implementation of `invalid action masking <https://arxiv.org/abs/2006.14171>`_ for the Proximal Policy Optimization (PPO)
 algorithm. Other than adding support for action masking, the behavior is the same as in SB3's core PPO algorithm.
@@ -0,0 +1,153 @@

.. _ppo_lstm:

.. automodule:: sb3_contrib.ppo_recurrent

Recurrent PPO
=============

Implementation of recurrent policies for the Proximal Policy Optimization (PPO)
algorithm. Other than adding support for recurrent policies (LSTM here), the behavior is the same as in SB3's core PPO algorithm.

.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpLstmPolicy
    CnnLstmPolicy
    MultiInputLstmPolicy

Notes
-----

- Blog post: https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/

Can I use?
----------

- Recurrent policies: ✔️
- Multi processing: ✔️
- Gym spaces:

============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ✔️     ✔️
Box           ✔️     ✔️
MultiDiscrete ✔️     ✔️
MultiBinary   ✔️     ✔️
Dict          ❌     ✔️
============= ====== ===========

Example
-------

.. note::

    It is particularly important to pass the ``lstm_states``
    and ``episode_start`` arguments to the ``predict()`` method,
    so the cell and hidden states of the LSTM are correctly updated.

.. code-block:: python

    import numpy as np

    from sb3_contrib import RecurrentPPO
    from stable_baselines3.common.evaluation import evaluate_policy

    model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
    model.learn(5000)

    env = model.get_env()
    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20, warn=False)
    print(mean_reward)

    model.save("ppo_recurrent")
    del model  # remove to demonstrate saving and loading

    model = RecurrentPPO.load("ppo_recurrent")

    obs = env.reset()
    # cell and hidden state of the LSTM
    lstm_states = None
    num_envs = 1
    # Episode start signals are used to reset the lstm states
    episode_starts = np.ones((num_envs,), dtype=bool)
    while True:
        action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
        obs, rewards, dones, info = env.step(action)
        episode_starts = dones
        env.render()

Results
-------

A report on environments with masked velocity (with and without framestack) can be found here: https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4

``RecurrentPPO`` was evaluated against PPO on:

- PendulumNoVel-v1
- LunarLanderNoVel-v2
- CartPoleNoVel-v1
- MountainCarContinuousNoVel-v0
- CarRacing-v0

How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone the repo for the experiment:

.. code-block:: bash

    git clone https://github.com/DLR-RM/rl-baselines3-zoo
    cd rl-baselines3-zoo
    git checkout feat/recurrent-ppo

Run the benchmark (replace ``$ENV_ID`` with one of the envs mentioned above):

.. code-block:: bash

    python train.py --algo ppo_lstm --env $ENV_ID --eval-episodes 10 --eval-freq 10000

Parameters
----------

.. autoclass:: RecurrentPPO
    :members:
    :inherited-members:

RecurrentPPO Policies
---------------------

.. autoclass:: MlpLstmPolicy
    :members:
    :inherited-members:

.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentActorCriticPolicy
    :members:
    :noindex:

.. autoclass:: CnnLstmPolicy
    :members:

.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentActorCriticCnnPolicy
    :members:
    :noindex:

.. autoclass:: MultiInputLstmPolicy
    :members:

.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentMultiInputActorCriticPolicy
    :members:
    :noindex:
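The LSTM architecture can be adjusted through ``policy_kwargs``; a small sketch using the constructor parameters defined in this commit (the values below are illustrative only):

    from sb3_contrib import RecurrentPPO

    model = RecurrentPPO(
        "MlpLstmPolicy",
        "CartPole-v1",
        policy_kwargs=dict(
            lstm_hidden_size=128,     # units per LSTM layer (default 256)
            n_lstm_layers=2,          # stacked LSTM layers (default 1)
            enable_critic_lstm=True,  # separate LSTM for the value function (default)
        ),
        verbose=1,
    )
    model.learn(5000)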
@@ -2,6 +2,7 @@ import os

 from sb3_contrib.ars import ARS
 from sb3_contrib.ppo_mask import MaskablePPO
+from sb3_contrib.ppo_recurrent import RecurrentPPO
 from sb3_contrib.qrdqn import QRDQN
 from sb3_contrib.tqc import TQC
 from sb3_contrib.trpo import TRPO
@@ -215,12 +215,12 @@ class MaskableActorCriticPolicy(BasePolicy):
         action_masks: Optional[np.ndarray] = None,
     ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
         """
-        Get the policy action and state from an observation (and optional state).
+        Get the policy action from an observation (and optional hidden state).
         Includes sugar-coating to handle different observations (e.g. normalizing images).

         :param observation: the input observation
         :param state: The last states (can be None, used in recurrent policies)
-        :param mask: The last masks (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
         :param deterministic: Whether or not to return deterministic actions.
         :param action_masks: Action masks to apply to the action distribution
         :return: the model's action and the next state
@@ -229,8 +229,8 @@
         # TODO (GH/1): add support for RNN policies
         # if state is None:
         #     state = self.initial_state
-        # if mask is None:
-        #     mask = [False for _ in range(self.n_envs)]
+        # if episode_start is None:
+        #     episode_start = [False for _ in range(self.n_envs)]

         # Switch to eval mode (this affects batch norm / dropout)
         self.set_training_mode(False)
@@ -256,7 +256,7 @@
                 raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
             actions = actions[0]

-        return actions, state
+        return actions, None

     def evaluate_actions(
         self,
@@ -0,0 +1,384 @@

from functools import partial
from typing import Callable, Generator, Optional, Tuple, Union

import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.vec_env import VecNormalize

from sb3_contrib.common.recurrent.type_aliases import (
    RecurrentDictRolloutBufferSamples,
    RecurrentRolloutBufferSamples,
    RNNStates,
)


def pad(
    seq_start_indices: np.ndarray,
    seq_end_indices: np.ndarray,
    device: th.device,
    tensor: np.ndarray,
    padding_value: float = 0.0,
) -> th.Tensor:
    """
    Chunk sequences and pad them to have constant dimensions.

    :param seq_start_indices: Indices of the transitions that start a sequence
    :param seq_end_indices: Indices of the transitions that end a sequence
    :param device: PyTorch device
    :param tensor: Tensor of shape (batch_size, *tensor_shape)
    :param padding_value: Value used to pad sequence to the same length
        (zero padding by default)
    :return: (n_seq, max_length, *tensor_shape)
    """
    # Create sequences given start and end
    seq = [th.tensor(tensor[start : end + 1], device=device) for start, end in zip(seq_start_indices, seq_end_indices)]
    return th.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=padding_value)
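As a quick illustration of ``pad`` (toy values, not from the original sources): five transitions split into sequences ``[0..2]`` and ``[3..4]`` come back as a single zero-padded tensor:

    import numpy as np
    import torch as th

    tensor = np.arange(5, dtype=np.float32).reshape(5, 1)
    seq_start_indices = np.array([0, 3])
    seq_end_indices = np.array([2, 4])
    padded = pad(seq_start_indices, seq_end_indices, th.device("cpu"), tensor)
    # padded.shape == (2, 3, 1); the shorter sequence is zero-padded: [[3.], [4.], [0.]]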
def pad_and_flatten(
    seq_start_indices: np.ndarray,
    seq_end_indices: np.ndarray,
    device: th.device,
    tensor: np.ndarray,
    padding_value: float = 0.0,
) -> th.Tensor:
    """
    Pad and flatten the sequences of scalar values,
    while keeping the sequence order.
    From (batch_size, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,)

    :param seq_start_indices: Indices of the transitions that start a sequence
    :param seq_end_indices: Indices of the transitions that end a sequence
    :param device: PyTorch device (cpu, gpu, ...)
    :param tensor: Tensor of shape (max_length, n_seq, 1)
    :param padding_value: Value used to pad sequence to the same length
        (zero padding by default)
    :return: (n_seq * max_length,) aka (padded_batch_size,)
    """
    return pad(seq_start_indices, seq_end_indices, device, tensor, padding_value).flatten()


def create_sequencers(
    episode_starts: np.ndarray,
    env_change: np.ndarray,
    device: th.device,
) -> Tuple[np.ndarray, Callable, Callable]:
    """
    Create the utility function to chunk data into
    sequences and pad them to create fixed size tensors.

    :param episode_starts: Indices where an episode starts
    :param env_change: Indices where the data collected
        come from a different env (when using multiple env for data collection)
    :param device: PyTorch device
    :return: Indices of the transitions that start a sequence,
        pad and pad_and_flatten utilities tailored for this batch
        (sequence starts and ends indices are fixed)
    """
    # Create sequence if env changes too
    seq_start = np.logical_or(episode_starts, env_change).flatten()
    # First index is always the beginning of a sequence
    seq_start[0] = True
    # Retrieve indices of sequence starts
    seq_start_indices = np.where(seq_start == True)[0]  # noqa: E712
    # End of sequence are just before sequence starts
    # Last index is also always end of a sequence
    seq_end_indices = np.concatenate([(seq_start_indices - 1)[1:], np.array([len(episode_starts)])])

    # Create padding method for this minibatch
    # to avoid repeating arguments (seq_start_indices, seq_end_indices)
    local_pad = partial(pad, seq_start_indices, seq_end_indices, device)
    local_pad_and_flatten = partial(pad_and_flatten, seq_start_indices, seq_end_indices, device)
    return seq_start_indices, local_pad, local_pad_and_flatten
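A toy walk-through of ``create_sequencers`` (illustrative values): with a single environment and an episode boundary at index 2, two sequences are produced and the returned helpers pad them to a common length:

    import numpy as np
    import torch as th

    episode_starts = np.array([0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
    env_change = np.zeros(6)
    seq_start_indices, local_pad, local_pad_and_flatten = create_sequencers(
        episode_starts, env_change, th.device("cpu")
    )
    # seq_start_indices == [0, 2]: index 0 always opens a sequence, index 2 starts a new episode
    values = np.arange(6, dtype=np.float32).reshape(6, 1)
    # two sequences of lengths 2 and 4, padded to (2, 4, 1), then flattened to (8,)
    flat = local_pad_and_flatten(values)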
class RecurrentRolloutBuffer(RolloutBuffer):
    """
    Rollout buffer that also stores the LSTM cell and hidden states.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param hidden_state_shape: Shape of the buffer that will collect lstm states
        (n_steps, lstm.num_layers, n_envs, lstm.hidden_size)
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        hidden_state_shape: Tuple[int, int, int, int],
        device: Union[th.device, str] = "cpu",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        self.hidden_state_shape = hidden_state_shape
        self.seq_start_indices, self.seq_end_indices = None, None
        super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)

    def reset(self):
        super().reset()
        self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)

    def add(self, *args, lstm_states: RNNStates, **kwargs) -> None:
        """
        :param lstm_states: LSTM cell and hidden state
        """
        self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy())
        self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy())
        self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy())
        self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy())

        super().add(*args, **kwargs)

    def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]:
        assert self.full, "Rollout buffer must be full before sampling from it"

        # Prepare the data
        if not self.generator_ready:
            # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)
            # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size)
            for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]:
                self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2)

            # flatten but keep the sequence order
            # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape)
            # 2. (n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape)
            for tensor in [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "hidden_states_pi",
                "cell_states_pi",
                "hidden_states_vf",
                "cell_states_vf",
                "episode_starts",
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        # Sampling strategy that allows any mini batch size but requires
        # more complexity and use of padding
        # Trick to shuffle a bit: keep the sequence order
        # but split the indices in two
        split_index = np.random.randint(self.buffer_size * self.n_envs)
        indices = np.arange(self.buffer_size * self.n_envs)
        indices = np.concatenate((indices[split_index:], indices[:split_index]))

        env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs)
        # Flag first timestep as change of environment
        env_change[0, :] = 1.0
        env_change = self.swap_and_flatten(env_change)

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            batch_inds = indices[start_idx : start_idx + batch_size]
            yield self._get_samples(batch_inds, env_change)
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env_change: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> RecurrentRolloutBufferSamples:
        # Retrieve sequence starts and utility function
        self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers(
            self.episode_starts[batch_inds], env_change[batch_inds], self.device
        )

        n_layers = self.hidden_states_pi.shape[1]
        # Number of sequences
        n_seq = len(self.seq_start_indices)
        max_length = self.pad(self.actions[batch_inds]).shape[1]
        padded_batch_size = n_seq * max_length
        # We retrieve the lstm hidden states that will allow
        # to properly initialize the LSTM at the beginning of each sequence
        lstm_states_pi = (
            # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
            self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
            self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
        )
        lstm_states_vf = (
            # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
            self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
            self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
        )
        lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1]))
        lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1]))

        return RecurrentRolloutBufferSamples(
            # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
            observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size,) + self.obs_shape),
            actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]),
            old_values=self.pad_and_flatten(self.values[batch_inds]),
            old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]),
            advantages=self.pad_and_flatten(self.advantages[batch_inds]),
            returns=self.pad_and_flatten(self.returns[batch_inds]),
            lstm_states=RNNStates(lstm_states_pi, lstm_states_vf),
            episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]),
            mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])),
        )


class RecurrentDictRolloutBuffer(DictRolloutBuffer):
    """
    Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
    Extends the RecurrentRolloutBuffer to use dictionary observations

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param hidden_state_shape: Shape of the buffer that will collect lstm states
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        hidden_state_shape: Tuple[int, int, int, int],
        device: Union[th.device, str] = "cpu",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        self.hidden_state_shape = hidden_state_shape
        self.seq_start_indices, self.seq_end_indices = None, None
        super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)

    def reset(self):
        super().reset()
        self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)

    def add(self, *args, lstm_states: RNNStates, **kwargs) -> None:
        """
        :param lstm_states: LSTM cell and hidden state
        """
        self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy())
        self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy())
        self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy())
        self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy())

        super().add(*args, **kwargs)

    def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSamples, None, None]:
        assert self.full, "Rollout buffer must be full before sampling from it"

        # Prepare the data
        if not self.generator_ready:
            # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)
            # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size)
            for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]:
                self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2)

            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            for tensor in [
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "hidden_states_pi",
                "cell_states_pi",
                "hidden_states_vf",
                "cell_states_vf",
                "episode_starts",
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        # Trick to shuffle a bit: keep the sequence order
        # but split the indices in two
        split_index = np.random.randint(self.buffer_size * self.n_envs)
        indices = np.arange(self.buffer_size * self.n_envs)
        indices = np.concatenate((indices[split_index:], indices[:split_index]))

        env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs)
        # Flag first timestep as change of environment
        env_change[0, :] = 1.0
        env_change = self.swap_and_flatten(env_change)

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            batch_inds = indices[start_idx : start_idx + batch_size]
            yield self._get_samples(batch_inds, env_change)
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env_change: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> RecurrentDictRolloutBufferSamples:
        # Retrieve sequence starts and utility function
        self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers(
            self.episode_starts[batch_inds], env_change[batch_inds], self.device
        )

        n_layers = self.hidden_states_pi.shape[1]
        n_seq = len(self.seq_start_indices)
        max_length = self.pad(self.actions[batch_inds]).shape[1]
        padded_batch_size = n_seq * max_length
        # We retrieve the lstm hidden states that will allow
        # to properly initialize the LSTM at the beginning of each sequence
        lstm_states_pi = (
            # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
            self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
            self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
        )
        lstm_states_vf = (
            # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
            self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
            self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
        )
        lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1]))
        lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1]))

        observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()}
        observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()}

        return RecurrentDictRolloutBufferSamples(
            observations=observations,
            actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]),
            old_values=self.pad_and_flatten(self.values[batch_inds]),
            old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]),
            advantages=self.pad_and_flatten(self.advantages[batch_inds]),
            returns=self.pad_and_flatten(self.returns[batch_inds]),
            lstm_states=RNNStates(lstm_states_pi, lstm_states_vf),
            episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]),
            mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])),
        )
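For orientation, a hedged construction sketch (shapes are illustrative; ``RecurrentPPO`` normally creates this buffer internally):

    import gym

    env = gym.make("CartPole-v1")
    # hidden_state_shape = (n_steps, lstm.num_layers, n_envs, lstm.hidden_size)
    buffer = RecurrentRolloutBuffer(
        buffer_size=128,
        observation_space=env.observation_space,
        action_space=env.action_space,
        hidden_state_shape=(128, 1, 1, 256),
        n_envs=1,
    )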
@@ -0,0 +1,601 @@

from typing import Any, Dict, List, Optional, Tuple, Type, Union

import gym
import numpy as np
import torch as th
from stable_baselines3.common.distributions import Distribution
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor,
    MlpExtractor,
    NatureCNN,
)
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import zip_strict
from torch import nn

from sb3_contrib.common.recurrent.type_aliases import RNNStates


class RecurrentActorCriticPolicy(ActorCriticPolicy):
    """
    Recurrent policy class for actor-critic algorithms (has both policy and value prediction).
    To be used with A2C, PPO and the likes.
    It assumes that both the actor and the critic LSTM
    have the same architecture.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param sde_net_arch: Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param lstm_hidden_size: Number of hidden units for each LSTM layer.
    :param n_lstm_layers: Number of LSTM layers.
    :param shared_lstm: Whether the LSTM is shared between the actor and the critic
        (in that case, only the actor gradient is used).
        By default, the actor and the critic have two separate LSTMs.
    :param enable_critic_lstm: Use a separate LSTM for the critic.
    :param lstm_kwargs: Additional keyword arguments to pass to the LSTM
        constructor.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        lstm_hidden_size: int = 256,
        n_lstm_layers: int = 1,
        shared_lstm: bool = False,
        enable_critic_lstm: bool = True,
        lstm_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.lstm_output_dim = lstm_hidden_size
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            sde_net_arch,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )

        self.lstm_kwargs = lstm_kwargs or {}
        self.shared_lstm = shared_lstm
        self.enable_critic_lstm = enable_critic_lstm
        self.lstm_actor = nn.LSTM(
            self.features_dim,
            lstm_hidden_size,
            num_layers=n_lstm_layers,
            **self.lstm_kwargs,
        )
        # For the predict() method, to initialize hidden states
        # (n_lstm_layers, batch_size, lstm_hidden_size)
        self.lstm_hidden_state_shape = (n_lstm_layers, 1, lstm_hidden_size)
        self.critic = None
        self.lstm_critic = None
        assert not (
            self.shared_lstm and self.enable_critic_lstm
        ), "You must choose between shared LSTM, separate or no LSTM for the critic"

        # No LSTM for the critic, we still need to convert
        # output of features extractor to the correct size
        # (size of the output of the actor lstm)
        if not (self.shared_lstm or self.enable_critic_lstm):
            self.critic = nn.Linear(self.features_dim, lstm_hidden_size)

        # Use a separate LSTM for the critic
        if self.enable_critic_lstm:
            self.lstm_critic = nn.LSTM(
                self.features_dim,
                lstm_hidden_size,
                num_layers=n_lstm_layers,
                **self.lstm_kwargs,
            )

        # Setup optimizer with initial learning rate
        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
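The assertion above leaves three valid critic configurations; a sketch of how they would be selected via ``policy_kwargs`` (illustrative only, the defaults already give the first option):

    from sb3_contrib import RecurrentPPO

    # 1. Separate critic LSTM (default): enable_critic_lstm=True, shared_lstm=False
    # 2. Shared LSTM, critic reuses actor features without gradients:
    #    shared_lstm=True, enable_critic_lstm=False
    # 3. Feedforward critic on top of the features extractor:
    #    shared_lstm=False, enable_critic_lstm=False
    model = RecurrentPPO(
        "MlpLstmPolicy",
        "CartPole-v1",
        policy_kwargs=dict(shared_lstm=True, enable_critic_lstm=False),
    )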
    def _build_mlp_extractor(self) -> None:
        """
        Create the policy and value networks.
        Part of the layers can be shared.
        """
        self.mlp_extractor = MlpExtractor(
            self.lstm_output_dim,
            net_arch=self.net_arch,
            activation_fn=self.activation_fn,
            device=self.device,
        )

    @staticmethod
    def _process_sequence(
        features: th.Tensor,
        lstm_states: Tuple[th.Tensor, th.Tensor],
        episode_starts: th.Tensor,
        lstm: nn.LSTM,
    ) -> Tuple[th.Tensor, th.Tensor]:
        """
        Do a forward pass in the LSTM network.

        :param features: Input tensor
        :param lstm_states: previous cell and hidden states of the LSTM
        :param episode_starts: Indicates when a new episode starts,
            in that case, we need to reset LSTM states.
        :param lstm: LSTM object.
        :return: LSTM output and updated LSTM states.
        """
        # LSTM logic
        # (sequence length, batch size, features dim)
        # (batch size = n_envs for data collection or n_seq when doing gradient update)
        n_seq = lstm_states[0].shape[1]
        # Batch to sequence
        # (padded batch size, features_dim) -> (n_seq, max length, features_dim) -> (max length, n_seq, features_dim)
        # note: max length (max sequence length) is always 1 during data collection
        features_sequence = features.reshape((n_seq, -1, lstm.input_size)).swapaxes(0, 1)
        episode_starts = episode_starts.reshape((n_seq, -1)).swapaxes(0, 1)

        # If we don't have to reset the state in the middle of a sequence
        # we can avoid the for loop, which speeds up things
        if th.all(episode_starts == 0.0):
            lstm_output, lstm_states = lstm(features_sequence, lstm_states)
            lstm_output = th.flatten(lstm_output.transpose(0, 1), start_dim=0, end_dim=1)
            return lstm_output, lstm_states

        lstm_output = []
        # Iterate over the sequence
        for features, episode_start in zip_strict(features_sequence, episode_starts):
            hidden, lstm_states = lstm(
                features.unsqueeze(dim=0),
                (
                    # Reset the states at the beginning of a new episode
                    (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[0],
                    (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[1],
                ),
            )
            lstm_output += [hidden]
        # Sequence to batch
        # (sequence length, n_seq, lstm_out_dim) -> (batch_size, lstm_out_dim)
        lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1)
        return lstm_output, lstm_states
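A shape walk-through of ``_process_sequence`` (assumed toy sizes: 2 sequences of length 3, features dim 64, hidden size 32, no episode boundaries, so the fast path is taken):

    import torch as th
    from torch import nn

    lstm = nn.LSTM(64, 32, num_layers=1)
    features = th.randn(6, 64)  # padded batch: n_seq * max_length = 2 * 3
    lstm_states = (th.zeros(1, 2, 32), th.zeros(1, 2, 32))  # (n_layers, n_seq, hidden)
    episode_starts = th.zeros(6)
    output, new_states = RecurrentActorCriticPolicy._process_sequence(
        features, lstm_states, episode_starts, lstm
    )
    # output.shape == (6, 32): back to the flattened (padded_batch_size, hidden) layout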
    def forward(
        self,
        obs: th.Tensor,
        lstm_states: RNNStates,
        episode_starts: th.Tensor,
        deterministic: bool = False,
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, RNNStates]:
        """
        Forward pass in all the networks (actor and critic)

        :param obs: Observation
        :param lstm_states: The last hidden and memory states for the LSTM.
        :param episode_starts: Whether the observations correspond to new episodes
            or not (we reset the lstm states in that case).
        :param deterministic: Whether to sample or use deterministic actions
        :return: action, value and log probability of the action
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        # latent_pi, latent_vf = self.mlp_extractor(features)
        latent_pi, lstm_states_pi = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor)
        if self.lstm_critic is not None:
            latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic)
        elif self.shared_lstm:
            # Re-use LSTM features but do not backpropagate
            latent_vf = latent_pi.detach()
            lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach())
        else:
            # Critic only has a feedforward network
            latent_vf = self.critic(features)
            lstm_states_vf = lstm_states_pi

        latent_pi = self.mlp_extractor.forward_actor(latent_pi)
        latent_vf = self.mlp_extractor.forward_critic(latent_vf)

        # Evaluate the values for the given observations
        values = self.value_net(latent_vf)
        distribution = self._get_action_dist_from_latent(latent_pi)
        actions = distribution.get_actions(deterministic=deterministic)
        log_prob = distribution.log_prob(actions)
        return actions, values, log_prob, RNNStates(lstm_states_pi, lstm_states_vf)

    def get_distribution(
        self,
        obs: th.Tensor,
        lstm_states: Tuple[th.Tensor, th.Tensor],
        episode_starts: th.Tensor,
    ) -> Tuple[Distribution, Tuple[th.Tensor, ...]]:
        """
        Get the current policy distribution given the observations.

        :param obs: Observation.
        :param lstm_states: The last hidden and memory states for the LSTM.
        :param episode_starts: Whether the observations correspond to new episodes
            or not (we reset the lstm states in that case).
        :return: the action distribution and new hidden states.
        """
        features = self.extract_features(obs)
        latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor)
        latent_pi = self.mlp_extractor.forward_actor(latent_pi)
        return self._get_action_dist_from_latent(latent_pi), lstm_states

    def predict_values(
        self,
        obs: th.Tensor,
        lstm_states: Tuple[th.Tensor, th.Tensor],
        episode_starts: th.Tensor,
    ) -> th.Tensor:
        """
        Get the estimated values according to the current policy given the observations.

        :param obs: Observation.
        :param lstm_states: The last hidden and memory states for the LSTM.
        :param episode_starts: Whether the observations correspond to new episodes
            or not (we reset the lstm states in that case).
        :return: the estimated values.
        """
        features = self.extract_features(obs)
        if self.lstm_critic is not None:
            latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic)
        elif self.shared_lstm:
            # Use LSTM from the actor
            latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor)
            latent_vf = latent_pi.detach()
        else:
            latent_vf = self.critic(features)

        latent_vf = self.mlp_extractor.forward_critic(latent_vf)
        return self.value_net(latent_vf)

    def evaluate_actions(
        self,
        obs: th.Tensor,
        actions: th.Tensor,
        lstm_states: RNNStates,
        episode_starts: th.Tensor,
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Evaluate actions according to the current policy,
        given the observations.

        :param obs: Observation.
        :param actions:
        :param lstm_states: The last hidden and memory states for the LSTM.
        :param episode_starts: Whether the observations correspond to new episodes
            or not (we reset the lstm states in that case).
        :return: estimated value, log likelihood of taking those actions
            and entropy of the action distribution.
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        latent_pi, _ = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor)

        if self.lstm_critic is not None:
            latent_vf, _ = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic)
        elif self.shared_lstm:
            latent_vf = latent_pi.detach()
        else:
            latent_vf = self.critic(features)

        latent_pi = self.mlp_extractor.forward_actor(latent_pi)
        latent_vf = self.mlp_extractor.forward_critic(latent_vf)

        distribution = self._get_action_dist_from_latent(latent_pi)
        log_prob = distribution.log_prob(actions)
        values = self.value_net(latent_vf)
        return values, log_prob, distribution.entropy()

    def _predict(
        self,
        observation: th.Tensor,
        lstm_states: Tuple[th.Tensor, th.Tensor],
        episode_starts: th.Tensor,
        deterministic: bool = False,
    ) -> Tuple[th.Tensor, Tuple[th.Tensor, ...]]:
        """
        Get the action according to the policy for a given observation.

        :param observation:
        :param lstm_states: The last hidden and memory states for the LSTM.
        :param episode_starts: Whether the observations correspond to new episodes
            or not (we reset the lstm states in that case).
        :param deterministic: Whether to use stochastic or deterministic actions
        :return: Taken action according to the policy and hidden states of the RNN
        """
        distribution, lstm_states = self.get_distribution(observation, lstm_states, episode_starts)
        return distribution.get_actions(deterministic=deterministic), lstm_states

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :param state: The last hidden and memory states for the LSTM.
        :param episode_start: Whether the observations correspond to new episodes
            or not (we reset the lstm states in that case).
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next hidden state
            (used in recurrent policies)
        """
        # Switch to eval mode (this affects batch norm / dropout)
        self.set_training_mode(False)

        observation, vectorized_env = self.obs_to_tensor(observation)

        if isinstance(observation, dict):
            n_envs = observation[list(observation.keys())[0]].shape[0]
        else:
            n_envs = observation.shape[0]
        # state : (n_layers, n_envs, dim)
        if state is None:
            # Initialize hidden states to zeros
            state = np.concatenate([np.zeros(self.lstm_hidden_state_shape) for _ in range(n_envs)], axis=1)
            state = (state, state)

        if episode_start is None:
            episode_start = np.array([False for _ in range(n_envs)])

        with th.no_grad():
            # Convert to PyTorch tensors
            states = th.tensor(state[0]).float().to(self.device), th.tensor(state[1]).float().to(self.device)
            episode_starts = th.tensor(episode_start).float().to(self.device)
            actions, states = self._predict(
                observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic
            )
            states = (states[0].cpu().numpy(), states[1].cpu().numpy())

        # Convert to numpy
        actions = actions.cpu().numpy()

        if isinstance(self.action_space, gym.spaces.Box):
            if self.squash_output:
                # Rescale to proper domain when using squashing
                actions = self.unscale_action(actions)
            else:
                # Actions could be on arbitrary scale, so clip the actions to avoid
                # out of bound error (e.g. if sampling from a Gaussian distribution)
                actions = np.clip(actions, self.action_space.low, self.action_space.high)

        # Remove batch dimension if needed
        if not vectorized_env:
            actions = actions[0]

        return actions, states


class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
    """
    CNN recurrent policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param sde_net_arch: Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param lstm_hidden_size: Number of hidden units for each LSTM layer.
    :param n_lstm_layers: Number of LSTM layers.
    :param shared_lstm: Whether the LSTM is shared between the actor and the critic.
        By default, only the actor has a recurrent network.
    :param enable_critic_lstm: Use a separate LSTM for the critic.
    :param lstm_kwargs: Additional keyword arguments to pass to the LSTM
        constructor.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        lstm_hidden_size: int = 256,
        n_lstm_layers: int = 1,
        shared_lstm: bool = False,
        enable_critic_lstm: bool = True,
        lstm_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            sde_net_arch,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
            lstm_hidden_size,
            n_lstm_layers,
            shared_lstm,
            enable_critic_lstm,
            lstm_kwargs,
        )
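Usage mirrors the MLP variant; a hedged sketch on an image-based env (``CarRacing-v0`` is one of the envs used in the Results section of the docs above):

    from sb3_contrib import RecurrentPPO

    model = RecurrentPPO("CnnLstmPolicy", "CarRacing-v0", verbose=1)
    model.learn(5000)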
class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
||||
"""
|
||||
    MultiInput policy class for recurrent actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether or not to use orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param sde_net_arch: Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It keeps the variance
        above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this ensures bounded actions when using gSDE.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param lstm_hidden_size: Number of hidden units for each LSTM layer.
    :param n_lstm_layers: Number of LSTM layers.
    :param shared_lstm: Whether the LSTM is shared between the actor and the critic.
        By default, only the actor has a recurrent network.
    :param enable_critic_lstm: Use a separate LSTM for the critic.
    :param lstm_kwargs: Additional keyword arguments to pass to the LSTM
        constructor.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        lstm_hidden_size: int = 256,
        n_lstm_layers: int = 1,
        enable_critic_lstm: bool = True,
        lstm_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            sde_net_arch,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
            lstm_hidden_size,
            n_lstm_layers,
            enable_critic_lstm,
            lstm_kwargs,
        )
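As a quick orientation for the new keyword arguments, here is a minimal sketch of configuring this recurrent multi-input policy through RecurrentPPO; the Dict-observation environment id is hypothetical:

from sb3_contrib import RecurrentPPO

model = RecurrentPPO(
    "MultiInputLstmPolicy",
    "YourDictObsEnv-v0",  # hypothetical Dict-observation env id
    policy_kwargs=dict(
        lstm_hidden_size=128,  # hidden units per LSTM layer
        n_lstm_layers=2,  # stacked LSTM layers
        enable_critic_lstm=True,  # separate LSTM for the value function
    ),
)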
@ -0,0 +1,33 @@
from typing import NamedTuple, Tuple

import torch as th
from stable_baselines3.common.type_aliases import TensorDict


class RNNStates(NamedTuple):
    pi: Tuple[th.Tensor, ...]
    vf: Tuple[th.Tensor, ...]


class RecurrentRolloutBufferSamples(NamedTuple):
    observations: th.Tensor
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
    lstm_states: RNNStates
    episode_starts: th.Tensor
    mask: th.Tensor


class RecurrentDictRolloutBufferSamples(RecurrentRolloutBufferSamples):
    observations: TensorDict
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
    lstm_states: RNNStates
    episode_starts: th.Tensor
    mask: th.Tensor
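For orientation, each pi/vf entry of RNNStates is the (h, c) pair that torch.nn.LSTM consumes and returns; a small sketch with assumed sizes (one LSTM layer, 4 envs, 256 hidden units):

import torch as th
from sb3_contrib.common.recurrent.type_aliases import RNNStates

n_lstm_layers, n_envs, hidden_size = 1, 4, 256
hidden = th.zeros(n_lstm_layers, n_envs, hidden_size)  # h_0
cell = th.zeros(n_lstm_layers, n_envs, hidden_size)  # c_0
states = RNNStates(pi=(hidden, cell), vf=(hidden.clone(), cell.clone()))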
@ -392,14 +392,17 @@ class MaskablePPO(OnPolicyAlgorithm):
        action_masks: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get the model's action(s) from an observation.
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :param state: The last states (can be None, used in recurrent policies)
        :param mask: The last masks (can be None, used in recurrent policies)
        :param state: The last hidden states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies),
            this corresponds to the beginning of episodes,
            where the hidden states of the RNN must be reset.
        :param deterministic: Whether or not to return deterministic actions.
        :param action_masks: Action masks to apply to the action distribution.
        :return: the model's action and the next state (used in recurrent policies)
        :return: the model's action and the next hidden state
            (used in recurrent policies)
        """
        return self.policy.predict(observation, state, episode_start, deterministic, action_masks=action_masks)
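The new signature implies the following calling convention; a sketch assuming a trained model and a vectorized env already exist:

import numpy as np

obs = env.reset()
states = None  # no hidden state yet
episode_starts = np.ones((env.num_envs,), dtype=bool)  # reset the LSTM at the start
for _ in range(1000):
    action, states = model.predict(obs, state=states, episode_start=episode_starts, deterministic=True)
    obs, rewards, dones, infos = env.step(action)
    episode_starts = dones  # reset hidden states where an episode just ended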
@ -0,0 +1,2 @@
from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy
from sb3_contrib.ppo_recurrent.ppo_recurrent import RecurrentPPO
@ -0,0 +1,9 @@
from sb3_contrib.common.recurrent.policies import (
    RecurrentActorCriticCnnPolicy,
    RecurrentActorCriticPolicy,
    RecurrentMultiInputActorCriticPolicy,
)

MlpLstmPolicy = RecurrentActorCriticPolicy
CnnLstmPolicy = RecurrentActorCriticCnnPolicy
MultiInputLstmPolicy = RecurrentMultiInputActorCriticPolicy
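Since these aliases are also registered as policy strings on RecurrentPPO (see policy_aliases below), both forms should be equivalent; a sketch:

from sb3_contrib import RecurrentPPO
from sb3_contrib.ppo_recurrent.policies import MlpLstmPolicy

model_from_string = RecurrentPPO("MlpLstmPolicy", "CartPole-v1")
model_from_class = RecurrentPPO(MlpLstmPolicy, "CartPole-v1")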
@ -0,0 +1,530 @@
import time
from copy import deepcopy
from typing import Any, Dict, Optional, Tuple, Type, Union

import gym
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv

from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer
from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
from sb3_contrib.common.recurrent.type_aliases import RNNStates
from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy


class RecurrentPPO(OnPolicyAlgorithm):
    """
    Proximal Policy Optimization algorithm (PPO) (clip version)
    with support for recurrent policies (LSTM).

    Based on the original Stable Baselines 3 implementation.

    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_envs where n_envs is the number of environment copies running in parallel)
    :param batch_size: Minibatch size
    :param n_epochs: Number of epochs when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param clip_range: Clipping parameter, it can be a function of the current progress
        remaining (from 1 to 0).
    :param clip_range_vf: Clipping parameter for the value function,
        it can be a function of the current progress remaining (from 1 to 0).
        This is a parameter specific to the OpenAI implementation. If None is passed (default),
        no clipping will be done on the value function.
        IMPORTANT: this clipping depends on the reward scaling.
    :param normalize_advantage: Whether to normalize the advantage or not
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param target_kl: Limit the KL divergence between updates,
        because the clipping is not enough to prevent large updates,
        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
        By default, there is no limit on the kl div.
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing a string for the environment)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """
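As the docstring notes, learning_rate and clip_range accept schedules over the remaining progress (from 1 to 0); a sketch of a linear schedule:

from sb3_contrib import RecurrentPPO

def linear_schedule(initial_value: float):
    def schedule(progress_remaining: float) -> float:
        # progress_remaining goes from 1 (start of training) to 0 (end)
        return progress_remaining * initial_value
    return schedule

model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", clip_range=linear_schedule(0.2))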
    policy_aliases: Dict[str, Type[BasePolicy]] = {
        "MlpLstmPolicy": MlpLstmPolicy,
        "CnnLstmPolicy": CnnLstmPolicy,
        "MultiInputLstmPolicy": MultiInputLstmPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[RecurrentActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 128,
        batch_size: Optional[int] = 128,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        target_kl: Optional[float] = None,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            tensorboard_log=tensorboard_log,
            create_eval_env=create_eval_env,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Box,
                spaces.Discrete,
                spaces.MultiDiscrete,
                spaces.MultiBinary,
            ),
        )

        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.normalize_advantage = normalize_advantage
        self.target_kl = target_kl
        self._last_lstm_states = None

        if _init_setup_model:
            self._setup_model()
    def _setup_model(self) -> None:
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        buffer_cls = (
            RecurrentDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RecurrentRolloutBuffer
        )

        self.policy = self.policy_class(
            self.observation_space,
            self.action_space,
            self.lr_schedule,
            use_sde=self.use_sde,
            **self.policy_kwargs,  # pytype:disable=not-instantiable
        )
        self.policy = self.policy.to(self.device)

        # We assume that LSTM for the actor and the critic
        # have the same architecture
        lstm = self.policy.lstm_actor

        if not isinstance(self.policy, RecurrentActorCriticPolicy):
            raise ValueError("Policy must subclass RecurrentActorCriticPolicy")

        single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size)
        # hidden and cell states for actor and critic
        self._last_lstm_states = RNNStates(
            (
                th.zeros(single_hidden_state_shape).to(self.device),
                th.zeros(single_hidden_state_shape).to(self.device),
            ),
            (
                th.zeros(single_hidden_state_shape).to(self.device),
                th.zeros(single_hidden_state_shape).to(self.device),
            ),
        )

        hidden_state_buffer_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)

        self.rollout_buffer = buffer_cls(
            self.n_steps,
            self.observation_space,
            self.action_space,
            hidden_state_buffer_shape,
            self.device,
            gamma=self.gamma,
            gae_lambda=self.gae_lambda,
            n_envs=self.n_envs,
        )

        # Initialize schedules for policy/value clipping
        self.clip_range = get_schedule_fn(self.clip_range)
        if self.clip_range_vf is not None:
            if isinstance(self.clip_range_vf, (float, int)):
                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, pass `None` to deactivate vf clipping"

            self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
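To make the shape bookkeeping above concrete, a sketch with assumed values (n_steps=128, one LSTM layer, 8 envs, 256 hidden units):

n_steps, num_layers, n_envs, hidden_size = 128, 1, 8, 256
single_hidden_state_shape = (num_layers, n_envs, hidden_size)  # one h or c tensor
hidden_state_buffer_shape = (n_steps, num_layers, n_envs, hidden_size)  # one per rollout step
assert hidden_state_buffer_shape == (128, 1, 8, 256)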
    def _setup_learn(
        self,
        total_timesteps: int,
        eval_env: Optional[GymEnv],
        callback: MaybeCallback = None,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        tb_log_name: str = "RecurrentPPO",
    ) -> Tuple[int, BaseCallback]:
        """
        Initialize different variables needed for training.

        :param total_timesteps: The total number of samples (env steps) to train on
        :param eval_env: Environment to use for evaluation.
        :param callback: Callback(s) called at every step with state of the algorithm.
        :param eval_freq: How many steps between evaluations
        :param n_eval_episodes: How many episodes to play per evaluation
        :param log_path: Path to a folder where the evaluations will be saved
        :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
        :param tb_log_name: the name of the run for tensorboard log
        :return: the total number of timesteps and the callback object
        """

        total_timesteps, callback = super()._setup_learn(
            total_timesteps,
            eval_env,
            callback,
            eval_freq,
            n_eval_episodes,
            log_path,
            reset_num_timesteps,
            tb_log_name,
        )
        return total_timesteps, callback
    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
    ) -> bool:
        """
        Collect experiences using the current policy and fill a ``RolloutBuffer``.
        The term rollout here refers to the model-free notion and should not
        be used with the concept of rollout used in model-based RL or planning.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param rollout_buffer: Buffer to fill with rollouts
        :param n_rollout_steps: Number of experiences to collect per environment
        :return: True if function returned with at least ``n_rollout_steps``
            collected, False if callback terminated rollout prematurely.
        """
        assert isinstance(
            rollout_buffer, (RecurrentRolloutBuffer, RecurrentDictRolloutBuffer)
        ), f"{rollout_buffer} doesn't support recurrent policy"

        assert self._last_obs is not None, "No previous observation was provided"
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        n_steps = 0
        rollout_buffer.reset()
        # Sample new weights for the state dependent exploration
        if self.use_sde:
            self.policy.reset_noise(env.num_envs)

        callback.on_rollout_start()

        lstm_states = deepcopy(self._last_lstm_states)

        while n_steps < n_rollout_steps:
            if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.policy.reset_noise(env.num_envs)

            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)
                episode_starts = th.tensor(self._last_episode_starts).float().to(self.device)
                actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)

            actions = actions.cpu().numpy()

            # Rescale and perform action
            clipped_actions = actions
            # Clip the actions to avoid out of bound error
            if isinstance(self.action_space, gym.spaces.Box):
                clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

            new_obs, rewards, dones, infos = env.step(clipped_actions)

            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if callback.on_step() is False:
                return False

            self._update_info_buffer(infos)
            n_steps += 1

            if isinstance(self.action_space, gym.spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            # Handle timeout by bootstrapping with value function
            # see GitHub issue #633
            for idx, done_ in enumerate(dones):
                if (
                    done_
                    and infos[idx].get("terminal_observation") is not None
                    and infos[idx].get("TimeLimit.truncated", False)
                ):
                    terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                    with th.no_grad():
                        terminal_lstm_state = (
                            lstm_states.vf[0][:, idx : idx + 1, :],
                            lstm_states.vf[1][:, idx : idx + 1, :],
                        )
                        # terminal_lstm_state = None
                        episode_starts = th.tensor([False]).float().to(self.device)
                        terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0]
                    rewards[idx] += self.gamma * terminal_value

            rollout_buffer.add(
                self._last_obs,
                actions,
                rewards,
                self._last_episode_starts,
                values,
                log_probs,
                lstm_states=self._last_lstm_states,
            )

            self._last_obs = new_obs
            self._last_episode_starts = dones
            self._last_lstm_states = lstm_states

        with th.no_grad():
            # Compute value for the last timestep
            episode_starts = th.tensor(dones).float().to(self.device)
            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts)

        rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

        callback.on_rollout_end()

        return True
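The episode_starts tensor passed to policy.forward above is what lets the policy zero its hidden states at episode boundaries; a conceptual sketch of that reset, not the actual policy code:

import torch as th

def reset_hidden_states(hidden: th.Tensor, cell: th.Tensor, episode_starts: th.Tensor):
    # episode_starts: (n_envs,) with 1.0 where a new episode begins;
    # broadcast over (n_lstm_layers, n_envs, hidden_size) and zero those columns.
    not_start = (1.0 - episode_starts).view(1, -1, 1)
    return hidden * not_start, cell * not_start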
    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                values, log_prob, entropy = self.policy.evaluate_actions(
                    rollout_data.observations,
                    actions,
                    rollout_data.lstm_states,
                    rollout_data.episode_starts,
                )

                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                if self.normalize_advantage:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                # Mask padded sequences
                policy_loss_1 = policy_loss_1 * rollout_data.mask
                policy_loss_2 = policy_loss_2 * rollout_data.mask
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                # Mask padded sequences
                value_loss = th.mean(((rollout_data.returns - values_pred) * rollout_data.mask) ** 2)

                value_losses.append(value_loss.item())

                # Entropy loss to favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-(log_prob * rollout_data.mask))
                else:
                    entropy_loss = -th.mean(entropy * rollout_data.mask)

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            if not continue_training:
                break

        self._n_updates += self.n_epochs
        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)
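Why the losses above multiply by rollout_data.mask: sequences of different lengths are padded to a rectangular batch, and the mask zeroes the padded timesteps so they contribute nothing. A toy illustration with made-up numbers:

import torch as th

per_step_loss = th.tensor([0.5, 0.2, 0.9, 0.4])
mask = th.tensor([1.0, 1.0, 0.0, 1.0])  # the third timestep is padding
masked_loss = (per_step_loss * mask).mean()  # padding contributes zero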
    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "RecurrentPPO",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> "RecurrentPPO":
        iteration = 0

        total_timesteps, callback = self._setup_learn(
            total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
        )

        callback.on_training_start(locals(), globals())

        while self.num_timesteps < total_timesteps:

            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)

            if continue_training is False:
                break

            iteration += 1
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

            # Display training infos
            if log_interval is not None and iteration % log_interval == 0:
                fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
                self.logger.record("time/iterations", iteration, exclude="tensorboard")
                if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                    self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
                    self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
                self.logger.record("time/fps", fps)
                self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
                self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
                self.logger.dump(step=self.num_timesteps)

            self.train()

        callback.on_training_end()

        return self
@ -218,13 +218,16 @@ class QRDQN(OffPolicyAlgorithm):
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Overrides the base_class predict function to include epsilon-greedy exploration.
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :param state: The last states (can be None, used in recurrent policies)
        :param mask: The last masks (can be None, used in recurrent policies)
        :param state: The last hidden states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies),
            this corresponds to the beginning of episodes,
            where the hidden states of the RNN must be reset.
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next state
        :return: the model's action and the next hidden state
            (used in recurrent policies)
        """
        if not deterministic and np.random.rand() < self.exploration_rate:
@ -1 +1 @@
1.5.1a7
1.5.1a8
@ -1,2 +1,2 @@
#!/bin/bash
python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes
python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes -m "not slow"
@ -13,6 +13,8 @@ filterwarnings =
    ignore:Parameters to load are deprecated.:DeprecationWarning
    ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
    ignore::UserWarning:gym
markers =
    slow: marks tests as slow (deselect with '-m "not slow"')

[pytype]
inputs = sb3_contrib
@ -24,6 +26,7 @@ per-file-ignores =
    ./sb3_contrib/__init__.py:F401
    ./sb3_contrib/ars/__init__.py:F401
    ./sb3_contrib/ppo_mask/__init__.py:F401
    ./sb3_contrib/ppo_recurrent/__init__.py:F401
    ./sb3_contrib/qrdqn/__init__.py:F401
    ./sb3_contrib/tqc/__init__.py:F401
    ./sb3_contrib/trpo/__init__.py:F401
@ -3,7 +3,7 @@ from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.vec_env import VecNormalize

from sb3_contrib import ARS, QRDQN, TQC
from sb3_contrib import ARS, QRDQN, TQC, RecurrentPPO
from sb3_contrib.common.vec_env import AsyncEval

N_STEPS_TRAINING = 500
@ -11,7 +11,7 @@ SEED = 0
ARS_MULTI = "ars_multi"


@pytest.mark.parametrize("algo", [ARS, QRDQN, TQC, ARS_MULTI])
@pytest.mark.parametrize("algo", [ARS, QRDQN, TQC, ARS_MULTI, RecurrentPPO])
def test_deterministic_training_common(algo):
    results = [[], []]
    rewards = [[], []]
@ -32,9 +32,12 @@ def test_deterministic_training_common(algo):
        kwargs.update({"learning_starts": 100, "target_update_interval": 100})
    elif algo == ARS:
        kwargs.update({"n_delta": 2})

    elif algo == RecurrentPPO:
        kwargs.update({"policy_kwargs": dict(net_arch=[], enable_critic_lstm=True, lstm_hidden_size=8)})
        kwargs.update({"n_steps": 50, "n_epochs": 4})
    policy_str = "MlpLstmPolicy" if algo == RecurrentPPO else "MlpPolicy"
    for i in range(2):
        model = algo("MlpPolicy", env_id, seed=SEED, **kwargs)
        model = algo(policy_str, env_id, seed=SEED, **kwargs)

        learn_kwargs = {"total_timesteps": N_STEPS_TRAINING}
        if ars_multi:
@ -46,9 +49,11 @@ def test_deterministic_training_common(algo):
        model.learn(**learn_kwargs)
        env = model.get_env()
        obs = env.reset()
        states = None
        episode_start = None
        for _ in range(100):
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, _, _ = env.step(action)
            action, states = model.predict(obs, state=states, episode_start=episode_start, deterministic=False)
            obs, reward, episode_start, _ = env.step(action)
            results[i].append(action)
            rewards[i].append(reward)
    assert sum(results[0]) == sum(results[1]), results
@ -0,0 +1,186 @@
import gym
import numpy as np
import pytest
from gym import spaces
from gym.envs.classic_control import CartPoleEnv
from gym.wrappers.time_limit import TimeLimit
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import FakeImageEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecNormalize

from sb3_contrib import RecurrentPPO


class ToDictWrapper(gym.Wrapper):
    """
    Simple wrapper to test MultiInputPolicy on Dict obs.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space})

    def reset(self):
        return {"obs": self.env.reset()}

    def step(self, action):
        obs, reward, done, infos = self.env.step(action)
        return {"obs": obs}, reward, done, infos
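A sketch of what the wrapper does when applied directly (mirrors its use in test_dict_obs below):

env = ToDictWrapper(gym.make("CartPole-v1"))
obs = env.reset()  # {"obs": array([...])} instead of a flat array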
class CartPoleNoVelEnv(CartPoleEnv):
    """Variant of CartPoleEnv with velocity information removed. This task requires memory to solve."""

    def __init__(self):
        super().__init__()
        high = np.array(
            [
                self.x_threshold * 2,
                self.theta_threshold_radians * 2,
            ]
        )
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

    @staticmethod
    def _pos_obs(full_obs):
        xpos, _xvel, thetapos, _thetavel = full_obs
        return xpos, thetapos

    def reset(self):
        full_obs = super().reset()
        return CartPoleNoVelEnv._pos_obs(full_obs)

    def step(self, action):
        full_obs, rew, done, info = super().step(action)
        return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info
def test_cnn():
    model = RecurrentPPO(
        "CnnLstmPolicy",
        FakeImageEnv(screen_height=40, screen_width=40, n_channels=3),
        n_steps=16,
        seed=0,
        policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
    )

    model.learn(total_timesteps=32)


@pytest.mark.parametrize(
    "policy_kwargs",
    [
        {},
        dict(shared_lstm=True, enable_critic_lstm=False),
        dict(
            enable_critic_lstm=True,
            lstm_hidden_size=4,
            lstm_kwargs=dict(dropout=0.5),
        ),
        dict(
            enable_critic_lstm=False,
            lstm_hidden_size=4,
            lstm_kwargs=dict(dropout=0.5),
        ),
    ],
)
def test_policy_kwargs(policy_kwargs):
    model = RecurrentPPO(
        "MlpLstmPolicy",
        "CartPole-v1",
        n_steps=16,
        seed=0,
        policy_kwargs=policy_kwargs,
    )

    model.learn(total_timesteps=32)
def test_check():
    policy_kwargs = dict(shared_lstm=True, enable_critic_lstm=True)
    with pytest.raises(AssertionError):
        RecurrentPPO(
            "MlpLstmPolicy",
            "CartPole-v1",
            n_steps=16,
            seed=0,
            policy_kwargs=policy_kwargs,
        )


@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
def test_run(env):
    model = RecurrentPPO(
        "MlpLstmPolicy",
        env,
        n_steps=16,
        seed=0,
        create_eval_env=True,
    )

    model.learn(total_timesteps=32, eval_freq=16)


def test_run_sde():
    model = RecurrentPPO(
        "MlpLstmPolicy",
        "Pendulum-v1",
        n_steps=16,
        seed=0,
        create_eval_env=True,
        sde_sample_freq=4,
        use_sde=True,
        clip_range_vf=0.1,
    )

    model.learn(total_timesteps=200, eval_freq=150)


def test_dict_obs():
    env = make_vec_env("CartPole-v1", n_envs=1, wrapper_class=ToDictWrapper)
    model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32).learn(64)
    evaluate_policy(model, env, warn=False)
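Taken together, the tests above exercise the supported LSTM configurations; a sketch of an accepted combination versus the one that is rejected:

# Accepted: one LSTM shared by actor and critic.
# Rejected: shared_lstm=True together with enable_critic_lstm=True
# (see test_check above, which expects an AssertionError).
RecurrentPPO(
    "MlpLstmPolicy",
    "CartPole-v1",
    policy_kwargs=dict(shared_lstm=True, enable_critic_lstm=False),
)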
@pytest.mark.slow
def test_ppo_lstm_performance():
    # env = make_vec_env("CartPole-v1", n_envs=16)
    def make_env():
        env = CartPoleNoVelEnv()
        env = TimeLimit(env, max_episode_steps=500)
        return env

    env = VecNormalize(make_vec_env(make_env, n_envs=8))

    eval_callback = EvalCallback(
        VecNormalize(make_vec_env(make_env, n_envs=4), training=False, norm_reward=False),
        n_eval_episodes=20,
        eval_freq=5000 // env.num_envs,
    )

    model = RecurrentPPO(
        "MlpLstmPolicy",
        env,
        n_steps=128,
        learning_rate=0.0007,
        verbose=1,
        batch_size=256,
        seed=1,
        n_epochs=10,
        max_grad_norm=1,
        gae_lambda=0.98,
        policy_kwargs=dict(
            net_arch=[dict(vf=[64])],
            lstm_hidden_size=64,
            ortho_init=False,
            enable_critic_lstm=True,
        ),
    )

    model.learn(total_timesteps=50_000, callback=eval_callback)
    # Maximum episode reward is 500.
    # In CartPole-v1, a non-recurrent policy can easily get >= 450.
    # In CartPoleNoVelEnv, a non-recurrent policy doesn't get more than ~50.
    evaluate_policy(model, env, reward_threshold=450)