385 lines
17 KiB
Python
385 lines
17 KiB
Python
from functools import partial
|
|
from typing import Callable, Generator, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import torch as th
|
|
from gym import spaces
|
|
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
|
|
from stable_baselines3.common.vec_env import VecNormalize
|
|
|
|
from sb3_contrib.common.recurrent.type_aliases import (
|
|
RecurrentDictRolloutBufferSamples,
|
|
RecurrentRolloutBufferSamples,
|
|
RNNStates,
|
|
)
|
|
|
|
|
|
def pad(
    seq_start_indices: np.ndarray,
    seq_end_indices: np.ndarray,
    device: th.device,
    tensor: np.ndarray,
    padding_value: float = 0.0,
) -> th.Tensor:
    """
    Split a flat batch into sequences and pad them to a common length.

    :param seq_start_indices: Index of the first transition of each sequence
    :param seq_end_indices: Index of the last transition of each sequence
        (inclusive)
    :param device: PyTorch device on which the chunks are created
    :param tensor: Array of shape (batch_size, *tensor_shape)
    :param padding_value: Filler used to bring shorter sequences up to the
        length of the longest one (zero padding by default)
    :return: Tensor of shape (n_seq, max_length, *tensor_shape)
    """
    chunks = []
    for first, last in zip(seq_start_indices, seq_end_indices):
        # The end index is inclusive, hence the "+ 1" on the slice stop
        chunks.append(th.tensor(tensor[first : last + 1], device=device))
    # batch_first=True puts the sequence axis first: (n_seq, max_length, ...)
    return th.nn.utils.rnn.pad_sequence(chunks, batch_first=True, padding_value=padding_value)
|
|
|
|
|
|
def pad_and_flatten(
    seq_start_indices: np.ndarray,
    seq_end_indices: np.ndarray,
    device: th.device,
    tensor: np.ndarray,
    padding_value: float = 0.0,
) -> th.Tensor:
    """
    Pad sequences of scalar values, then flatten the result
    while keeping the sequence order.
    From (batch_size, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,)

    :param seq_start_indices: Index of the first transition of each sequence
    :param seq_end_indices: Index of the last transition of each sequence
    :param device: PyTorch device (cpu, gpu, ...)
    :param tensor: Array of per-transition scalars, shape (batch_size, 1)
    :param padding_value: Filler used to bring shorter sequences up to the
        length of the longest one (zero padding by default)
    :return: (n_seq * max_length,) aka (padded_batch_size,)
    """
    padded = pad(seq_start_indices, seq_end_indices, device, tensor, padding_value)
    # Collapse every axis into a single one, sequence after sequence
    return padded.flatten()
|
|
|
|
|
|
def create_sequencers(
    episode_starts: np.ndarray,
    env_change: np.ndarray,
    device: th.device,
) -> Tuple[np.ndarray, Callable, Callable]:
    """
    Create the utility functions that chunk a batch into
    sequences and pad them to create fixed size tensors.

    :param episode_starts: Flags marking the transitions where an episode starts
        (one entry per transition of the batch)
    :param env_change: Flags marking the transitions where the data collected
        come from a different env (when using multiple env for data collection)
    :param device: PyTorch device
    :return: Indices of the transitions that start a sequence,
        pad and pad_and_flatten utilities tailored for this batch
        (sequence starts and ends indices are fixed)
    """
    # A new sequence begins when an episode starts or when the env changes.
    # flatten() returns a copy, so the caller's arrays are not modified below.
    seq_start = np.logical_or(episode_starts, env_change).flatten()
    # First index is always the beginning of a sequence
    seq_start[0] = True
    # Retrieve indices of sequence starts
    seq_start_indices = np.flatnonzero(seq_start)
    # A sequence ends just before the next one starts, and the last sequence
    # ends on the last transition of the batch. Ends are inclusive (pad()
    # slices up to end + 1), so the last valid index is len - 1; the previous
    # value `len` only worked because slicing clips an out-of-range stop.
    last_index = len(episode_starts) - 1
    seq_end_indices = np.concatenate([seq_start_indices[1:] - 1, np.array([last_index])])

    # Create padding method for this minibatch
    # to avoid repeating arguments (seq_start_indices, seq_end_indices)
    local_pad = partial(pad, seq_start_indices, seq_end_indices, device)
    local_pad_and_flatten = partial(pad_and_flatten, seq_start_indices, seq_end_indices, device)
    return seq_start_indices, local_pad, local_pad_and_flatten
