Faster tests

This commit is contained in:
Antonin RAFFIN 2020-10-17 17:06:11 +02:00
parent 926e488196
commit 72fe9a2072
2 changed files with 7 additions and 7 deletions

View File

@ -14,7 +14,7 @@ def test_tqc(ent_coef):
create_eval_env=True,
ent_coef=ent_coef,
)
model.learn(total_timesteps=500, eval_freq=250)
model.learn(total_timesteps=300, eval_freq=250)
@pytest.mark.parametrize("n_critics", [1, 3])
@ -23,7 +23,7 @@ def test_n_critics(n_critics):
model = TQC(
"MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], n_critics=n_critics), learning_starts=100, verbose=1
)
model.learn(total_timesteps=500)
model.learn(total_timesteps=300)
def test_sde():
@ -35,6 +35,6 @@ def test_sde():
learning_starts=100,
verbose=1,
)
model.learn(total_timesteps=500)
model.learn(total_timesteps=300)
model.policy.reset_noise()
model.policy.actor.get_std()

View File

@ -43,7 +43,7 @@ def test_save_load(tmp_path, model_class):
# create model
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
model.learn(total_timesteps=500)
model.learn(total_timesteps=300)
env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
@ -152,7 +152,7 @@ def test_save_load(tmp_path, model_class):
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
# check if learn still works
model.learn(total_timesteps=500)
model.learn(total_timesteps=300)
del model
@ -224,7 +224,7 @@ def test_save_load_replay_buffer(tmp_path, model_class):
path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning
model = model_class("MlpPolicy", select_env(model_class), buffer_size=1000)
model.learn(500)
model.learn(300)
old_replay_buffer = deepcopy(model.replay_buffer)
model.save_replay_buffer(path)
model.replay_buffer = None
@ -268,7 +268,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
# create model
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
model.learn(total_timesteps=500)
model.learn(total_timesteps=300)
env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)