Added save_load test for SACD

This commit is contained in:
Paul Auerbach 2023-08-07 14:23:07 +02:00
parent d97dbc727c
commit bc08ee985e
1 changed file with 8 additions and 4 deletions

View File

@ -12,16 +12,16 @@ from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnv
from stable_baselines3.common.utils import get_device from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv from stable_baselines3.common.vec_env import DummyVecEnv
from sb3_contrib import ARS, QRDQN, TQC, TRPO from sb3_contrib import ARS, QRDQN, SACD, TQC, TRPO
MODEL_LIST = [ARS, QRDQN, TQC, TRPO] MODEL_LIST = [ARS, QRDQN, SACD, TQC, TRPO]
def select_env(model_class: BaseAlgorithm) -> gym.Env:
    """
    Select a test environment whose action space suits the given algorithm.

    QRDQN only supports discrete action spaces; SACD is given the same
    discrete environment. Every other algorithm receives ``IdentityEnvBox``.

    :param model_class: The algorithm class under test.
    :return: ``IdentityEnv(10)`` for QRDQN and SACD,
        ``IdentityEnvBox(-10, 10)`` otherwise.
    """
    if model_class in [QRDQN, SACD]:
        return IdentityEnv(10)
    else:
        return IdentityEnvBox(-10, 10)
@ -281,12 +281,16 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
learning_starts=100, learning_starts=100,
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
) )
elif model_class == SACD:
kwargs = dict(
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
)
else: else:
kwargs = dict( kwargs = dict(
n_steps=128, n_steps=128,
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
) )
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN) env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class in [QRDQN, SACD])
# Reduce number of quantiles for faster tests # Reduce number of quantiles for faster tests
if model_class in [TQC, QRDQN]: if model_class in [TQC, QRDQN]: