Add notes about MaskablePPO (#227)
* Add note about MaskableEvalCallback * Update doc --------- Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
This commit is contained in:
parent
3f0c5088b3
commit
bc3c0a9595
|
|
@ -35,6 +35,12 @@ MaskablePPO
|
|||
|
||||
Train a PPO with invalid action masking agent on a toy environment.
|
||||
|
||||
.. warning::
|
||||
You must use ``MaskableEvalCallback`` from ``sb3_contrib.common.maskable.callbacks`` instead of the base ``EvalCallback`` to properly evaluate a model with action masks.
|
||||
Similarly, you must use ``evaluate_policy`` from ``sb3_contrib.common.maskable.evaluation`` instead of the SB3 one.
|
||||
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from sb3_contrib import MaskablePPO
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ Others:
|
|||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
- Add some additional notes about ``MaskablePPO`` (evaluation and multi-process) (@icheered)
|
||||
|
||||
|
||||
Release 2.2.1 (2023-11-17)
|
||||
|
|
@ -544,4 +545,4 @@ Contributors:
|
|||
-------------
|
||||
|
||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
||||
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher
|
||||
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered
|
||||
|
|
|
|||
|
|
@ -45,6 +45,16 @@ MultiBinary ✔️ ✔️
|
|||
Dict ❌ ✔️
|
||||
============= ====== ===========
|
||||
|
||||
.. warning::
|
||||
You must use ``MaskableEvalCallback`` from ``sb3_contrib.common.maskable.callbacks`` instead of the base ``EvalCallback`` to properly evaluate a model with action masks.
|
||||
Similarly, you must use ``evaluate_policy`` from ``sb3_contrib.common.maskable.evaluation`` instead of the SB3 one.
|
||||
|
||||
|
||||
.. warning::
|
||||
In order to use ``SubprocVecEnv`` with ``MaskablePPO``, you must implement the ``action_masks`` inside the environment (``ActionMasker`` cannot be used).
|
||||
You can have a look at the `built-in environments with invalid action masks <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/sb3_contrib/common/envs/invalid_actions_env.py>`_ to have a working example.
|
||||
|
||||
|
||||
|
||||
Example
|
||||
-------
|
||||
|
|
@ -58,6 +68,8 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe
|
|||
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
|
||||
from sb3_contrib.common.maskable.evaluation import evaluate_policy
|
||||
from sb3_contrib.common.maskable.utils import get_action_masks
|
||||
# This is a drop-in replacement for EvalCallback
|
||||
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
|
||||
|
||||
|
||||
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
|
||||
|
|
|
|||
Loading…
Reference in New Issue