stable-baselines3-contrib-sacd/sb3_contrib/common/maskable/utils.py

39 lines
1.1 KiB
Python

import numpy as np
from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.common.vec_env import VecEnv
# Name of the environment method that the helpers in this module look up
# (and call) to obtain action masks.
EXPECTED_METHOD_NAME = "action_masks"
def get_action_masks(env: GymEnv) -> np.ndarray:
    """
    Fetch the action masks exposed by an environment.

    For a ``VecEnv`` the mask method is invoked through ``env_method`` and
    the returned results are combined with ``np.stack``. For any other
    environment the mask method is called directly and its result is
    returned unchanged.

    :param env: the Gym environment to get masks from
    :return: A numpy array of the masks
    """
    if not isinstance(env, VecEnv):
        # Plain environment: call the mask method on the object itself.
        mask_fn = getattr(env, EXPECTED_METHOD_NAME)
        return mask_fn()

    # Vectorized environment: dispatch the call by name, then stack the results.
    collected_masks = env.env_method(EXPECTED_METHOD_NAME)
    return np.stack(collected_masks)
def is_masking_supported(env: GymEnv) -> bool:
    """
    Tell whether an environment provides the action-mask method.

    For a ``VecEnv`` the attribute is requested through ``get_attr`` and an
    ``AttributeError`` is interpreted as "not supported". For any other
    environment a plain ``hasattr`` check is used.

    :param env: the Gym environment to check
    :return: True if the method is found, False otherwise
    """
    if not isinstance(env, VecEnv):
        return hasattr(env, EXPECTED_METHOD_NAME)

    try:
        # TODO: add VecEnv.has_attr()
        env.get_attr(EXPECTED_METHOD_NAME)
    except AttributeError:
        return False
    return True