From 610fd3dcf66adb22470799644f460f45e0ed1dc1 Mon Sep 17 00:00:00 2001 From: Paul Auerbach Date: Mon, 7 Aug 2023 13:16:00 +0200 Subject: [PATCH] Added run test for SACD --- tests/test_run.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/test_run.py b/tests/test_run.py index 90b5703..9ce4c71 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -3,7 +3,7 @@ import pytest from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.vec_env import VecNormalize -from sb3_contrib import ARS, QRDQN, TQC, TRPO, MaskablePPO +from sb3_contrib import ARS, QRDQN, SACD, TQC, TRPO, MaskablePPO from sb3_contrib.common.envs import InvalidActionEnvDiscrete from sb3_contrib.common.vec_env import AsyncEval @@ -61,6 +61,17 @@ def test_qrdqn(): model.learn(total_timesteps=500) +def test_sacd(): + model = SACD( + "MlpPolicy", + "CartPole-v1", + policy_kwargs=dict(net_arch=[64, 64], n_critics=2), + learning_starts=100, + verbose=1, + ) + model.learn(total_timesteps=3000) + + @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) def test_trpo(env_id): model = TRPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1)