From 30cc2065781ea5aef3967f7bbf7d348cb85cff0a Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 12 May 2021 11:39:56 +0200 Subject: [PATCH] Add test for pytorch variables --- sb3_contrib/tqc/tqc.py | 4 ++-- tests/test_save_load.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 103c25e..a9d5df4 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -303,9 +303,9 @@ class TQC(OffPolicyAlgorithm): def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] - saved_pytorch_variables = ["log_ent_coef"] if self.ent_coef_optimizer is not None: + saved_pytorch_variables = ["log_ent_coef"] state_dicts.append("ent_coef_optimizer") else: - saved_pytorch_variables.append("ent_coef_tensor") + saved_pytorch_variables = ["ent_coef_tensor"] return state_dicts, saved_pytorch_variables diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 716f202..d2ee3a2 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -433,3 +433,40 @@ def test_save_load_q_net(tmp_path, model_class, policy_str): # clear file from os os.remove(tmp_path / "q_net.pkl") + + +def test_save_load_pytorch_var(tmp_path): + model = TQC("MlpPolicy", "Pendulum-v0", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1)) + model.learn(200) + save_path = str(tmp_path / "tqc_pendulum") + model.save(save_path) + env = model.get_env() + log_ent_coef_before = model.log_ent_coef + + del model + + model = TQC.load(save_path, env=env) + assert th.allclose(log_ent_coef_before, model.log_ent_coef) + model.learn(200) + log_ent_coef_after = model.log_ent_coef + # Check that the entropy coefficient is still optimized + assert not th.allclose(log_ent_coef_before, log_ent_coef_after) + + # With a fixed entropy coef + model = TQC("MlpPolicy", "Pendulum-v0", seed=3, ent_coef=0.01, policy_kwargs=dict(net_arch=[64], n_critics=1)) + model.learn(200) + save_path = str(tmp_path / "tqc_pendulum") + model.save(save_path) + env = model.get_env() + assert model.log_ent_coef is None + ent_coef_before = model.ent_coef_tensor + + del model + + model = TQC.load(save_path, env=env) + assert th.allclose(ent_coef_before, model.ent_coef_tensor) + model.learn(200) + ent_coef_after = model.ent_coef_tensor + assert model.log_ent_coef is None + # Check that the entropy coefficient is still the same + assert th.allclose(ent_coef_before, ent_coef_after)