diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index 19a4037..f728c3a 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -183,18 +183,6 @@ class Actor(BasePolicy): def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: return self.forward(observation, deterministic) - def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor: - """ - Evaluate actions according to the current policy, - given the observations. Only useful when using SDE. - :param obs: - :param actions: - :return: log likelihood of taking those actions - """ - mean_actions, log_std, kwargs = self.get_action_dist_params(obs) - self.action_dist.proba_distribution(mean_actions, log_std, **kwargs) - return self.action_dist.log_prob(actions) - class Critic(BaseModel): """ diff --git a/tests/test_run.py b/tests/test_run.py index 49cd2bf..241ea92 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -27,6 +27,8 @@ def test_n_critics(n_critics): def test_sde(): model = TQC( - "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64]), use_sde=True, learning_starts=100, verbose=1 + "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], sde_net_arch=[8]), use_sde=True, learning_starts=100, verbose=1 ) model.learn(total_timesteps=500) + model.policy.reset_noise() + model.policy.actor.get_std() diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 875efb9..8a704ac 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -177,17 +177,17 @@ def test_set_env(model_class): # create model model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16])) # learn - model.learn(total_timesteps=500) + model.learn(total_timesteps=300) # change env model.set_env(env2) # learn again - model.learn(total_timesteps=500) + model.learn(total_timesteps=300) # change env test wrapping model.set_env(env3) # learn again - model.learn(total_timesteps=500) + model.learn(total_timesteps=300) @pytest.mark.parametrize("model_class", MODEL_LIST)