Fix masked quantities in RecurrentPPO (#78)
* Ignore masked indexes when calculating the loss functions
This commit is contained in:
parent
75b2de1399
commit
bfa86ce4fe
|
|
@ -22,6 +22,7 @@ New Features:
|
||||||
- Added ``RecurrentPPO`` (aka PPO LSTM)
|
- Added ``RecurrentPPO`` (aka PPO LSTM)
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
|
- Fixed a bug in ``RecurrentPPO`` where padded (masked-out) steps were still counted when averaging the loss functions (@rnederstigt)
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
||||||
Deprecations:
|
Deprecations:
|
||||||
|
|
@ -271,4 +272,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
|
||||||
Contributors:
|
Contributors:
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar
|
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt
|
||||||
|
|
|
||||||
|
|
@ -378,6 +378,9 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
# Convert discrete action from float to long
|
# Convert discrete action from float to long
|
||||||
actions = rollout_data.actions.long().flatten()
|
actions = rollout_data.actions.long().flatten()
|
||||||
|
|
||||||
|
# Convert mask from float to bool
|
||||||
|
mask = rollout_data.mask > 1e-8
|
||||||
|
|
||||||
# Re-sample the noise matrix because the log_std has changed
|
# Re-sample the noise matrix because the log_std has changed
|
||||||
if self.use_sde:
|
if self.use_sde:
|
||||||
self.policy.reset_noise(self.batch_size)
|
self.policy.reset_noise(self.batch_size)
|
||||||
|
|
@ -393,7 +396,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
# Normalize advantage
|
# Normalize advantage
|
||||||
advantages = rollout_data.advantages
|
advantages = rollout_data.advantages
|
||||||
if self.normalize_advantage:
|
if self.normalize_advantage:
|
||||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8)
|
||||||
|
|
||||||
# ratio between old and new policy, should be one at the first iteration
|
# ratio between old and new policy, should be one at the first iteration
|
||||||
ratio = th.exp(log_prob - rollout_data.old_log_prob)
|
ratio = th.exp(log_prob - rollout_data.old_log_prob)
|
||||||
|
|
@ -401,14 +404,11 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
# clipped surrogate loss
|
# clipped surrogate loss
|
||||||
policy_loss_1 = advantages * ratio
|
policy_loss_1 = advantages * ratio
|
||||||
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
|
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
|
||||||
# Mask padded sequences
|
policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)[mask])
|
||||||
policy_loss_1 = policy_loss_1 * rollout_data.mask
|
|
||||||
policy_loss_2 = policy_loss_2 * rollout_data.mask
|
|
||||||
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
|
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
pg_losses.append(policy_loss.item())
|
pg_losses.append(policy_loss.item())
|
||||||
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
|
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()[mask]).item()
|
||||||
clip_fractions.append(clip_fraction)
|
clip_fractions.append(clip_fraction)
|
||||||
|
|
||||||
if self.clip_range_vf is None:
|
if self.clip_range_vf is None:
|
||||||
|
|
@ -422,16 +422,16 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
)
|
)
|
||||||
# Value loss using the TD(gae_lambda) target
|
# Value loss using the TD(gae_lambda) target
|
||||||
# Mask padded sequences
|
# Mask padded sequences
|
||||||
value_loss = th.mean(((rollout_data.returns - values_pred) * rollout_data.mask) ** 2)
|
value_loss = th.mean(((rollout_data.returns - values_pred) ** 2)[mask])
|
||||||
|
|
||||||
value_losses.append(value_loss.item())
|
value_losses.append(value_loss.item())
|
||||||
|
|
||||||
# Entropy loss favors exploration
|
# Entropy loss favors exploration
|
||||||
if entropy is None:
|
if entropy is None:
|
||||||
# Approximate entropy when no analytical form
|
# Approximate entropy when no analytical form
|
||||||
entropy_loss = -th.mean(-(log_prob * rollout_data.mask))
|
entropy_loss = -th.mean(-log_prob[mask])
|
||||||
else:
|
else:
|
||||||
entropy_loss = -th.mean(entropy * rollout_data.mask)
|
entropy_loss = -th.mean(entropy[mask])
|
||||||
|
|
||||||
entropy_losses.append(entropy_loss.item())
|
entropy_losses.append(entropy_loss.item())
|
||||||
|
|
||||||
|
|
@ -443,7 +443,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
# and Schulman blog: http://joschu.net/blog/kl-approx.html
|
# and Schulman blog: http://joschu.net/blog/kl-approx.html
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
log_ratio = log_prob - rollout_data.old_log_prob
|
log_ratio = log_prob - rollout_data.old_log_prob
|
||||||
approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
|
approx_kl_div = th.mean(((th.exp(log_ratio) - 1) - log_ratio)[mask]).cpu().numpy()
|
||||||
approx_kl_divs.append(approx_kl_div)
|
approx_kl_divs.append(approx_kl_div)
|
||||||
|
|
||||||
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
|
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue