diff --git a/sb3_contrib/sacd/policies.py b/sb3_contrib/sacd/policies.py
index 49f560a..fa5b3fd 100644
--- a/sb3_contrib/sacd/policies.py
+++ b/sb3_contrib/sacd/policies.py
@@ -61,7 +61,6 @@ class Actor(BasePolicy):
             action_space,
             features_extractor=features_extractor,
             normalize_images=normalize_images,
-            # squash_output=True,
             squash_output=False,
         )
 
@@ -78,7 +77,7 @@ class Actor(BasePolicy):
 
         num_actions = self.action_space.n
 
-        latent_pi_net = create_mlp(features_dim, num_actions, net_arch, activation_fn)
+        latent_pi_net = create_mlp(features_dim, num_actions.item(), net_arch, activation_fn)
         self.latent_pi = nn.Sequential(*latent_pi_net)
 
         self.output_activation = nn.Softmax(dim=1)
@@ -184,8 +183,8 @@ class DiscreteCritic(BaseModel):
         self.n_critics = n_critics
         self.q_networks = []
         for idx in range(n_critics):
-            q_net = create_mlp(features_dim, num_actions, net_arch, activation_fn)
-            q_net = nn.Sequential(*q_net)
+            q_net_list = create_mlp(features_dim, num_actions.item(), net_arch, activation_fn)
+            q_net = nn.Sequential(*q_net_list)
             self.add_module(f"qf{idx}", q_net)
             self.q_networks.append(q_net)
 
@@ -195,8 +194,9 @@ class DiscreteCritic(BaseModel):
     def forward(self, obs: th.Tensor) -> Tuple[th.Tensor, ...]:
         # Learn the features extractor using the policy loss only
         # when the features_extractor is shared with the actor
-        with th.set_grad_enabled(not self.share_features_extractor):
-            features = self.extract_features(obs, self.features_extractor)
+        if self.features_extractor is not None:
+            with th.set_grad_enabled(not self.share_features_extractor):
+                features = self.extract_features(obs, self.features_extractor)
         return tuple(q_net(features) for q_net in self.q_networks)
 
diff --git a/sb3_contrib/sacd/sacd.py b/sb3_contrib/sacd/sacd.py
index 8433728..f1796b5 100644
--- a/sb3_contrib/sacd/sacd.py
+++ b/sb3_contrib/sacd/sacd.py
@@ -18,17 +18,14 @@ SelfSACD = TypeVar("SelfSACD", bound="SACD")
 
 class SACD(OffPolicyAlgorithm):
     """
-    Soft Actor-Critic (SACD)
-    Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,
-    This implementation borrows code from original implementation (https://github.com/haarnoja/sac)
-    from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo
-    (https://github.com/rail-berkeley/softlearning/)
-    and from Stable Baselines (https://github.com/hill-a/stable-baselines)
-    Paper: https://arxiv.org/abs/1801.01290
-    Introduction to SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
+    Discrete Soft Actor-Critic (SACD)
+    The Soft Actor-Critic algorithm modified for discrete action spaces.
+    This implementation borrows code from the original implementation by
+    the paper's author (https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch),
+    from toshikwa (https://github.com/toshikwa/sac-discrete.pytorch#references),
+    and from Stable-Baselines3 (https://github.com/DLR-RM/stable-baselines3).
 
-    Note: we use double q target and not value target as discussed
-    in https://github.com/hill-a/stable-baselines/issues/270
+    Paper: https://arxiv.org/abs/1910.07207
 
     :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
     :param env: The environment to learn from (if registered in Gym, can be str)
@@ -168,7 +165,7 @@ class SACD(OffPolicyAlgorithm):
         # Target entropy is used when learning the entropy coefficient
         if self.target_entropy == "auto":
             # we set the max possible entropy as the target entropy
-            self.target_entropy = 0.98 * -np.log(1 / np.prod(self.env.action_space.shape))
+            self.target_entropy = 0.98 * -np.log(1 / self.action_space.n)
         else:
             # Force conversion
             # this will also throw an error for unexpected string
@@ -220,17 +217,45 @@ class SACD(OffPolicyAlgorithm):
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
 
             # Compute the critic loss
-            critic_loss = self.calc_critic_loss(replay_data)
+            with th.no_grad():
+                # Select action according to policy
+                action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
+
+                # Compute the next Q values: min over all critics targets
+                next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
+                next_q_values, _ = th.min(next_q_values, dim=2)
+
+                next_q_values = (action_prob * (next_q_values - self.ent_coef_tensor * next_log_prob)).sum(dim=1).unsqueeze(-1)
+                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
+
+            # Get current Q-values estimates for each critic network
+            current_q_values = self.critic(replay_data.observations)
+            critic_loss = 0.5 * sum(
+                F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values
+            )
+            assert isinstance(critic_loss, th.Tensor)  # for type checker
             critic_losses.append(critic_loss.item())
             self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, self.gradient_clip_norm)
 
             # Compute the actor loss
-            actor_loss, log_action_prob = self.calc_actor_loss(replay_data)
+            action_prob, log_action_prob = self.actor.action_log_prob(replay_data.observations)
+
+            # Min over all critic networks
+            q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
+            min_qf_pi, _ = th.min(q_values_pi, dim=2)
+
+            inside_term = self.ent_coef_tensor * log_action_prob - min_qf_pi
+            actor_loss = (action_prob * inside_term).sum(dim=1).mean()
             actor_losses.append(actor_loss.item())
             self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, self.gradient_clip_norm)
 
-            # Compute entropy loss and optimize
-            self.ent_coef_tensor = self.calc_entropy_loss(log_action_prob)
+            # Compute entropy loss
+            if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
+                ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()
+                ent_coef_losses.append(ent_coef_loss.item())
+                self.take_optimisation_step(self.ent_coef_optimizer, None, ent_coef_loss, None)
+                self.ent_coef_tensor = th.exp(self.log_ent_coef.detach())
+
             ent_coefs.append(self.ent_coef_tensor.item())
 
             # Update target networks
@@ -248,59 +273,19 @@ class SACD(OffPolicyAlgorithm):
         if len(ent_coef_losses) > 0:
             self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
 
-    def take_optimisation_step(self, optimizer, network, loss, clipping_norm=None):
+    def take_optimisation_step(
+        self,
+        optimizer: th.optim.Optimizer,
+        network: Optional[th.nn.Module],
+        loss: th.Tensor,
+        clipping_norm: Optional[float] = None,
+    ) -> None:
         optimizer.zero_grad()
         loss.backward()
-        if clipping_norm is not None:
+        if clipping_norm is not None and network is not None:
             th.nn.utils.clip_grad_norm_(network.parameters(), clipping_norm)  # clip gradients to help stabilise training
         optimizer.step()
 
-    def calc_critic_loss(self, replay_data):
-        with th.no_grad():
-            # Select action according to policy
-            action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
-
-            # Compute the next Q values: min over all critics targets
-            next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
-            next_q_values, _ = th.min(next_q_values, dim=2)
-
-            next_q_values = (action_prob * (next_q_values - self.ent_coef_tensor * next_log_prob)).sum(dim=1).unsqueeze(-1)
-            target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
-
-        # Get current Q-values estimates for each critic network
-        # using action from the replay buffer
-        current_q_values = self.critic(replay_data.observations)
-
-        # Compute critic loss
-        critic_loss = 0.5 * sum(
-            F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values
-        )
-
-        return critic_loss
-
-    def calc_actor_loss(self, replay_data):
-        action_prob, log_prob = self.actor.action_log_prob(replay_data.observations)
-
-        # Min over all critic networks
-        q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
-        min_qf_pi, _ = th.min(q_values_pi, dim=2)
-
-        inside_term = self.ent_coef_tensor * log_prob - min_qf_pi
-        actor_loss = (action_prob * inside_term).sum(dim=1).mean()
-        return actor_loss, log_prob
-
-    def calc_entropy_loss(self, log_action_prob):
-        if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
-            # Important: detach the variable from the graph
-            # so we don't change it with other losses
-            # see https://github.com/rail-berkeley/softlearning/issues/60
-            ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()
-            ent_coef_loss.backward()
-            self.ent_coef_optimizer.step()
-            return th.exp(self.log_ent_coef.detach())
-        else:
-            return self.ent_coef_tensor
-
     def learn(
         self: SelfSACD,
         total_timesteps: int,
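
Usage sketch (not part of the patch): a minimal example of how the SACD algorithm introduced above might be trained, assuming it is exported from sb3_contrib like the other contrib algorithms and registers an "MlpPolicy" alias; the environment name is only an illustration of a task with a discrete action space.

from sb3_contrib import SACD  # assumed export path, mirroring the other contrib algorithms

# "CartPole-v1" is an illustrative env with a Discrete(2) action space,
# which is what the SACD actor/critic above expect
model = SACD("MlpPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=10_000)
model.save("sacd_cartpole")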