Added save_load test for SACD
This commit is contained in:
parent
d97dbc727c
commit
bc08ee985e
|
|
@ -12,16 +12,16 @@ from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnv
|
||||||
from stable_baselines3.common.utils import get_device
|
from stable_baselines3.common.utils import get_device
|
||||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||||
|
|
||||||
from sb3_contrib import ARS, QRDQN, TQC, TRPO
|
from sb3_contrib import ARS, QRDQN, SACD, TQC, TRPO
|
||||||
|
|
||||||
MODEL_LIST = [ARS, QRDQN, TQC, TRPO]
|
MODEL_LIST = [ARS, QRDQN, SACD, TQC, TRPO]
|
||||||
|
|
||||||
|
|
||||||
def select_env(model_class: BaseAlgorithm) -> gym.Env:
|
def select_env(model_class: BaseAlgorithm) -> gym.Env:
|
||||||
"""
|
"""
|
||||||
Selects an environment with the correct action space as QRDQN only supports discrete action space
|
Selects an environment with the correct action space as QRDQN only supports discrete action space
|
||||||
"""
|
"""
|
||||||
if model_class == QRDQN:
|
if model_class in [QRDQN, SACD]:
|
||||||
return IdentityEnv(10)
|
return IdentityEnv(10)
|
||||||
else:
|
else:
|
||||||
return IdentityEnvBox(-10, 10)
|
return IdentityEnvBox(-10, 10)
|
||||||
|
|
@ -281,12 +281,16 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
||||||
learning_starts=100,
|
learning_starts=100,
|
||||||
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
||||||
)
|
)
|
||||||
|
elif model_class == SACD:
|
||||||
|
kwargs = dict(
|
||||||
|
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
n_steps=128,
|
n_steps=128,
|
||||||
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
||||||
)
|
)
|
||||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN)
|
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class in [QRDQN, SACD])
|
||||||
|
|
||||||
# Reduce number of quantiles for faster tests
|
# Reduce number of quantiles for faster tests
|
||||||
if model_class in [TQC, QRDQN]:
|
if model_class in [TQC, QRDQN]:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue