Added run test for SACD

This commit is contained in:
Paul Auerbach 2023-08-07 13:16:00 +02:00
parent fca2c6d490
commit 610fd3dcf6
1 changed files with 12 additions and 1 deletions

View File

@ -3,7 +3,7 @@ import pytest
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize
from sb3_contrib import ARS, QRDQN, TQC, TRPO, MaskablePPO
from sb3_contrib import ARS, QRDQN, SACD, TQC, TRPO, MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.vec_env import AsyncEval
@ -61,6 +61,17 @@ def test_qrdqn():
model.learn(total_timesteps=500)
def test_sacd():
    """Smoke-test SACD: build it on CartPole-v1 and run a short training loop.

    The test passes as long as construction and ``learn`` finish without
    raising; nothing about the learned policy is asserted.
    """
    # Two hidden layers of 64 units and two critics.
    # NOTE(review): presumably mirrors the SAC-style twin-critic setup — confirm
    # against the SACD policy implementation.
    network_config = {"net_arch": [64, 64], "n_critics": 2}
    agent = SACD(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs=network_config,
        learning_starts=100,
        verbose=1,
    )
    # total_timesteps is well above learning_starts, so (assuming the usual
    # off-policy semantics of learning_starts) gradient updates are exercised.
    agent.learn(total_timesteps=3000)
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_trpo(env_id):
model = TRPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1)