diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py
index 3fbd28d..9e5f4cd 100644
--- a/sb3_contrib/__init__.py
+++ b/sb3_contrib/__init__.py
@@ -6,6 +6,7 @@ from sb3_contrib.ppo_recurrent import RecurrentPPO
 from sb3_contrib.qrdqn import QRDQN
 from sb3_contrib.tqc import TQC
 from sb3_contrib.trpo import TRPO
+from sb3_contrib.sacd import SACD
 
 # Read version from file
 version_file = os.path.join(os.path.dirname(__file__), "version.txt")
@@ -19,4 +20,5 @@ __all__ = [
     "QRDQN",
     "TQC",
     "TRPO",
+    "SACD",
 ]
diff --git a/sb3_contrib/sacd/policies.py b/sb3_contrib/sacd/policies.py
index 1676d58..f4efef5 100644
--- a/sb3_contrib/sacd/policies.py
+++ b/sb3_contrib/sacd/policies.py
@@ -178,13 +178,13 @@ class DiscreteCritic(BaseModel):
             normalize_images=normalize_images,
         )
 
-        action_dim = get_action_dim(self.action_space)
+        num_actions = self.action_space.n
 
         self.share_features_extractor = share_features_extractor
        self.n_critics = n_critics
         self.q_networks = []
         for idx in range(n_critics):
-            q_net = create_mlp(features_dim, action_dim, net_arch, activation_fn)
+            q_net = create_mlp(features_dim, num_actions, net_arch, activation_fn)
             q_net = nn.Sequential(*q_net)
             self.add_module(f"qf{idx}", q_net)
             self.q_networks.append(q_net)
diff --git a/sb3_contrib/sacd/sacd.py b/sb3_contrib/sacd/sacd.py
index 1119bf0..4b612a0 100644
--- a/sb3_contrib/sacd/sacd.py
+++ b/sb3_contrib/sacd/sacd.py
@@ -186,6 +186,7 @@ class SACD(OffPolicyAlgorithm):
             # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
             # self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
             self.log_ent_coef = th.zeros(1, device=self.device, requires_grad=True)
+            self.ent_coef = th.exp(self.log_ent_coef.detach())
             self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
         else:
             # Force conversion to float
@@ -216,80 +217,19 @@ class SACD(OffPolicyAlgorithm):
             # Sample replay buffer
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
 
-            # We need to sample because `log_std` may have changed between two gradient steps
-            # if self.use_sde:
-            #     self.actor.reset_noise()
+            # Compute the critic loss and take a clipped gradient step on the critic
+            critic_loss = self.calc_critic_loss(replay_data)
+            critic_losses.append(critic_loss.item())
+            self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, 5.0)
 
-            # Action by the current actor for the sampled state
-            actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
-
-            # Compute entropy loss
-            ent_coef_loss = None
-            if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
-                # Important: detach the variable from the graph
-                # so we don't change it with other losses
-                # see https://github.com/rail-berkeley/softlearning/issues/60
-                ent_coef = th.exp(self.log_ent_coef.detach())
-                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
-                ent_coef_losses.append(ent_coef_loss.item())
-            else:
-                ent_coef = self.ent_coef_tensor
-
-            ent_coefs.append(ent_coef.item())
-
-            # print(f"Alpha Loss{ent_coef_loss.item()}")
-
-            # Optimize entropy coefficient, also called
-            # entropy temperature or alpha in the paper
-            if ent_coef_loss is not None and self.ent_coef_optimizer is not None:
-                self.ent_coef_optimizer.zero_grad()
-                ent_coef_loss.backward()
-                self.ent_coef_optimizer.step()
-
-            with th.no_grad():
-                # Select action according to policy
-                action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
-                # Compute the next Q values: min over all critics targets
-                next_q_values = th.cat(self.critic_target(replay_data.next_observations), dim=1)
-                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
-
-                # add entropy term
-                next_q_values = (action_prob * next_q_values - ent_coef * next_log_prob).sum(dim=1).unsqueeze(-1)
-                # td error + entropy term
-                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
-
-            # Get current Q-values estimates for each critic network
-            # using action from the replay buffer
-            current_q_values = self.critic(replay_data.observations)
-
-            # Compute critic loss
-            critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
-            critic_losses.append(critic_loss.item())  # type: ignore[union-attr]
-
-            # print(f"Critic Loss{critic_loss.item()}")
-
-            # Optimize the critic
-            self.critic.optimizer.zero_grad()
-            critic_loss.backward()
-            th.nn.utils.clip_grad_norm(self.actor.parameters(), 5.0)
-            self.critic.optimizer.step()
-
-            # Compute actor loss
-            # Min over all critic networks
-            q_values_pi = th.cat(self.critic(replay_data.observations), dim=1)
-            min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
-
-            inside_term = ent_coef * log_prob - min_qf_pi
-            actor_loss = (actions_pi * inside_term).sum(dim=1).mean()
+            # Compute the actor loss and take a clipped gradient step on the actor
+            actor_loss, log_action_prob = self.calc_actor_loss(replay_data)
             actor_losses.append(actor_loss.item())
+            self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, 5.0)
 
-            # print(f"Actor Loss{actor_loss.item()}")
-
-            # Optimize the actor
-            self.actor.optimizer.zero_grad()
-            actor_loss.backward()
-            th.nn.utils.clip_grad_norm(self.critic.parameters(), 5.0)
-            self.actor.optimizer.step()
+            # Compute the entropy coefficient loss and update the entropy coefficient (alpha)
+            ent_coef_loss = self.calc_entropy_loss(log_action_prob)
+            if ent_coef_loss is not None:
+                ent_coef_losses.append(ent_coef_loss.item())
+            ent_coefs.append(self.ent_coef.item())
 
             # Update target networks
             if gradient_step % self.target_update_interval == 0:
@@ -306,6 +246,70 @@ class SACD(OffPolicyAlgorithm):
         if len(ent_coef_losses) > 0:
             self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
 
+    def take_optimisation_step(self, optimizer, network, loss, clipping_norm=None):
+        """Take one gradient step on ``loss``, optionally clipping the gradient norm of ``network``."""
+        optimizer.zero_grad()
+        loss.backward()
+        if clipping_norm is not None:
+            # Clip gradients to help stabilise training
+            th.nn.utils.clip_grad_norm_(network.parameters(), clipping_norm)
+        optimizer.step()
+
+    def calc_critic_loss(self, replay_data):
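+        """
+        Compute the soft Q-learning critic loss for a batch of replay transitions.
+
+        For a discrete action space the expectation over next actions is taken
+        explicitly, weighting the target Q-values by the policy's action
+        probabilities instead of sampling a single next action.
+        """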
+        with th.no_grad():
+            # Action probabilities and log-probabilities for the next observations
+            action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
+            # Compute the next Q values: element-wise min over all critic targets
+            next_q_values = self.critic_target(replay_data.next_observations)
+            next_q_values = th.min(th.stack(next_q_values), dim=0).values
+
+            # Expectation over next actions (weighted by their probabilities), plus entropy term
+            next_q_values = (action_prob * (next_q_values - self.ent_coef * next_log_prob)).sum(dim=1).unsqueeze(-1)
+            # TD target: reward plus discounted entropy-regularised next value
+            target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
+
+        # Get current Q-value estimates for each critic network,
+        # selecting the Q-value of the action stored in the replay buffer
+        current_q_values = self.critic(replay_data.observations)
+
+        # Compute critic loss
+        critic_loss = 0.5 * sum(
+            F.mse_loss(current_q.gather(1, replay_data.actions.long()), target_q_values)
+            for current_q in current_q_values
+        )
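+        """
+        Compute the entropy coefficient (alpha) loss and, when the coefficient is
+        learned automatically, take a gradient step on it. The updated coefficient
+        is cached in ``self.ent_coef`` (detached from the graph).
+
+        Returns the entropy coefficient loss, or ``None`` if the coefficient is fixed.
+        """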
+        return critic_loss
+
+    def calc_actor_loss(self, replay_data):
+        """Compute the actor loss as the probability-weighted expectation of alpha * log pi(a|s) - Q(s, a)."""
+        action_prob, log_prob = self.actor.action_log_prob(replay_data.observations)
+
+        # Element-wise min over all critic networks
+        q_values_pi = self.critic(replay_data.observations)
+        min_qf_pi = th.min(th.stack(q_values_pi), dim=0).values
+
+        inside_term = self.ent_coef * log_prob - min_qf_pi
+        actor_loss = (action_prob * inside_term).sum(dim=1).mean()
+        return actor_loss, log_prob
+
+    def calc_entropy_loss(self, log_action_prob):
+        ent_coef_loss = None
+        if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
+            # Important: detach the variable from the graph
+            # so we don't change it with other losses
+            # see https://github.com/rail-berkeley/softlearning/issues/60
+            ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()
+            self.ent_coef_optimizer.zero_grad()
+            ent_coef_loss.backward()
+            self.ent_coef_optimizer.step()
+            self.ent_coef = th.exp(self.log_ent_coef.detach())
+        else:
+            self.ent_coef = self.ent_coef_tensor
+
+        return ent_coef_loss
+
     def learn(
         self: SelfSACD,
         total_timesteps: int,