Add test for pytorch variables
This commit is contained in:
parent
3665695d1e
commit
30cc206578
|
|
@ -303,9 +303,9 @@ class TQC(OffPolicyAlgorithm):
|
||||||
|
|
||||||
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
    """
    Return the names of the torch attributes to persist when saving the model.

    :return: A tuple ``(state_dicts, saved_pytorch_variables)`` where
        ``state_dicts`` are names of objects saved via their ``state_dict()``
        (policy and optimizers) and ``saved_pytorch_variables`` are names of
        plain torch tensors saved directly.
    """
    state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
    if self.ent_coef_optimizer is not None:
        # Automatic entropy tuning: the learnable log coefficient and its
        # optimizer must both be saved. Do NOT pre-populate the list outside
        # this branch: with a fixed entropy coefficient, log_ent_coef is None
        # and listing it would break saving.
        saved_pytorch_variables = ["log_ent_coef"]
        state_dicts.append("ent_coef_optimizer")
    else:
        # Fixed entropy coefficient: only the constant tensor is saved.
        saved_pytorch_variables = ["ent_coef_tensor"]
    return state_dicts, saved_pytorch_variables
|
|
|
||||||
|
|
@ -433,3 +433,40 @@ def test_save_load_q_net(tmp_path, model_class, policy_str):
|
||||||
|
|
||||||
# clear file from os
|
# clear file from os
|
||||||
os.remove(tmp_path / "q_net.pkl")
|
os.remove(tmp_path / "q_net.pkl")
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_load_pytorch_var(tmp_path):
    """Check that TQC's extra pytorch variables survive a save/load round trip."""
    # --- Case 1: automatic entropy tuning (log_ent_coef is a learnable tensor) ---
    model = TQC("MlpPolicy", "Pendulum-v0", seed=3, policy_kwargs={"net_arch": [64], "n_critics": 1})
    model.learn(200)
    saved_file = str(tmp_path / "tqc_pendulum")
    model.save(saved_file)
    train_env = model.get_env()
    coef_snapshot = model.log_ent_coef

    del model

    model = TQC.load(saved_file, env=train_env)
    # The learnable coefficient must be restored exactly
    assert th.allclose(coef_snapshot, model.log_ent_coef)
    model.learn(200)
    coef_resumed = model.log_ent_coef
    # Check that the entropy coefficient is still optimized
    assert not th.allclose(coef_snapshot, coef_resumed)

    # With a fixed entropy coef
    model = TQC("MlpPolicy", "Pendulum-v0", seed=3, ent_coef=0.01, policy_kwargs={"net_arch": [64], "n_critics": 1})
    model.learn(200)
    saved_file = str(tmp_path / "tqc_pendulum")
    model.save(saved_file)
    train_env = model.get_env()
    # No learnable coefficient in this mode
    assert model.log_ent_coef is None
    fixed_snapshot = model.ent_coef_tensor

    del model

    model = TQC.load(saved_file, env=train_env)
    # The constant tensor must be restored exactly
    assert th.allclose(fixed_snapshot, model.ent_coef_tensor)
    model.learn(200)
    fixed_resumed = model.ent_coef_tensor
    assert model.log_ent_coef is None
    # Check that the entropy coefficient is still the same
    assert th.allclose(fixed_snapshot, fixed_resumed)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue