From bc08ee985e191156faf8fe224e026edd9819f8e1 Mon Sep 17 00:00:00 2001 From: Paul Auerbach Date: Mon, 7 Aug 2023 14:23:07 +0200 Subject: [PATCH] Added save_load test for SACD --- tests/test_save_load.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 502d239..5528b23 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -12,16 +12,16 @@ from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnv from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv -from sb3_contrib import ARS, QRDQN, TQC, TRPO +from sb3_contrib import ARS, QRDQN, SACD, TQC, TRPO -MODEL_LIST = [ARS, QRDQN, TQC, TRPO] +MODEL_LIST = [ARS, QRDQN, SACD, TQC, TRPO] def select_env(model_class: BaseAlgorithm) -> gym.Env: """ Selects an environment with the correct action space as QRDQN only supports discrete action space """ - if model_class == QRDQN: + if model_class in [QRDQN, SACD]: return IdentityEnv(10) else: return IdentityEnvBox(-10, 10) @@ -281,12 +281,16 @@ def test_save_load_policy(tmp_path, model_class, policy_str): learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), ) + elif model_class == SACD: + kwargs = dict( + policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), + ) else: kwargs = dict( n_steps=128, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), ) - env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN) + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class in [QRDQN, SACD]) # Reduce number of quantiles for faster tests if model_class in [TQC, QRDQN]: