stable-baselines3-contrib-sacd/sb3_contrib/common/envs/invalid_actions_env.py

117 lines
4.2 KiB
Python

from typing import List, Optional
import numpy as np
from gym import spaces
from stable_baselines3.common.envs import IdentityEnv
class InvalidActionEnvDiscrete(IdentityEnv):
    """
    Identity env with a discrete action space. Supports action masking.

    Each time ``_choose_next_state`` runs, a new target state is sampled from the
    action space. Then ``n_invalid_actions`` actions, never including the target
    itself, are marked invalid. ``action_masks`` reports which actions remain valid.

    :param dim: Number of discrete actions (defaults to 1).
    :param ep_length: Episode length, forwarded to ``IdentityEnv``.
    :param n_invalid_actions: Number of actions masked out for each state. Must be
        strictly smaller than ``dim`` so that the target action is always valid.
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        ep_length: int = 100,
        n_invalid_actions: int = 0,
    ):
        if dim is None:
            dim = 1
        # The target action is never masked, so at most dim - 1 actions can be invalid.
        assert n_invalid_actions < dim, f"Too many invalid actions: {n_invalid_actions} >= {dim}"
        space = spaces.Discrete(dim)
        self.n_invalid_actions = n_invalid_actions
        self.possible_actions = np.arange(space.n)
        self.invalid_actions: List[int] = []
        # These attributes are set before super().__init__. Presumably the base class
        # may call _choose_next_state during construction (TODO confirm).
        super().__init__(space=space, ep_length=ep_length)

    def _choose_next_state(self) -> None:
        """Sample a new target state and a fresh set of invalid actions that excludes it."""
        self.state = self.action_space.sample()
        # Randomly choose invalid actions that are not the current state
        potential_invalid_actions = [i for i in self.possible_actions if i != self.state]
        # NOTE: np.random.choice draws from NumPy's global RNG, not the action space's RNG.
        # .tolist() keeps the attribute consistent with its List[int] annotation.
        self.invalid_actions = np.random.choice(
            potential_invalid_actions, self.n_invalid_actions, replace=False
        ).tolist()

    def action_masks(self) -> List[bool]:
        """
        Return one flag per entry of ``possible_actions``.

        :return: ``True`` where the action is valid, ``False`` where it is masked out.
        """
        # Build the set once so each membership test is O(1) instead of a list scan.
        invalid = set(self.invalid_actions)
        return [action not in invalid for action in self.possible_actions]
class InvalidActionEnvMultiDiscrete(IdentityEnv):
    """
    Identity env with a multidiscrete action space. Supports action masking.

    The masks use a flattened index space of size ``sum(dims)``. Sub-space ``j``
    owns the contiguous indices ``[offset_j, offset_j + dims[j])``, where
    ``offset_j = sum(dims[:j])``. The flattened index of the current state's
    value in every sub-space is never marked invalid.

    :param dims: Number of choices in each sub-space (defaults to ``[1, 1]``).
    :param ep_length: Episode length, forwarded to ``IdentityEnv``.
    :param n_invalid_actions: Number of flattened actions masked out for each state.
    :raises ValueError: if ``n_invalid_actions`` exceeds ``sum(dims) - len(dims)``.
        Each sub-space must keep at least one valid action.
    """

    def __init__(
        self,
        dims: Optional[List[int]] = None,
        ep_length: int = 100,
        n_invalid_actions: int = 0,
    ):
        if dims is None:
            dims = [1, 1]
        # One action per sub-space (the target) must stay valid.
        if n_invalid_actions > sum(dims) - len(dims):
            raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {sum(dims) - len(dims)}")
        space = spaces.MultiDiscrete(dims)
        self.n_invalid_actions = n_invalid_actions
        self.possible_actions = np.arange(sum(dims))
        self.invalid_actions: List[int] = []
        # These attributes are set before super().__init__. Presumably the base class
        # may call _choose_next_state during construction (TODO confirm).
        super().__init__(space=space, ep_length=ep_length)

    def _choose_next_state(self) -> None:
        """Sample a new target state and a fresh set of invalid actions that excludes it."""
        self.state = self.action_space.sample()
        # Map each sub-space's sampled value to its flattened index (offset + value).
        state_actions = set()
        offset = 0
        for n_choices, value in zip(self.action_space.nvec, self.state):
            state_actions.add(offset + int(value))
            offset += int(n_choices)
        # Randomly choose invalid actions that are not the current state
        potential_invalid_actions = [i for i in self.possible_actions if i not in state_actions]
        # NOTE: np.random.choice draws from NumPy's global RNG, not the action space's RNG.
        # .tolist() keeps the attribute consistent with its List[int] annotation.
        self.invalid_actions = np.random.choice(
            potential_invalid_actions, self.n_invalid_actions, replace=False
        ).tolist()

    def action_masks(self) -> List[bool]:
        """
        Return one flag per flattened action in ``possible_actions``.

        :return: ``True`` where the action is valid, ``False`` where it is masked out.
        """
        # Build the set once so each membership test is O(1) instead of a list scan.
        invalid = set(self.invalid_actions)
        return [action not in invalid for action in self.possible_actions]
class InvalidActionEnvMultiBinary(IdentityEnv):
    """
    Identity env with a multibinary action space. Supports action masking.

    The masks use a flattened index space of size ``2 * dims``. Binary dimension
    ``i`` maps to index ``2 * i`` for value 0 and ``2 * i + 1`` for value 1. The
    index matching the current state's bit in every dimension is never marked
    invalid.

    :param dims: Number of binary dimensions (defaults to 1).
    :param ep_length: Episode length, forwarded to ``IdentityEnv``.
    :param n_invalid_actions: Number of flattened actions masked out for each state.
    :raises ValueError: if ``n_invalid_actions`` exceeds ``dims``. Each binary
        dimension must keep its target value valid.
    """

    def __init__(
        self,
        dims: Optional[int] = None,
        ep_length: int = 100,
        n_invalid_actions: int = 0,
    ):
        if dims is None:
            dims = 1
        # Each binary dimension has 2 actions and its target one must stay valid.
        if n_invalid_actions > dims:
            raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {dims}")
        space = spaces.MultiBinary(dims)
        self.n_invalid_actions = n_invalid_actions
        self.possible_actions = np.arange(2 * dims)
        self.invalid_actions: List[int] = []
        # These attributes are set before super().__init__. Presumably the base class
        # may call _choose_next_state during construction (TODO confirm).
        super().__init__(space=space, ep_length=ep_length)

    def _choose_next_state(self) -> None:
        """Sample a new target state and a fresh set of invalid actions that excludes it."""
        self.state = self.action_space.sample()
        # Flattened index of each sampled bit. int() keeps the arithmetic in Python
        # ints. Adding a large Python int to a small-width NumPy integer sample can
        # overflow or raise under NumPy 2 promotion rules.
        state_actions = {2 * i + int(bit) for i, bit in enumerate(self.state)}
        # Randomly choose invalid actions that are not the current state
        potential_invalid_actions = [i for i in self.possible_actions if i not in state_actions]
        # NOTE: np.random.choice draws from NumPy's global RNG, not the action space's RNG.
        # .tolist() keeps the attribute consistent with its List[int] annotation.
        self.invalid_actions = np.random.choice(
            potential_invalid_actions, self.n_invalid_actions, replace=False
        ).tolist()

    def action_masks(self) -> List[bool]:
        """
        Return one flag per flattened action in ``possible_actions``.

        :return: ``True`` where the action is valid, ``False`` where it is masked out.
        """
        # Build the set once so each membership test is O(1) instead of a list scan.
        invalid = set(self.invalid_actions)
        return [action not in invalid for action in self.possible_actions]