Fix masked quantities in RecurrentPPO (#78)

* Ignore masked indexes when calculating the loss functions
This commit is contained in:
rnederstigt 2022-06-13 16:00:40 +02:00 committed by GitHub
parent 75b2de1399
commit bfa86ce4fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 11 deletions

View File

@@ -22,6 +22,7 @@ New Features:
- Added ``RecurrentPPO`` (aka PPO LSTM)
Bug Fixes:
+- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)
^^^^^^^^^^
Deprecations:
@@ -271,4 +272,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
Contributors:
-------------
-@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar
+@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt

View File

@@ -378,6 +378,9 @@ class RecurrentPPO(OnPolicyAlgorithm):
# Convert discrete action from float to long
actions = rollout_data.actions.long().flatten()
+# Convert mask from float to bool
+mask = rollout_data.mask > 1e-8
# Re-sample the noise matrix because the log_std has changed
if self.use_sde:
self.policy.reset_noise(self.batch_size)
@@ -393,7 +396,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
# Normalize advantage
advantages = rollout_data.advantages
if self.normalize_advantage:
-advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration
ratio = th.exp(log_prob - rollout_data.old_log_prob)
@@ -401,14 +404,11 @@ class RecurrentPPO(OnPolicyAlgorithm):
# clipped surrogate loss
policy_loss_1 = advantages * ratio
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
-# Mask padded sequences
-policy_loss_1 = policy_loss_1 * rollout_data.mask
-policy_loss_2 = policy_loss_2 * rollout_data.mask
-policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
+policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)[mask])
# Logging
pg_losses.append(policy_loss.item())
-clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
+clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()[mask]).item()
clip_fractions.append(clip_fraction)
if self.clip_range_vf is None:
@@ -422,16 +422,16 @@
)
# Value loss using the TD(gae_lambda) target
-# Mask padded sequences
-value_loss = th.mean(((rollout_data.returns - values_pred) * rollout_data.mask) ** 2)
+value_loss = th.mean(((rollout_data.returns - values_pred) ** 2)[mask])
value_losses.append(value_loss.item())
# Entropy loss favor exploration
if entropy is None:
# Approximate entropy when no analytical form
-entropy_loss = -th.mean(-(log_prob * rollout_data.mask))
+entropy_loss = -th.mean(-log_prob[mask])
else:
-entropy_loss = -th.mean(entropy * rollout_data.mask)
+entropy_loss = -th.mean(entropy[mask])
entropy_losses.append(entropy_loss.item())
@@ -443,7 +443,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
# and Schulman blog: http://joschu.net/blog/kl-approx.html
with th.no_grad():
log_ratio = log_prob - rollout_data.old_log_prob
-approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
+approx_kl_div = th.mean(((th.exp(log_ratio) - 1) - log_ratio)[mask]).cpu().numpy()
approx_kl_divs.append(approx_kl_div)
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: