99 lines
3.6 KiB
Python
99 lines
3.6 KiB
Python
from typing import Generator, NamedTuple, Optional
|
|
|
|
import numpy as np
|
|
import torch as th
|
|
from gym import spaces
|
|
from stable_baselines3.common.buffers import RolloutBuffer
|
|
from stable_baselines3.common.vec_env import VecNormalize
|
|
|
|
|
|
class MaskableRolloutBufferSamples(NamedTuple):
    """
    One minibatch of rollout data, extended with the per-sample action masks
    that were active when each action was taken.
    """

    observations: th.Tensor
    actions: th.Tensor
    # Value estimates recorded at collection time (before the current update)
    old_values: th.Tensor
    # Log-probabilities of the taken actions under the collection-time policy
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
    # Flattened mask of shape (batch_size, mask_dims); nonzero entries mark valid actions
    action_masks: th.Tensor
|
|
|
|
|
|
class MaskableRolloutBuffer(RolloutBuffer):
    """
    Rollout buffer that also stores the invalid action masks associated with each observation.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device on which the sampled tensors will be placed
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(self, *args, **kwargs):
        # Must exist before super().__init__(), which calls reset() and
        # (re)allocates the mask storage.
        self.action_masks = None
        super().__init__(*args, **kwargs)

    def reset(self) -> None:
        """
        Reallocate the buffer storage, including a flat mask array whose last
        dimension depends on the action space type.

        :raises ValueError: If the action space is not Discrete, MultiDiscrete
            or MultiBinary (Box spaces cannot be masked).
        """
        if isinstance(self.action_space, spaces.Discrete):
            mask_dims = self.action_space.n
        elif isinstance(self.action_space, spaces.MultiDiscrete):
            # One mask entry per choice of each sub-action, concatenated
            mask_dims = sum(self.action_space.nvec)
        elif isinstance(self.action_space, spaces.MultiBinary):
            mask_dims = 2 * self.action_space.n  # One mask per binary outcome
        else:
            raise ValueError(f"Unsupported action space {type(self.action_space)}")

        self.mask_dims = mask_dims
        # Default to all-ones: every action is considered valid unless a mask
        # is explicitly supplied via add().
        self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)

        super().reset()

    def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
        """
        Add a transition to the buffer.

        :param action_masks: Masks applied to constrain the choice of possible actions.
            Reshaped to (n_envs, mask_dims); when None, the all-valid default is kept.
        """
        if action_masks is not None:
            self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))

        super().add(*args, **kwargs)

    def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]:
        """
        Yield shuffled minibatches of the collected rollout data.

        :param batch_size: Number of samples per minibatch; when None, a single
            batch containing the whole buffer is yielded.
        """
        # Was `assert self.full, ""` — empty message gave no diagnostic on failure
        assert self.full, "Rollout buffer must be full before sampling from it"
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data: flatten (buffer_size, n_envs, ...) -> (buffer_size * n_envs, ...)
        # exactly once; subsequent calls reuse the flattened arrays.
        if not self.generator_ready:
            for tensor in [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "action_masks",
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples:
        """
        Gather one minibatch (as torch tensors) for the given flat indices.

        :param batch_inds: Indices into the flattened buffer.
        :param env: Unused here; kept for signature compatibility with the base class.
        """
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
            self.action_masks[batch_inds].reshape(-1, self.mask_dims),
        )
        return MaskableRolloutBufferSamples(*map(self.to_torch, data))
|