from typing import Callable, Union import gymnasium as gym import numpy as np class ActionMasker(gym.Wrapper): """ Env wrapper providing the method required to support masking. Exposes a method called action_masks(), which returns masks for the wrapped env. This wrapper is not needed if the env exposes the expected method itself. :param env: the Gym environment to wrap :param action_mask_fn: A function that takes a Gym environment and returns an action mask, or the name of such a method provided by the environment. """ def __init__(self, env: gym.Env, action_mask_fn: Union[str, Callable[[gym.Env], np.ndarray]]): super().__init__(env) if isinstance(action_mask_fn, str): found_method = getattr(self.env, action_mask_fn) if not callable(found_method): raise ValueError(f"Environment attribute {action_mask_fn} is not a method") self._action_mask_fn = found_method else: self._action_mask_fn = action_mask_fn def action_masks(self) -> np.ndarray: return self._action_mask_fn(self.env)