Faster tests
This commit is contained in:
parent
926e488196
commit
72fe9a2072
|
|
@ -14,7 +14,7 @@ def test_tqc(ent_coef):
|
|||
create_eval_env=True,
|
||||
ent_coef=ent_coef,
|
||||
)
|
||||
model.learn(total_timesteps=500, eval_freq=250)
|
||||
model.learn(total_timesteps=300, eval_freq=250)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_critics", [1, 3])
|
||||
|
|
@ -23,7 +23,7 @@ def test_n_critics(n_critics):
|
|||
model = TQC(
|
||||
"MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], n_critics=n_critics), learning_starts=100, verbose=1
|
||||
)
|
||||
model.learn(total_timesteps=500)
|
||||
model.learn(total_timesteps=300)
|
||||
|
||||
|
||||
def test_sde():
|
||||
|
|
@ -35,6 +35,6 @@ def test_sde():
|
|||
learning_starts=100,
|
||||
verbose=1,
|
||||
)
|
||||
model.learn(total_timesteps=500)
|
||||
model.learn(total_timesteps=300)
|
||||
model.policy.reset_noise()
|
||||
model.policy.actor.get_std()
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ def test_save_load(tmp_path, model_class):
|
|||
|
||||
# create model
|
||||
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
|
||||
model.learn(total_timesteps=500)
|
||||
model.learn(total_timesteps=300)
|
||||
|
||||
env.reset()
|
||||
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
|
||||
|
|
@ -152,7 +152,7 @@ def test_save_load(tmp_path, model_class):
|
|||
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
|
||||
|
||||
# check if learn still works
|
||||
model.learn(total_timesteps=500)
|
||||
model.learn(total_timesteps=300)
|
||||
|
||||
del model
|
||||
|
||||
|
|
@ -224,7 +224,7 @@ def test_save_load_replay_buffer(tmp_path, model_class):
|
|||
path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
|
||||
path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning
|
||||
model = model_class("MlpPolicy", select_env(model_class), buffer_size=1000)
|
||||
model.learn(500)
|
||||
model.learn(300)
|
||||
old_replay_buffer = deepcopy(model.replay_buffer)
|
||||
model.save_replay_buffer(path)
|
||||
model.replay_buffer = None
|
||||
|
|
@ -268,7 +268,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
|||
|
||||
# create model
|
||||
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
|
||||
model.learn(total_timesteps=500)
|
||||
model.learn(total_timesteps=300)
|
||||
|
||||
env.reset()
|
||||
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
|
||||
|
|
|
|||
Loading…
Reference in New Issue