diff --git a/tests/test_run.py b/tests/test_run.py index c2bc941..8d59976 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -14,7 +14,7 @@ def test_tqc(ent_coef): create_eval_env=True, ent_coef=ent_coef, ) - model.learn(total_timesteps=500, eval_freq=250) + model.learn(total_timesteps=300, eval_freq=250) @pytest.mark.parametrize("n_critics", [1, 3]) @@ -23,7 +23,7 @@ def test_n_critics(n_critics): model = TQC( "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], n_critics=n_critics), learning_starts=100, verbose=1 ) - model.learn(total_timesteps=500) + model.learn(total_timesteps=300) def test_sde(): @@ -35,6 +35,6 @@ def test_sde(): learning_starts=100, verbose=1, ) - model.learn(total_timesteps=500) + model.learn(total_timesteps=300) model.policy.reset_noise() model.policy.actor.get_std() diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 9a73c77..396bc62 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -43,7 +43,7 @@ def test_save_load(tmp_path, model_class): # create model model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1) - model.learn(total_timesteps=500) + model.learn(total_timesteps=300) env.reset() observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) @@ -152,7 +152,7 @@ def test_save_load(tmp_path, model_class): assert np.allclose(selected_actions, new_selected_actions, 1e-4) # check if learn still works - model.learn(total_timesteps=500) + model.learn(total_timesteps=300) del model @@ -224,7 +224,7 @@ def test_save_load_replay_buffer(tmp_path, model_class): path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl") path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning model = model_class("MlpPolicy", select_env(model_class), buffer_size=1000) - model.learn(500) + model.learn(300) old_replay_buffer = deepcopy(model.replay_buffer) model.save_replay_buffer(path) model.replay_buffer = None @@ -268,7 +268,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str): # create model model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs) - model.learn(total_timesteps=500) + model.learn(total_timesteps=300) env.reset() observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)