33 lines
1.1 KiB
Python
33 lines
1.1 KiB
Python
from typing import Callable, Union
|
|
|
|
import gym
|
|
import numpy as np
|
|
|
|
|
|
class ActionMasker(gym.Wrapper):
|
|
"""
|
|
Env wrapper providing the method required to support masking.
|
|
|
|
Exposes a method called action_masks(), which returns masks for the wrapped env.
|
|
This wrapper is not needed if the env exposes the expected method itself.
|
|
|
|
:param env: the Gym environment to wrap
|
|
:param action_mask_fn: A function that takes a Gym environment and returns an action mask,
|
|
or the name of such a method provided by the environment.
|
|
"""
|
|
|
|
def __init__(self, env: gym.Env, action_mask_fn: Union[str, Callable[[gym.Env], np.ndarray]]):
|
|
super().__init__(env)
|
|
|
|
if isinstance(action_mask_fn, str):
|
|
found_method = getattr(self.env, action_mask_fn)
|
|
if not callable(found_method):
|
|
raise ValueError(f"Environment attribute {action_mask_fn} is not a method")
|
|
|
|
self._action_mask_fn = found_method
|
|
else:
|
|
self._action_mask_fn = action_mask_fn
|
|
|
|
def action_masks(self) -> np.ndarray:
|
|
return self._action_mask_fn(self.env)
|