Fix masked quantities in RecurrentPPO (#78)

* Ignore masked indexes when calculating the loss functions
This commit is contained in:
rnederstigt 2022-06-13 16:00:40 +02:00 committed by GitHub
parent 75b2de1399
commit bfa86ce4fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 11 deletions

View File

@ -22,6 +22,7 @@ New Features:
- Added ``RecurrentPPO`` (aka PPO LSTM) - Added ``RecurrentPPO`` (aka PPO LSTM)
Bug Fixes: Bug Fixes:
- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)
^^^^^^^^^^ ^^^^^^^^^^
Deprecations: Deprecations:
@ -271,4 +272,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
Contributors: Contributors:
------------- -------------
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt

View File

@ -378,6 +378,9 @@ class RecurrentPPO(OnPolicyAlgorithm):
# Convert discrete action from float to long # Convert discrete action from float to long
actions = rollout_data.actions.long().flatten() actions = rollout_data.actions.long().flatten()
# Convert mask from float to bool
mask = rollout_data.mask > 1e-8
# Re-sample the noise matrix because the log_std has changed # Re-sample the noise matrix because the log_std has changed
if self.use_sde: if self.use_sde:
self.policy.reset_noise(self.batch_size) self.policy.reset_noise(self.batch_size)
@ -393,7 +396,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
# Normalize advantage # Normalize advantage
advantages = rollout_data.advantages advantages = rollout_data.advantages
if self.normalize_advantage: if self.normalize_advantage:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration # ratio between old and new policy, should be one at the first iteration
ratio = th.exp(log_prob - rollout_data.old_log_prob) ratio = th.exp(log_prob - rollout_data.old_log_prob)
@ -401,14 +404,11 @@ class RecurrentPPO(OnPolicyAlgorithm):
# clipped surrogate loss # clipped surrogate loss
policy_loss_1 = advantages * ratio policy_loss_1 = advantages * ratio
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
# Mask padded sequences policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)[mask])
policy_loss_1 = policy_loss_1 * rollout_data.mask
policy_loss_2 = policy_loss_2 * rollout_data.mask
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
# Logging # Logging
pg_losses.append(policy_loss.item()) pg_losses.append(policy_loss.item())
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()[mask]).item()
clip_fractions.append(clip_fraction) clip_fractions.append(clip_fraction)
if self.clip_range_vf is None: if self.clip_range_vf is None:
@ -422,16 +422,16 @@ class RecurrentPPO(OnPolicyAlgorithm):
) )
# Value loss using the TD(gae_lambda) target # Value loss using the TD(gae_lambda) target
# Mask padded sequences # Mask padded sequences
value_loss = th.mean(((rollout_data.returns - values_pred) * rollout_data.mask) ** 2) value_loss = th.mean(((rollout_data.returns - values_pred) ** 2)[mask])
value_losses.append(value_loss.item()) value_losses.append(value_loss.item())
# Entropy loss favor exploration # Entropy loss favor exploration
if entropy is None: if entropy is None:
# Approximate entropy when no analytical form # Approximate entropy when no analytical form
entropy_loss = -th.mean(-(log_prob * rollout_data.mask)) entropy_loss = -th.mean(-log_prob[mask])
else: else:
entropy_loss = -th.mean(entropy * rollout_data.mask) entropy_loss = -th.mean(entropy[mask])
entropy_losses.append(entropy_loss.item()) entropy_losses.append(entropy_loss.item())
@ -443,7 +443,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
# and Schulman blog: http://joschu.net/blog/kl-approx.html # and Schulman blog: http://joschu.net/blog/kl-approx.html
with th.no_grad(): with th.no_grad():
log_ratio = log_prob - rollout_data.old_log_prob log_ratio = log_prob - rollout_data.old_log_prob
approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() approx_kl_div = th.mean(((th.exp(log_ratio) - 1) - log_ratio)[mask]).cpu().numpy()
approx_kl_divs.append(approx_kl_div) approx_kl_divs.append(approx_kl_div)
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: