# File: stable-baselines3-contrib-sacd/sb3_contrib/trpo/policies.py
# (source listing metadata: 17 lines, 494 B, Python)

# This file only defines the MlpPolicy / CnnPolicy / MultiInputPolicy
# aliases used by TRPO.
from stable_baselines3.common.policies import (
ActorCriticCnnPolicy,
ActorCriticPolicy,
MultiInputActorCriticPolicy,
register_policy,
)
# TRPO reuses Stable-Baselines3's generic actor-critic policies unchanged,
# so each public policy name is a plain alias of the matching SB3 class.
MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy

# Register every alias under its string name. Each call goes through the alias
# defined just above, so the string name and the registered class are stated
# once and cannot drift apart. Presumably this lets callers pass
# "MlpPolicy" etc. instead of the class; the lookup itself lives in
# stable_baselines3 — confirm there.
# NOTE(review): newer SB3 releases drop `register_policy` in favour of
# per-algorithm `policy_aliases`; this module needs an SB3 version that
# still exports it.
register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)