Added save_load test for SACD
This commit is contained in:
parent
d97dbc727c
commit
bc08ee985e
|
|
@ -12,16 +12,16 @@ from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnv
|
|||
from stable_baselines3.common.utils import get_device
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
from sb3_contrib import ARS, QRDQN, TQC, TRPO
|
||||
from sb3_contrib import ARS, QRDQN, SACD, TQC, TRPO
|
||||
|
||||
MODEL_LIST = [ARS, QRDQN, TQC, TRPO]
|
||||
MODEL_LIST = [ARS, QRDQN, SACD, TQC, TRPO]
|
||||
|
||||
|
||||
def select_env(model_class: BaseAlgorithm) -> gym.Env:
|
||||
"""
|
||||
Selects an environment with the correct action space as QRDQN only supports discrete action space
|
||||
"""
|
||||
if model_class == QRDQN:
|
||||
if model_class in [QRDQN, SACD]:
|
||||
return IdentityEnv(10)
|
||||
else:
|
||||
return IdentityEnvBox(-10, 10)
|
||||
|
|
@ -281,12 +281,16 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
|||
learning_starts=100,
|
||||
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
||||
)
|
||||
elif model_class == SACD:
|
||||
kwargs = dict(
|
||||
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
||||
)
|
||||
else:
|
||||
kwargs = dict(
|
||||
n_steps=128,
|
||||
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
||||
)
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN)
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class in [QRDQN, SACD])
|
||||
|
||||
# Reduce number of quantiles for faster tests
|
||||
if model_class in [TQC, QRDQN]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue