diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 825b80c..621043c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -14,6 +14,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Allow MaskablePPO to turn off advantage normalization (see `PR #61 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/61>`_) @vwxyzjn Bug Fixes: ^^^^^^^^^^ diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index 767f36e..b9cd6b5 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -46,6 +46,7 @@ class MaskablePPO(OnPolicyAlgorithm): This is a parameter specific to the OpenAI implementation. If None is passed (default), no clipping will be done on the value function. IMPORTANT: this clipping depends on the reward scaling. + :param normalize_advantage: Whether to normalize or not the advantage :param ent_coef: Entropy coefficient for the loss calculation :param vf_coef: Value function coefficient for the loss calculation :param max_grad_norm: The maximum value for the gradient clipping @@ -76,6 +77,7 @@ class MaskablePPO(OnPolicyAlgorithm): gae_lambda: float = 0.95, clip_range: Union[float, Schedule] = 0.2, clip_range_vf: Union[None, float, Schedule] = None, + normalize_advantage: bool = True, ent_coef: float = 0.0, vf_coef: float = 0.5, max_grad_norm: float = 0.5, @@ -119,6 +121,7 @@ class MaskablePPO(OnPolicyAlgorithm): self.n_epochs = n_epochs self.clip_range = clip_range self.clip_range_vf = clip_range_vf + self.normalize_advantage = normalize_advantage self.target_kl = target_kl if _init_setup_model: @@ -432,7 +435,8 @@ class MaskablePPO(OnPolicyAlgorithm): values = values.flatten() # Normalize advantage advantages = rollout_data.advantages - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + if self.normalize_advantage: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # ratio between old and new policy, should be one at the first iteration ratio = th.exp(log_prob - rollout_data.old_log_prob) 
diff --git a/tests/test_run.py b/tests/test_run.py index d6a2307..c9c8584 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -3,7 +3,8 @@ import pytest from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.vec_env import VecNormalize -from sb3_contrib import ARS, QRDQN, TQC, TRPO +from sb3_contrib import ARS, QRDQN, TQC, TRPO, MaskablePPO +from sb3_contrib.common.envs import InvalidActionEnvDiscrete from sb3_contrib.common.vec_env import AsyncEval @@ -142,3 +143,10 @@ def test_offpolicy_multi_env(model_class): train_freq=5, ) model.learn(total_timesteps=150) + + +@pytest.mark.parametrize("normalize_advantage", [False, True]) +def test_advantage_normalization(normalize_advantage): + env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60) + model = MaskablePPO("MlpPolicy", env, n_steps=64, normalize_advantage=normalize_advantage) + model.learn(64)