Improve coverage
This commit is contained in:
parent
7609c87e84
commit
5d7b79d41a
|
|
@ -183,18 +183,6 @@ class Actor(BasePolicy):
|
||||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||||
return self.forward(observation, deterministic)
|
return self.forward(observation, deterministic)
|
||||||
|
|
||||||
def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
|
|
||||||
"""
|
|
||||||
Evaluate actions according to the current policy,
|
|
||||||
given the observations. Only useful when using SDE.
|
|
||||||
:param obs:
|
|
||||||
:param actions:
|
|
||||||
:return: log likelihood of taking those actions
|
|
||||||
"""
|
|
||||||
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
|
||||||
self.action_dist.proba_distribution(mean_actions, log_std, **kwargs)
|
|
||||||
return self.action_dist.log_prob(actions)
|
|
||||||
|
|
||||||
|
|
||||||
class Critic(BaseModel):
|
class Critic(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,8 @@ def test_n_critics(n_critics):
|
||||||
|
|
||||||
def test_sde():
|
def test_sde():
|
||||||
model = TQC(
|
model = TQC(
|
||||||
"MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64]), use_sde=True, learning_starts=100, verbose=1
|
"MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], sde_net_arch=[8]), use_sde=True, learning_starts=100, verbose=1
|
||||||
)
|
)
|
||||||
model.learn(total_timesteps=500)
|
model.learn(total_timesteps=500)
|
||||||
|
model.policy.reset_noise()
|
||||||
|
model.policy.actor.get_std()
|
||||||
|
|
|
||||||
|
|
@ -177,17 +177,17 @@ def test_set_env(model_class):
|
||||||
# create model
|
# create model
|
||||||
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]))
|
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]))
|
||||||
# learn
|
# learn
|
||||||
model.learn(total_timesteps=500)
|
model.learn(total_timesteps=300)
|
||||||
|
|
||||||
# change env
|
# change env
|
||||||
model.set_env(env2)
|
model.set_env(env2)
|
||||||
# learn again
|
# learn again
|
||||||
model.learn(total_timesteps=500)
|
model.learn(total_timesteps=300)
|
||||||
|
|
||||||
# change env test wrapping
|
# change env test wrapping
|
||||||
model.set_env(env3)
|
model.set_env(env3)
|
||||||
# learn again
|
# learn again
|
||||||
model.learn(total_timesteps=500)
|
model.learn(total_timesteps=300)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue