Add notes about MaskablePPO (#227)

* Add note about MaskableEvalCallback

* Update doc

---------

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
This commit is contained in:
Tjeerd Bakker 2024-01-18 17:03:02 +01:00 committed by Andreas Schaarschmidt
parent 3f0c5088b3
commit bc3c0a9595
3 changed files with 20 additions and 1 deletions

View File

@ -35,6 +35,12 @@ MaskablePPO
Train a PPO with invalid action masking agent on a toy environment. Train a PPO with invalid action masking agent on a toy environment.
.. warning::
You must use ``MaskableEvalCallback`` from ``sb3_contrib.common.maskable.callbacks`` instead of the base ``EvalCallback`` to properly evaluate a model with action masks.
Similarly, you must use ``evaluate_policy`` from ``sb3_contrib.common.maskable.evaluation`` instead of the SB3 one.
.. code-block:: python .. code-block:: python
from sb3_contrib import MaskablePPO from sb3_contrib import MaskablePPO

View File

@ -34,6 +34,7 @@ Others:
Documentation: Documentation:
^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^
- Add some additional notes about ``MaskablePPO`` (evaluation and multi-process) (@icheered)
Release 2.2.1 (2023-11-17) Release 2.2.1 (2023-11-17)
@ -544,4 +545,4 @@ Contributors:
------------- -------------
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered

View File

@ -45,6 +45,16 @@ MultiBinary ✔️ ✔️
Dict ❌ ✔️ Dict ❌ ✔️
============= ====== =========== ============= ====== ===========
.. warning::
You must use ``MaskableEvalCallback`` from ``sb3_contrib.common.maskable.callbacks`` instead of the base ``EvalCallback`` to properly evaluate a model with action masks.
Similarly, you must use ``evaluate_policy`` from ``sb3_contrib.common.maskable.evaluation`` instead of the SB3 one.
.. warning::
In order to use ``SubprocVecEnv`` with ``MaskablePPO``, you must implement the ``action_masks`` inside the environment (``ActionMasker`` cannot be used).
You can have a look at the `built-in environments with invalid action masks <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/sb3_contrib/common/envs/invalid_actions_env.py>`_ to have a working example.
Example Example
------- -------
@ -58,6 +68,8 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe
from sb3_contrib.common.envs import InvalidActionEnvDiscrete from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.maskable.evaluation import evaluate_policy from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.utils import get_action_masks from sb3_contrib.common.maskable.utils import get_action_masks
# This is a drop-in replacement for EvalCallback
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60) env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)