PPO variant with invalid action masking (#25)

* Add wrappers
* Add maskable distributions
* Add mypy configuration
* Add maskable base datastructures
* Add ppo_mask package
* Fix circular dependency and remove test code that slipped in
* Automatically mask vecenv if env is masked
* Fix debugging change that slipped in
* Workaround for subclassing RolloutBufferSamples
* Duplicate lots of policy code in order to swap out the distributions used
* Fix pytype error
* Maintain py 3.6 compatibility
* Fix isort lint errors
* Use pyproject.toml to configure black line length
* Blacken
* Remove mypy.ini
* Fully replace RolloutBufferSamples
* Drop support for continuous distributions, remove SDE-related code
* Eliminate MaskableAlgorithm and MaskableOnPolicyAlgorithm
* Fix formatting
* Override superclass methods as needed, fix circular import, improve naming
* Fix codestyle
* Eliminate VecActionMasker, replace with utils
* Fix codestyle
* Support masking for MultiDiscrete action spaces
* Fix codestyle
* Don't require the env to provide the mask already flattened
* Consistent naming, prefer 'Maskable' to 'Masked'
* Register policy
* Link to abstract instead of pdf
* Allow distribution masking to be unapplied + improved comments and docstrings
* Don't use deprecated implicit optional typing
* Check codestyle
* Add docstring and remove misplaced TODO
* Simplify env masking API, error if API unmet. Make use_masking a learn() kwarg
* Fix codestyle
* Update various internals to be consistent with latest SB3
* Simplify MaskableRolloutBuffer reset
* Add docstring and type annotations
* Ensure old probs aren't cached
* Fix for new logger
* Add test + fixes
* Start doc
* Fix type annotation
* Remove abstract class + add test
* Fix evaluation (add support for multi envs)
* Handle merge conflicts in documentation
* Bugfix: mask updates should apply to original logits, not the last masked output
* Add test of distribution masking behavior
* Reformat
* Add MultiBinary support, remove unneeded distribution type checks
* Remove unused import
* Fix when using multiple envs
* Remove addressed TODO
* Upgrade for SB3 1.2.0
* Update docs with results + how to replicate
* Add action masker tests, move wrapper tests
* Move distributions, add more distribution tests
* Add MaskablePPO tests, simplify and rename discrete test env
* Address TODO
* Add tests for MaskableMultiCategoricalDistribution, fix distributions
* Add maskable identity envs for all supported action spaces, add tests, fix bug
* Formatting fixes
* Update doc env
* Dict support not ready
* Cleanup

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent b2e7126840
commit ab24f8039f

Makefile (2 lines changed)
@@ -24,7 +24,7 @@ check-codestyle:
 	# Sort imports
 	isort --check ${LINT_PATHS}
 	# Reformat using black
-	black --check -l 127 ${LINT_PATHS}
+	black --check ${LINT_PATHS}

 commit-checks: format type lint
@@ -6,7 +6,7 @@ dependencies:
   - cpuonly=1.0=0
   - pip=20.2
   - python=3.6
-  - pytorch=1.5.0=py3.6_cpu_0
+  - pytorch=1.8.1=py3.6_cpu_0
   - pip:
     - gym>=0.17.2
     - cloudpickle
@@ -15,4 +15,6 @@ dependencies:
     - numpy
     - matplotlib
     - sphinx_autodoc_typehints
-    - stable-baselines3>=0.10.0a0
+    - stable-baselines3>=1.2.0
+    # See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115
+    - sphinx_rtd_theme>=1.0
|||
Binary file not shown.
|
After Width: | Height: | Size: 76 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 50 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 59 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 80 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 58 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 61 KiB |
|
|
@@ -33,6 +33,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d

    modules/tqc
    modules/qrdqn
+   modules/ppo_mask

 .. toctree::
   :maxdepth: 1
@@ -3,6 +3,31 @@
 Changelog
 ==========

+Release 1.2.1a0 (WIP)
+-------------------------------
+
+**Invalid action masking for PPO**
+
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+
+New Features:
+^^^^^^^^^^^^^
+- Added ``MaskablePPO`` algorithm (@kronion)
+
+Bug Fixes:
+^^^^^^^^^^
+
+Deprecations:
+^^^^^^^^^^^^^
+
+Others:
+^^^^^^^
+
+Documentation:
+^^^^^^^^^^^^^^
+
+
 Release 1.2.0 (2021-09-08)
 -------------------------------
@@ -13,6 +38,9 @@ Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 - Upgraded to Stable-Baselines3 >= 1.2.0

 New Features:
 ^^^^^^^^^^^^^

 Bug Fixes:
 ^^^^^^^^^^
 - QR-DQN and TQC updated so that their policies are switched between train and eval mode at the correct time (@ayeright)
@@ -156,4 +184,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
 Contributors:
 -------------

-@ku2482 @guyk1971 @minhlong94 @ayeright
+@ku2482 @guyk1971 @minhlong94 @ayeright @kronion
@@ -0,0 +1,256 @@
.. _ppo_mask:

.. automodule:: sb3_contrib.ppo_mask

Maskable PPO
============

Implementation of `invalid action masking <https://arxiv.org/abs/2006.14171>`_ for the Proximal Policy Optimization (PPO)
algorithm. Other than adding support for action masking, the behavior is the same as in SB3's core PPO algorithm.


.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy
    CnnPolicy
    .. MultiInputPolicy


Notes
-----

- Paper: https://arxiv.org/abs/2006.14171
- Blog post: https://costa.sh/blog-a-closer-look-at-invalid-action-masking-in-policy-gradient-algorithms.html
- Additional blog post: https://boring-guy.sh/posts/masking-rl/


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ✔️      ✔️
Box           ❌      ✔️
MultiDiscrete ✔️      ✔️
MultiBinary   ✔️      ✔️
Dict          ❌      ❌
============= ====== ===========


Example
-------

Train a PPO agent on ``InvalidActionEnvDiscrete``. ``InvalidActionEnvDiscrete`` has an ``action_masks`` method that
returns the action mask (``True`` if the action is valid, ``False`` otherwise).

.. code-block:: python

    from sb3_contrib import MaskablePPO
    from sb3_contrib.common.envs import InvalidActionEnvDiscrete
    from sb3_contrib.common.maskable.evaluation import evaluate_policy
    from sb3_contrib.common.maskable.utils import get_action_masks


    env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
    model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, verbose=1)
    model.learn(5000)

    evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)

    model.save("ppo_mask")
    del model  # remove to demonstrate saving and loading

    model = MaskablePPO.load("ppo_mask")

    obs = env.reset()
    while True:
        # Retrieve current action mask
        action_masks = get_action_masks(env)
        action, _states = model.predict(obs, action_masks=action_masks)
        obs, reward, done, info = env.step(action)
        env.render()


If the environment exposes the action mask under a different name, you can wrap it with ``ActionMasker``
and supply a function that retrieves the mask (see `PR #25 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/25>`_):


.. code-block:: python

    import gym
    import numpy as np

    from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
    from sb3_contrib.common.wrappers import ActionMasker
    from sb3_contrib.ppo_mask import MaskablePPO


    def mask_fn(env: gym.Env) -> np.ndarray:
        # Do whatever you'd like in this function to return the action mask
        # for the current env. In this example, we assume the env has a
        # helpful method we can rely on.
        return env.valid_action_mask()


    env = ...  # Initialize env
    env = ActionMasker(env, mask_fn)  # Wrap to enable masking

    # MaskablePPO behaves the same as SB3's PPO unless the env is wrapped
    # with ActionMasker. If the wrapper is detected, the masks are automatically
    # retrieved and used when learning. Note that MaskablePPO does not accept
    # a new action_mask_fn kwarg, as it did in an earlier draft.
    model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1)
    model.learn()

    # Note that use of masks is manual and optional outside of learning,
    # so masking can be "removed" at testing time
    model.predict(observation, action_masks=valid_action_array)


Results
-------

Results are shown for two MicroRTS benchmarks: MicrortsMining4x4F9-v0 (600K steps) and MicrortsMining10x10F9-v0
(1.5M steps). For each, models were trained with and without masking, using 3 seeds.

4x4
^^^

No masking
""""""""""

.. image:: /images/4x4_no_mask.png

With masking
""""""""""""

.. image:: /images/4x4_mask.png

Combined
""""""""

.. image:: /images/4x4_combined.png

10x10
^^^^^

No masking
""""""""""

.. image:: /images/10x10_no_mask.png

With masking
""""""""""""

.. image:: /images/10x10_mask.png

Combined
""""""""

.. image:: /images/10x10_combined.png

More information may be found in the
`associated PR <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/25>`_.

How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone the repo for the experiment:

.. code-block:: bash

    git clone git@github.com:kronion/microrts-ppo-comparison.git
    cd microrts-ppo-comparison

Install dependencies:

.. code-block:: bash

    # Install MicroRTS:
    rm -fR ~/microrts && mkdir ~/microrts && \
    wget -O ~/microrts/microrts.zip http://microrts.s3.amazonaws.com/microrts/artifacts/202004222224.microrts.zip && \
    unzip ~/microrts/microrts.zip -d ~/microrts/

    # You may want to make a venv before installing packages
    pip install -r requirements.txt

Train several times with various seeds, with and without masking:

.. code-block:: bash

    # python sb3/train_ppo.py [output dir] [MicroRTS map size] [--mask] [--seed int]

    # 4x4 unmasked
    python sb3/train_ppo.py zoo 4 --seed 42
    python sb3/train_ppo.py zoo 4 --seed 43
    python sb3/train_ppo.py zoo 4 --seed 44

    # 4x4 masked
    python sb3/train_ppo.py zoo 4 --mask --seed 42
    python sb3/train_ppo.py zoo 4 --mask --seed 43
    python sb3/train_ppo.py zoo 4 --mask --seed 44

    # 10x10 unmasked
    python sb3/train_ppo.py zoo 10 --seed 42
    python sb3/train_ppo.py zoo 10 --seed 43
    python sb3/train_ppo.py zoo 10 --seed 44

    # 10x10 masked
    python sb3/train_ppo.py zoo 10 --mask --seed 42
    python sb3/train_ppo.py zoo 10 --mask --seed 43
    python sb3/train_ppo.py zoo 10 --mask --seed 44

View the tensorboard log output:

.. code-block:: bash

    # For 4x4 environment
    tensorboard --logdir zoo/4x4/runs

    # For 10x10 environment
    tensorboard --logdir zoo/10x10/runs


Parameters
----------

.. autoclass:: MaskablePPO
  :members:
  :inherited-members:


MaskablePPO Policies
--------------------

.. autoclass:: MlpPolicy
  :members:
  :inherited-members:

.. autoclass:: sb3_contrib.common.maskable.policies.MaskableActorCriticPolicy
  :members:
  :noindex:

.. autoclass:: CnnPolicy
  :members:

.. autoclass:: sb3_contrib.common.maskable.policies.MaskableActorCriticCnnPolicy
  :members:
  :noindex:

.. .. autoclass:: MultiInputPolicy
..   :members:
..
.. .. autoclass:: sb3_contrib.common.maskable.policies.MaskableMultiInputActorCriticPolicy
..   :members:
..   :noindex:
@@ -0,0 +1,2 @@
[tool.black]
line-length = 127
@@ -1,5 +1,6 @@
 import os

+from sb3_contrib.ppo_mask import MaskablePPO
 from sb3_contrib.qrdqn import QRDQN
 from sb3_contrib.tqc import TQC
@@ -0,0 +1,5 @@
from sb3_contrib.common.envs.invalid_actions_env import (
    InvalidActionEnvDiscrete,
    InvalidActionEnvMultiBinary,
    InvalidActionEnvMultiDiscrete,
)
@@ -0,0 +1,116 @@
from typing import List, Optional

import numpy as np
from gym import spaces
from stable_baselines3.common.envs import IdentityEnv


class InvalidActionEnvDiscrete(IdentityEnv):
    """
    Identity env with a discrete action space. Supports action masking.
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        ep_length: int = 100,
        n_invalid_actions: int = 0,
    ):
        if dim is None:
            dim = 1
        assert n_invalid_actions < dim, f"Too many invalid actions: {n_invalid_actions}, must be < {dim}"

        space = spaces.Discrete(dim)
        self.n_invalid_actions = n_invalid_actions
        self.possible_actions = np.arange(space.n)
        self.invalid_actions: List[int] = []
        super().__init__(space=space, ep_length=ep_length)

    def _choose_next_state(self) -> None:
        self.state = self.action_space.sample()
        # Randomly choose invalid actions that are not the current state
        potential_invalid_actions = [i for i in self.possible_actions if i != self.state]
        self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)

    def action_masks(self) -> List[bool]:
        return [action not in self.invalid_actions for action in self.possible_actions]


class InvalidActionEnvMultiDiscrete(IdentityEnv):
    """
    Identity env with a multidiscrete action space. Supports action masking.
    """

    def __init__(
        self,
        dims: Optional[List[int]] = None,
        ep_length: int = 100,
        n_invalid_actions: int = 0,
    ):
        if dims is None:
            dims = [1, 1]

        if n_invalid_actions > sum(dims) - len(dims):
            raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {sum(dims) - len(dims)}")

        space = spaces.MultiDiscrete(dims)
        self.n_invalid_actions = n_invalid_actions
        self.possible_actions = np.arange(sum(dims))
        self.invalid_actions: List[int] = []
        super().__init__(space=space, ep_length=ep_length)

    def _choose_next_state(self) -> None:
        self.state = self.action_space.sample()

        converted_state: List[int] = []
        running_total = 0
        for i in range(len(self.action_space.nvec)):
            converted_state.append(running_total + self.state[i])
            running_total += self.action_space.nvec[i]

        # Randomly choose invalid actions that are not the current state
        potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state]
        self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)

    def action_masks(self) -> List[bool]:
        return [action not in self.invalid_actions for action in self.possible_actions]


class InvalidActionEnvMultiBinary(IdentityEnv):
    """
    Identity env with a multibinary action space. Supports action masking.
    """

    def __init__(
        self,
        dims: Optional[int] = None,
        ep_length: int = 100,
        n_invalid_actions: int = 0,
    ):
        if dims is None:
            dims = 1

        if n_invalid_actions > dims:
            raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {dims}")

        space = spaces.MultiBinary(dims)
        self.n_invalid_actions = n_invalid_actions
        self.possible_actions = np.arange(2 * dims)
        self.invalid_actions: List[int] = []
        super().__init__(space=space, ep_length=ep_length)

    def _choose_next_state(self) -> None:
        self.state = self.action_space.sample()

        converted_state: List[int] = []
        running_total = 0
        for i in range(self.action_space.n):
            converted_state.append(running_total + self.state[i])
            running_total += 2

        # Randomly choose invalid actions that are not the current state
        potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state]
        self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)

    def action_masks(self) -> List[bool]:
        return [action not in self.invalid_actions for action in self.possible_actions]
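All three identity envs follow the same masking recipe: sample a new target state, draw invalid actions from the remaining choices, and expose a boolean mask via ``action_masks()``. A minimal dependency-free sketch of the Discrete case (``make_discrete_mask`` is a hypothetical helper for illustration, not part of the package):

```python
import random


def make_discrete_mask(dim: int, state: int, n_invalid: int) -> list:
    """Boolean mask over `dim` actions (True = valid), mirroring
    InvalidActionEnvDiscrete: the current state is never invalidated,
    so a valid action always remains."""
    candidates = [a for a in range(dim) if a != state]
    invalid = set(random.sample(candidates, n_invalid))
    return [a not in invalid for a in range(dim)]


mask = make_discrete_mask(dim=5, state=2, n_invalid=3)
```

Because the current state is excluded from the invalid pool, exactly ``dim - n_invalid`` actions stay valid and the mask never rules out the correct answer.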
@@ -0,0 +1,98 @@
from typing import Generator, NamedTuple, Optional

import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.vec_env import VecNormalize


class MaskableRolloutBufferSamples(NamedTuple):
    observations: th.Tensor
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
    action_masks: th.Tensor


class MaskableRolloutBuffer(RolloutBuffer):
    """
    Rollout buffer that also stores the invalid action masks associated with each observation.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device:
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(self, *args, **kwargs):
        self.action_masks = None
        super().__init__(*args, **kwargs)

    def reset(self) -> None:
        if isinstance(self.action_space, spaces.Discrete):
            mask_dims = self.action_space.n
        elif isinstance(self.action_space, spaces.MultiDiscrete):
            mask_dims = sum(self.action_space.nvec)
        elif isinstance(self.action_space, spaces.MultiBinary):
            mask_dims = 2 * self.action_space.n  # One mask per binary outcome
        else:
            raise ValueError(f"Unsupported action space {type(self.action_space)}")

        self.mask_dims = mask_dims
        self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)

        super().reset()

    def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
        """
        :param action_masks: Masks applied to constrain the choice of possible actions.
        """
        if action_masks is not None:
            self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))

        super().add(*args, **kwargs)

    def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]:
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            for tensor in [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "action_masks",
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples:
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
            self.action_masks[batch_inds].reshape(-1, self.mask_dims),
        )
        return MaskableRolloutBufferSamples(*map(self.to_torch, data))
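The ``reset()`` logic above fixes the flattened mask length per action space: one entry per Discrete action, concatenated per-dimension masks for MultiDiscrete, and two entries (off/on) per MultiBinary bit. A quick sanity sketch of that arithmetic (``mask_dims_for`` is an illustrative helper, not part of the buffer API):

```python
def mask_dims_for(space: str, n: int = 0, nvec: tuple = ()) -> int:
    """Flattened mask length, mirroring MaskableRolloutBuffer.reset():
    Discrete -> one entry per action; MultiDiscrete -> sum of the
    per-dimension sizes; MultiBinary -> two entries per binary choice."""
    if space == "Discrete":
        return n
    if space == "MultiDiscrete":
        return sum(nvec)
    if space == "MultiBinary":
        return 2 * n
    raise ValueError(f"Unsupported action space {space}")


# e.g. a MultiDiscrete([3, 2]) space stores a length-5 mask per step
dims = mask_dims_for("MultiDiscrete", nvec=(3, 2))
```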
@@ -0,0 +1,107 @@
import os

import numpy as np
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import sync_envs_normalization

from sb3_contrib.common.maskable.evaluation import evaluate_policy


class MaskableEvalCallback(EvalCallback):
    """
    Callback for evaluating an agent. Supports invalid action masking.

    :param eval_env: The environment used for initialization
    :param callback_on_new_best: Callback to trigger
        when there is a new best model according to the ``mean_reward``
    :param n_eval_episodes: The number of episodes to test the agent
    :param eval_freq: Evaluate the agent every eval_freq call of the callback.
    :param log_path: Path to a folder where the evaluations (``evaluations.npz``)
        will be saved. It will be updated at each evaluation.
    :param best_model_save_path: Path to a folder where the best model
        according to performance on the eval env will be saved.
    :param deterministic: Whether the evaluation should
        use stochastic or deterministic actions.
    :param render: Whether to render the environment during evaluation
    :param verbose:
    :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
        wrapped with a Monitor wrapper)
    :param use_masking: Whether or not to use invalid action masks during evaluation
    """

    def __init__(self, *args, use_masking: bool = True, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_masking = use_masking

    def _on_step(self) -> bool:
        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            # Sync training and eval env if there is VecNormalize
            sync_envs_normalization(self.training_env, self.eval_env)

            # Reset success rate buffer
            self._is_success_buffer = []

            # Note that evaluate_policy() has been patched to support masking
            episode_rewards, episode_lengths = evaluate_policy(
                self.model,
                self.eval_env,
                n_eval_episodes=self.n_eval_episodes,
                render=self.render,
                deterministic=self.deterministic,
                return_episode_rewards=True,
                warn=self.warn,
                callback=self._log_success_callback,
                use_masking=self.use_masking,
            )

            if self.log_path is not None:
                self.evaluations_timesteps.append(self.num_timesteps)
                self.evaluations_results.append(episode_rewards)
                self.evaluations_length.append(episode_lengths)

                kwargs = {}
                # Save success log if present
                if len(self._is_success_buffer) > 0:
                    self.evaluations_successes.append(self._is_success_buffer)
                    kwargs = dict(successes=self.evaluations_successes)

                np.savez(
                    self.log_path,
                    timesteps=self.evaluations_timesteps,
                    results=self.evaluations_results,
                    ep_lengths=self.evaluations_length,
                    **kwargs,
                )

            mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
            mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
            self.last_mean_reward = mean_reward

            if self.verbose > 0:
                print(f"Eval num_timesteps={self.num_timesteps}, episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
                print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
            # Add to current Logger
            self.logger.record("eval/mean_reward", float(mean_reward))
            self.logger.record("eval/mean_ep_length", mean_ep_length)

            if len(self._is_success_buffer) > 0:
                success_rate = np.mean(self._is_success_buffer)
                if self.verbose > 0:
                    print(f"Success rate: {100 * success_rate:.2f}%")
                self.logger.record("eval/success_rate", success_rate)

            # Dump log so the evaluation results are printed with the correct timestep
            self.logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard")
            self.logger.dump(self.num_timesteps)

            if mean_reward > self.best_mean_reward:
                if self.verbose > 0:
                    print("New best mean reward!")
                if self.best_model_save_path is not None:
                    self.model.save(os.path.join(self.best_model_save_path, "best_model"))
                self.best_mean_reward = mean_reward
                # Trigger callback if needed
                if self.callback is not None:
                    return self._on_event()

        return True
@@ -0,0 +1,272 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.distributions import Distribution
from torch import nn
from torch.distributions import Categorical
from torch.distributions.utils import logits_to_probs


class MaskableCategorical(Categorical):
    """
    Modified PyTorch Categorical distribution with support for invalid action masking.

    To instantiate, must provide either probs or logits, but not both.

    :param probs: Tensor containing finite non-negative values, which will be renormalized
        to sum to 1 along the last dimension.
    :param logits: Tensor of unnormalized log probabilities.
    :param validate_args: Whether or not to validate that arguments to methods like log_prob()
        and icdf() match the distribution's shape, support, etc.
    :param masks: An optional boolean ndarray of compatible shape with the distribution.
        If True, the corresponding choice's logit value is preserved. If False, it is set to a
        large negative value, resulting in near 0 probability.
    """

    def __init__(
        self,
        probs: Optional[th.Tensor] = None,
        logits: Optional[th.Tensor] = None,
        validate_args: Optional[bool] = None,
        masks: Optional[np.ndarray] = None,
    ):
        self.masks: Optional[th.Tensor] = None
        super().__init__(probs, logits, validate_args)
        self._original_logits = self.logits
        self.apply_masking(masks)

    def apply_masking(self, masks: Optional[np.ndarray]) -> None:
        """
        Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0.

        :param masks: An optional boolean ndarray of compatible shape with the distribution.
            If True, the corresponding choice's logit value is preserved. If False, it is set
            to a large negative value, resulting in near 0 probability. If masks is None, any
            previously applied masking is removed, and the original logits are restored.
        """

        if masks is not None:
            device = self.logits.device
            self.masks = th.as_tensor(masks, dtype=th.bool, device=device).reshape(self.logits.shape)
            HUGE_NEG = th.tensor(-1e8, dtype=self.logits.dtype, device=device)

            logits = th.where(self.masks, self._original_logits, HUGE_NEG)
        else:
            self.masks = None
            logits = self._original_logits

        # Reinitialize with updated logits
        super().__init__(logits=logits)

        # self.probs may already be cached, so we must force an update
        self.probs = logits_to_probs(self.logits)

    def entropy(self) -> th.Tensor:
        if self.masks is None:
            return super().entropy()

        # Highly negative logits don't result in 0 probs, so we must replace
        # with 0s to ensure 0 contribution to the distribution's entropy, since
        # masked actions possess no uncertainty.
        device = self.logits.device
        p_log_p = self.logits * self.probs
        p_log_p = th.where(self.masks, p_log_p, th.tensor(0.0, device=device))
        return -p_log_p.sum(-1)
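The ``-1e8`` replacement drives a masked action's probability to numerical zero after the softmax, which is why the entropy override can safely zero out masked terms. A dependency-free sketch of the same trick (``masked_softmax`` is illustrative, not the class's API):

```python
import math

HUGE_NEG = -1e8  # same sentinel value the class substitutes for masked logits


def masked_softmax(logits, mask):
    """Softmax after replacing masked-out logits with a huge negative value:
    True keeps the logit, False makes its probability underflow to 0."""
    adjusted = [l if keep else HUGE_NEG for l, keep in zip(logits, mask)]
    top = max(adjusted)  # subtract the max for numerical stability
    exps = [math.exp(a - top) for a in adjusted]
    total = sum(exps)
    return [e / total for e in exps]


probs = masked_softmax([1.0, 2.0, 3.0], [True, False, True])
```

The surviving probabilities renormalize over the valid actions only, so the relative preferences among unmasked actions are unchanged.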
class MaskableDistribution(Distribution, ABC):
    @abstractmethod
    def apply_masking(self, masks: Optional[np.ndarray]) -> None:
        """
        Eliminate ("mask out") chosen distribution outcomes by setting their probability to 0.

        :param masks: An optional boolean ndarray of compatible shape with the distribution.
            If True, the corresponding choice's logit value is preserved. If False, it is set
            to a large negative value, resulting in near 0 probability. If masks is None, any
            previously applied masking is removed, and the original logits are restored.
        """


class MaskableCategoricalDistribution(MaskableDistribution):
    """
    Categorical distribution for discrete actions. Supports invalid action masking.

    :param action_dim: Number of discrete actions
    """

    def __init__(self, action_dim: int):
        super().__init__()
        self.distribution: Optional[MaskableCategorical] = None
        self.action_dim = action_dim

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Create the layer that represents the distribution:
        it will be the logits of the Categorical distribution.
        You can then get probabilities using a softmax.

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return:
        """
        action_logits = nn.Linear(latent_dim, self.action_dim)
        return action_logits

    def proba_distribution(self, action_logits: th.Tensor) -> "MaskableCategoricalDistribution":
        # Restructure shape to align with logits
        reshaped_logits = action_logits.view(-1, self.action_dim)
        self.distribution = MaskableCategorical(logits=reshaped_logits)
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        assert self.distribution is not None, "Must set distribution parameters"
        return self.distribution.log_prob(actions)

    def entropy(self) -> th.Tensor:
        assert self.distribution is not None, "Must set distribution parameters"
        return self.distribution.entropy()

    def sample(self) -> th.Tensor:
        assert self.distribution is not None, "Must set distribution parameters"
        return self.distribution.sample()

    def mode(self) -> th.Tensor:
        assert self.distribution is not None, "Must set distribution parameters"
        return th.argmax(self.distribution.probs, dim=1)

    def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
        # Update the proba distribution
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        actions = self.actions_from_params(action_logits)
        log_prob = self.log_prob(actions)
        return actions, log_prob

    def apply_masking(self, masks: Optional[np.ndarray]) -> None:
        assert self.distribution is not None, "Must set distribution parameters"
        self.distribution.apply_masking(masks)


class MaskableMultiCategoricalDistribution(MaskableDistribution):
    """
    MultiCategorical distribution for multi discrete actions. Supports invalid action masking.

    :param action_dims: List of sizes of discrete action spaces
    """

    def __init__(self, action_dims: List[int]):
        super().__init__()
        self.distributions: List[MaskableCategorical] = []
        self.action_dims = action_dims

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Create the layer that represents the distribution:
        it will be the logits (flattened) of the MultiCategorical distribution.
        You can then get probabilities using a softmax on each sub-space.

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return:
        """

        action_logits = nn.Linear(latent_dim, sum(self.action_dims))
|
||||
return action_logits
|
||||
|
||||
def proba_distribution(self, action_logits: th.Tensor) -> "MaskableMultiCategoricalDistribution":
|
||||
# Restructure shape to align with logits
|
||||
reshaped_logits = action_logits.view(-1, sum(self.action_dims))
|
||||
|
||||
self.distributions = [
|
||||
MaskableCategorical(logits=split) for split in th.split(reshaped_logits, tuple(self.action_dims), dim=1)
|
||||
]
|
||||
return self
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
assert len(self.distributions) > 0, "Must set distribution parameters"
|
||||
|
||||
# Restructure shape to align with each categorical
|
||||
actions = actions.view(-1, len(self.action_dims))
|
||||
|
||||
# Extract each discrete action and compute log prob for their respective distributions
|
||||
return th.stack(
|
||||
[dist.log_prob(action) for dist, action in zip(self.distributions, th.unbind(actions, dim=1))], dim=1
|
||||
).sum(dim=1)
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
assert len(self.distributions) > 0, "Must set distribution parameters"
|
||||
return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1)
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
assert len(self.distributions) > 0, "Must set distribution parameters"
|
||||
return th.stack([dist.sample() for dist in self.distributions], dim=1)
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
assert len(self.distributions) > 0, "Must set distribution parameters"
|
||||
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1)
|
||||
|
||||
def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
# Update the proba distribution
|
||||
self.proba_distribution(action_logits)
|
||||
return self.get_actions(deterministic=deterministic)
|
||||
|
||||
def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
actions = self.actions_from_params(action_logits)
|
||||
log_prob = self.log_prob(actions)
|
||||
return actions, log_prob
|
||||
|
||||
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
|
||||
assert len(self.distributions) > 0, "Must set distribution parameters"
|
||||
|
||||
split_masks = [None] * len(self.distributions)
|
||||
if masks is not None:
|
||||
masks = th.as_tensor(masks)
|
||||
|
||||
# Restructure shape to align with logits
|
||||
masks = masks.view(-1, sum(self.action_dims))
|
||||
|
||||
# Then split columnwise for each discrete action
|
||||
split_masks = th.split(masks, tuple(self.action_dims), dim=1)
|
||||
|
||||
for distribution, mask in zip(self.distributions, split_masks):
|
||||
distribution.apply_masking(mask)
|
||||
|
||||
|
||||
class MaskableBernoulliDistribution(MaskableMultiCategoricalDistribution):
|
||||
"""
|
||||
Bernoulli distribution for multibinary actions. Supports invalid action masking.
|
||||
|
||||
:param action_dim: Number of binary actions
|
||||
"""
|
||||
|
||||
def __init__(self, action_dim: int):
|
||||
# Two states per binary action
|
||||
action_dims = [2] * action_dim
|
||||
super().__init__(action_dims)
|
||||
|
||||
|
||||
def make_masked_proba_distribution(action_space: spaces.Space) -> MaskableDistribution:
|
||||
"""
|
||||
Return an instance of Distribution for the correct type of action space
|
||||
|
||||
:param action_space: the input action space
|
||||
:return: the appropriate Distribution object
|
||||
"""
|
||||
|
||||
if isinstance(action_space, spaces.Discrete):
|
||||
return MaskableCategoricalDistribution(action_space.n)
|
||||
elif isinstance(action_space, spaces.MultiDiscrete):
|
||||
return MaskableMultiCategoricalDistribution(action_space.nvec)
|
||||
elif isinstance(action_space, spaces.MultiBinary):
|
||||
return MaskableBernoulliDistribution(action_space.n)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Error: probability distribution, not implemented for action space"
|
||||
f"of type {type(action_space)}."
|
||||
" Must be of type Gym Spaces: Discrete, MultiDiscrete."
|
||||
)
|
||||
|
|
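The masking mechanic documented above (preserve a logit where the mask is True, push it to a large negative value where False, so softmax drives its probability to near zero) can be sketched with plain NumPy. `mask_logits` and `softmax` are illustrative helper names, not part of the sb3-contrib API:

```python
import numpy as np


def mask_logits(logits: np.ndarray, mask: np.ndarray, neg: float = -1e8) -> np.ndarray:
    """Keep logits where mask is True, replace with a large negative value where False."""
    return np.where(mask, logits, neg)


def softmax(x: np.ndarray) -> np.ndarray:
    # Subtract the max for numerical stability before exponentiating
    z = x - x.max()
    e = np.exp(z)
    return e / e.sum()


logits = np.array([1.0, 2.0, 3.0])
mask = np.array([True, False, True])  # action 1 is invalid
probs = softmax(mask_logits(logits, mask))
```

After masking, the invalid action's probability underflows to (effectively) zero while the remaining probabilities still sum to one, which is why sampling from the masked distribution can never select an invalid action.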
@@ -0,0 +1,148 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import gym
import numpy as np
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped

from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
from sb3_contrib.ppo_mask import MaskablePPO


def evaluate_policy(  # noqa: C901
    model: MaskablePPO,
    env: Union[gym.Env, VecEnv],
    n_eval_episodes: int = 10,
    deterministic: bool = True,
    render: bool = False,
    callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
    reward_threshold: Optional[float] = None,
    return_episode_rewards: bool = False,
    warn: bool = True,
    use_masking: bool = True,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
    """
    Runs policy for ``n_eval_episodes`` episodes and returns average reward.
    If a vector env is passed in, this divides the episodes to evaluate onto the
    different elements of the vector env. This static division of work is done to
    remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more
    details and discussion.

    .. note::
        If environment has not been wrapped with ``Monitor`` wrapper, reward and
        episode lengths are counted as they appear with ``env.step`` calls. If
        the environment contains wrappers that modify rewards or episode lengths
        (e.g. reward scaling, early episode reset), these will affect the evaluation
        results as well. You can avoid this by wrapping environment with ``Monitor``
        wrapper before anything else.

    :param model: The RL agent you want to evaluate.
    :param env: The gym environment or ``VecEnv`` environment.
    :param n_eval_episodes: Number of episodes to evaluate the agent
    :param deterministic: Whether to use deterministic or stochastic actions
    :param render: Whether to render the environment or not
    :param callback: callback function to do additional checks,
        called after each step. Gets locals() and globals() passed as parameters.
    :param reward_threshold: Minimum expected reward per episode,
        this will raise an error if the performance is not met
    :param return_episode_rewards: If True, a list of rewards and episode lengths
        per episode will be returned instead of the mean.
    :param warn: If True (default), warns user about lack of a Monitor wrapper in the
        evaluation environment.
    :param use_masking: Whether or not to use invalid action masks during evaluation
    :return: Mean reward per episode, std of reward per episode.
        Returns ([float], [int]) when ``return_episode_rewards`` is True, first
        list containing per-episode rewards and second containing per-episode lengths
        (in number of steps).
    """

    if use_masking and not is_masking_supported(env):
        raise ValueError("Environment does not support action masking. Consider using ActionMasker wrapper")

    is_monitor_wrapped = False

    if not isinstance(env, VecEnv):
        env = DummyVecEnv([lambda: env])

    is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]

    if not is_monitor_wrapped and warn:
        warnings.warn(
            "Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
            "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
            "Consider wrapping environment first with ``Monitor`` wrapper.",
            UserWarning,
        )

    n_envs = env.num_envs
    episode_rewards = []
    episode_lengths = []

    episode_counts = np.zeros(n_envs, dtype="int")
    # Divides episodes among different sub environments in the vector as evenly as possible
    episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int")

    current_rewards = np.zeros(n_envs)
    current_lengths = np.zeros(n_envs, dtype="int")
    observations = env.reset()
    states = None

    while (episode_counts < episode_count_targets).any():
        if use_masking:
            action_masks = get_action_masks(env)
            actions, states = model.predict(
                observations,
                state=states,
                deterministic=deterministic,
                action_masks=action_masks,
            )
        else:
            actions, states = model.predict(observations, state=states, deterministic=deterministic)
        observations, rewards, dones, infos = env.step(actions)
        current_rewards += rewards
        current_lengths += 1
        for i in range(n_envs):
            if episode_counts[i] < episode_count_targets[i]:

                # unpack values so that the callback can access the local variables
                reward = rewards[i]
                done = dones[i]
                info = infos[i]

                if callback is not None:
                    callback(locals(), globals())

                if dones[i]:
                    if is_monitor_wrapped:
                        # Atari wrapper can send a "done" signal when
                        # the agent loses a life, but it does not correspond
                        # to the true end of episode
                        if "episode" in info.keys():
                            # Do not trust "done" with episode endings.
                            # Monitor wrapper includes "episode" key in info if environment
                            # has been wrapped with it. Use those rewards instead.
                            episode_rewards.append(info["episode"]["r"])
                            episode_lengths.append(info["episode"]["l"])
                            # Only increment at the real end of an episode
                            episode_counts[i] += 1
                    else:
                        episode_rewards.append(current_rewards[i])
                        episode_lengths.append(current_lengths[i])
                        episode_counts[i] += 1
                    current_rewards[i] = 0
                    current_lengths[i] = 0
                    if states is not None:
                        states[i] *= 0

        if render:
            env.render()

    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)
    if reward_threshold is not None:
        assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"
    if return_episode_rewards:
        return episode_rewards, episode_lengths
    return mean_reward, std_reward
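The `episode_count_targets` expression used above divides `n_eval_episodes` across the sub-envs statically, so per-env targets differ by at most one and always sum to the total. A standalone sketch of just that formula (the function name is illustrative):

```python
def episode_targets(n_eval_episodes: int, n_envs: int) -> list:
    # Sub-env i runs (n_eval_episodes + i) // n_envs episodes; the +i offset
    # spreads any remainder across the last few envs.
    return [(n_eval_episodes + i) // n_envs for i in range(n_envs)]


print(episode_targets(10, 3))  # [3, 3, 4]
```

This static split is what removes the bias discussed in DLR-RM/stable-baselines3#402: fast-terminating sub-envs cannot contribute more than their share of episodes.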
@@ -0,0 +1,402 @@
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import gym
import numpy as np
import torch as th
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor,
    MlpExtractor,
    NatureCNN,
)
from stable_baselines3.common.type_aliases import Schedule
from torch import nn

from sb3_contrib.common.maskable.distributions import MaskableDistribution, make_masked_proba_distribution


class MaskableActorCriticPolicy(BasePolicy):
    """
    Policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):

        if optimizer_kwargs is None:
            optimizer_kwargs = {}
            # Small values to avoid NaN in Adam optimizer
            if optimizer_class == th.optim.Adam:
                optimizer_kwargs["eps"] = 1e-5

        super().__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=False,
        )

        # Default network architecture, from stable-baselines
        if net_arch is None:
            if features_extractor_class == NatureCNN:
                net_arch = []
            else:
                net_arch = [dict(pi=[64, 64], vf=[64, 64])]

        self.net_arch = net_arch
        self.activation_fn = activation_fn
        self.ortho_init = ortho_init

        self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
        self.features_dim = self.features_extractor.features_dim

        self.normalize_images = normalize_images

        # Action distribution
        self.action_dist = make_masked_proba_distribution(action_space)

        self._build(lr_schedule)

    def forward(
        self,
        obs: th.Tensor,
        deterministic: bool = False,
        action_masks: Optional[np.ndarray] = None,
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Forward pass in all the networks (actor and critic)

        :param obs: Observation
        :param deterministic: Whether to sample or use deterministic actions
        :param action_masks: Action masks to apply to the action distribution
        :return: action, value and log probability of the action
        """
        latent_pi, latent_vf = self._get_latent(obs)
        # Evaluate the values for the given observations
        values = self.value_net(latent_vf)
        distribution = self._get_action_dist_from_latent(latent_pi)
        if action_masks is not None:
            distribution.apply_masking(action_masks)
        actions = distribution.get_actions(deterministic=deterministic)
        log_prob = distribution.log_prob(actions)
        return actions, values, log_prob

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                net_arch=self.net_arch,
                activation_fn=self.activation_fn,
                lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
                ortho_init=self.ortho_init,
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
            )
        )
        return data

    def _build_mlp_extractor(self) -> None:
        """
        Create the policy and value networks.
        Part of the layers can be shared.
        """
        # Note: If net_arch is None and some features extractor is used,
        # net_arch here is an empty list and mlp_extractor does not
        # really contain any layers (acts like an identity module).
        self.mlp_extractor = MlpExtractor(
            self.features_dim,
            net_arch=self.net_arch,
            activation_fn=self.activation_fn,
            device=self.device,
        )

    def _build(self, lr_schedule: Schedule) -> None:
        """
        Create the networks and the optimizer.

        :param lr_schedule: Learning rate schedule
            lr_schedule(1) is the initial learning rate
        """
        self._build_mlp_extractor()

        self.action_net = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi)
        self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)

        # Init weights: use orthogonal initialization
        # with small initial weight for the output
        if self.ortho_init:
            # TODO: check for features_extractor
            # Values from stable-baselines.
            # features_extractor/mlp values are
            # originally from openai/baselines (default gains/init_scales).
            module_gains = {
                self.features_extractor: np.sqrt(2),
                self.mlp_extractor: np.sqrt(2),
                self.action_net: 0.01,
                self.value_net: 1,
            }
            for module, gain in module_gains.items():
                module.apply(partial(self.init_weights, gain=gain))

        # Setup optimizer with initial learning rate
        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

    def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        Get the latent code (i.e., activations of the last layer of each network)
        for the different networks.

        :param obs: Observation
        :return: Latent codes for the actor and the value function
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        latent_pi, latent_vf = self.mlp_extractor(features)

        return latent_pi, latent_vf

    def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution:
        """
        Retrieve action distribution given the latent codes.

        :param latent_pi: Latent code for the actor
        :return: Action distribution
        """
        action_logits = self.action_net(latent_pi)
        return self.action_dist.proba_distribution(action_logits=action_logits)

    def _predict(
        self,
        observation: th.Tensor,
        deterministic: bool = False,
        action_masks: Optional[np.ndarray] = None,
    ) -> th.Tensor:
        """
        Get the action according to the policy for a given observation.

        :param observation:
        :param deterministic: Whether to use stochastic or deterministic actions
        :param action_masks: Action masks to apply to the action distribution
        :return: Taken action according to the policy
        """
        latent_pi, _ = self._get_latent(observation)
        distribution = self._get_action_dist_from_latent(latent_pi)
        if action_masks is not None:
            distribution.apply_masking(action_masks)
        return distribution.get_actions(deterministic=deterministic)

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[np.ndarray] = None,
        mask: Optional[np.ndarray] = None,
        deterministic: bool = False,
        action_masks: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Get the policy action and state from an observation (and optional state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :param state: The last states (can be None, used in recurrent policies)
        :param mask: The last masks (can be None, used in recurrent policies)
        :param deterministic: Whether or not to return deterministic actions.
        :param action_masks: Action masks to apply to the action distribution
        :return: the model's action and the next state
            (used in recurrent policies)
        """
        # TODO (GH/1): add support for RNN policies
        # if state is None:
        #     state = self.initial_state
        # if mask is None:
        #     mask = [False for _ in range(self.n_envs)]

        # Switch to eval mode (this affects batch norm / dropout)
        self.set_training_mode(False)

        observation, vectorized_env = self.obs_to_tensor(observation)

        with th.no_grad():
            actions = self._predict(observation, deterministic=deterministic, action_masks=action_masks)
        # Convert to numpy
        actions = actions.cpu().numpy()

        if isinstance(self.action_space, gym.spaces.Box):
            if self.squash_output:
                # Rescale to proper domain when using squashing
                actions = self.unscale_action(actions)
            else:
                # Actions could be on arbitrary scale, so clip the actions to avoid
                # out of bound error (e.g. if sampling from a Gaussian distribution)
                actions = np.clip(actions, self.action_space.low, self.action_space.high)

        if not vectorized_env:
            if state is not None:
                raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
            actions = actions[0]

        return actions, state

    def evaluate_actions(
        self,
        obs: th.Tensor,
        actions: th.Tensor,
        action_masks: Optional[np.ndarray] = None,
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Evaluate actions according to the current policy,
        given the observations.

        :param obs:
        :param actions:
        :param action_masks: Action masks to apply to the action distribution
        :return: estimated value, log likelihood of taking those actions
            and entropy of the action distribution.
        """
        latent_pi, latent_vf = self._get_latent(obs)
        distribution = self._get_action_dist_from_latent(latent_pi)
        if action_masks is not None:
            distribution.apply_masking(action_masks)
        log_prob = distribution.log_prob(actions)
        values = self.value_net(latent_vf)
        return values, log_prob, distribution.entropy()


class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
    """
    CNN policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super(MaskableActorCriticCnnPolicy, self).__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )


class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
    """
    Multi-input policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space (Tuple)
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param features_extractor_class: Uses the CombinedExtractor
    :param features_extractor_kwargs: Keyword arguments
        to pass to the feature extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: gym.spaces.Dict,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super(MaskableMultiInputActorCriticPolicy, self).__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )
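The ordering inside `forward()` and `_predict()` matters: the distribution is built from the raw logits, the mask is applied, and only then is an action sampled and its log-probability taken, so both come from the *masked* distribution. A NumPy-only sketch of that ordering (helper names are illustrative, not sb3-contrib API):

```python
import numpy as np


def masked_probs(logits: np.ndarray, mask: np.ndarray, neg: float = -1e8) -> np.ndarray:
    # Build the distribution from logits, then zero out masked actions via softmax
    z = np.where(mask, logits, neg)
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


rng = np.random.default_rng(0)
logits = np.array([0.5, 1.5, -0.5, 2.0])
mask = np.array([True, True, False, True])  # action 2 is invalid

probs = masked_probs(logits, mask)
action = rng.choice(len(probs), p=probs)     # sample from the masked distribution
log_prob = np.log(probs[action])             # log-prob under the same masked distribution
```

Because masking happens before sampling, an invalid action can never be drawn, and the stored log-probability is consistent with the distribution the agent actually sampled from, which is what PPO's ratio computation requires.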
@@ -0,0 +1,38 @@
import numpy as np
from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.common.vec_env import VecEnv

EXPECTED_METHOD_NAME = "action_masks"


def get_action_masks(env: GymEnv) -> np.ndarray:
    """
    Retrieve the invalid action masks from the env's ``action_masks`` method

    :param env: the Gym environment to get masks from
    :return: A numpy array of the masks
    """

    if isinstance(env, VecEnv):
        return np.stack(env.env_method(EXPECTED_METHOD_NAME))
    else:
        return getattr(env, EXPECTED_METHOD_NAME)()


def is_masking_supported(env: GymEnv) -> bool:
    """
    Checks whether gym env exposes a method returning invalid action masks

    :param env: the Gym environment to check
    :return: True if the method is found, False otherwise
    """

    if isinstance(env, VecEnv):
        try:
            # TODO: add VecEnv.has_attr()
            env.get_attr(EXPECTED_METHOD_NAME)
            return True
        except AttributeError:
            return False
    else:
        return hasattr(env, EXPECTED_METHOD_NAME)
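The protocol here is duck-typed: any env exposing a method with the expected name `action_masks` counts as maskable. A dependency-free sketch of the non-VecEnv branch, using toy stand-in classes rather than real Gym envs:

```python
EXPECTED_METHOD_NAME = "action_masks"


class ToyMaskableEnv:
    """Stand-in for an env that implements the masking protocol."""

    def action_masks(self):
        return [True, False, True]


class ToyPlainEnv:
    """Stand-in for an env without masking support."""


def supports_masking(env) -> bool:
    # Mirrors the hasattr() check: the method just needs to exist and be callable
    return callable(getattr(env, EXPECTED_METHOD_NAME, None))
```

This is why no base class or registration is needed: wrapping with `ActionMasker` simply bolts the expected method onto an env that lacks it.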
@@ -1 +1,2 @@
from sb3_contrib.common.wrappers.action_masker import ActionMasker
from sb3_contrib.common.wrappers.time_feature import TimeFeatureWrapper
@@ -0,0 +1,32 @@
from typing import Callable, Union

import gym
import numpy as np


class ActionMasker(gym.Wrapper):
    """
    Env wrapper providing the method required to support masking.

    Exposes a method called action_masks(), which returns masks for the wrapped env.
    This wrapper is not needed if the env exposes the expected method itself.

    :param env: the Gym environment to wrap
    :param action_mask_fn: A function that takes a Gym environment and returns an action mask,
        or the name of such a method provided by the environment.
    """

    def __init__(self, env: gym.Env, action_mask_fn: Union[str, Callable[[gym.Env], np.ndarray]]):
        super().__init__(env)

        if isinstance(action_mask_fn, str):
            found_method = getattr(self.env, action_mask_fn)
            if not callable(found_method):
                raise ValueError(f"Environment attribute {action_mask_fn} is not a method")

            self._action_mask_fn = found_method
        else:
            self._action_mask_fn = action_mask_fn

    def action_masks(self) -> np.ndarray:
        return self._action_mask_fn(self.env)
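The str-vs-callable resolution in `ActionMasker.__init__` can be sketched without the gym dependency; `ToyEnv` and `resolve_mask_fn` are illustrative names standing in for the wrapper's internals:

```python
from typing import Callable, Union


class ToyEnv:
    """Stand-in for a wrapped env that happens to provide a mask method."""

    def my_masks(self):
        return [True, True, False]


def resolve_mask_fn(env, action_mask_fn: Union[str, Callable]):
    # A string names a method on the env; anything else must already be callable
    if isinstance(action_mask_fn, str):
        found_method = getattr(env, action_mask_fn)
        if not callable(found_method):
            raise ValueError(f"Environment attribute {action_mask_fn} is not a method")
        return found_method
    return action_mask_fn


env = ToyEnv()
by_name = resolve_mask_fn(env, "my_masks")          # bound method, call with no args
by_fn = resolve_mask_fn(env, lambda e: e.my_masks())  # callable taking the env
```

Accepting either form lets users keep mask logic inside the env (string) or entirely outside it (callable) without changing the wrapper.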
@@ -0,0 +1,2 @@
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy  # , MultiInputPolicy
from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO
@@ -0,0 +1,15 @@
from stable_baselines3.common.policies import register_policy

from sb3_contrib.common.maskable.policies import (  # MaskableMultiInputActorCriticPolicy,
    MaskableActorCriticCnnPolicy,
    MaskableActorCriticPolicy,
)

MlpPolicy = MaskableActorCriticPolicy
CnnPolicy = MaskableActorCriticCnnPolicy
# MultiInputPolicy = MaskableMultiInputActorCriticPolicy

register_policy("MlpPolicy", MaskableActorCriticPolicy)
register_policy("CnnPolicy", MaskableActorCriticCnnPolicy)
# Currently not supported
# register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy)
@@ -0,0 +1,551 @@
import time
from collections import deque
from typing import Any, Dict, Optional, Tuple, Type, Union

import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
from torch.nn import functional as F

from sb3_contrib.common.maskable.buffers import MaskableRolloutBuffer
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported


class MaskablePPO(OnPolicyAlgorithm):
    """
    Proximal Policy Optimization algorithm (PPO) (clip version) with Invalid Action Masking.

    Based on the original Stable Baselines 3 implementation.

    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
    Background on Invalid Action Masking: https://arxiv.org/abs/2006.14171

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    :param batch_size: Minibatch size
    :param n_epochs: Number of epochs when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param clip_range: Clipping parameter, it can be a function of the current progress
        remaining (from 1 to 0).
    :param clip_range_vf: Clipping parameter for the value function,
        it can be a function of the current progress remaining (from 1 to 0).
        This is a parameter specific to the OpenAI implementation. If None is passed (default),
        no clipping will be done on the value function.
        IMPORTANT: this clipping depends on the reward scaling.
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param target_kl: Limit the KL divergence between updates,
        because the clipping is not enough to prevent large updates,
        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
        By default, there is no limit on the kl div.
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing string for the environment)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    def __init__(
        self,
        policy: Union[str, Type[MaskableActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: Optional[int] = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        target_kl: Optional[float] = None,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=False,
            sde_sample_freq=-1,
            tensorboard_log=tensorboard_log,
            create_eval_env=create_eval_env,
            policy_kwargs=policy_kwargs,
            policy_base=MaskableActorCriticPolicy,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Discrete,
                spaces.MultiDiscrete,
                spaces.MultiBinary,
            ),
        )

        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.target_kl = target_kl

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        self.policy = self.policy_class(
            self.observation_space,
            self.action_space,
            self.lr_schedule,
            **self.policy_kwargs,  # pytype:disable=not-instantiable
        )
        self.policy = self.policy.to(self.device)

        if not isinstance(self.policy, MaskableActorCriticPolicy):
            raise ValueError("Policy must subclass MaskableActorCriticPolicy")

        self.rollout_buffer = MaskableRolloutBuffer(
            self.n_steps,
            self.observation_space,
            self.action_space,
            self.device,
            gamma=self.gamma,
            gae_lambda=self.gae_lambda,
            n_envs=self.n_envs,
        )

        # Initialize schedules for policy/value clipping
        self.clip_range = get_schedule_fn(self.clip_range)
        if self.clip_range_vf is not None:
            if isinstance(self.clip_range_vf, (float, int)):
                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, pass `None` to deactivate vf clipping"

            self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

    def _init_callback(
        self,
        callback: MaybeCallback,
        eval_env: Optional[VecEnv] = None,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        log_path: Optional[str] = None,
        use_masking: bool = True,
    ) -> BaseCallback:
        """
        :param callback: Callback(s) called at every step with state of the algorithm.
        :param eval_env: Environment to use for evaluation.
        :param eval_freq: How many steps between evaluations; if None, do not evaluate.
        :param n_eval_episodes: Number of episodes to rollout during evaluation.
        :param log_path: Path to a folder where the evaluations will be saved
        :param use_masking: Whether or not to use invalid action masks during evaluation
        :return: A hybrid callback calling `callback` and performing evaluation.
        """
        # Convert a list of callbacks into a callback
        if isinstance(callback, list):
            callback = CallbackList(callback)

        # Convert functional callback to object
        if not isinstance(callback, BaseCallback):
            callback = ConvertCallback(callback)

        # Create eval callback in charge of the evaluation
        if eval_env is not None:
            # Avoid circular import error
            from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback

            eval_callback = MaskableEvalCallback(
                eval_env,
                best_model_save_path=log_path,
                log_path=log_path,
                eval_freq=eval_freq,
                n_eval_episodes=n_eval_episodes,
                use_masking=use_masking,
            )
            callback = CallbackList([callback, eval_callback])

        callback.init_callback(self)
        return callback

    def _setup_learn(
        self,
        total_timesteps: int,
        eval_env: Optional[GymEnv],
        callback: MaybeCallback = None,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        tb_log_name: str = "run",
        use_masking: bool = True,
    ) -> Tuple[int, BaseCallback]:
        """
        Initialize different variables needed for training.

        :param total_timesteps: The total number of samples (env steps) to train on
        :param eval_env: Environment to use for evaluation.
        :param callback: Callback(s) called at every step with state of the algorithm.
        :param eval_freq: How many steps between evaluations
        :param n_eval_episodes: How many episodes to play per evaluation
        :param log_path: Path to a folder where the evaluations will be saved
        :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
        :param tb_log_name: the name of the run for tensorboard log
        :param use_masking: Whether or not to use invalid action masks during training
        :return: Total timesteps and the callback to call at every step
        """

        self.start_time = time.time()
        if self.ep_info_buffer is None or reset_num_timesteps:
            # Initialize buffers if they don't exist, or reinitialize if resetting counters
            self.ep_info_buffer = deque(maxlen=100)
            self.ep_success_buffer = deque(maxlen=100)

        if reset_num_timesteps:
            self.num_timesteps = 0
            self._episode_num = 0
        else:
            # Make sure training timesteps are ahead of the internal counter
            total_timesteps += self.num_timesteps
        self._total_timesteps = total_timesteps

        # Avoid resetting the environment when calling ``.learn()`` consecutive times
        if reset_num_timesteps or self._last_obs is None:
            self._last_obs = self.env.reset()
            self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
            # Retrieve unnormalized observation for saving into the buffer
            if self._vec_normalize_env is not None:
                self._last_original_obs = self._vec_normalize_env.get_original_obs()

        if eval_env is not None and self.seed is not None:
            eval_env.seed(self.seed)

        eval_env = self._get_eval_env(eval_env)

        # Configure logger's outputs if no logger was passed
        if not self._custom_logger:
            self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)

        # Create eval callback if needed
        callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking)

        return total_timesteps, callback

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
        use_masking: bool = True,
    ) -> bool:
        """
        Collect experiences using the current policy and fill a ``RolloutBuffer``.
        The term rollout here refers to the model-free notion and should not
        be used with the concept of rollout used in model-based RL or planning.

        This method is largely identical to the implementation found in the parent class.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param rollout_buffer: Buffer to fill with rollouts
        :param n_rollout_steps: Number of experiences to collect per environment
        :param use_masking: Whether or not to use invalid action masks during training
        :return: True if function returned with at least ``n_rollout_steps``
            collected, False if callback terminated rollout prematurely.
        """

        assert isinstance(rollout_buffer, MaskableRolloutBuffer), "RolloutBuffer doesn't support action masking"
        assert self._last_obs is not None, "No previous observation was provided"
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)
        n_steps = 0
        action_masks = None
        rollout_buffer.reset()

        if use_masking and not is_masking_supported(env):
            raise ValueError("Environment does not support action masking. Consider using ActionMasker wrapper")

        callback.on_rollout_start()

        while n_steps < n_rollout_steps:
            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)

                # This is the only change related to invalid action masking
                if use_masking:
                    action_masks = get_action_masks(env)

                actions, values, log_probs = self.policy.forward(obs_tensor, action_masks=action_masks)

            actions = actions.cpu().numpy()
            new_obs, rewards, dones, infos = env.step(actions)

            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if callback.on_step() is False:
                return False

            self._update_info_buffer(infos)
            n_steps += 1

            if isinstance(self.action_space, spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            rollout_buffer.add(
                self._last_obs,
                actions,
                rewards,
                self._last_episode_starts,
                values,
                log_probs,
                action_masks=action_masks,
            )
            self._last_obs = new_obs
            self._last_episode_starts = dones

        with th.no_grad():
            # Compute value for the last timestep
            obs_tensor = obs_as_tensor(new_obs, self.device)

            # Masking is not needed here, the choice of action doesn't matter.
            # We only want the value of the current observation.
            _, values, _ = self.policy.forward(obs_tensor)

        rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

        callback.on_rollout_end()

        return True

    def predict(
        self,
        observation: np.ndarray,
        state: Optional[np.ndarray] = None,
        mask: Optional[np.ndarray] = None,
        deterministic: bool = False,
        action_masks: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Get the model's action(s) from an observation.

        :param observation: the input observation
        :param state: The last states (can be None, used in recurrent policies)
        :param mask: The last masks (can be None, used in recurrent policies)
        :param deterministic: Whether or not to return deterministic actions.
        :param action_masks: Action masks to apply to the action distribution.
        :return: the model's action and the next state (used in recurrent policies)
        """
        return self.policy.predict(observation, state, mask, deterministic, action_masks=action_masks)

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                values, log_prob, entropy = self.policy.evaluate_actions(
                    rollout_data.observations,
                    actions,
                    action_masks=rollout_data.action_masks,
                )

                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favors exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            if not continue_training:
                break

        self._n_updates += self.n_epochs
        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "PPO",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        use_masking: bool = True,
    ) -> "MaskablePPO":
        iteration = 0

        total_timesteps, callback = self._setup_learn(
            total_timesteps,
            eval_env,
            callback,
            eval_freq,
            n_eval_episodes,
            eval_log_path,
            reset_num_timesteps,
            tb_log_name,
            use_masking,
        )

        callback.on_training_start(locals(), globals())

        while self.num_timesteps < total_timesteps:
            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)

            if continue_training is False:
                break

            iteration += 1
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

            # Display training infos
            if log_interval is not None and iteration % log_interval == 0:
                fps = int(self.num_timesteps / (time.time() - self.start_time))
                self.logger.record("time/iterations", iteration, exclude="tensorboard")
                if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                    self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
                    self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
                self.logger.record("time/fps", fps)
                self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
                self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
                self.logger.dump(step=self.num_timesteps)

            self.train()

        callback.on_training_end()

        return self
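The core quantities in `train()` — the clipped surrogate loss and the `(exp(log_ratio) - 1) - log_ratio` KL estimator used for early stopping — can be sketched for a single sample in plain Python. The helpers and numbers below are illustrative, not part of the library:

```python
import math


def clipped_surrogate(advantage: float, ratio: float, clip_range: float) -> float:
    """PPO clip objective for one sample: negate the minimum of the
    unclipped and clipped surrogate terms (mirrors policy_loss above)."""
    unclipped = advantage * ratio
    clipped = advantage * min(max(ratio, 1 - clip_range), 1 + clip_range)
    return -min(unclipped, clipped)


def approx_kl(log_ratio: float) -> float:
    """Low-variance reverse-KL estimator (Schulman's k3): (r - 1) - log r."""
    return (math.exp(log_ratio) - 1) - log_ratio


# With the ratio inside the clip window, the unclipped term is used
assert clipped_surrogate(advantage=1.0, ratio=1.1, clip_range=0.2) == -1.1
# A ratio outside the window is clipped to 1 + clip_range
assert math.isclose(clipped_surrogate(advantage=1.0, ratio=1.5, clip_range=0.2), -1.2)
# The estimator is zero when old and new policies agree, positive otherwise
assert approx_kl(0.0) == 0.0 and approx_kl(0.3) > 0.0
```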
@@ -1 +1 @@
-1.2.0
+1.2.1a0
@@ -22,9 +22,12 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators
# Ignore import not used when aliases are defined
per-file-ignores =
    ./sb3_contrib/__init__.py:F401
    ./sb3_contrib/ppo_mask/__init__.py:F401
    ./sb3_contrib/qrdqn/__init__.py:F401
    ./sb3_contrib/tqc/__init__.py:F401
    ./sb3_contrib/common/vec_env/wrappers/__init__.py:F401
    ./sb3_contrib/common/wrappers/__init__.py:F401
    ./sb3_contrib/common/envs/__init__.py:F401
exclude =
    # No need to traverse our git directory
    .git,
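The distribution tests that follow exercise the renormalization that masking performs. Its effect can be reproduced with a plain zero-and-renormalize sketch; `masked_probs` is an illustrative helper, not the library's API (the real implementation works on logits, setting masked entries to a large negative value):

```python
import math


def masked_probs(probs, mask):
    """Zero out masked outcomes and renormalize, mimicking what
    MaskableCategorical.apply_masking achieves via -inf logits."""
    kept = [p if m else 0.0 for p, m in zip(probs, mask)]
    total = sum(kept)
    return [p / total for p in kept]


# Matches test_applying_mask below: [0.2, 0.2, 0.6] with action 0 masked
result = masked_probs([0.2, 0.2, 0.6], [False, True, True])
assert all(math.isclose(a, b) for a, b in zip(result, [0.0, 0.25, 0.75]))

# Entropy of a masked uniform distribution: e^H equals the number
# of valid actions, as test_masking_affects_entropy checks
uniform = masked_probs([1 / 3] * 3, [True, True, False])
entropy = -sum(p * math.log(p) for p in uniform if p > 0)
assert math.isclose(math.exp(entropy), 2.0)
```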
@@ -0,0 +1,316 @@
import numpy as np
import pytest
import torch as th

from sb3_contrib.common.maskable.distributions import (
    MaskableBernoulliDistribution,
    MaskableCategorical,
    MaskableCategoricalDistribution,
    MaskableMultiCategoricalDistribution,
)


class TestMaskableCategorical:
    def test_applying_mask(self):
        """
        Show that probs change as a result of masking
        """

        starting_probs = th.Tensor([[0.2, 0.2, 0.6], [1, 0, 0]])
        expected_probs = th.Tensor([[0, 0.25, 0.75], [0, 0.5, 0.5]])
        mask = np.array([[False, True, True], [False, True, True]])

        distribution = MaskableCategorical(probs=starting_probs)
        distribution.apply_masking(mask)
        assert th.allclose(distribution.probs, expected_probs)

    def test_modifying_mask(self):
        """
        Show that masks apply independently of each other
        """

        starting_probs = th.Tensor([[0.2, 0.2, 0.6], [1, 0, 0]])
        expected_probs = th.Tensor([[0.5, 0.5, 0], [0, 1, 0]])
        first_mask = np.array([[False, True, True], [False, True, True]])
        second_mask = np.array([[True, True, False], [False, True, False]])

        # pytorch converts probs to logits in a way that loses some precision and makes
        # 0 probability outcomes slightly non-zero.
        atol = 2e-07
        distribution = MaskableCategorical(probs=starting_probs)
        assert th.allclose(distribution.probs, starting_probs, atol=atol)

        target_distribution = MaskableCategorical(probs=expected_probs)

        distribution.apply_masking(first_mask)
        distribution.apply_masking(second_mask)

        assert th.allclose(distribution.probs, target_distribution.probs, atol=atol)

    def test_removing_mask(self):
        """
        Show that masking may be unapplied to recover original probs
        """

        starting_probs = th.Tensor([[0.2, 0.2, 0.6], [1, 0, 0]])
        mask = np.array([[False, True, True], [False, True, True]])

        distribution = MaskableCategorical(probs=starting_probs)
        target_distribution = MaskableCategorical(probs=starting_probs)
        distribution.apply_masking(mask)
        distribution.apply_masking(None)
        assert th.allclose(distribution.probs, target_distribution.probs)

    def test_masking_affects_entropy(self):
        # All outcomes equally likely
        NUM_DIMS = 3
        logits = th.Tensor([[0] * NUM_DIMS])
        dist = MaskableCategorical(logits=logits)

        # For each possible number of valid actions v, show that e^entropy == v
        for v in range(1, NUM_DIMS + 1):
            masks = [j < v for j in range(NUM_DIMS)]
            dist.apply_masking(masks)
            assert int(dist.entropy().exp()) == v


class TestMaskableCategoricalDistribution:
    def test_distribution_must_be_initialized(self):
        """
        Cannot use distribution before it has logits
        """

        DIMS = 2
        dist = MaskableCategoricalDistribution(DIMS)
        with pytest.raises(AssertionError):
            dist.log_prob(th.randint(DIMS - 1, (1, 3)))

        with pytest.raises(AssertionError):
            dist.entropy()

        with pytest.raises(AssertionError):
            dist.sample()

        with pytest.raises(AssertionError):
            dist.mode()

        with pytest.raises(AssertionError):
            dist.apply_masking(None)

        # But now we can
        dist.proba_distribution(th.randn(1, DIMS))
        dist.log_prob(th.randint(DIMS - 1, (3, 1)))
        dist.entropy()
        dist.sample()
        dist.mode()
        dist.apply_masking(None)

    def test_logits_must_align_with_dims(self):
        NUM_DIMS = 3
        dist = MaskableCategoricalDistribution(NUM_DIMS)

        # There should be one logit per dim, we're one short
        logits = th.randn(1, NUM_DIMS - 1)
        with pytest.raises(RuntimeError):
            dist.proba_distribution(logits)

        # That's better
        logits = th.randn(1, NUM_DIMS)
        dist.proba_distribution(logits)

        # Other numbers of dimensions are acceptable as long as they can be realigned
        logits = th.randn(NUM_DIMS)
        dist.proba_distribution(logits)
        logits = th.randn(3, NUM_DIMS, 3)
        dist.proba_distribution(logits)

    def test_dim_masking(self):
        NUM_DIMS = 2
        dist = MaskableCategoricalDistribution(NUM_DIMS)

        logits = th.Tensor([[0] * NUM_DIMS])
        dist.proba_distribution(logits)

        assert (dist.distribution.probs == 0.5).all()
        assert int(dist.entropy().exp()) == NUM_DIMS

        for i in range(NUM_DIMS):
            mask = np.array([False] * NUM_DIMS)
            mask[i] = True
            dist.apply_masking(mask)
            probs = dist.distribution.probs
            assert probs.sum() == 1
            assert probs[0][i] == 1
            assert int(dist.entropy().exp()) == 1

        dist.apply_masking(None)
        assert (dist.distribution.probs == 0.5).all()
        assert int(dist.entropy().exp()) == NUM_DIMS


class TestMaskableMultiCategoricalDistribution:
    def test_distribution_must_be_initialized(self):
        """
        Cannot use distribution before it has logits
        """

        DIMS_PER_CAT = 2
        NUM_CATS = 2
        dist = MaskableMultiCategoricalDistribution([DIMS_PER_CAT] * NUM_CATS)

        with pytest.raises(AssertionError):
            dist.log_prob(th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS)))

        with pytest.raises(AssertionError):
            dist.entropy()

        with pytest.raises(AssertionError):
            dist.sample()

        with pytest.raises(AssertionError):
            dist.mode()

        with pytest.raises(AssertionError):
            dist.apply_masking(None)

        # But now we can
        dist.proba_distribution(th.randn(1, DIMS_PER_CAT * NUM_CATS))
        dist.log_prob(th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS)))
        dist.entropy()
        dist.sample()
        dist.mode()
        dist.apply_masking(None)

    def test_logits_must_align_with_dims(self):
        DIMS_PER_CAT = 3
        NUM_CATS = 2
        dist = MaskableMultiCategoricalDistribution([DIMS_PER_CAT] * NUM_CATS)

        # There should be one logit per dim, we're one short
        logits = th.randn(1, DIMS_PER_CAT * NUM_CATS - 1)
        with pytest.raises(RuntimeError):
            dist.proba_distribution(logits)

        # That's better
        logits = th.randn(1, DIMS_PER_CAT * NUM_CATS)
        dist.proba_distribution(logits)

        # Other numbers of dimensions are acceptable as long as they can be realigned
        logits = th.randn(DIMS_PER_CAT * NUM_CATS)
        dist.proba_distribution(logits)
        logits = th.randn(3, DIMS_PER_CAT * NUM_CATS, 3)
        dist.proba_distribution(logits)

    def test_dim_masking(self):
        DIMS_PER_CAT = 2
        NUM_CATS = 3
        dist = MaskableMultiCategoricalDistribution([DIMS_PER_CAT] * NUM_CATS)

        logits = th.Tensor([[0] * DIMS_PER_CAT * NUM_CATS])
        dist.proba_distribution(logits)

        assert len(dist.distributions) == NUM_CATS
        for i in range(NUM_CATS):
            assert (dist.distributions[i].probs == 0.5).all()
        assert int(dist.entropy().exp()) == DIMS_PER_CAT ** NUM_CATS

        for i in range(DIMS_PER_CAT):
            mask = np.array([False] * DIMS_PER_CAT * NUM_CATS)
            for j in range(NUM_CATS):
                mask[j * DIMS_PER_CAT + i] = True

            dist.apply_masking(mask)
            for j in range(NUM_CATS):
                probs = dist.distributions[j].probs
                assert probs.sum() == 1
                assert probs[0][i] == 1

            assert int(dist.entropy().exp()) == 1

        dist.apply_masking(None)
        for i in range(NUM_CATS):
            assert (dist.distributions[i].probs == 0.5).all()
        assert int(dist.entropy().exp()) == DIMS_PER_CAT ** NUM_CATS


class TestMaskableBernoulliDistribution:
    def test_distribution_must_be_initialized(self):
        """
        Cannot use distribution before it has logits
        """

        DIMS = 2
        dist = MaskableBernoulliDistribution(DIMS)

        with pytest.raises(AssertionError):
            dist.log_prob(th.randint(1, (2, DIMS)))

        with pytest.raises(AssertionError):
            dist.entropy()

        with pytest.raises(AssertionError):
            dist.sample()

        with pytest.raises(AssertionError):
            dist.mode()

        with pytest.raises(AssertionError):
            dist.apply_masking(None)

        # But now we can
        dist.proba_distribution(th.randn(1, 2 * DIMS))
        dist.log_prob(th.randint(1, (2, DIMS)))
        dist.entropy()
        dist.sample()
        dist.mode()
        dist.apply_masking(None)

    def test_logits_must_align_with_dims(self):
        NUM_DIMS = 3
        dist = MaskableBernoulliDistribution(NUM_DIMS)

        # There should be two logits per dim, we're one short
        logits = th.randn(1, 2 * NUM_DIMS - 1)
        with pytest.raises(RuntimeError):
            dist.proba_distribution(logits)

        # That's better
        logits = th.randn(1, 2 * NUM_DIMS)
        dist.proba_distribution(logits)

        # Other numbers of dimensions are acceptable as long as they can be realigned
        logits = th.randn(2 * NUM_DIMS)
        dist.proba_distribution(logits)
        logits = th.randn(3, 2 * NUM_DIMS, 3)
        dist.proba_distribution(logits)

    def test_dim_masking(self):
        NUM_DIMS = 2
        BINARY_STATES = 2
        dist = MaskableBernoulliDistribution(NUM_DIMS)

        logits = th.Tensor([[0] * BINARY_STATES * NUM_DIMS])
        dist.proba_distribution(logits)

        assert len(dist.distributions) == NUM_DIMS
        for i in range(NUM_DIMS):
            assert (dist.distributions[i].probs == 0.5).all()
        assert int(dist.entropy().exp()) == BINARY_STATES * NUM_DIMS

        for i in range(BINARY_STATES):
            mask = np.array([False] * BINARY_STATES * NUM_DIMS)
            for j in range(NUM_DIMS):
                mask[j * BINARY_STATES + i] = True

            dist.apply_masking(mask)
            for j in range(NUM_DIMS):
                probs = dist.distributions[j].probs
                assert probs.sum() == 1
||||
assert probs[0][i] == 1
|
||||
|
||||
assert int(dist.entropy().exp()) == 1
|
||||
|
||||
dist.apply_masking(None)
|
||||
for i in range(NUM_DIMS):
|
||||
assert (dist.distributions[i].probs == 0.5).all()
|
||||
assert int(dist.entropy().exp()) == BINARY_STATES * NUM_DIMS
|
||||
|
|
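The distribution behavior exercised by these tests (masked probabilities renormalize over the valid actions, entropy collapses to `exp(H) == 1` when a single action remains) comes down to replacing invalid-action logits with a large negative constant before softmax. A minimal pure-Python sketch of that idea, independent of sb3_contrib's actual classes (`masked_softmax` and `huge_neg` are illustrative names, not part of the API):

```python
import math


def masked_softmax(logits, mask, huge_neg=-1e8):
    # Replace logits of invalid actions (mask entry False) with a large
    # negative number, then apply a numerically stable softmax.
    adjusted = [logit if valid else huge_neg for logit, valid in zip(logits, mask)]
    peak = max(adjusted)
    exps = [math.exp(a - peak) for a in adjusted]
    total = sum(exps)
    return [e / total for e in exps]


# Uniform logits over 4 actions, two of them masked out:
probs = masked_softmax([0.0, 0.0, 0.0, 0.0], [True, False, True, False])
# The valid actions split the probability mass evenly; invalid ones get ~0.
```

With only one valid action left, the output is one-hot, which is exactly why the entropy assertions above reduce to `exp(H) == 1`.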
@@ -0,0 +1,232 @@
import random

import gym
import pytest
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.policies import ActorCriticPolicy

from sb3_contrib import MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete, InvalidActionEnvMultiBinary, InvalidActionEnvMultiDiscrete
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.utils import is_masking_supported
from sb3_contrib.common.wrappers import ActionMasker


def make_env():
    return InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)


class ToDictWrapper(gym.Wrapper):
    """
    Simple wrapper to test MultiInputPolicy on Dict obs.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space})

    def reset(self):
        return {"obs": self.env.reset()}

    def step(self, action):
        obs, reward, done, infos = self.env.step(action)
        return {"obs": obs}, reward, done, infos


def test_identity():
    """
    Performance test.
    A randomly initialized model cannot solve that task (score ~= 6),
    nor can a model without invalid action masking (score ~= 30 after training)
    with such a low training budget.
    """
    env = InvalidActionEnvDiscrete(dim=70, n_invalid_actions=55)
    model = MaskablePPO(
        "MlpPolicy",
        env,
        gamma=0.4,
        seed=32,
        verbose=0,
    )
    model.learn(3000)
    evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)


def test_supports_discrete_action_space():
    """
    No errors using algorithm with an env that has a discrete action space
    """

    env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
    model = MaskablePPO("MlpPolicy", env, n_steps=64, seed=8)
    model.learn(100)
    evaluate_policy(model, env, warn=False)

    # Mask all actions except the good one, a random model should succeed
    env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19)
    model = MaskablePPO("MlpPolicy", env, seed=8)
    evaluate_policy(model, env, reward_threshold=99, warn=False)


def test_supports_multi_discrete_action_space():
    """
    No errors using algorithm with an env that has a multidiscrete action space
    """

    env = InvalidActionEnvMultiDiscrete(dims=[2, 3], n_invalid_actions=1)
    model = MaskablePPO("MlpPolicy", env, n_steps=64, seed=8)
    model.learn(100)
    evaluate_policy(model, env, warn=False)

    # Mask all actions except the good ones, a random model should succeed
    env = InvalidActionEnvMultiDiscrete(dims=[2, 3], n_invalid_actions=3)
    model = MaskablePPO("MlpPolicy", env, seed=8)
    evaluate_policy(model, env, reward_threshold=99, warn=False)


def test_supports_multi_binary_action_space():
    """
    No errors using algorithm with an env that has a multibinary action space
    """

    env = InvalidActionEnvMultiBinary(dims=3, n_invalid_actions=1)
    model = MaskablePPO("MlpPolicy", env, n_steps=64, seed=8)
    model.learn(100)
    evaluate_policy(model, env, warn=False)

    # Mask all actions except the good ones, a random model should succeed
    env = InvalidActionEnvMultiBinary(dims=3, n_invalid_actions=3)
    model = MaskablePPO("MlpPolicy", env, seed=8)
    evaluate_policy(model, env, reward_threshold=99, warn=False)


def test_disabling_masking():
    """
    Behave like normal PPO if masking is disabled, which allows for envs that don't provide masks
    """

    env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19)
    model = MaskablePPO("MlpPolicy", env, seed=8)
    evaluate_policy(model, env, reward_threshold=99, warn=False)

    # With masking disabled, perfect performance disappears
    with pytest.raises(AssertionError):
        evaluate_policy(model, env, reward_threshold=99, warn=False, use_masking=False)

    # Unless masking is disabled, learning/evaluation will fail if the env doesn't provide masks
    env = IdentityEnv(dim=2)
    model = MaskablePPO("MlpPolicy", env, n_steps=64, seed=8)
    with pytest.raises(ValueError):
        model.learn(100)
    with pytest.raises(ValueError):
        evaluate_policy(model, env, warn=False)

    model.learn(100, use_masking=False)
    evaluate_policy(model, env, warn=False, use_masking=False)


def test_masked_evaluation():
    """
    Masking can be enabled or disabled for evaluation, but masking should perform better.
    """

    env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19)
    model = MaskablePPO("MlpPolicy", env, seed=8)
    masked_avg_rew, _ = evaluate_policy(model, env, warn=False)
    unmasked_avg_rew, _ = evaluate_policy(model, env, warn=False, use_masking=False)
    assert masked_avg_rew > unmasked_avg_rew


def test_supports_multi_envs():
    """
    Learning and evaluation work with VecEnvs
    """

    env = make_vec_env(make_env, n_envs=2)
    assert is_masking_supported(env)
    model = MaskablePPO("MlpPolicy", env, n_steps=256, gamma=0.4, seed=32, verbose=1)
    model.learn(100)
    evaluate_policy(model, env, warn=False)

    env = make_vec_env(IdentityEnv, n_envs=2, env_kwargs={"dim": 2})
    assert not is_masking_supported(env)
    model = MaskablePPO("MlpPolicy", env, n_steps=256, gamma=0.4, seed=32, verbose=1)
    with pytest.raises(ValueError):
        model.learn(100)
    with pytest.raises(ValueError):
        evaluate_policy(model, env, warn=False)
    model.learn(100, use_masking=False)
    evaluate_policy(model, env, warn=False, use_masking=False)


def test_callback(tmp_path):
    """
    No errors using MaskableEvalCallback during learning
    """

    env = make_env()
    eval_env = make_env()
    model = MaskablePPO("MlpPolicy", env, n_steps=64, gamma=0.4, seed=32, verbose=1)
    model.learn(100, callback=MaskableEvalCallback(eval_env, eval_freq=100, warn=False, log_path=tmp_path))

    model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False))


def test_maskable_policy_required():
    """
    MaskablePPO requires a policy that subclasses MaskableActorCriticPolicy
    """

    env = make_env()
    with pytest.raises(ValueError):
        MaskablePPO(ActorCriticPolicy, env)


def test_discrete_action_space_required():
    """
    MaskablePPO requires an env with a discrete (i.e. non-continuous) action space
    """

    env = IdentityEnvBox()
    with pytest.raises(AssertionError):
        MaskablePPO("MlpPolicy", env)


def test_cnn():
    def action_mask_fn(env):
        random_invalid_action = random.randrange(env.action_space.n)
        return [i != random_invalid_action for i in range(env.action_space.n)]

    env = FakeImageEnv()
    env = ActionMasker(env, action_mask_fn)

    model = MaskablePPO(
        "CnnPolicy",
        env,
        n_steps=64,
        seed=32,
        verbose=1,
        policy_kwargs=dict(
            features_extractor_kwargs=dict(features_dim=32),
        ),
    )
    model.learn(100)
    evaluate_policy(model, env, warn=False)


# Dict observations are currently not supported
# def test_dict_obs():
#     env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
#     env = ToDictWrapper(env)
#     model = MaskablePPO("MultiInputPolicy", env, n_steps=64, seed=8)
#     model.learn(100)
#     evaluate_policy(model, env, warn=False)
#
#     # Mask all actions except the good one, a random model should succeed
#     env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19)
#     env = ToDictWrapper(env)
#     model = MaskablePPO("MultiInputPolicy", env, seed=8)
#     evaluate_policy(model, env, reward_threshold=99, warn=False)
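Several of the tests above (e.g. `test_masked_evaluation`) rely on the fact that even an untrained model, restricted to the single valid action, achieves perfect reward. At prediction time that restriction is just a constrained argmax over the valid entries; a small sketch of the idea (`masked_argmax` is an illustrative helper, not part of the sb3_contrib API):

```python
def masked_argmax(logits, mask):
    # Deterministic action selection restricted to valid actions.
    valid = [(logit, i) for i, (logit, m) in enumerate(zip(logits, mask)) if m]
    if not valid:
        raise ValueError("At least one action must be valid")
    return max(valid)[1]


# Even when the highest raw logit belongs to an invalid action,
# the selected action is a valid one:
action = masked_argmax([2.0, 0.5, 1.0], [False, True, True])  # -> 2
```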
@@ -1,3 +1,5 @@
from typing import Tuple, Union

import gym
import numpy as np
import pytest
@@ -6,7 +8,9 @@ import torch.nn as nn
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

from sb3_contrib import QRDQN, TQC, MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.maskable.utils import get_action_masks


class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):

@@ -76,12 +80,51 @@ def clone_tqc_batch_norm_stats(
    return (actor_bias, actor_running_mean, critic_bias, critic_running_mean, critic_target_bias, critic_target_running_mean)


def clone_on_policy_batch_norm(model: MaskablePPO) -> Tuple[th.Tensor, th.Tensor]:
    return clone_batch_norm_stats(model.policy.features_extractor.batch_norm)


CLONE_HELPERS = {
    QRDQN: clone_qrdqn_batch_norm_stats,
    TQC: clone_tqc_batch_norm_stats,
    MaskablePPO: clone_on_policy_batch_norm,
}


def test_ppo_mask_train_eval_mode():
    env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
    model = MaskablePPO(
        "MlpPolicy",
        env,
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
        seed=1,
    )

    bias_before, running_mean_before = clone_on_policy_batch_norm(model)

    model.learn(total_timesteps=200)

    bias_after, running_mean_after = clone_on_policy_batch_norm(model)

    assert ~th.isclose(bias_before, bias_after).all()
    assert ~th.isclose(running_mean_before, running_mean_after).all()

    batch_norm_stats_before = clone_on_policy_batch_norm(model)

    observation = env.reset()
    action_masks = get_action_masks(env)
    first_prediction, _ = model.predict(observation, action_masks=action_masks, deterministic=True)
    for _ in range(5):
        prediction, _ = model.predict(observation, action_masks=action_masks, deterministic=True)
        np.testing.assert_allclose(first_prediction, prediction)

    batch_norm_stats_after = clone_on_policy_batch_norm(model)

    # No change in batch norm params
    for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after):
        assert th.isclose(param_before, param_after).all()


def test_qrdqn_train_with_batch_norm():
    model = QRDQN(
        "MlpPolicy",
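`test_ppo_mask_train_eval_mode` above checks that batch-norm running statistics change during `learn()` but stay frozen across repeated `predict()` calls. The statistic that must only move in train mode is an exponential moving average; a sketch assuming momentum semantics like PyTorch's `BatchNorm1d` default (`momentum=0.1`):

```python
def update_running_mean(running_mean, batch_mean, momentum=0.1):
    # Train-mode update: blend the batch statistic into the running estimate.
    # In eval mode (as during model.predict), the stored value is used
    # unchanged, which is what the test's isclose checks verify.
    return (1 - momentum) * running_mean + momentum * batch_mean


rm = 0.0
for batch_mean in (1.0, 1.0, 1.0):
    rm = update_running_mean(rm, batch_mean)
# rm drifts toward the batch statistic only while updates are applied.
```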
@@ -0,0 +1,78 @@
from typing import List

import pytest
from gym.spaces import Discrete
from stable_baselines3.common.envs import IdentityEnv

from sb3_contrib.common.wrappers import ActionMasker


class IdentityEnvDiscrete(IdentityEnv):
    def __init__(self, dim: int = 1, ep_length: int = 100):
        """
        Identity environment for testing purposes

        :param dim: the size of the dimensions you want to learn
        :param ep_length: the length of each episode in timesteps
        """
        space = Discrete(dim)
        self.useless_property = 1
        super().__init__(ep_length=ep_length, space=space)

    def _action_masks(self) -> List[bool]:
        return [i == self.state for i in range(self.action_space.n)]


def action_mask_fn(env: IdentityEnvDiscrete) -> List[bool]:
    return [i == env.state for i in range(env.action_space.n)]


def test_wrapper_accepts_function():
    """
    ActionMasker accepts a function
    """

    env = IdentityEnvDiscrete()

    assert not hasattr(env, "action_masks")
    env = ActionMasker(env, action_mask_fn)
    assert hasattr(env, "action_masks")


def test_wrapper_accepts_attr_name():
    """
    ActionMasker accepts the string name of a method on the underlying env
    """

    env = IdentityEnvDiscrete()

    assert not hasattr(env, "action_masks")
    env = ActionMasker(env, "_action_masks")
    assert hasattr(env, "action_masks")


def test_attr_must_be_callable():
    """
    Passing ActionMasker the string name of a non-callable is an error
    """

    env = IdentityEnvDiscrete()

    with pytest.raises(ValueError):
        env = ActionMasker(env, "useless_property")


def test_action_masks_returns_expected_result():
    """
    The ActionMasker-provided action_masks() method returns expected results
    """

    env = IdentityEnvDiscrete()
    env = ActionMasker(env, action_mask_fn)

    # Only one valid action expected
    masks = env.action_masks()
    masks[env.state] = not masks[env.state]  # Bit-flip the one expected valid action
    assert all([not mask for mask in masks])
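The wrapper tests above pin down `ActionMasker`'s dispatch rules: accept a plain function of the env, accept the string name of an env method, and raise `ValueError` for a non-callable attribute. A rough sketch of such a dispatch (`MaskingWrapper` and `DummyEnv` are illustrative stand-ins, not the real implementation):

```python
class DummyEnv:
    """Hypothetical stand-in env exposing state and a mask method."""

    def __init__(self):
        self.state = 1
        self.n_actions = 3

    def _action_masks(self):
        return [i == self.state for i in range(self.n_actions)]


class MaskingWrapper:
    """Illustrative dispatch only; not the actual ActionMasker."""

    def __init__(self, env, action_mask_fn):
        self.env = env
        if isinstance(action_mask_fn, str):
            found = getattr(env, action_mask_fn)
            if not callable(found):
                raise ValueError(f"Attribute {action_mask_fn!r} is not a method")
            self._get_masks = found  # bound method, called with no args
        else:
            self._get_masks = lambda: action_mask_fn(env)  # plain function of env

    def action_masks(self):
        return self._get_masks()


def mask_fn(env):
    return [i == env.state for i in range(env.n_actions)]


wrapped = MaskingWrapper(DummyEnv(), mask_fn)
masks = wrapped.action_masks()  # only index env.state is True
```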