|
|
|
|
|
|
class RecurrentRolloutBuffer(RolloutBuffer):
    """
    Rollout buffer that also stores the LSTM cell and hidden states.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param hidden_state_shape: Shape of the buffer that will collect lstm states
        (n_steps, lstm.num_layers, n_envs, lstm.hidden_size)
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        hidden_state_shape: Tuple[int, int, int, int],
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        # Set before calling the parent constructor: reset() reads
        # hidden_state_shape (presumably the parent constructor calls reset()
        # -- confirm against RolloutBuffer).
        self.hidden_state_shape = hidden_state_shape
        self.seq_start_indices, self.seq_end_indices = None, None
        super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)

    def reset(self) -> None:
        """
        Reset the parent buffer and re-allocate zeroed storage for the
        hidden and cell states of the policy (pi) and value (vf) LSTMs.
        """
        super().reset()
        self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)

    def add(self, *args, lstm_states: RNNStates, **kwargs) -> None:
        """
        Store the LSTM states for the current position, then forward the
        remaining arguments to the parent ``add()``.

        :param lstm_states: LSTM states; ``pi`` and ``vf`` are each indexed as
            (hidden state, cell state) tensors
        """
        # Written at self.pos before the parent add() is called
        self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy())
        self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy())
        self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy())
        self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy())

        super().add(*args, **kwargs)

    def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]:
        """
        Yield padded minibatches of consecutive transitions.

        :param batch_size: Number of transitions per minibatch (before padding);
            ``None`` returns the whole buffer as a single batch
        :return: Generator of ``RecurrentRolloutBufferSamples``
        """
        assert self.full, "Rollout buffer must be full before sampling from it"

        # Prepare the data
        if not self.generator_ready:
            # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)
            # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size)
            for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]:
                self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2)

            # flatten but keep the sequence order
            # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape)
            # 2. (n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape)
            for tensor in [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "hidden_states_pi",
                "cell_states_pi",
                "hidden_states_vf",
                "cell_states_vf",
                "episode_starts",
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        n_transitions = self.buffer_size * self.n_envs

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = n_transitions

        # Sampling strategy that allows any mini batch size but requires
        # more complexity and use of padding
        # Trick to shuffle a bit: keep the sequence order
        # but split the indices in two
        split_index = np.random.randint(n_transitions)
        indices = np.arange(n_transitions)
        indices = np.concatenate((indices[split_index:], indices[:split_index]))

        env_change = np.zeros((self.buffer_size, self.n_envs))
        # Flag first timestep as change of environment
        env_change[0, :] = 1.0
        env_change = self.swap_and_flatten(env_change)

        start_idx = 0
        while start_idx < n_transitions:
            batch_inds = indices[start_idx : start_idx + batch_size]
            yield self._get_samples(batch_inds, env_change)
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env_change: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> RecurrentRolloutBufferSamples:
        """
        Build one padded minibatch from the flattened buffer.

        :param batch_inds: Indices (into the flattened buffer) of the
            transitions of this minibatch
        :param env_change: Flattened flags marking the first transition
            of each env
        :param env: Not used by this method
        :return: Padded samples, the LSTM states at the start of each sequence
            and a mask that is 1 for real transitions and 0 for padding
        """
        # Retrieve sequence starts and utility function
        self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers(
            self.episode_starts[batch_inds], env_change[batch_inds], self.device
        )

        # Pad the actions only once: the result provides max_length
        # and is reused for the returned samples
        padded_actions = self.pad(self.actions[batch_inds])
        # Number of sequences
        n_seq = len(self.seq_start_indices)
        max_length = padded_actions.shape[1]
        padded_batch_size = n_seq * max_length
        # We retrieve the lstm hidden states that will allow
        # to properly initialize the LSTM at the beginning of each sequence
        lstm_states_pi = (
            # 1. (n_envs * n_steps, n_layers, dim) -> (batch_size, n_layers, dim)
            # 2. (batch_size, n_layers, dim) -> (n_seq, n_layers, dim)
            # 3. (n_seq, n_layers, dim) -> (n_layers, n_seq, dim)
            self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
            self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
        )
        lstm_states_vf = (
            # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
            self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
            self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
        )
        lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
        lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())

        return RecurrentRolloutBufferSamples(
            # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
            observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size,) + self.obs_shape),
            actions=padded_actions.reshape((padded_batch_size,) + self.actions.shape[1:]),
            old_values=self.pad_and_flatten(self.values[batch_inds]),
            old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]),
            advantages=self.pad_and_flatten(self.advantages[batch_inds]),
            returns=self.pad_and_flatten(self.returns[batch_inds]),
            lstm_states=RNNStates(lstm_states_pi, lstm_states_vf),
            episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]),
            # Padded ones: padding entries take the default padding value 0.0
            mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])),
        )
|
|
|
|
|
|
class RecurrentDictRolloutBuffer(DictRolloutBuffer):
    """
    Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
    Extends the DictRolloutBuffer (dictionary observations) to also
    store the LSTM cell and hidden states.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param hidden_state_shape: Shape of the buffer that will collect lstm states
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        hidden_state_shape: Tuple[int, int, int, int],
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        # Set before calling the parent constructor: reset() reads
        # hidden_state_shape (presumably the parent constructor calls reset()
        # -- confirm against DictRolloutBuffer).
        self.hidden_state_shape = hidden_state_shape
        self.seq_start_indices, self.seq_end_indices = None, None
        super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)

    def reset(self) -> None:
        """
        Reset the parent buffer and re-allocate zeroed storage for the
        hidden and cell states of the policy (pi) and value (vf) LSTMs.
        """
        super().reset()
        self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)

    def add(self, *args, lstm_states: RNNStates, **kwargs) -> None:
        """
        Store the LSTM states for the current position, then forward the
        remaining arguments to the parent ``add()``.

        :param lstm_states: LSTM states; ``pi`` and ``vf`` are each indexed as
            (hidden state, cell state) tensors
        """
        # Written at self.pos before the parent add() is called
        self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy())
        self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy())
        self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy())
        self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy())

        super().add(*args, **kwargs)

    def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSamples, None, None]:
        """
        Yield padded minibatches of consecutive transitions.

        :param batch_size: Number of transitions per minibatch (before padding);
            ``None`` returns the whole buffer as a single batch
        :return: Generator of ``RecurrentDictRolloutBufferSamples``
        """
        assert self.full, "Rollout buffer must be full before sampling from it"

        # Prepare the data
        if not self.generator_ready:
            # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)
            # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size)
            for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]:
                self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2)

            # Observations are stored per key, flatten each entry separately
            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            for tensor in [
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "hidden_states_pi",
                "cell_states_pi",
                "hidden_states_vf",
                "cell_states_vf",
                "episode_starts",
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        n_transitions = self.buffer_size * self.n_envs

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = n_transitions

        # Trick to shuffle a bit: keep the sequence order
        # but split the indices in two
        split_index = np.random.randint(n_transitions)
        indices = np.arange(n_transitions)
        indices = np.concatenate((indices[split_index:], indices[:split_index]))

        env_change = np.zeros((self.buffer_size, self.n_envs))
        # Flag first timestep as change of environment
        env_change[0, :] = 1.0
        env_change = self.swap_and_flatten(env_change)

        start_idx = 0
        while start_idx < n_transitions:
            batch_inds = indices[start_idx : start_idx + batch_size]
            yield self._get_samples(batch_inds, env_change)
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env_change: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> RecurrentDictRolloutBufferSamples:
        """
        Build one padded minibatch from the flattened buffer.

        :param batch_inds: Indices (into the flattened buffer) of the
            transitions of this minibatch
        :param env_change: Flattened flags marking the first transition
            of each env
        :param env: Not used by this method
        :return: Padded samples, the LSTM states at the start of each sequence
            and a mask that is 1 for real transitions and 0 for padding
        """
        # Retrieve sequence starts and utility function
        self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers(
            self.episode_starts[batch_inds], env_change[batch_inds], self.device
        )

        # Pad the actions only once: the result provides max_length
        # and is reused for the returned samples
        padded_actions = self.pad(self.actions[batch_inds])
        n_seq = len(self.seq_start_indices)
        max_length = padded_actions.shape[1]
        padded_batch_size = n_seq * max_length
        # We retrieve the lstm hidden states that will allow
        # to properly initialize the LSTM at the beginning of each sequence
        lstm_states_pi = (
            # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
            self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
            self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
        )
        lstm_states_vf = (
            # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
            self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
            self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
        )
        lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
        lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())

        # Pad each observation entry, then merge the (n_seq, max_length) axes
        observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()}
        observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()}

        return RecurrentDictRolloutBufferSamples(
            observations=observations,
            actions=padded_actions.reshape((padded_batch_size,) + self.actions.shape[1:]),
            old_values=self.pad_and_flatten(self.values[batch_inds]),
            old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]),
            advantages=self.pad_and_flatten(self.advantages[batch_inds]),
            returns=self.pad_and_flatten(self.returns[batch_inds]),
            lstm_states=RNNStates(lstm_states_pi, lstm_states_vf),
            episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]),
            # Padded ones: padding entries take the default padding value 0.0
            mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])),
        )
|