"""Probability distributions with invalid action masking support."""
from abc import ABC, abstractmethod
|
|
from typing import List, Optional, Tuple, TypeVar
|
|
|
|
import numpy as np
|
|
import torch as th
|
|
from gymnasium import spaces
|
|
from stable_baselines3.common.distributions import Distribution
|
|
from torch import nn
|
|
from torch.distributions import Categorical
|
|
from torch.distributions.utils import logits_to_probs
|
|
|
|
# TypeVars used for fluent ``-> Self``-style return annotations on the
# ``proba_distribution`` methods of the concrete distribution classes below.
SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistribution", bound="MaskableCategoricalDistribution")
SelfMaskableMultiCategoricalDistribution = TypeVar(
    "SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution"
)
class MaskableCategorical(Categorical):
    """
    Modified PyTorch Categorical distribution with support for invalid action masking.

    To instantiate, must provide either probs or logits, but not both.

    :param probs: Tensor containing finite non-negative values, which will be renormalized
        to sum to 1 along the last dimension.
    :param logits: Tensor of unnormalized log probabilities.
    :param validate_args: Whether or not to validate that arguments to methods like log_prob()
        and icdf() match the distribution's shape, support, etc.
    :param masks: An optional boolean ndarray of compatible shape with the distribution.
        If True, the corresponding choice's logit value is preserved. If False, it is set to a
        large negative value, resulting in near 0 probability.
    """

    def __init__(
        self,
        probs: Optional[th.Tensor] = None,
        logits: Optional[th.Tensor] = None,
        validate_args: Optional[bool] = None,
        masks: Optional[np.ndarray] = None,
    ):
        # Boolean mask tensor currently in effect; None means unmasked.
        self.masks: Optional[th.Tensor] = None
        super().__init__(probs, logits, validate_args)
        # Keep a pristine copy of the logits so masking can be removed or
        # reapplied later without accumulating -1e8 replacements.
        self._original_logits = self.logits
        self.apply_masking(masks)

    def apply_masking(self, masks: Optional[np.ndarray]) -> None:
        """
        Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0.

        :param masks: An optional boolean ndarray of compatible shape with the distribution.
            If True, the corresponding choice's logit value is preserved. If False, it is set
            to a large negative value, resulting in near 0 probability. If masks is None, any
            previously applied masking is removed, and the original logits are restored.
        """

        if masks is not None:
            device = self.logits.device
            self.masks = th.as_tensor(masks, dtype=th.bool, device=device).reshape(self.logits.shape)
            HUGE_NEG = th.tensor(-1e8, dtype=self.logits.dtype, device=device)

            # Always mask starting from the ORIGINAL logits so repeated calls
            # with different masks don't compound.
            logits = th.where(self.masks, self._original_logits, HUGE_NEG)
        else:
            self.masks = None
            logits = self._original_logits

        # Reinitialize with updated logits
        super().__init__(logits=logits)

        # self.probs may already be cached, so we must force an update
        self.probs = logits_to_probs(self.logits)

    def entropy(self) -> th.Tensor:
        # Unmasked: plain categorical entropy.
        if self.masks is None:
            return super().entropy()

        # Highly negative logits don't result in 0 probs, so we must replace
        # with 0s to ensure 0 contribution to the distribution's entropy, since
        # masked actions possess no uncertainty.
        device = self.logits.device
        p_log_p = self.logits * self.probs
        p_log_p = th.where(self.masks, p_log_p, th.tensor(0.0, device=device))
        return -p_log_p.sum(-1)
|
class MaskableDistribution(Distribution, ABC):
    """Abstract base class for action distributions that support invalid action masking."""

    @abstractmethod
    def apply_masking(self, masks: Optional[np.ndarray]) -> None:
        """
        Eliminate ("mask out") chosen distribution outcomes by setting their probability to 0.

        :param masks: An optional boolean ndarray of compatible shape with the distribution.
            If True, the corresponding choice's logit value is preserved. If False, it is set
            to a large negative value, resulting in near 0 probability. If masks is None, any
            previously applied masking is removed, and the original logits are restored.
        """
