17 lines
494 B
Python
17 lines
494 B
Python
# This file is here just to define MlpPolicy/CnnPolicy
|
|
# that work for TRPO
|
|
from stable_baselines3.common.policies import (
|
|
ActorCriticCnnPolicy,
|
|
ActorCriticPolicy,
|
|
MultiInputActorCriticPolicy,
|
|
register_policy,
|
|
)
|
|
|
|
MlpPolicy = ActorCriticPolicy
|
|
CnnPolicy = ActorCriticCnnPolicy
|
|
MultiInputPolicy = MultiInputActorCriticPolicy
|
|
|
|
register_policy("MlpPolicy", ActorCriticPolicy)
|
|
register_policy("CnnPolicy", ActorCriticCnnPolicy)
|
|
register_policy("MultiInputPolicy", MultiInputPolicy)
|