diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 5554ded..fafe6f6 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -22,6 +22,7 @@ New Features:
 - Added ``RecurrentPPO`` (aka PPO LSTM)

 Bug Fixes:
 ^^^^^^^^^^
+- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)

 Deprecations:
@@ -271,4 +272,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)

 Contributors:
 -------------
-@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar
+@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt
diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
index f0920f9..3f9b6e4 100644
--- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py
+++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
@@ -378,6 +378,9 @@ class RecurrentPPO(OnPolicyAlgorithm):
                     # Convert discrete action from float to long
                     actions = rollout_data.actions.long().flatten()

+                # Convert mask from float to bool
+                mask = rollout_data.mask > 1e-8
+
                 # Re-sample the noise matrix because the log_std has changed
                 if self.use_sde:
                     self.policy.reset_noise(self.batch_size)
@@ -393,7 +396,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
                 # Normalize advantage
                 advantages = rollout_data.advantages
                 if self.normalize_advantage:
-                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+                    advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8)

                 # ratio between old and new policy, should be one at the first iteration
                 ratio = th.exp(log_prob - rollout_data.old_log_prob)
@@ -401,14 +404,11 @@ class RecurrentPPO(OnPolicyAlgorithm):
                 # clipped surrogate loss
                 policy_loss_1 = advantages * ratio
                 policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
-                # Mask padded sequences
-                policy_loss_1 = policy_loss_1 * rollout_data.mask
-                policy_loss_2 = policy_loss_2 * rollout_data.mask
-                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
+                policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)[mask])

                 # Logging
                 pg_losses.append(policy_loss.item())
-                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
+                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()[mask]).item()
                 clip_fractions.append(clip_fraction)

                 if self.clip_range_vf is None:
@@ -422,16 +422,16 @@ class RecurrentPPO(OnPolicyAlgorithm):
                     )
                 # Value loss using the TD(gae_lambda) target
                 # Mask padded sequences
-                value_loss = th.mean(((rollout_data.returns - values_pred) * rollout_data.mask) ** 2)
+                value_loss = th.mean(((rollout_data.returns - values_pred) ** 2)[mask])
                 value_losses.append(value_loss.item())

                 # Entropy loss favor exploration
                 if entropy is None:
                     # Approximate entropy when no analytical form
-                    entropy_loss = -th.mean(-(log_prob * rollout_data.mask))
+                    entropy_loss = -th.mean(-log_prob[mask])
                 else:
-                    entropy_loss = -th.mean(entropy * rollout_data.mask)
+                    entropy_loss = -th.mean(entropy[mask])

                 entropy_losses.append(entropy_loss.item())

@@ -443,7 +443,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
                 # and Schulman blog: http://joschu.net/blog/kl-approx.html
                 with th.no_grad():
                     log_ratio = log_prob - rollout_data.old_log_prob
-                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
+                    approx_kl_div = th.mean(((th.exp(log_ratio) - 1) - log_ratio)[mask]).cpu().numpy()
                     approx_kl_divs.append(approx_kl_div)

                 if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
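
For context, here is a minimal sketch (not part of the patch; the tensor values are made up for illustration) of why the old mask-multiply-then-mean losses were biased: multiplying by the float mask zeroes out padded timesteps but still counts them in the denominator of ``th.mean``, whereas indexing with the boolean mask averages over valid timesteps only.

```python
import torch as th

# Illustrative values: three valid timesteps plus one padded entry.
per_step_loss = th.tensor([1.0, 2.0, 3.0, 0.5])
float_mask = th.tensor([1.0, 1.0, 1.0, 0.0])  # float mask as stored in the buffer
bool_mask = float_mask > 1e-8                 # same conversion the patch introduces

# Old behaviour: padding is zeroed out, but the mean still divides by 4.
biased = th.mean(per_step_loss * float_mask)   # (1 + 2 + 3) / 4 = 1.5

# New behaviour: only the valid timesteps enter the average.
unbiased = th.mean(per_step_loss[bool_mask])   # (1 + 2 + 3) / 3 = 2.0

print(biased.item(), unbiased.item())  # 1.5 2.0
```

The same reasoning applies to every term touched by the diff: advantage normalization, clipped surrogate loss, clip fraction, value loss, entropy loss, and the approximate KL divergence.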