Add test for pytorch variables

This commit is contained in:
Antonin Raffin 2021-05-12 11:39:56 +02:00
parent 3665695d1e
commit 30cc206578
2 changed files with 39 additions and 2 deletions

View File

@ -303,9 +303,9 @@ class TQC(OffPolicyAlgorithm):
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
saved_pytorch_variables = ["log_ent_coef"]
if self.ent_coef_optimizer is not None:
saved_pytorch_variables = ["log_ent_coef"]
state_dicts.append("ent_coef_optimizer")
else:
saved_pytorch_variables.append("ent_coef_tensor")
saved_pytorch_variables = ["ent_coef_tensor"]
return state_dicts, saved_pytorch_variables

View File

@ -433,3 +433,40 @@ def test_save_load_q_net(tmp_path, model_class, policy_str):
# clear file from os
os.remove(tmp_path / "q_net.pkl")
def test_save_load_pytorch_var(tmp_path):
model = TQC("MlpPolicy", "Pendulum-v0", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
model.learn(200)
save_path = str(tmp_path / "tqc_pendulum")
model.save(save_path)
env = model.get_env()
log_ent_coef_before = model.log_ent_coef
del model
model = TQC.load(save_path, env=env)
assert th.allclose(log_ent_coef_before, model.log_ent_coef)
model.learn(200)
log_ent_coef_after = model.log_ent_coef
# Check that the entropy coefficient is still optimized
assert not th.allclose(log_ent_coef_before, log_ent_coef_after)
# With a fixed entropy coef
model = TQC("MlpPolicy", "Pendulum-v0", seed=3, ent_coef=0.01, policy_kwargs=dict(net_arch=[64], n_critics=1))
model.learn(200)
save_path = str(tmp_path / "tqc_pendulum")
model.save(save_path)
env = model.get_env()
assert model.log_ent_coef is None
ent_coef_before = model.ent_coef_tensor
del model
model = TQC.load(save_path, env=env)
assert th.allclose(ent_coef_before, model.ent_coef_tensor)
model.learn(200)
ent_coef_after = model.ent_coef_tensor
assert model.log_ent_coef is None
# Check that the entropy coefficient is still the same
assert th.allclose(ent_coef_before, ent_coef_after)