16 lines
575 B
Python
16 lines
575 B
Python
from stable_baselines3.common.policies import register_policy
|
|
|
|
from sb3_contrib.common.maskable.policies import ( # MaskableMultiInputActorCriticPolicy,
|
|
MaskableActorCriticCnnPolicy,
|
|
MaskableActorCriticPolicy,
|
|
)
|
|
|
|
# Public short-name aliases for the maskable actor-critic policy classes.
MlpPolicy = MaskableActorCriticPolicy
CnnPolicy = MaskableActorCriticCnnPolicy
# The multi-input variant is currently not supported:
# MultiInputPolicy = MaskableMultiInputActorCriticPolicy

# Register each alias name with its class via stable_baselines3's
# ``register_policy`` helper (the aliases above are the very same classes).
# NOTE(review): later stable_baselines3 releases appear to have removed
# ``register_policy`` in favour of per-algorithm ``policy_aliases`` —
# confirm against the pinned stable_baselines3 version.
register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)
# Currently not supported:
# register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy)