Allow PPO to turn off advantage normalization (#61)

* Allow PPO to turn off advantage normalization

* Quick fix

* Add test cases

* Update docs

* Quick fix

* Quick fix

* Fix sort
This commit is contained in:
Costa Huang 2022-02-23 04:11:16 -05:00 committed by GitHub
parent 901a648507
commit f5c1aaa194
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 2 deletions

View File

@ -14,6 +14,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Allow PPO to turn off advantage normalization (see `PR #61 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/61>`_) @vwxyzjn
Bug Fixes:
^^^^^^^^^^

View File

@ -46,6 +46,7 @@ class MaskablePPO(OnPolicyAlgorithm):
This is a parameter specific to the OpenAI implementation. If None is passed (default),
no clipping will be done on the value function.
IMPORTANT: this clipping depends on the reward scaling.
:param normalize_advantage: Whether or not to normalize the advantage
:param ent_coef: Entropy coefficient for the loss calculation
:param vf_coef: Value function coefficient for the loss calculation
:param max_grad_norm: The maximum value for the gradient clipping
@ -76,6 +77,7 @@ class MaskablePPO(OnPolicyAlgorithm):
gae_lambda: float = 0.95,
clip_range: Union[float, Schedule] = 0.2,
clip_range_vf: Union[None, float, Schedule] = None,
normalize_advantage: bool = True,
ent_coef: float = 0.0,
vf_coef: float = 0.5,
max_grad_norm: float = 0.5,
@ -119,6 +121,7 @@ class MaskablePPO(OnPolicyAlgorithm):
self.n_epochs = n_epochs
self.clip_range = clip_range
self.clip_range_vf = clip_range_vf
self.normalize_advantage = normalize_advantage
self.target_kl = target_kl
if _init_setup_model:
@ -432,6 +435,7 @@ class MaskablePPO(OnPolicyAlgorithm):
values = values.flatten()
# Normalize advantage
advantages = rollout_data.advantages
if self.normalize_advantage:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration

View File

@ -3,7 +3,8 @@ import pytest
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize
from sb3_contrib import ARS, QRDQN, TQC, TRPO
from sb3_contrib import ARS, QRDQN, TQC, TRPO, MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.vec_env import AsyncEval
@ -142,3 +143,10 @@ def test_offpolicy_multi_env(model_class):
train_freq=5,
)
model.learn(total_timesteps=150)
@pytest.mark.parametrize("normalize_advantage", [False, True])
def test_advantage_normalization(normalize_advantage):
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
model = MaskablePPO("MlpPolicy", env, n_steps=64, normalize_advantage=normalize_advantage)
model.learn(64)