from typing import Generator, NamedTuple, Optional

import numpy as np
import torch as th
from gym import spaces

from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.vec_env import VecNormalize


class MaskableRolloutBufferSamples(NamedTuple):
    """One minibatch of rollout data, including the action masks recorded with each observation."""

    observations: th.Tensor
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
    action_masks: th.Tensor


class MaskableRolloutBuffer(RolloutBuffer):
    """
    Rollout buffer that also stores the invalid action masks associated with each observation.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device:
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(self, *args, **kwargs):
        # Allocated lazily in reset(), once the mask dimensionality is known.
        self.action_masks = None
        super().__init__(*args, **kwargs)

    def reset(self) -> None:
        """Reallocate the buffer, sizing the flat mask dimension from the action-space type."""
        if isinstance(self.action_space, spaces.Discrete):
            mask_dims = self.action_space.n
        elif isinstance(self.action_space, spaces.MultiDiscrete):
            # One mask entry per discrete choice of every sub-action, concatenated.
            mask_dims = sum(self.action_space.nvec)
        elif isinstance(self.action_space, spaces.MultiBinary):
            mask_dims = 2 * self.action_space.n  # One mask per binary outcome
        else:
            raise ValueError(f"Unsupported action space {type(self.action_space)}")

        self.mask_dims = mask_dims
        # Default to all ones: every action is considered valid until a mask is recorded.
        self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)

        super().reset()

    def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
        """
        Add a transition to the buffer.

        :param action_masks: Masks applied to constrain the choice of possible actions.
            When ``None``, the pre-filled all-ones mask (all actions valid) is kept.
        """
        if action_masks is not None:
            self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))

        super().add(*args, **kwargs)

    def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]:
        """
        Yield shuffled minibatches of the collected rollout data.

        :param batch_size: Number of samples per minibatch; ``None`` yields one batch
            containing every sample in the buffer.
        """
        # Fix: the original assertion carried an empty message, which made failures uninformative.
        assert self.full, "Rollout buffer must be full before sampling from it"
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data: flatten (buffer_size, n_envs, ...) -> (buffer_size * n_envs, ...)
        if not self.generator_ready:
            for tensor in [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "action_masks",
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples:
        """Gather the rows selected by ``batch_inds`` and convert them to torch tensors."""
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
            self.action_masks[batch_inds].reshape(-1, self.mask_dims),
        )
        return MaskableRolloutBufferSamples(*map(self.to_torch, data))