Add notes about MaskablePPO (#227)
* Add note about MaskableEvalCallback * Update doc --------- Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
This commit is contained in:
parent
3f0c5088b3
commit
bc3c0a9595
|
|
@ -35,6 +35,12 @@ MaskablePPO
|
||||||
|
|
||||||
Train a PPO with invalid action masking agent on a toy environment.
|
Train a PPO with invalid action masking agent on a toy environment.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
You must use ``MaskableEvalCallback`` from ``sb3_contrib.common.maskable.callbacks`` instead of the base ``EvalCallback`` to properly evaluate a model with action masks.
|
||||||
|
Similarly, you must use ``evaluate_policy`` from ``sb3_contrib.common.maskable.evaluation`` instead of the SB3 one.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from sb3_contrib import MaskablePPO
|
from sb3_contrib import MaskablePPO
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ Others:
|
||||||
|
|
||||||
Documentation:
|
Documentation:
|
||||||
^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^
|
||||||
|
- Add some additional notes about ``MaskablePPO`` (evaluation and multi-process) (@icheered)
|
||||||
|
|
||||||
|
|
||||||
Release 2.2.1 (2023-11-17)
|
Release 2.2.1 (2023-11-17)
|
||||||
|
|
@ -544,4 +545,4 @@ Contributors:
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
||||||
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher
|
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,16 @@ MultiBinary ✔️ ✔️
|
||||||
Dict ❌ ✔️
|
Dict ❌ ✔️
|
||||||
============= ====== ===========
|
============= ====== ===========
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
You must use ``MaskableEvalCallback`` from ``sb3_contrib.common.maskable.callbacks`` instead of the base ``EvalCallback`` to properly evaluate a model with action masks.
|
||||||
|
Similarly, you must use ``evaluate_policy`` from ``sb3_contrib.common.maskable.evaluation`` instead of the SB3 one.
|
||||||
|
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
In order to use ``SubprocVecEnv`` with ``MaskablePPO``, you must implement the ``action_masks`` inside the environment (``ActionMasker`` cannot be used).
|
||||||
|
You can have a look at the `built-in environments with invalid action masks <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/sb3_contrib/common/envs/invalid_actions_env.py>`_ to have a working example.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Example
|
Example
|
||||||
-------
|
-------
|
||||||
|
|
@ -58,6 +68,8 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe
|
||||||
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
|
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
|
||||||
from sb3_contrib.common.maskable.evaluation import evaluate_policy
|
from sb3_contrib.common.maskable.evaluation import evaluate_policy
|
||||||
from sb3_contrib.common.maskable.utils import get_action_masks
|
from sb3_contrib.common.maskable.utils import get_action_masks
|
||||||
|
# This is a drop-in replacement for EvalCallback
|
||||||
|
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
|
||||||
|
|
||||||
|
|
||||||
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
|
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue