Fix masked quantities in RecurrentPPO (#78)
* Ignore masked indexes when calculating the loss functions
This commit is contained in:
parent
75b2de1399
commit
bfa86ce4fe
|
|
@ -22,6 +22,7 @@ New Features:
|
|||
- Added ``RecurrentPPO`` (aka PPO LSTM)
|
||||
|
||||
Bug Fixes:
|
||||
- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)
|
||||
^^^^^^^^^^
|
||||
|
||||
Deprecations:
|
||||
|
|
@ -271,4 +272,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
|
|||
Contributors:
|
||||
-------------
|
||||
|
||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar
|
||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt
|
||||
|
|
|
|||
|
|
@ -378,6 +378,9 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
# Convert discrete action from float to long
|
||||
actions = rollout_data.actions.long().flatten()
|
||||
|
||||
# Convert mask from float to bool
|
||||
mask = rollout_data.mask > 1e-8
|
||||
|
||||
# Re-sample the noise matrix because the log_std has changed
|
||||
if self.use_sde:
|
||||
self.policy.reset_noise(self.batch_size)
|
||||
|
|
@ -393,7 +396,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
# Normalize advantage
|
||||
advantages = rollout_data.advantages
|
||||
if self.normalize_advantage:
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
||||
advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8)
|
||||
|
||||
# ratio between old and new policy, should be one at the first iteration
|
||||
ratio = th.exp(log_prob - rollout_data.old_log_prob)
|
||||
|
|
@ -401,14 +404,11 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
# clipped surrogate loss
|
||||
policy_loss_1 = advantages * ratio
|
||||
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
|
||||
# Mask padded sequences
|
||||
policy_loss_1 = policy_loss_1 * rollout_data.mask
|
||||
policy_loss_2 = policy_loss_2 * rollout_data.mask
|
||||
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
|
||||
policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)[mask])
|
||||
|
||||
# Logging
|
||||
pg_losses.append(policy_loss.item())
|
||||
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
|
||||
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()[mask]).item()
|
||||
clip_fractions.append(clip_fraction)
|
||||
|
||||
if self.clip_range_vf is None:
|
||||
|
|
@ -422,16 +422,16 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
)
|
||||
# Value loss using the TD(gae_lambda) target
|
||||
# Mask padded sequences
|
||||
value_loss = th.mean(((rollout_data.returns - values_pred) * rollout_data.mask) ** 2)
|
||||
value_loss = th.mean(((rollout_data.returns - values_pred) ** 2)[mask])
|
||||
|
||||
value_losses.append(value_loss.item())
|
||||
|
||||
# Entropy loss favor exploration
|
||||
if entropy is None:
|
||||
# Approximate entropy when no analytical form
|
||||
entropy_loss = -th.mean(-(log_prob * rollout_data.mask))
|
||||
entropy_loss = -th.mean(-log_prob[mask])
|
||||
else:
|
||||
entropy_loss = -th.mean(entropy * rollout_data.mask)
|
||||
entropy_loss = -th.mean(entropy[mask])
|
||||
|
||||
entropy_losses.append(entropy_loss.item())
|
||||
|
||||
|
|
@ -443,7 +443,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
# and Schulman blog: http://joschu.net/blog/kl-approx.html
|
||||
with th.no_grad():
|
||||
log_ratio = log_prob - rollout_data.old_log_prob
|
||||
approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
|
||||
approx_kl_div = th.mean(((th.exp(log_ratio) - 1) - log_ratio)[mask]).cpu().numpy()
|
||||
approx_kl_divs.append(approx_kl_div)
|
||||
|
||||
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
|
||||
|
|
|
|||
Loading…
Reference in New Issue