Added save_load test for SACD

This commit is contained in:
Paul Auerbach 2023-08-07 14:23:07 +02:00
parent d97dbc727c
commit bc08ee985e
1 changed files with 8 additions and 4 deletions

View File

@ -12,16 +12,16 @@ from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnv
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv
from sb3_contrib import ARS, QRDQN, TQC, TRPO
from sb3_contrib import ARS, QRDQN, SACD, TQC, TRPO
MODEL_LIST = [ARS, QRDQN, TQC, TRPO]
MODEL_LIST = [ARS, QRDQN, SACD, TQC, TRPO]
def select_env(model_class: BaseAlgorithm) -> gym.Env:
"""
Selects an environment with the correct action space as QRDQN only supports discrete action space
"""
if model_class == QRDQN:
if model_class in [QRDQN, SACD]:
return IdentityEnv(10)
else:
return IdentityEnvBox(-10, 10)
@ -281,12 +281,16 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
learning_starts=100,
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
)
elif model_class == SACD:
kwargs = dict(
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
)
else:
kwargs = dict(
n_steps=128,
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
)
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN)
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class in [QRDQN, SACD])
# Reduce number of quantiles for faster tests
if model_class in [TQC, QRDQN]: