From bc3c0a95952916fcc64d989ca631a0e98e676091 Mon Sep 17 00:00:00 2001
From: Tjeerd Bakker <37150873+icheered@users.noreply.github.com>
Date: Thu, 18 Jan 2024 17:03:02 +0100
Subject: [PATCH] Add notes about MaskablePPO (#227)

* Add note about MaskableEvalCallback

* Update doc

---------

Co-authored-by: Antonin Raffin
---
 docs/guide/examples.rst   |  6 ++++++
 docs/misc/changelog.rst   |  3 ++-
 docs/modules/ppo_mask.rst | 12 ++++++++++++
 3 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst
index de1f946..83546cf 100644
--- a/docs/guide/examples.rst
+++ b/docs/guide/examples.rst
@@ -35,6 +35,12 @@ MaskablePPO
 
 Train a PPO with invalid action masking agent on a toy environment.
 
+.. warning::
+  You must use ``MaskableEvalCallback`` from ``sb3_contrib.common.maskable.callbacks`` instead of the base ``EvalCallback`` to properly evaluate a model with action masks.
+  Similarly, you must use ``evaluate_policy`` from ``sb3_contrib.common.maskable.evaluation`` instead of the SB3 one.
+
+
+
 .. code-block:: python
 
   from sb3_contrib import MaskablePPO
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index ba3bb60..a9e94af 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -34,6 +34,7 @@ Others:
 
 Documentation:
 ^^^^^^^^^^^^^^
+- Add some additional notes about ``MaskablePPO`` (evaluation and multi-process) (@icheered)
 
 
 Release 2.2.1 (2023-11-17)
@@ -544,4 +545,4 @@ Contributors:
 -------------
 
 @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
-@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher
+@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered
diff --git a/docs/modules/ppo_mask.rst b/docs/modules/ppo_mask.rst
index 89d14d1..4ff6f06 100644
--- a/docs/modules/ppo_mask.rst
+++ b/docs/modules/ppo_mask.rst
@@ -45,6 +45,16 @@ MultiBinary   ✔️      ✔️
 Dict          ❌      ✔️
 ============= ====== ===========
 
+.. warning::
+  You must use ``MaskableEvalCallback`` from ``sb3_contrib.common.maskable.callbacks`` instead of the base ``EvalCallback`` to properly evaluate a model with action masks.
+  Similarly, you must use ``evaluate_policy`` from ``sb3_contrib.common.maskable.evaluation`` instead of the SB3 one.
+
+
+.. warning::
+  In order to use ``SubprocVecEnv`` with ``MaskablePPO``, you must implement the ``action_masks`` inside the environment (``ActionMasker`` cannot be used).
+  You can have a look at the `built-in environments with invalid action masks `_ to have a working example.
+
+
 Example
 -------
 
@@ -58,6 +68,8 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe
   from sb3_contrib.common.envs import InvalidActionEnvDiscrete
   from sb3_contrib.common.maskable.evaluation import evaluate_policy
   from sb3_contrib.common.maskable.utils import get_action_masks
+  # This is a drop-in replacement for EvalCallback
+  from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
 
 
   env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)