Added run test for SACD
This commit is contained in:
parent
fca2c6d490
commit
610fd3dcf6
|
|
@ -3,7 +3,7 @@ import pytest
|
||||||
from stable_baselines3.common.env_util import make_vec_env
|
from stable_baselines3.common.env_util import make_vec_env
|
||||||
from stable_baselines3.common.vec_env import VecNormalize
|
from stable_baselines3.common.vec_env import VecNormalize
|
||||||
|
|
||||||
from sb3_contrib import ARS, QRDQN, TQC, TRPO, MaskablePPO
|
from sb3_contrib import ARS, QRDQN, SACD, TQC, TRPO, MaskablePPO
|
||||||
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
|
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
|
||||||
from sb3_contrib.common.vec_env import AsyncEval
|
from sb3_contrib.common.vec_env import AsyncEval
|
||||||
|
|
||||||
|
|
@ -61,6 +61,17 @@ def test_qrdqn():
|
||||||
model.learn(total_timesteps=500)
|
model.learn(total_timesteps=500)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sacd():
|
||||||
|
model = SACD(
|
||||||
|
"MlpPolicy",
|
||||||
|
"CartPole-v1",
|
||||||
|
policy_kwargs=dict(net_arch=[64, 64], n_critics=2),
|
||||||
|
learning_starts=100,
|
||||||
|
verbose=1,
|
||||||
|
)
|
||||||
|
model.learn(total_timesteps=3000)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
|
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
|
||||||
def test_trpo(env_id):
|
def test_trpo(env_id):
|
||||||
model = TRPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1)
|
model = TRPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue