117 lines
4.2 KiB
Python
117 lines
4.2 KiB
Python
from typing import List, Optional
|
|
|
|
import numpy as np
|
|
from gym import spaces
|
|
from stable_baselines3.common.envs import IdentityEnv
|
|
|
|
|
|
class InvalidActionEnvDiscrete(IdentityEnv):
    """
    Identity env with a discrete action space. Supports action masking.

    :param dim: Number of discrete actions; ``None`` falls back to 1.
    :param ep_length: Episode length, forwarded unchanged to ``IdentityEnv``.
    :param n_invalid_actions: How many actions are marked invalid every time a
        new state is chosen. Must be strictly smaller than ``dim`` because the
        action equal to the current state is never a candidate for masking.
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        ep_length: int = 100,
        n_invalid_actions: int = 0,
    ):
        if dim is None:
            dim = 1
        # Kept as an ``assert`` so the exception type callers may rely on does not change;
        # only the message was corrected to show the condition that actually failed.
        assert n_invalid_actions < dim, f"Too many invalid actions: {n_invalid_actions} >= {dim}"

        space = spaces.Discrete(dim)
        self.n_invalid_actions = n_invalid_actions
        self.possible_actions = np.arange(space.n)
        self.invalid_actions: List[int] = []
        # NOTE(review): the attributes above are assigned before the base initializer runs,
        # presumably because it may already call ``_choose_next_state`` -- confirm against IdentityEnv.
        super().__init__(space=space, ep_length=ep_length)

    def _choose_next_state(self) -> None:
        """Sample a new state and redraw which of the other actions are invalid."""
        self.state = self.action_space.sample()
        # Randomly choose invalid actions that are not the current state
        potential_invalid_actions = [i for i in self.possible_actions if i != self.state]
        # ``.tolist()`` keeps the attribute a plain list of ints, matching its annotation
        self.invalid_actions = np.random.choice(
            potential_invalid_actions, self.n_invalid_actions, replace=False
        ).tolist()

    def action_masks(self) -> List[bool]:
        """
        Build the mask over ``possible_actions``.

        :return: One flag per action, ``True`` when the action is not
            currently listed in ``invalid_actions``.
        """
        # Build the lookup set once instead of scanning the sequence for every action
        invalid = set(self.invalid_actions)
        return [action not in invalid for action in self.possible_actions]
|
|
|
|
|
|
class InvalidActionEnvMultiDiscrete(IdentityEnv):
    """
    Identity env with a multidiscrete action space. Supports action masking.

    Actions are addressed through one flattened index range of length
    ``sum(dims)``: the choices of the first dimension come first, followed by
    the choices of the second dimension, and so on.

    :param dims: Number of choices per dimension; ``None`` falls back to ``[1, 1]``.
    :param ep_length: Episode length, forwarded unchanged to ``IdentityEnv``.
    :param n_invalid_actions: How many flattened actions are marked invalid every
        time a new state is chosen. At most ``sum(dims) - len(dims)``, since the
        entry matching the current state in each dimension is never masked.
    :raises ValueError: If ``n_invalid_actions`` exceeds that limit.
    """

    def __init__(
        self,
        dims: Optional[List[int]] = None,
        ep_length: int = 100,
        n_invalid_actions: int = 0,
    ):
        if dims is None:
            dims = [1, 1]

        if n_invalid_actions > sum(dims) - len(dims):
            raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {sum(dims) - len(dims)}")

        space = spaces.MultiDiscrete(dims)
        self.n_invalid_actions = n_invalid_actions
        self.possible_actions = np.arange(sum(dims))
        self.invalid_actions: List[int] = []
        # NOTE(review): the attributes above are assigned before the base initializer runs,
        # presumably because it may already call ``_choose_next_state`` -- confirm against IdentityEnv.
        super().__init__(space=space, ep_length=ep_length)

    def _choose_next_state(self) -> None:
        """Sample a new state and redraw which of the remaining flattened actions are invalid."""
        self.state = self.action_space.sample()

        # Translate every per-dimension choice into its position in the flattened action range
        converted_state: List[int] = []
        running_total = 0
        for choice, n_choices in zip(self.state, self.action_space.nvec):
            # Plain ints keep the offsets independent of the sampled array's dtype
            converted_state.append(running_total + int(choice))
            running_total += int(n_choices)

        # Randomly choose invalid actions that are not the current state
        state_indices = set(converted_state)
        potential_invalid_actions = [i for i in self.possible_actions if i not in state_indices]
        # ``.tolist()`` keeps the attribute a plain list of ints, matching its annotation
        self.invalid_actions = np.random.choice(
            potential_invalid_actions, self.n_invalid_actions, replace=False
        ).tolist()

    def action_masks(self) -> List[bool]:
        """
        Build the mask over the flattened ``possible_actions``.

        :return: One flag per flattened action, ``True`` when the action is not
            currently listed in ``invalid_actions``.
        """
        # Build the lookup set once instead of scanning the sequence for every action
        invalid = set(self.invalid_actions)
        return [action not in invalid for action in self.possible_actions]
|
|
|
|
|
|
class InvalidActionEnvMultiBinary(IdentityEnv):
    """
    Identity env with a multibinary action space. Supports action masking.

    Actions are addressed through one flattened index range of length
    ``2 * dims``: bit ``i`` owns the indices ``2 * i`` (bit value 0) and
    ``2 * i + 1`` (bit value 1).

    :param dims: Number of binary dimensions; ``None`` falls back to 1.
    :param ep_length: Episode length, forwarded unchanged to ``IdentityEnv``.
    :param n_invalid_actions: How many flattened actions are marked invalid every
        time a new state is chosen. At most ``dims``, since the entry matching
        the current value of each bit is never masked.
    :raises ValueError: If ``n_invalid_actions`` exceeds ``dims``.
    """

    def __init__(
        self,
        dims: Optional[int] = None,
        ep_length: int = 100,
        n_invalid_actions: int = 0,
    ):
        if dims is None:
            dims = 1

        if n_invalid_actions > dims:
            raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {dims}")

        space = spaces.MultiBinary(dims)
        self.n_invalid_actions = n_invalid_actions
        self.possible_actions = np.arange(2 * dims)
        self.invalid_actions: List[int] = []
        # NOTE(review): the attributes above are assigned before the base initializer runs,
        # presumably because it may already call ``_choose_next_state`` -- confirm against IdentityEnv.
        super().__init__(space=space, ep_length=ep_length)

    def _choose_next_state(self) -> None:
        """Sample a new state and redraw which of the remaining flattened actions are invalid."""
        self.state = self.action_space.sample()

        # Translate every bit into its position in the flattened action range.
        # The bit is cast to a Python int before the offset is added: if the sampled
        # entries are narrow NumPy integers (gym documents int8 for MultiBinary -- confirm
        # for the installed version), adding an offset above 127 to them raises
        # OverflowError under the NumPy >= 2 promotion rules.
        converted_state: List[int] = [2 * i + int(bit) for i, bit in enumerate(self.state)]

        # Randomly choose invalid actions that are not the current state
        state_indices = set(converted_state)
        potential_invalid_actions = [i for i in self.possible_actions if i not in state_indices]
        # ``.tolist()`` keeps the attribute a plain list of ints, matching its annotation
        self.invalid_actions = np.random.choice(
            potential_invalid_actions, self.n_invalid_actions, replace=False
        ).tolist()

    def action_masks(self) -> List[bool]:
        """
        Build the mask over the flattened ``possible_actions``.

        :return: One flag per flattened action, ``True`` when the action is not
            currently listed in ``invalid_actions``.
        """
        # Build the lookup set once instead of scanning the sequence for every action
        invalid = set(self.invalid_actions)
        return [action not in invalid for action in self.possible_actions]
|