Add test for pytorch variables
This commit is contained in:
parent
3665695d1e
commit
30cc206578
|
|
@ -303,9 +303,9 @@ class TQC(OffPolicyAlgorithm):
|
|||
|
||||
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
||||
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
|
||||
saved_pytorch_variables = ["log_ent_coef"]
|
||||
if self.ent_coef_optimizer is not None:
|
||||
saved_pytorch_variables = ["log_ent_coef"]
|
||||
state_dicts.append("ent_coef_optimizer")
|
||||
else:
|
||||
saved_pytorch_variables.append("ent_coef_tensor")
|
||||
saved_pytorch_variables = ["ent_coef_tensor"]
|
||||
return state_dicts, saved_pytorch_variables
|
||||
|
|
|
|||
|
|
@ -433,3 +433,40 @@ def test_save_load_q_net(tmp_path, model_class, policy_str):
|
|||
|
||||
# clear file from os
|
||||
os.remove(tmp_path / "q_net.pkl")
|
||||
|
||||
|
||||
def test_save_load_pytorch_var(tmp_path):
|
||||
model = TQC("MlpPolicy", "Pendulum-v0", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
|
||||
model.learn(200)
|
||||
save_path = str(tmp_path / "tqc_pendulum")
|
||||
model.save(save_path)
|
||||
env = model.get_env()
|
||||
log_ent_coef_before = model.log_ent_coef
|
||||
|
||||
del model
|
||||
|
||||
model = TQC.load(save_path, env=env)
|
||||
assert th.allclose(log_ent_coef_before, model.log_ent_coef)
|
||||
model.learn(200)
|
||||
log_ent_coef_after = model.log_ent_coef
|
||||
# Check that the entropy coefficient is still optimized
|
||||
assert not th.allclose(log_ent_coef_before, log_ent_coef_after)
|
||||
|
||||
# With a fixed entropy coef
|
||||
model = TQC("MlpPolicy", "Pendulum-v0", seed=3, ent_coef=0.01, policy_kwargs=dict(net_arch=[64], n_critics=1))
|
||||
model.learn(200)
|
||||
save_path = str(tmp_path / "tqc_pendulum")
|
||||
model.save(save_path)
|
||||
env = model.get_env()
|
||||
assert model.log_ent_coef is None
|
||||
ent_coef_before = model.ent_coef_tensor
|
||||
|
||||
del model
|
||||
|
||||
model = TQC.load(save_path, env=env)
|
||||
assert th.allclose(ent_coef_before, model.ent_coef_tensor)
|
||||
model.learn(200)
|
||||
ent_coef_after = model.ent_coef_tensor
|
||||
assert model.log_ent_coef is None
|
||||
# Check that the entropy coefficient is still the same
|
||||
assert th.allclose(ent_coef_before, ent_coef_after)
|
||||
|
|
|
|||
Loading…
Reference in New Issue