diff --git a/sb3_contrib/sacd/sacd.py b/sb3_contrib/sacd/sacd.py
index 4b612a0..9e48513 100644
--- a/sb3_contrib/sacd/sacd.py
+++ b/sb3_contrib/sacd/sacd.py
@@ -68,6 +68,7 @@ class SACD(OffPolicyAlgorithm):
         during the warm up phase (before learning starts)
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
+    :param max_grad_norm: The maximum value for the gradient clipping
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
     :param policy_kwargs: additional arguments to be passed to the policy on creation
     :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
@@ -111,6 +112,7 @@ class SACD(OffPolicyAlgorithm):
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
         stats_window_size: int = 100,
+        max_grad_norm: float = 5.0,
         tensorboard_log: Optional[str] = None,
         policy_kwargs: Optional[Dict[str, Any]] = None,
         verbose: int = 0,
@@ -153,6 +155,7 @@ class SACD(OffPolicyAlgorithm):
         self.ent_coef = ent_coef
         self.target_update_interval = target_update_interval
         self.ent_coef_optimizer: Optional[th.optim.Adam] = None
+        self.gradient_clip_norm = max_grad_norm

         if _init_setup_model:
             self._setup_model()
@@ -220,12 +223,12 @@ class SACD(OffPolicyAlgorithm):
             # Compute the critic loss
             critic_loss = self.calc_critic_loss(replay_data)
             critic_losses.append(critic_loss.item())
-            self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, 5.0)
+            self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, self.gradient_clip_norm)

             # Compute the actor loss
             actor_loss, log_action_prob = self.calc_actor_loss(replay_data)
             actor_losses.append(actor_loss.item())
-            self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, 5.0)
+            self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, self.gradient_clip_norm)

             # Compute entropy loss and optimize ent_coeff
             ent_coeff = self.calc_entropy_loss(log_action_prob)
@@ -235,7 +238,7 @@ class SACD(OffPolicyAlgorithm):
             if gradient_step % self.target_update_interval == 0:
                 polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                 # Copy running stats, see GH issue #996
-                # polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
+                polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

         self._n_updates += gradient_steps
@@ -257,17 +260,10 @@ class SACD(OffPolicyAlgorithm):
         with th.no_grad():
             # Select action according to policy
             action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
-            # Compute the next Q values: min over all critics targets
-            # next_q_values = th.cat(self.critic_target(replay_data.next_observations), dim=1)
-            next_q_values = self.critic_target(replay_data.next_observations)
-            # print(self.critic_target(replay_data.next_observations))
-            # exit(0)
-            # print(next_q_values)
-            # next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
-            next_q_values = th.min(*next_q_values)
-            # print(next_q_values)
-            # exit(0)
+            # Compute the next Q values: min over all critics targets
+            next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
+            next_q_values, _ = th.min(next_q_values, dim=2)
             next_q_values = (action_prob * (next_q_values - self.ent_coef * next_log_prob)).sum(dim=1).unsqueeze(-1)
             target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
@@ -285,10 +281,8 @@ class SACD(OffPolicyAlgorithm):
         action_prob, log_prob = self.actor.action_log_prob(replay_data.observations)
         # Min over all critic networks
-        # q_values_pi = th.cat(self.critic(replay_data.observations), dim=1)
-        q_values_pi = self.critic(replay_data.observations)
-        # min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
-        min_qf_pi = th.min(*q_values_pi)
+        q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
+        min_qf_pi, _ = th.min(q_values_pi, dim=2)
         inside_term = self.ent_coef * log_prob - min_qf_pi
         actor_loss = (action_prob * inside_term).sum(dim=1).mean()
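
Reviewer note (not part of the patch): the stack-then-min rewrite generalizes the old two-critic th.min(*q_values) call to any number of critics, taking the element-wise minimum over the critic ensemble for every (batch, action) pair. Below is a minimal sketch of that equivalence, assuming each critic returns a (batch_size, n_actions) Q-value tensor as in the discrete SAC setup above; the take_optimisation_step helper is also sketched hypothetically, since its body is not shown in this diff and the real method in sacd.py may differ.

    import torch as th

    batch_size, n_actions, n_critics = 4, 3, 2
    q_values = [th.randn(batch_size, n_actions) for _ in range(n_critics)]

    # Old path: th.min(a, b) is the element-wise minimum,
    # but only accepts exactly two tensors.
    min_two = th.min(*q_values)

    # New path: stack along a fresh critic dimension, then reduce over it.
    stacked = th.stack(q_values, dim=2)   # (batch_size, n_actions, n_critics)
    min_any, _ = th.min(stacked, dim=2)   # (batch_size, n_actions)

    assert th.allclose(min_two, min_any)

    # Hypothetical sketch of the gradient-clipping helper called above with
    # self.gradient_clip_norm; shown only to illustrate how max_grad_norm is used.
    def take_optimisation_step(optimizer, network, loss, max_grad_norm):
        optimizer.zero_grad()
        loss.backward()
        th.nn.utils.clip_grad_norm_(network.parameters(), max_grad_norm)
        optimizer.step()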