# stable-baselines3-contrib-sacd/sb3_contrib/common/maskable/buffers.py
from typing import Generator, NamedTuple, Optional
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.vec_env import VecNormalize
class MaskableRolloutBufferSamples(NamedTuple):
    """
    A minibatch of rollout data augmented with the action masks that were
    recorded at collection time by ``MaskableRolloutBuffer``.
    """

    observations: th.Tensor
    actions: th.Tensor
    old_values: th.Tensor  # value estimates produced by the collecting policy
    old_log_prob: th.Tensor  # log-probabilities of the stored actions under the collecting policy
    advantages: th.Tensor
    returns: th.Tensor
    action_masks: th.Tensor  # per-action validity masks, flattened to (batch_size, mask_dims)
class MaskableRolloutBuffer(RolloutBuffer):
    """
    Rollout buffer that also stores the invalid action masks associated with each observation.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device:
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(self, *args, **kwargs):
        # Must be defined before super().__init__(), which calls reset()
        # (reset() assigns the real storage array).
        self.action_masks = None
        super().__init__(*args, **kwargs)

    def reset(self) -> None:
        """
        Allocate all buffers, including per-step action-mask storage.

        :raises ValueError: If the action space is not Discrete, MultiDiscrete
            or MultiBinary (the only spaces a flat mask can describe).
        """
        if isinstance(self.action_space, spaces.Discrete):
            mask_dims = self.action_space.n
        elif isinstance(self.action_space, spaces.MultiDiscrete):
            # One mask entry per sub-action choice, concatenated across dimensions
            mask_dims = sum(self.action_space.nvec)
        elif isinstance(self.action_space, spaces.MultiBinary):
            mask_dims = 2 * self.action_space.n  # One mask per binary outcome
        else:
            raise ValueError(f"Unsupported action space {type(self.action_space)}")

        self.mask_dims = mask_dims
        # Default to all-ones (every action valid) so that steps added without
        # an explicit mask are unconstrained.
        self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)

        super().reset()

    def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
        """
        Add a transition to the buffer, recording its action mask.

        :param action_masks: Masks applied to constrain the choice of possible actions.
        """
        if action_masks is not None:
            self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))

        super().add(*args, **kwargs)

    def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]:
        """
        Yield minibatches of the collected rollout in a random order.

        :param batch_size: Number of samples per minibatch; ``None`` yields
            the whole buffer as a single batch.
        """
        assert self.full, "Rollout buffer must be full before sampling from it"
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data: flatten the (buffer_size, n_envs, ...) layout once,
        # so subsequent epochs can index directly.
        if not self.generator_ready:
            for tensor in [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "action_masks",
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples:
        """
        Convert the indexed numpy slices into a batch of torch tensors.

        :param batch_inds: Flat indices into the flattened buffer.
        :param env: Unused here; kept for signature compatibility with the base class.
        """
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
            self.action_masks[batch_inds].reshape(-1, self.mask_dims),
        )
        return MaskableRolloutBufferSamples(*map(self.to_torch, data))