|
class MaskableCategoricalDistribution(MaskableDistribution):
    """
    Categorical distribution for discrete actions. Supports invalid action masking.

    :param action_dim: Number of discrete actions
    """

    def __init__(self, action_dim: int):
        super().__init__()
        # Underlying torch distribution; populated by proba_distribution().
        self.distribution: Optional[MaskableCategorical] = None
        self.action_dim = action_dim

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Create the layer that represents the distribution:
        it will be the logits of the Categorical distribution.
        You can then get probabilities using a softmax.

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return: the logits layer
        """
        return nn.Linear(latent_dim, self.action_dim)

    def proba_distribution(
        self: SelfMaskableCategoricalDistribution, action_logits: th.Tensor
    ) -> SelfMaskableCategoricalDistribution:
        # Flatten everything except the action dimension before building the distribution
        flat_logits = action_logits.view(-1, self.action_dim)
        self.distribution = MaskableCategorical(logits=flat_logits)
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        assert self.distribution is not None, "Must set distribution parameters"
        return self.distribution.log_prob(actions)

    def entropy(self) -> th.Tensor:
        assert self.distribution is not None, "Must set distribution parameters"
        return self.distribution.entropy()

    def sample(self) -> th.Tensor:
        assert self.distribution is not None, "Must set distribution parameters"
        return self.distribution.sample()

    def mode(self) -> th.Tensor:
        assert self.distribution is not None, "Must set distribution parameters"
        # Most probable action per batch element
        return self.distribution.probs.argmax(dim=1)

    def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
        # Refresh the distribution from the given logits, then draw actions
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        actions = self.actions_from_params(action_logits)
        return actions, self.log_prob(actions)

    def apply_masking(self, masks: Optional[np.ndarray]) -> None:
        assert self.distribution is not None, "Must set distribution parameters"
        self.distribution.apply_masking(masks)
|
class MaskableMultiCategoricalDistribution(MaskableDistribution):
    """
    MultiCategorical distribution for multi discrete actions. Supports invalid action masking.

    :param action_dims: List of sizes of discrete action spaces
    """

    def __init__(self, action_dims: List[int]):
        super().__init__()
        # One independent MaskableCategorical per discrete sub-space;
        # populated by proba_distribution().
        self.distributions: List[MaskableCategorical] = []
        self.action_dims = action_dims

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Create the layer that represents the distribution:
        it will be the logits (flattened) of the MultiCategorical distribution.
        You can then get probabilities using a softmax on each sub-space.

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return: the flattened logits layer
        """
        return nn.Linear(latent_dim, sum(self.action_dims))

    def proba_distribution(
        self: SelfMaskableMultiCategoricalDistribution, action_logits: th.Tensor
    ) -> SelfMaskableMultiCategoricalDistribution:
        # Restructure shape to align with logits, then split columnwise
        # into one logits block per discrete sub-space
        flat_logits = action_logits.view(-1, sum(self.action_dims))
        splits = th.split(flat_logits, tuple(self.action_dims), dim=1)
        self.distributions = [MaskableCategorical(logits=split) for split in splits]
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        assert len(self.distributions) > 0, "Must set distribution parameters"

        # Restructure shape to align with each categorical
        per_subspace_actions = th.unbind(actions.view(-1, len(self.action_dims)), dim=1)

        # Joint log prob is the sum of the per-sub-space log probs (independence)
        log_probs = [dist.log_prob(act) for dist, act in zip(self.distributions, per_subspace_actions)]
        return th.stack(log_probs, dim=1).sum(dim=1)

    def entropy(self) -> th.Tensor:
        assert len(self.distributions) > 0, "Must set distribution parameters"
        entropies = [dist.entropy() for dist in self.distributions]
        return th.stack(entropies, dim=1).sum(dim=1)

    def sample(self) -> th.Tensor:
        assert len(self.distributions) > 0, "Must set distribution parameters"
        samples = [dist.sample() for dist in self.distributions]
        return th.stack(samples, dim=1)

    def mode(self) -> th.Tensor:
        assert len(self.distributions) > 0, "Must set distribution parameters"
        modes = [dist.probs.argmax(dim=1) for dist in self.distributions]
        return th.stack(modes, dim=1)

    def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
        # Refresh the distributions from the given logits, then draw actions
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        actions = self.actions_from_params(action_logits)
        return actions, self.log_prob(actions)

    def apply_masking(self, masks: Optional[np.ndarray]) -> None:
        assert len(self.distributions) > 0, "Must set distribution parameters"

        if masks is None:
            # No mask for any sub-space: each sub-distribution is unmasked
            split_masks = [None] * len(self.distributions)
        else:
            # Restructure shape to align with logits, then split columnwise
            # so each sub-distribution receives its own mask slice
            flat_masks = th.as_tensor(masks).view(-1, sum(self.action_dims))
            split_masks = th.split(flat_masks, tuple(self.action_dims), dim=1)

        for distribution, mask in zip(self.distributions, split_masks):
            distribution.apply_masking(mask)
|
class MaskableBernoulliDistribution(MaskableMultiCategoricalDistribution):
    """
    Bernoulli distribution for multibinary actions. Supports invalid action masking.

    :param action_dim: Number of binary actions
    """

    def __init__(self, action_dim: int):
        # Each binary action is modeled as a 2-way categorical choice
        super().__init__([2] * action_dim)
|
def make_masked_proba_distribution(action_space: spaces.Space) -> MaskableDistribution:
    """
    Return an instance of Distribution for the correct type of action space.

    :param action_space: the input action space
    :return: the appropriate Distribution object
    :raises NotImplementedError: if the action space is not Discrete,
        MultiDiscrete or MultiBinary
    """

    if isinstance(action_space, spaces.Discrete):
        # Cast to a plain int: gymnasium may expose ``n`` as np.int64
        return MaskableCategoricalDistribution(int(action_space.n))
    elif isinstance(action_space, spaces.MultiDiscrete):
        # ``nvec`` is an ndarray; convert to a list of ints as declared by
        # MaskableMultiCategoricalDistribution's signature
        return MaskableMultiCategoricalDistribution(list(action_space.nvec))
    elif isinstance(action_space, spaces.MultiBinary):
        # Multi-dimensional MultiBinary (n given as a tuple/array) would
        # silently produce a malformed distribution, so reject it explicitly.
        assert isinstance(action_space.n, int), (
            f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. "
            "You can flatten it instead."
        )
        return MaskableBernoulliDistribution(action_space.n)
    else:
        # NOTE: original message lacked spaces between the concatenated string
        # parts and omitted MultiBinary from the supported list.
        raise NotImplementedError(
            "Error: probability distribution, not implemented for action space "
            f"of type {type(action_space)}. "
            "Must be of type Gym Spaces: Discrete, MultiDiscrete, MultiBinary."
        )