39 lines
1.1 KiB
Python
39 lines
1.1 KiB
Python
import numpy as np
|
|
from stable_baselines3.common.type_aliases import GymEnv
|
|
from stable_baselines3.common.vec_env import VecEnv
|
|
|
|
# Name of the environment method that provides action masks; looked up by name
# (getattr / env_method / get_attr) in get_action_masks and is_masking_supported.
EXPECTED_METHOD_NAME = "action_masks"
|
|
|
|
|
|
def get_action_masks(env: GymEnv) -> np.ndarray:
    """
    Get the action masks from the environment by calling its ``action_masks()`` method.

    For a ``VecEnv``, the method is invoked on every sub-environment via
    ``env_method`` and the per-environment results are stacked with
    ``np.stack`` along a new leading axis (one entry per sub-environment).
    For a plain environment, the method's return value is passed through as-is.

    :param env: the Gym environment (or ``VecEnv``) to get masks from
    :return: A numpy array of the masks
    :raises AttributeError: if a non-vectorized ``env`` has no ``action_masks``
        method. For a ``VecEnv``, the error raised for a missing method depends
        on the VecEnv implementation.
    """
    if isinstance(env, VecEnv):
        # env_method returns a list with one result per sub-env; stack them
        # into a single batch array.
        return np.stack(env.env_method(EXPECTED_METHOD_NAME))
    return getattr(env, EXPECTED_METHOD_NAME)()
|
|
|
|
|
|
def is_masking_supported(env: GymEnv) -> bool:
    """
    Checks whether gym env exposes a method returning invalid action masks

    For a ``VecEnv``, the check is delegated to ``VecEnv.has_attr()`` when the
    installed VecEnv provides it. Otherwise it falls back to fetching the
    attribute with ``get_attr()`` and treating ``AttributeError`` as "not
    supported".

    :param env: the Gym environment (or ``VecEnv``) to check
    :return: True if the method is found, False otherwise
    """
    if isinstance(env, VecEnv):
        # Prefer the native check when available. The get_attr() probe only
        # works if the VecEnv re-raises AttributeError in this process, which
        # may not hold for implementations that query sub-envs elsewhere
        # (e.g. in worker processes) -- NOTE(review): confirm for your VecEnvs.
        has_attr = getattr(env, "has_attr", None)
        if callable(has_attr):
            return bool(has_attr(EXPECTED_METHOD_NAME))
        # Fallback for VecEnv versions without has_attr().
        try:
            env.get_attr(EXPECTED_METHOD_NAME)
        except AttributeError:
            return False
        return True
    return hasattr(env, EXPECTED_METHOD_NAME)
|