stable-baselines3-contrib-sacd/sb3_contrib/ppo_mask/policies.py

16 lines
575 B
Python

from stable_baselines3.common.policies import register_policy
from sb3_contrib.common.maskable.policies import ( # MaskableMultiInputActorCriticPolicy,
MaskableActorCriticCnnPolicy,
MaskableActorCriticPolicy,
)
MlpPolicy = MaskableActorCriticPolicy
CnnPolicy = MaskableActorCriticCnnPolicy
# MultiInputPolicy = MaskableMultiInputActorCriticPolicy
register_policy("MlpPolicy", MaskableActorCriticPolicy)
register_policy("CnnPolicy", MaskableActorCriticCnnPolicy)
# Currently not supported
# register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy)