# sb3_contrib/common/maskable/buffers.py
from typing import Generator, NamedTuple, Optional, Union

import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.type_aliases import TensorDict
from stable_baselines3.common.vec_env import VecNormalize


class MaskableRolloutBufferSamples(NamedTuple):
    observations: th.Tensor
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
    action_masks: th.Tensor


class MaskableDictRolloutBufferSamples(NamedTuple):
    observations: TensorDict
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
    action_masks: th.Tensor
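
# Sample fields are flattened across the (n_steps, n_envs) axes before batching:
# scalar fields such as old_values and advantages come out with shape
# (batch_size,), while action_masks keeps shape (batch_size, mask_dims).
# See the _get_samples methods below.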


class MaskableRolloutBuffer(RolloutBuffer):
    """
    Rollout buffer that also stores the invalid action masks associated with each observation.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    action_masks: np.ndarray

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)

    def reset(self) -> None:
        if isinstance(self.action_space, spaces.Discrete):
            mask_dims = self.action_space.n
        elif isinstance(self.action_space, spaces.MultiDiscrete):
            mask_dims = sum(self.action_space.nvec)
        elif isinstance(self.action_space, spaces.MultiBinary):
            assert isinstance(self.action_space.n, int), (
                f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
                "You can flatten it instead."
            )
            mask_dims = 2 * self.action_space.n  # One mask per binary outcome
        else:
            raise ValueError(f"Unsupported action space {type(self.action_space)}")

        self.mask_dims = mask_dims
        self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)

        super().reset()
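
    # Illustrative mask sizes implied by the branches above (hypothetical spaces):
    #   spaces.Discrete(5)           -> mask_dims = 5
    #   spaces.MultiDiscrete([3, 2]) -> mask_dims = 3 + 2 = 5 (sub-space masks concatenated)
    #   spaces.MultiBinary(4)        -> mask_dims = 2 * 4 = 8 (a valid/invalid pair per bit)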

    def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
        """
        :param action_masks: Masks applied to constrain the choice of possible actions.
        """
        if action_masks is not None:
            self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))

        super().add(*args, **kwargs)

    def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]:  # type: ignore[override]
        assert self.full, "Rollout buffer must be full before sampling from it"
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            for tensor in [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "action_masks",
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples:  # type: ignore[override]
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
            self.action_masks[batch_inds].reshape(-1, self.mask_dims),
        )
        return MaskableRolloutBufferSamples(*map(self.to_torch, data))
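
# In practice the masks passed to add() come from the training loop: MaskablePPO
# queries the environments for their current action masks at each rollout step
# and forwards them here, so each stored observation is paired with the mask
# that was in effect when its action was taken.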


class MaskableDictRolloutBuffer(DictRolloutBuffer):
    """
    Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
    Extends the RolloutBuffer to use dictionary observations.

    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use the PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be confused with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    action_masks: np.ndarray

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Dict,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)

    def reset(self) -> None:
        if isinstance(self.action_space, spaces.Discrete):
            mask_dims = self.action_space.n
        elif isinstance(self.action_space, spaces.MultiDiscrete):
            mask_dims = sum(self.action_space.nvec)
        elif isinstance(self.action_space, spaces.MultiBinary):
            assert isinstance(self.action_space.n, int), (
                f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
                "You can flatten it instead."
            )
            mask_dims = 2 * self.action_space.n  # One mask per binary outcome
        else:
            raise ValueError(f"Unsupported action space {type(self.action_space)}")

        self.mask_dims = mask_dims
        self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)

        super().reset()

    def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
        """
        :param action_masks: Masks applied to constrain the choice of possible actions.
        """
        if action_masks is not None:
            self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))

        super().add(*args, **kwargs)

    def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]:  # type: ignore[override]
        assert self.full, "Rollout buffer must be full before sampling from it"
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            _tensor_names = ["actions", "values", "log_probs", "advantages", "returns", "action_masks"]
            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:  # type: ignore[override]
        return MaskableDictRolloutBufferSamples(
            observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
            actions=self.to_torch(self.actions[batch_inds]),
            old_values=self.to_torch(self.values[batch_inds].flatten()),
            old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
            advantages=self.to_torch(self.advantages[batch_inds].flatten()),
            returns=self.to_torch(self.returns[batch_inds].flatten()),
            action_masks=self.to_torch(self.action_masks[batch_inds].reshape(-1, self.mask_dims)),
        )
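

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the library): fill a
    # MaskableRolloutBuffer with dummy transitions and iterate over minibatches.
    # The spaces and all values below are made up; the positional arguments
    # follow the base RolloutBuffer.add() signature
    # (obs, action, reward, episode_start, value, log_prob).
    obs_space = spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
    act_space = spaces.Discrete(3)
    n_envs, buffer_size = 2, 8

    buffer = MaskableRolloutBuffer(buffer_size, obs_space, act_space, device="cpu", n_envs=n_envs)
    while not buffer.full:
        buffer.add(
            np.random.randn(n_envs, 4).astype(np.float32),  # observations
            np.random.randint(act_space.n, size=n_envs),  # actions
            np.zeros(n_envs, dtype=np.float32),  # rewards
            np.zeros(n_envs, dtype=np.float32),  # episode starts
            th.zeros(n_envs),  # value estimates
            th.zeros(n_envs),  # action log-probs
            action_masks=np.ones((n_envs, act_space.n)),  # here: all actions valid
        )
    buffer.compute_returns_and_advantage(last_values=th.zeros(n_envs), dones=np.zeros(n_envs))

    for batch in buffer.get(batch_size=4):
        # Each minibatch carries the masks alongside the usual PPO quantities.
        assert batch.action_masks.shape == (4, act_space.n)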