Allow PPO to turn off advantage normalization (#61)

* Allow PPO to turn off advantage normalization

* Quick fix

* Add test cases

* Update docs

* Quick fix

* Quick fix

* Fix sort
This commit is contained in:
Costa Huang 2022-02-23 04:11:16 -05:00 committed by GitHub
parent 901a648507
commit f5c1aaa194
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 2 deletions

View File

@ -14,6 +14,7 @@ Breaking Changes:
New Features: New Features:
^^^^^^^^^^^^^ ^^^^^^^^^^^^^
- Allow PPO to turn off advantage normalization (see `PR #61 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/61>`_) @vwxyzjn
Bug Fixes: Bug Fixes:
^^^^^^^^^^ ^^^^^^^^^^

View File

@ -46,6 +46,7 @@ class MaskablePPO(OnPolicyAlgorithm):
This is a parameter specific to the OpenAI implementation. If None is passed (default), This is a parameter specific to the OpenAI implementation. If None is passed (default),
no clipping will be done on the value function. no clipping will be done on the value function.
IMPORTANT: this clipping depends on the reward scaling. IMPORTANT: this clipping depends on the reward scaling.
:param normalize_advantage: Whether to normalize or not the advantage
:param ent_coef: Entropy coefficient for the loss calculation :param ent_coef: Entropy coefficient for the loss calculation
:param vf_coef: Value function coefficient for the loss calculation :param vf_coef: Value function coefficient for the loss calculation
:param max_grad_norm: The maximum value for the gradient clipping :param max_grad_norm: The maximum value for the gradient clipping
@ -76,6 +77,7 @@ class MaskablePPO(OnPolicyAlgorithm):
gae_lambda: float = 0.95, gae_lambda: float = 0.95,
clip_range: Union[float, Schedule] = 0.2, clip_range: Union[float, Schedule] = 0.2,
clip_range_vf: Union[None, float, Schedule] = None, clip_range_vf: Union[None, float, Schedule] = None,
normalize_advantage: bool = True,
ent_coef: float = 0.0, ent_coef: float = 0.0,
vf_coef: float = 0.5, vf_coef: float = 0.5,
max_grad_norm: float = 0.5, max_grad_norm: float = 0.5,
@ -119,6 +121,7 @@ class MaskablePPO(OnPolicyAlgorithm):
self.n_epochs = n_epochs self.n_epochs = n_epochs
self.clip_range = clip_range self.clip_range = clip_range
self.clip_range_vf = clip_range_vf self.clip_range_vf = clip_range_vf
self.normalize_advantage = normalize_advantage
self.target_kl = target_kl self.target_kl = target_kl
if _init_setup_model: if _init_setup_model:
@ -432,7 +435,8 @@ class MaskablePPO(OnPolicyAlgorithm):
values = values.flatten() values = values.flatten()
# Normalize advantage # Normalize advantage
advantages = rollout_data.advantages advantages = rollout_data.advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) if self.normalize_advantage:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration # ratio between old and new policy, should be one at the first iteration
ratio = th.exp(log_prob - rollout_data.old_log_prob) ratio = th.exp(log_prob - rollout_data.old_log_prob)

View File

@ -3,7 +3,8 @@ import pytest
from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize from stable_baselines3.common.vec_env import VecNormalize
from sb3_contrib import ARS, QRDQN, TQC, TRPO from sb3_contrib import ARS, QRDQN, TQC, TRPO, MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.vec_env import AsyncEval from sb3_contrib.common.vec_env import AsyncEval
@ -142,3 +143,10 @@ def test_offpolicy_multi_env(model_class):
train_freq=5, train_freq=5,
) )
model.learn(total_timesteps=150) model.learn(total_timesteps=150)
@pytest.mark.parametrize("normalize_advantage", [False, True])
def test_advantage_normalization(normalize_advantage):
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
model = MaskablePPO("MlpPolicy", env, n_steps=64, normalize_advantage=normalize_advantage)
model.learn(64)