# stable-baselines3-contrib-sacd/sb3_contrib/ppo_mask/policies.py
#
# 10 lines
# 294 B
# Python

from sb3_contrib.common.maskable.policies import (
MaskableActorCriticCnnPolicy,
MaskableActorCriticPolicy,
MaskableMultiInputActorCriticPolicy,
)
# Short public aliases for the maskable actor-critic policy classes imported
# from ``sb3_contrib.common.maskable.policies``. Each alias is bound to the
# very same class object (no subclassing or wrapping), so for example
# ``MlpPolicy is MaskableActorCriticPolicy`` holds.
# NOTE(review): presumably these short names are what callers use to select a
# policy by name (e.g. "MlpPolicy"); confirm against the algorithm's policy
# alias mapping.
MlpPolicy, CnnPolicy, MultiInputPolicy = (
    MaskableActorCriticPolicy,  # bound to MlpPolicy
    MaskableActorCriticCnnPolicy,  # bound to CnnPolicy
    MaskableMultiInputActorCriticPolicy,  # bound to MultiInputPolicy
)