stable-baselines3-contrib-sacd/sb3_contrib/common/wrappers/action_masker.py

33 lines
1.1 KiB
Python

from typing import Callable, Union
import gymnasium as gym
import numpy as np
class ActionMasker(gym.Wrapper):
    """
    Env wrapper providing the method required to support masking.

    Exposes a method called action_masks(), which returns masks for the wrapped env.
    This wrapper is not needed if the env exposes the expected method itself.

    :param env: the Gym environment to wrap
    :param action_mask_fn: A function that takes a Gym environment and returns an action mask,
        or the name of such a method provided by the environment. When a name is given,
        the attribute is looked up on the wrapped env once, at construction time. If the
        resolved callable cannot accept a positional argument (the usual case for a bound
        method declared as ``def mask(self)``), it is called without arguments; otherwise
        it is called with the wrapped env, like a plain function.
    :raises ValueError: if ``action_mask_fn`` is a name and the matching env attribute
        is not callable.
    """

    def __init__(self, env: gym.Env, action_mask_fn: Union[str, Callable[[gym.Env], np.ndarray]]):
        super().__init__(env)

        # A user-supplied function always receives the wrapped env as its only argument.
        self._pass_env = True

        if isinstance(action_mask_fn, str):
            found_method = getattr(self.env, action_mask_fn)
            if not callable(found_method):
                raise ValueError(f"Environment attribute {action_mask_fn} is not a method")

            self._action_mask_fn = found_method
            # ``getattr`` on an instance yields a bound method: ``self`` is already
            # supplied, so passing the env again would raise a TypeError for a
            # regular ``def mask(self)`` method. Only pass the env when the
            # resolved callable can actually take it.
            self._pass_env = self._accepts_positional_argument(found_method)
        else:
            self._action_mask_fn = action_mask_fn

    @staticmethod
    def _accepts_positional_argument(fn: Callable) -> bool:
        """Return True if ``fn`` can be called with exactly one positional argument."""
        import inspect

        try:
            signature = inspect.signature(fn)
        except (TypeError, ValueError):
            # No introspectable signature: keep the historical calling
            # convention and pass the env.
            return True

        try:
            signature.bind(None)
        except TypeError:
            return False
        return True

    def action_masks(self) -> np.ndarray:
        """
        Compute the action mask for the wrapped env.

        :return: whatever the configured mask function returns
        """
        if self._pass_env:
            return self._action_mask_fn(self.env)
        return self._action_mask_fn()