PPO variant with invalid action masking (#25)

* Add wrappers

* Add maskable distributions

* Add mypy configuration

* Add maskable base datastructures

* Add ppo_mask package

* Fix circular dependency and remove test code that slipped in

* Automatically mask vecenv if env is masked

* Fix debugging change that slipped in

* Workaround for subclassing RolloutBufferSamples

* Duplicate lots of policy code in order to swap out the distributions used

* Fix pytype error

* Maintain py 3.6 compatibility

* Fix isort lint errors

* Use pyproject.toml to configure black line length

* Blacken

* Remove mypy.ini

* Fully replace RolloutBufferSamples

* Drop support for continuous distributions, remove SDE-related code

* Eliminate MaskableAlgorithm and MaskableOnPolicyAlgorithm

* Fix formatting

* Override superclass methods as needed, fix circular import, improve naming

* Fix codestyle

* Eliminate VecActionMasker, replace with utils

* Fix codestyle

* Support masking for MultiDiscrete action spaces

* Fix codestyle

* Don't require the env to provide the mask already flattened

* Consistent naming, prefer 'Maskable' to 'Masked'

* Register policy

* Link to abstract instead of pdf

* Allow distribution masking to be unapplied + improved comments and docstrings

* Don't use deprecated implicit optional typing

* Check codestyle

* Add docstring and remove misplaced TODO

* Simplify env masking API, error if API unmet. Make use_masking a learn() kwarg

* Fix codestyle

* Update various internals to be consistent with latest SB3

* Simplify MaskableRolloutBuffer reset

* Add docstring and type annotations

* Ensure old probs aren't cached

* Fix for new logger

* Add test + fixes

* Start doc

* Fix type annotation

* Remove abstract class + add test

* Fix evaluation (add support for multi envs)

* Handle merge conflicts in documentation

* Bugfix: mask updates should apply to original logits, not the last masked output

* Add test of distribution masking behavior

* Reformat

* Add MultiBinary support, remove unneeded distribution type checks

* Remove unused import

* Fix when using multiple envs

* Remove addressed TODO

* Upgrade for SB3 1.2.0

* Update docs with results + how to replicate

* Add action masker tests, move wrapper tests

* Move distributions, add more distribution tests

* Add MaskablePPO tests, simplify and rename discrete test env

* Address TODO

* Add tests for MaskableMultiCategoricalDistribution, fix distributions

* Add maskable identity envs for all supported action spaces, add tests, fix bug

* Formatting fixes

* Update doc env

* Dict support not ready

* Cleanup

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
kronion 2021-09-23 07:50:10 -05:00 committed by GitHub
parent b2e7126840
commit ab24f8039f
34 changed files with 2755 additions and 6 deletions


@@ -24,7 +24,7 @@ check-codestyle:
# Sort imports
isort --check ${LINT_PATHS}
# Reformat using black
-black --check -l 127 ${LINT_PATHS}
+black --check ${LINT_PATHS}
commit-checks: format type lint


@@ -6,7 +6,7 @@ dependencies:
- cpuonly=1.0=0
- pip=20.2
- python=3.6
-  - pytorch=1.5.0=py3.6_cpu_0
+  - pytorch=1.8.1=py3.6_cpu_0
- pip:
- gym>=0.17.2
- cloudpickle
@@ -15,4 +15,6 @@ dependencies:
- numpy
- matplotlib
- sphinx_autodoc_typehints
-  - stable-baselines3>=0.10.0a0
+  - stable-baselines3>=1.2.0
# See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115
- sphinx_rtd_theme>=1.0

(Six binary image files added, including docs/images/10x10_mask.png, docs/images/4x4_mask.png, and docs/images/4x4_no_mask.png.)


@@ -33,6 +33,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
modules/tqc
modules/qrdqn
modules/ppo_mask
.. toctree::
:maxdepth: 1


@@ -3,6 +3,31 @@
Changelog
==========
Release 1.2.1a0 (WIP)
-------------------------------
**Invalid action masking for PPO**
Breaking Changes:
^^^^^^^^^^^^^^^^^
New Features:
^^^^^^^^^^^^^
- Added ``MaskablePPO`` algorithm (@kronion)
Bug Fixes:
^^^^^^^^^^
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
Documentation:
^^^^^^^^^^^^^^
Release 1.2.0 (2021-09-08)
-------------------------------
@@ -13,6 +38,9 @@ Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 1.2.0
New Features:
^^^^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
- QR-DQN and TQC updated so that their policies are switched between train and eval mode at the correct time (@ayeright)
@@ -156,4 +184,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
Contributors:
-------------
-@ku2482 @guyk1971 @minhlong94 @ayeright
+@ku2482 @guyk1971 @minhlong94 @ayeright @kronion

docs/modules/ppo_mask.rst (new file, 256 lines)

@@ -0,0 +1,256 @@
.. _ppo_mask:
.. automodule:: sb3_contrib.ppo_mask
Maskable PPO
============
Implementation of `invalid action masking <https://arxiv.org/abs/2006.14171>`_ for the Proximal Policy Optimization (PPO)
algorithm. Other than adding support for action masking, the behavior is the same as in SB3's core PPO algorithm.
.. rubric:: Available Policies
.. autosummary::
:nosignatures:
MlpPolicy
CnnPolicy
.. MultiInputPolicy
Notes
-----
- Paper: https://arxiv.org/abs/2006.14171
- Blog post: https://costa.sh/blog-a-closer-look-at-invalid-action-masking-in-policy-gradient-algorithms.html
- Additional Blog post: https://boring-guy.sh/posts/masking-rl/
Can I use?
----------
- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:
============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ✔️ ✔️
Box ❌ ✔️
MultiDiscrete ✔️ ✔️
MultiBinary ✔️ ✔️
Dict ❌ ❌
============= ====== ===========
Example
-------
Train a PPO agent on ``InvalidActionEnvDiscrete``. ``InvalidActionEnvDiscrete`` has an ``action_masks`` method that
returns the action mask (``True`` if the action is valid, ``False`` otherwise).
.. code-block:: python
from sb3_contrib import MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.utils import get_action_masks
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, verbose=1)
model.learn(5000)
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)
model.save("ppo_mask")
del model # remove to demonstrate saving and loading
model = MaskablePPO.load("ppo_mask")
obs = env.reset()
while True:
# Retrieve current action mask
action_masks = get_action_masks(env)
action, _states = model.predict(obs, action_masks=action_masks)
obs, rewards, dones, info = env.step(action)
env.render()
If the environment exposes the invalid action mask under a different method name, you can use the ``ActionMasker``
wrapper to specify how to retrieve it (see `PR #25 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/25>`_):
.. code-block:: python
import gym
import numpy as np
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker
from sb3_contrib.ppo_mask import MaskablePPO
def mask_fn(env: gym.Env) -> np.ndarray:
# Do whatever you'd like in this function to return the action mask
# for the current env. In this example, we assume the env has a
# helpful method we can rely on.
return env.valid_action_mask()
env = ... # Initialize env
env = ActionMasker(env, mask_fn) # Wrap to enable masking
# MaskablePPO behaves the same as SB3's PPO unless the env is wrapped
# with ActionMasker. If the wrapper is detected, the masks are automatically
# retrieved and used when learning. Note that MaskablePPO does not accept
# a new action_mask_fn kwarg, as it did in an earlier draft.
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1)
model.learn()
# Note that use of masks is manual and optional outside of learning,
# so masking can be "removed" at testing time
model.predict(observation, action_masks=valid_action_array)
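The length of the mask array depends on the action space. A plain-Python sketch of the expected mask lengths (illustrative only; ``expected_mask_length`` is a hypothetical helper, and the dimensions mirror those used by the new rollout buffer):

```python
# Hypothetical helper: mask lengths implied by the supported action spaces.
# Discrete(n)         -> n entries, one per action
# MultiDiscrete(nvec) -> sum(nvec) entries, sub-space masks concatenated
# MultiBinary(n)      -> 2 * n entries, one (off, on) pair per binary action
def expected_mask_length(space_type: str, param) -> int:
    if space_type == "Discrete":
        return param
    if space_type == "MultiDiscrete":
        return sum(param)
    if space_type == "MultiBinary":
        return 2 * param
    raise ValueError(f"Unsupported space: {space_type}")

print(expected_mask_length("Discrete", 5))            # 5
print(expected_mask_length("MultiDiscrete", [3, 2]))  # 5
print(expected_mask_length("MultiBinary", 4))         # 8
```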
Results
-------
Results are shown for two MicroRTS benchmarks: MicrortsMining4x4F9-v0 (600K steps) and MicrortsMining10x10F9-v0
(1.5M steps). For each, models were trained with and without masking, using 3 seeds.
4x4
^^^
No masking
""""""""""
.. image:: /images/4x4_no_mask.png
With masking
""""""""""""
.. image:: /images/4x4_mask.png
Combined
""""""""
.. image:: /images/4x4_combined.png
10x10
^^^^^
No masking
""""""""""
.. image:: /images/10x10_no_mask.png
With masking
""""""""""""
.. image:: /images/10x10_mask.png
Combined
""""""""
.. image:: /images/10x10_combined.png
More information may be found in the
`associated PR <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/25>`_.
How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Clone the repo for the experiment:
.. code-block:: bash
git clone git@github.com:kronion/microrts-ppo-comparison.git
cd microrts-ppo-comparison
Install dependencies:
.. code-block:: bash
# Install MicroRTS:
rm -fR ~/microrts && mkdir ~/microrts && \
wget -O ~/microrts/microrts.zip http://microrts.s3.amazonaws.com/microrts/artifacts/202004222224.microrts.zip && \
unzip ~/microrts/microrts.zip -d ~/microrts/
# You may want to make a venv before installing packages
pip install -r requirements.txt
Train several times with various seeds, with and without masking:
.. code-block:: bash
# python sb3/train_ppo.py [output dir] [MicroRTS map size] [--mask] [--seed int]
# 4x4 unmasked
python sb3/train_ppo.py zoo 4 --seed 42
python sb3/train_ppo.py zoo 4 --seed 43
python sb3/train_ppo.py zoo 4 --seed 44
# 4x4 masked
python sb3/train_ppo.py zoo 4 --mask --seed 42
python sb3/train_ppo.py zoo 4 --mask --seed 43
python sb3/train_ppo.py zoo 4 --mask --seed 44
# 10x10 unmasked
python sb3/train_ppo.py zoo 10 --seed 42
python sb3/train_ppo.py zoo 10 --seed 43
python sb3/train_ppo.py zoo 10 --seed 44
# 10x10 masked
python sb3/train_ppo.py zoo 10 --mask --seed 42
python sb3/train_ppo.py zoo 10 --mask --seed 43
python sb3/train_ppo.py zoo 10 --mask --seed 44
View the tensorboard log output:
.. code-block:: bash
# For 4x4 environment
tensorboard --logdir zoo/4x4/runs
# For 10x10 environment
tensorboard --logdir zoo/10x10/runs
Parameters
----------
.. autoclass:: MaskablePPO
:members:
:inherited-members:
MaskablePPO Policies
--------------------
.. autoclass:: MlpPolicy
:members:
:inherited-members:
.. autoclass:: sb3_contrib.common.maskable.policies.MaskableActorCriticPolicy
:members:
:noindex:
.. autoclass:: CnnPolicy
:members:
.. autoclass:: sb3_contrib.common.maskable.policies.MaskableActorCriticCnnPolicy
:members:
:noindex:
.. .. autoclass:: MultiInputPolicy
.. :members:
..
.. .. autoclass:: sb3_contrib.common.maskable.policies.MaskableMultiInputActorCriticPolicy
.. :members:
.. :noindex:

pyproject.toml (new file, 2 lines)

@@ -0,0 +1,2 @@
[tool.black]
line-length = 127


@@ -1,5 +1,6 @@
import os
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC


@@ -0,0 +1,5 @@
from sb3_contrib.common.envs.invalid_actions_env import (
InvalidActionEnvDiscrete,
InvalidActionEnvMultiBinary,
InvalidActionEnvMultiDiscrete,
)


@@ -0,0 +1,116 @@
from typing import List, Optional
import numpy as np
from gym import spaces
from stable_baselines3.common.envs import IdentityEnv
class InvalidActionEnvDiscrete(IdentityEnv):
"""
Identity env with a discrete action space. Supports action masking.
"""
def __init__(
self,
dim: Optional[int] = None,
ep_length: int = 100,
n_invalid_actions: int = 0,
):
if dim is None:
dim = 1
assert n_invalid_actions < dim, f"Too many invalid actions: {n_invalid_actions} >= {dim}"
space = spaces.Discrete(dim)
self.n_invalid_actions = n_invalid_actions
self.possible_actions = np.arange(space.n)
self.invalid_actions: List[int] = []
super().__init__(space=space, ep_length=ep_length)
def _choose_next_state(self) -> None:
self.state = self.action_space.sample()
# Randomly choose invalid actions that are not the current state
potential_invalid_actions = [i for i in self.possible_actions if i != self.state]
self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)
def action_masks(self) -> List[bool]:
return [action not in self.invalid_actions for action in self.possible_actions]
class InvalidActionEnvMultiDiscrete(IdentityEnv):
"""
Identity env with a multidiscrete action space. Supports action masking.
"""
def __init__(
self,
dims: Optional[List[int]] = None,
ep_length: int = 100,
n_invalid_actions: int = 0,
):
if dims is None:
dims = [1, 1]
if n_invalid_actions > sum(dims) - len(dims):
raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {sum(dims) - len(dims)}")
space = spaces.MultiDiscrete(dims)
self.n_invalid_actions = n_invalid_actions
self.possible_actions = np.arange(sum(dims))
self.invalid_actions: List[int] = []
super().__init__(space=space, ep_length=ep_length)
def _choose_next_state(self) -> None:
self.state = self.action_space.sample()
converted_state: List[int] = []
running_total = 0
for i in range(len(self.action_space.nvec)):
converted_state.append(running_total + self.state[i])
running_total += self.action_space.nvec[i]
# Randomly choose invalid actions that are not the current state
potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state]
self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)
def action_masks(self) -> List[bool]:
return [action not in self.invalid_actions for action in self.possible_actions]
class InvalidActionEnvMultiBinary(IdentityEnv):
"""
Identity env with a multibinary action space. Supports action masking.
"""
def __init__(
self,
dims: Optional[int] = None,
ep_length: int = 100,
n_invalid_actions: int = 0,
):
if dims is None:
dims = 1
if n_invalid_actions > dims:
raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {dims}")
space = spaces.MultiBinary(dims)
self.n_invalid_actions = n_invalid_actions
self.possible_actions = np.arange(2 * dims)
self.invalid_actions: List[int] = []
super().__init__(space=space, ep_length=ep_length)
def _choose_next_state(self) -> None:
self.state = self.action_space.sample()
converted_state: List[int] = []
running_total = 0
for i in range(self.action_space.n):
converted_state.append(running_total + self.state[i])
running_total += 2
# Randomly choose invalid actions that are not the current state
potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state]
self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)
def action_masks(self) -> List[bool]:
return [action not in self.invalid_actions for action in self.possible_actions]
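The ``converted_state`` loops in the envs above map each sub-action into a single flat index space by offsetting it with the cumulative size of the preceding sub-spaces. A standalone sketch of that offset encoding (``flatten_multidiscrete_state`` is a hypothetical name):

```python
# Offset encoding used by the multi-discrete identity env above: each
# sub-action index is shifted by the total size of the preceding sub-spaces,
# so every (sub-space, value) pair maps to a unique flat mask position.
def flatten_multidiscrete_state(state, nvec):
    converted = []
    running_total = 0
    for sub_action, sub_size in zip(state, nvec):
        converted.append(running_total + sub_action)
        running_total += sub_size
    return converted

# State (1, 0) in MultiDiscrete([3, 2]) -> flat indices [1, 3]
print(flatten_multidiscrete_state([1, 0], [3, 2]))
```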


@@ -0,0 +1,98 @@
from typing import Generator, NamedTuple, Optional
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.vec_env import VecNormalize
class MaskableRolloutBufferSamples(NamedTuple):
observations: th.Tensor
actions: th.Tensor
old_values: th.Tensor
old_log_prob: th.Tensor
advantages: th.Tensor
returns: th.Tensor
action_masks: th.Tensor
class MaskableRolloutBuffer(RolloutBuffer):
"""
Rollout buffer that also stores the invalid action masks associated with each observation.
:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device:
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
:param n_envs: Number of parallel environments
"""
def __init__(self, *args, **kwargs):
self.action_masks = None
super().__init__(*args, **kwargs)
def reset(self) -> None:
if isinstance(self.action_space, spaces.Discrete):
mask_dims = self.action_space.n
elif isinstance(self.action_space, spaces.MultiDiscrete):
mask_dims = sum(self.action_space.nvec)
elif isinstance(self.action_space, spaces.MultiBinary):
mask_dims = 2 * self.action_space.n # One mask per binary outcome
else:
raise ValueError(f"Unsupported action space {type(self.action_space)}")
self.mask_dims = mask_dims
self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)
super().reset()
def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
"""
:param action_masks: Masks applied to constrain the choice of possible actions.
"""
if action_masks is not None:
self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))
super().add(*args, **kwargs)
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]:
assert self.full, "Rollout buffer must be full before sampling from it"
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
if not self.generator_ready:
for tensor in [
"observations",
"actions",
"values",
"log_probs",
"advantages",
"returns",
"action_masks",
]:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
self.generator_ready = True
# Return everything, don't create minibatches
if batch_size is None:
batch_size = self.buffer_size * self.n_envs
start_idx = 0
while start_idx < self.buffer_size * self.n_envs:
yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples:
data = (
self.observations[batch_inds],
self.actions[batch_inds],
self.values[batch_inds].flatten(),
self.log_probs[batch_inds].flatten(),
self.advantages[batch_inds].flatten(),
self.returns[batch_inds].flatten(),
self.action_masks[batch_inds].reshape(-1, self.mask_dims),
)
return MaskableRolloutBufferSamples(*map(self.to_torch, data))
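``get()`` above defers to SB3's ``swap_and_flatten``, which reorders the masks from ``(buffer_size, n_envs, mask_dims)`` to ``(buffer_size * n_envs, mask_dims)`` in env-major order. A plain-Python sketch of that reordering (illustrative, using nested lists instead of ndarrays):

```python
# Sketch of SB3's swap_and_flatten on nested lists: swap the step and env
# axes, then flatten, so samples from the same env become contiguous.
def swap_and_flatten(arr):
    n_steps, n_envs = len(arr), len(arr[0])
    return [arr[step][env] for env in range(n_envs) for step in range(n_steps)]

# 2 steps x 2 envs of per-env masks -> 4 rows, env-major
masks = [[[1, 0], [0, 1]],
         [[1, 1], [0, 0]]]
print(swap_and_flatten(masks))  # [[1, 0], [1, 1], [0, 1], [0, 0]]
```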


@@ -0,0 +1,107 @@
import os
import numpy as np
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import sync_envs_normalization
from sb3_contrib.common.maskable.evaluation import evaluate_policy
class MaskableEvalCallback(EvalCallback):
"""
Callback for evaluating an agent. Supports invalid action masking.
:param eval_env: The environment used for initialization
:param callback_on_new_best: Callback to trigger
when there is a new best model according to the ``mean_reward``
:param n_eval_episodes: The number of episodes to test the agent
:param eval_freq: Evaluate the agent every eval_freq call of the callback.
:param log_path: Path to a folder where the evaluations (``evaluations.npz``)
will be saved. It will be updated at each evaluation.
:param best_model_save_path: Path to a folder where the best model
according to performance on the eval env will be saved.
:param deterministic: Whether the evaluation should
use deterministic or stochastic actions.
:param render: Whether or not to render the environment during evaluation
:param verbose:
:param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
wrapped with a Monitor wrapper)
:param use_masking: Whether or not to use invalid action masks during evaluation
"""
def __init__(self, *args, use_masking: bool = True, **kwargs):
super().__init__(*args, **kwargs)
self.use_masking = use_masking
def _on_step(self) -> bool:
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
# Sync training and eval env if there is VecNormalize
sync_envs_normalization(self.training_env, self.eval_env)
# Reset success rate buffer
self._is_success_buffer = []
# Note that evaluate_policy() has been patched to support masking
episode_rewards, episode_lengths = evaluate_policy(
self.model,
self.eval_env,
n_eval_episodes=self.n_eval_episodes,
render=self.render,
deterministic=self.deterministic,
return_episode_rewards=True,
warn=self.warn,
callback=self._log_success_callback,
use_masking=self.use_masking,
)
if self.log_path is not None:
self.evaluations_timesteps.append(self.num_timesteps)
self.evaluations_results.append(episode_rewards)
self.evaluations_length.append(episode_lengths)
kwargs = {}
# Save success log if present
if len(self._is_success_buffer) > 0:
self.evaluations_successes.append(self._is_success_buffer)
kwargs = dict(successes=self.evaluations_successes)
np.savez(
self.log_path,
timesteps=self.evaluations_timesteps,
results=self.evaluations_results,
ep_lengths=self.evaluations_length,
**kwargs,
)
mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
self.last_mean_reward = mean_reward
if self.verbose > 0:
print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
# Add to current Logger
self.logger.record("eval/mean_reward", float(mean_reward))
self.logger.record("eval/mean_ep_length", mean_ep_length)
if len(self._is_success_buffer) > 0:
success_rate = np.mean(self._is_success_buffer)
if self.verbose > 0:
print(f"Success rate: {100 * success_rate:.2f}%")
self.logger.record("eval/success_rate", success_rate)
# Dump log so the evaluation results are printed with the correct timestep
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
self.logger.dump(self.num_timesteps)
if mean_reward > self.best_mean_reward:
if self.verbose > 0:
print("New best mean reward!")
if self.best_model_save_path is not None:
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
self.best_mean_reward = mean_reward
# Trigger callback if needed
if self.callback is not None:
return self._on_event()
return True


@@ -0,0 +1,272 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.distributions import Distribution
from torch import nn
from torch.distributions import Categorical
from torch.distributions.utils import logits_to_probs
class MaskableCategorical(Categorical):
"""
Modified PyTorch Categorical distribution with support for invalid action masking.
To instantiate, must provide either probs or logits, but not both.
:param probs: Tensor containing finite non-negative values, which will be renormalized
to sum to 1 along the last dimension.
:param logits: Tensor of unnormalized log probabilities.
:param validate_args: Whether or not to validate that arguments to methods like log_prob()
and icdf() match the distribution's shape, support, etc.
:param masks: An optional boolean ndarray of compatible shape with the distribution.
If True, the corresponding choice's logit value is preserved. If False, it is set to a
large negative value, resulting in near 0 probability.
"""
def __init__(
self,
probs: Optional[th.Tensor] = None,
logits: Optional[th.Tensor] = None,
validate_args: Optional[bool] = None,
masks: Optional[np.ndarray] = None,
):
self.masks: Optional[th.Tensor] = None
super().__init__(probs, logits, validate_args)
self._original_logits = self.logits
self.apply_masking(masks)
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
"""
Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0.
:param masks: An optional boolean ndarray of compatible shape with the distribution.
If True, the corresponding choice's logit value is preserved. If False, it is set
to a large negative value, resulting in near 0 probability. If masks is None, any
previously applied masking is removed, and the original logits are restored.
"""
if masks is not None:
device = self.logits.device
self.masks = th.as_tensor(masks, dtype=th.bool, device=device).reshape(self.logits.shape)
HUGE_NEG = th.tensor(-1e8, dtype=self.logits.dtype, device=device)
logits = th.where(self.masks, self._original_logits, HUGE_NEG)
else:
self.masks = None
logits = self._original_logits
# Reinitialize with updated logits
super().__init__(logits=logits)
# self.probs may already be cached, so we must force an update
self.probs = logits_to_probs(self.logits)
def entropy(self) -> th.Tensor:
if self.masks is None:
return super().entropy()
# Highly negative logits don't result in 0 probs, so we must replace
# with 0s to ensure 0 contribution to the distribution's entropy, since
# masked actions possess no uncertainty.
device = self.logits.device
p_log_p = self.logits * self.probs
p_log_p = th.where(self.masks, p_log_p, th.tensor(0.0, device=device))
return -p_log_p.sum(-1)
class MaskableDistribution(Distribution, ABC):
@abstractmethod
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
"""
Eliminate ("mask out") chosen distribution outcomes by setting their probability to 0.
:param masks: An optional boolean ndarray of compatible shape with the distribution.
If True, the corresponding choice's logit value is preserved. If False, it is set
to a large negative value, resulting in near 0 probability. If masks is None, any
previously applied masking is removed, and the original logits are restored.
"""
class MaskableCategoricalDistribution(MaskableDistribution):
"""
Categorical distribution for discrete actions. Supports invalid action masking.
:param action_dim: Number of discrete actions
"""
def __init__(self, action_dim: int):
super().__init__()
self.distribution: Optional[MaskableCategorical] = None
self.action_dim = action_dim
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits of the Categorical distribution.
You can then get probabilities using a softmax.
:param latent_dim: Dimension of the last layer
of the policy network (before the action layer)
:return:
"""
action_logits = nn.Linear(latent_dim, self.action_dim)
return action_logits
def proba_distribution(self, action_logits: th.Tensor) -> "MaskableCategoricalDistribution":
# Restructure shape to align with logits
reshaped_logits = action_logits.view(-1, self.action_dim)
self.distribution = MaskableCategorical(logits=reshaped_logits)
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
assert self.distribution is not None, "Must set distribution parameters"
return self.distribution.log_prob(actions)
def entropy(self) -> th.Tensor:
assert self.distribution is not None, "Must set distribution parameters"
return self.distribution.entropy()
def sample(self) -> th.Tensor:
assert self.distribution is not None, "Must set distribution parameters"
return self.distribution.sample()
def mode(self) -> th.Tensor:
assert self.distribution is not None, "Must set distribution parameters"
return th.argmax(self.distribution.probs, dim=1)
def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
assert self.distribution is not None, "Must set distribution parameters"
self.distribution.apply_masking(masks)
class MaskableMultiCategoricalDistribution(MaskableDistribution):
"""
MultiCategorical distribution for multi discrete actions. Supports invalid action masking.
:param action_dims: List of sizes of discrete action spaces
"""
def __init__(self, action_dims: List[int]):
super().__init__()
self.distributions: List[MaskableCategorical] = []
self.action_dims = action_dims
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits (flattened) of the MultiCategorical distribution.
You can then get probabilities using a softmax on each sub-space.
:param latent_dim: Dimension of the last layer
of the policy network (before the action layer)
:return:
"""
action_logits = nn.Linear(latent_dim, sum(self.action_dims))
return action_logits
def proba_distribution(self, action_logits: th.Tensor) -> "MaskableMultiCategoricalDistribution":
# Restructure shape to align with logits
reshaped_logits = action_logits.view(-1, sum(self.action_dims))
self.distributions = [
MaskableCategorical(logits=split) for split in th.split(reshaped_logits, tuple(self.action_dims), dim=1)
]
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
assert len(self.distributions) > 0, "Must set distribution parameters"
# Restructure shape to align with each categorical
actions = actions.view(-1, len(self.action_dims))
# Extract each discrete action and compute log prob for their respective distributions
return th.stack(
[dist.log_prob(action) for dist, action in zip(self.distributions, th.unbind(actions, dim=1))], dim=1
).sum(dim=1)
def entropy(self) -> th.Tensor:
assert len(self.distributions) > 0, "Must set distribution parameters"
return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1)
def sample(self) -> th.Tensor:
assert len(self.distributions) > 0, "Must set distribution parameters"
return th.stack([dist.sample() for dist in self.distributions], dim=1)
def mode(self) -> th.Tensor:
assert len(self.distributions) > 0, "Must set distribution parameters"
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1)
def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
assert len(self.distributions) > 0, "Must set distribution parameters"
split_masks = [None] * len(self.distributions)
if masks is not None:
masks = th.as_tensor(masks)
# Restructure shape to align with logits
masks = masks.view(-1, sum(self.action_dims))
# Then split columnwise for each discrete action
split_masks = th.split(masks, tuple(self.action_dims), dim=1)
for distribution, mask in zip(self.distributions, split_masks):
distribution.apply_masking(mask)
class MaskableBernoulliDistribution(MaskableMultiCategoricalDistribution):
"""
Bernoulli distribution for multibinary actions. Supports invalid action masking.
:param action_dim: Number of binary actions
"""
def __init__(self, action_dim: int):
# Two states per binary action
action_dims = [2] * action_dim
super().__init__(action_dims)
def make_masked_proba_distribution(action_space: spaces.Space) -> MaskableDistribution:
"""
Return an instance of Distribution for the correct type of action space
:param action_space: the input action space
:return: the appropriate Distribution object
"""
if isinstance(action_space, spaces.Discrete):
return MaskableCategoricalDistribution(action_space.n)
elif isinstance(action_space, spaces.MultiDiscrete):
return MaskableMultiCategoricalDistribution(action_space.nvec)
elif isinstance(action_space, spaces.MultiBinary):
return MaskableBernoulliDistribution(action_space.n)
else:
raise NotImplementedError(
"Error: probability distribution not implemented for action space "
f"of type {type(action_space)}."
" Must be of type Gym Spaces: Discrete, MultiDiscrete or MultiBinary."
)
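The masking principle these distributions implement can be shown in a few lines of numpy: invalid actions have their logits pushed toward negative infinity, so the softmax assigns them (numerically) zero probability while the remaining probabilities renormalize. This is only a sketch of the idea; ``HUGE_NEG`` stands in for the smallest representable logit value.

```python
import numpy as np

HUGE_NEG = -1e8  # stand-in for -inf; avoids nan when exponentiating

def masked_softmax(logits, mask):
    """Zero out the probability of invalid actions by masking their logits."""
    logits = np.where(np.asarray(mask, dtype=bool), logits, HUGE_NEG)
    z = np.exp(logits - logits.max())  # shift for numerical stability
    return z / z.sum()

# Action 2 is invalid: its probability collapses to ~0
probs = masked_softmax(np.array([1.0, 2.0, 3.0]), [1, 1, 0])
```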


@@ -0,0 +1,148 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import gym
import numpy as np
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
from sb3_contrib.ppo_mask import MaskablePPO
def evaluate_policy( # noqa: C901
model: MaskablePPO,
env: Union[gym.Env, VecEnv],
n_eval_episodes: int = 10,
deterministic: bool = True,
render: bool = False,
callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
reward_threshold: Optional[float] = None,
return_episode_rewards: bool = False,
warn: bool = True,
use_masking: bool = True,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
"""
Runs policy for ``n_eval_episodes`` episodes and returns average reward.
If a vector env is passed in, this divides the episodes to evaluate onto the
different elements of the vector env. This static division of work is done to
remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more
details and discussion.
.. note::
If environment has not been wrapped with ``Monitor`` wrapper, reward and
episode lengths are counted as it appears with ``env.step`` calls. If
the environment contains wrappers that modify rewards or episode lengths
(e.g. reward scaling, early episode reset), these will affect the evaluation
results as well. You can avoid this by wrapping environment with ``Monitor``
wrapper before anything else.
:param model: The RL agent you want to evaluate.
:param env: The gym environment. In the case of a ``VecEnv``
this must contain only one environment.
:param n_eval_episodes: Number of episodes to evaluate the agent
:param deterministic: Whether to use deterministic or stochastic actions
:param render: Whether to render the environment or not
:param callback: callback function to do additional checks,
called after each step. Gets locals() and globals() passed as parameters.
:param reward_threshold: Minimum expected reward per episode,
this will raise an error if the performance is not met
:param return_episode_rewards: If True, a list of rewards and episode lengths
per episode will be returned instead of the mean.
:param warn: If True (default), warns user about lack of a Monitor wrapper in the
evaluation environment.
:param use_masking: Whether or not to use invalid action masks during evaluation
:return: Mean reward per episode, std of reward per episode.
Returns ([float], [int]) when ``return_episode_rewards`` is True, first
list containing per-episode rewards and second containing per-episode lengths
(in number of steps).
"""
if use_masking and not is_masking_supported(env):
raise ValueError("Environment does not support action masking. Consider using ActionMasker wrapper")
is_monitor_wrapped = False
if not isinstance(env, VecEnv):
env = DummyVecEnv([lambda: env])
is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]
if not is_monitor_wrapped and warn:
warnings.warn(
"Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
"This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
"Consider wrapping environment first with ``Monitor`` wrapper.",
UserWarning,
)
n_envs = env.num_envs
episode_rewards = []
episode_lengths = []
episode_counts = np.zeros(n_envs, dtype="int")
# Divides episodes among different sub environments in the vector as evenly as possible
episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int")
current_rewards = np.zeros(n_envs)
current_lengths = np.zeros(n_envs, dtype="int")
observations = env.reset()
states = None
while (episode_counts < episode_count_targets).any():
if use_masking:
action_masks = get_action_masks(env)
actions, states = model.predict(
observations,
state=states,
deterministic=deterministic,
action_masks=action_masks,
)
else:
actions, states = model.predict(observations, state=states, deterministic=deterministic)
observations, rewards, dones, infos = env.step(actions)
current_rewards += rewards
current_lengths += 1
for i in range(n_envs):
if episode_counts[i] < episode_count_targets[i]:
# unpack values so that the callback can access the local variables
reward = rewards[i]
done = dones[i]
info = infos[i]
if callback is not None:
callback(locals(), globals())
if dones[i]:
if is_monitor_wrapped:
# Atari wrapper can send a "done" signal when
# the agent loses a life, but it does not correspond
# to the true end of episode
if "episode" in info.keys():
# Do not trust "done" with episode endings.
# Monitor wrapper includes "episode" key in info if environment
# has been wrapped with it. Use those rewards instead.
episode_rewards.append(info["episode"]["r"])
episode_lengths.append(info["episode"]["l"])
# Only increment at the real end of an episode
episode_counts[i] += 1
else:
episode_rewards.append(current_rewards[i])
episode_lengths.append(current_lengths[i])
episode_counts[i] += 1
current_rewards[i] = 0
current_lengths[i] = 0
if states is not None:
states[i] *= 0
if render:
env.render()
mean_reward = np.mean(episode_rewards)
std_reward = np.std(episode_rewards)
if reward_threshold is not None:
assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"
if return_episode_rewards:
return episode_rewards, episode_lengths
return mean_reward, std_reward
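The static work division used above is worth isolating: each of the ``n_envs`` sub-environments is assigned a near-equal share of ``n_eval_episodes``, which avoids the length bias discussed in DLR-RM/stable-baselines3#402. A self-contained sketch (``divide_episodes`` is a hypothetical name for the inline formula):

```python
import numpy as np

def divide_episodes(n_eval_episodes, n_envs):
    # Same formula as in evaluate_policy: env i runs (n + i) // n_envs episodes
    return np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype=int)

targets = divide_episodes(10, 4)
```

Every target differs by at most one episode, and the targets always sum to the requested total.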


@@ -0,0 +1,402 @@
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import numpy as np
import torch as th
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
MlpExtractor,
NatureCNN,
)
from stable_baselines3.common.type_aliases import Schedule
from torch import nn
from sb3_contrib.common.maskable.distributions import MaskableDistribution, make_masked_proba_distribution
class MaskableActorCriticPolicy(BasePolicy):
"""
Policy class for actor-critic algorithms (has both policy and value prediction).
Used by A2C, PPO and the likes.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
if optimizer_kwargs is None:
optimizer_kwargs = {}
# Small values to avoid NaN in Adam optimizer
if optimizer_class == th.optim.Adam:
optimizer_kwargs["eps"] = 1e-5
super().__init__(
observation_space,
action_space,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
squash_output=False,
)
# Default network architecture, from stable-baselines
if net_arch is None:
if features_extractor_class == NatureCNN:
net_arch = []
else:
net_arch = [dict(pi=[64, 64], vf=[64, 64])]
self.net_arch = net_arch
self.activation_fn = activation_fn
self.ortho_init = ortho_init
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.normalize_images = normalize_images
# Action distribution
self.action_dist = make_masked_proba_distribution(action_space)
self._build(lr_schedule)
def forward(
self,
obs: th.Tensor,
deterministic: bool = False,
action_masks: Optional[np.ndarray] = None,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
Forward pass in all the networks (actor and critic)
:param obs: Observation
:param deterministic: Whether to sample or use deterministic actions
:param action_masks: Action masks to apply to the action distribution
:return: action, value and log probability of the action
"""
latent_pi, latent_vf = self._get_latent(obs)
# Evaluate the values for the given observations
values = self.value_net(latent_vf)
distribution = self._get_action_dist_from_latent(latent_pi)
if action_masks is not None:
distribution.apply_masking(action_masks)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
return actions, values, log_prob
def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()
data.update(
dict(
net_arch=self.net_arch,
activation_fn=self.activation_fn,
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
ortho_init=self.ortho_init,
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs,
)
)
return data
def _build_mlp_extractor(self) -> None:
"""
Create the policy and value networks.
Part of the layers can be shared.
"""
# Note: If net_arch is None and some features extractor is used,
# net_arch here is an empty list and mlp_extractor does not
# really contain any layers (acts like an identity module).
self.mlp_extractor = MlpExtractor(
self.features_dim,
net_arch=self.net_arch,
activation_fn=self.activation_fn,
device=self.device,
)
def _build(self, lr_schedule: Schedule) -> None:
"""
Create the networks and the optimizer.
:param lr_schedule: Learning rate schedule
lr_schedule(1) is the initial learning rate
"""
self._build_mlp_extractor()
self.action_net = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi)
self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
# Init weights: use orthogonal initialization
# with small initial weight for the output
if self.ortho_init:
# TODO: check for features_extractor
# Values from stable-baselines.
# features_extractor/mlp values are
# originally from openai/baselines (default gains/init_scales).
module_gains = {
self.features_extractor: np.sqrt(2),
self.mlp_extractor: np.sqrt(2),
self.action_net: 0.01,
self.value_net: 1,
}
for module, gain in module_gains.items():
module.apply(partial(self.init_weights, gain=gain))
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
"""
Get the latent code (i.e., activations of the last layer of each network)
for the different networks.
:param obs: Observation
:return: Latent codes for the actor and the value function
"""
# Preprocess the observation if needed
features = self.extract_features(obs)
latent_pi, latent_vf = self.mlp_extractor(features)
return latent_pi, latent_vf
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution:
"""
Retrieve action distribution given the latent codes.
:param latent_pi: Latent code for the actor
:return: Action distribution
"""
action_logits = self.action_net(latent_pi)
return self.action_dist.proba_distribution(action_logits=action_logits)
def _predict(
self,
observation: th.Tensor,
deterministic: bool = False,
action_masks: Optional[np.ndarray] = None,
) -> th.Tensor:
"""
Get the action according to the policy for a given observation.
:param observation: The input observation
:param deterministic: Whether to use stochastic or deterministic actions
:param action_masks: Action masks to apply to the action distribution
:return: Taken action according to the policy
"""
latent_pi, _ = self._get_latent(observation)
distribution = self._get_action_dist_from_latent(latent_pi)
if action_masks is not None:
distribution.apply_masking(action_masks)
return distribution.get_actions(deterministic=deterministic)
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
deterministic: bool = False,
action_masks: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Get the policy action and state from an observation (and optional state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param deterministic: Whether or not to return deterministic actions.
:param action_masks: Action masks to apply to the action distribution
:return: the model's action and the next state
(used in recurrent policies)
"""
# TODO (GH/1): add support for RNN policies
# if state is None:
# state = self.initial_state
# if mask is None:
# mask = [False for _ in range(self.n_envs)]
# Switch to eval mode (this affects batch norm / dropout)
self.set_training_mode(False)
observation, vectorized_env = self.obs_to_tensor(observation)
with th.no_grad():
actions = self._predict(observation, deterministic=deterministic, action_masks=action_masks)
# Convert to numpy
actions = actions.cpu().numpy()
if isinstance(self.action_space, gym.spaces.Box):
if self.squash_output:
# Rescale to proper domain when using squashing
actions = self.unscale_action(actions)
else:
# Actions could be on arbitrary scale, so clip the actions to avoid
# out of bound error (e.g. if sampling from a Gaussian distribution)
actions = np.clip(actions, self.action_space.low, self.action_space.high)
if not vectorized_env:
if state is not None:
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
actions = actions[0]
return actions, state
def evaluate_actions(
self,
obs: th.Tensor,
actions: th.Tensor,
action_masks: Optional[np.ndarray] = None,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
Evaluate actions according to the current policy,
given the observations.
:param obs: Observation
:param actions: Actions to evaluate
:param action_masks: Action masks to apply to the action distribution
:return: estimated value, log likelihood of taking those actions
and entropy of the action distribution.
"""
latent_pi, latent_vf = self._get_latent(obs)
distribution = self._get_action_dist_from_latent(latent_pi)
if action_masks is not None:
distribution.apply_masking(action_masks)
log_prob = distribution.log_prob(actions)
values = self.value_net(latent_vf)
return values, log_prob, distribution.entropy()
class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
"""
CNN policy class for actor-critic algorithms (has both policy and value prediction).
Used by A2C, PPO and the likes.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
ortho_init,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
)
class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
"""
MultiInputActorClass policy class for actor-critic algorithms (has both policy and value prediction).
Used by A2C, PPO and the likes.
:param observation_space: Observation space (Tuple)
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param features_extractor_class: Uses the CombinedExtractor
:param features_extractor_kwargs: Keyword arguments
to pass to the feature extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: gym.spaces.Dict,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
ortho_init,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
)


@@ -0,0 +1,38 @@
import numpy as np
from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.common.vec_env import VecEnv
EXPECTED_METHOD_NAME = "action_masks"
def get_action_masks(env: GymEnv) -> np.ndarray:
"""
Retrieve the invalid action masks from a gym env exposing an ``action_masks`` method
:param env: the Gym environment to get masks from
:return: A numpy array of the masks
"""
if isinstance(env, VecEnv):
return np.stack(env.env_method(EXPECTED_METHOD_NAME))
else:
return getattr(env, EXPECTED_METHOD_NAME)()
def is_masking_supported(env: GymEnv) -> bool:
"""
Checks whether gym env exposes a method returning invalid action masks
:param env: the Gym environment to check
:return: True if the method is found, False otherwise
"""
if isinstance(env, VecEnv):
try:
# TODO: add VecEnv.has_attr()
env.get_attr(EXPECTED_METHOD_NAME)
return True
except AttributeError:
return False
else:
return hasattr(env, EXPECTED_METHOD_NAME)
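These utilities are plain duck-typing: any env (wrapped or not) that exposes an ``action_masks`` method is considered maskable. A minimal sketch with a hypothetical env class, showing the non-VecEnv code path:

```python
import numpy as np

EXPECTED_METHOD_NAME = "action_masks"

class DummyMaskedEnv:
    """Hypothetical env exposing the expected method directly (no wrapper needed)."""

    def action_masks(self):
        # True marks a valid action, False an invalid one
        return np.array([True, False, True])

env = DummyMaskedEnv()
supported = hasattr(env, EXPECTED_METHOD_NAME)      # is_masking_supported path
mask = getattr(env, EXPECTED_METHOD_NAME)()         # get_action_masks path
```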


@@ -1 +1,2 @@
from sb3_contrib.common.wrappers.action_masker import ActionMasker
from sb3_contrib.common.wrappers.time_feature import TimeFeatureWrapper


@@ -0,0 +1,32 @@
from typing import Callable, Union
import gym
import numpy as np
class ActionMasker(gym.Wrapper):
"""
Env wrapper providing the method required to support masking.
Exposes a method called action_masks(), which returns masks for the wrapped env.
This wrapper is not needed if the env exposes the expected method itself.
:param env: the Gym environment to wrap
:param action_mask_fn: A function that takes a Gym environment and returns an action mask,
or the name of such a method provided by the environment.
"""
def __init__(self, env: gym.Env, action_mask_fn: Union[str, Callable[[gym.Env], np.ndarray]]):
super().__init__(env)
if isinstance(action_mask_fn, str):
found_method = getattr(self.env, action_mask_fn)
if not callable(found_method):
raise ValueError(f"Environment attribute {action_mask_fn} is not a method")
self._action_mask_fn = found_method
else:
self._action_mask_fn = action_mask_fn
def action_masks(self) -> np.ndarray:
return self._action_mask_fn(self.env)


@@ -0,0 +1,2 @@
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy # , MultiInputPolicy
from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO


@@ -0,0 +1,15 @@
from stable_baselines3.common.policies import register_policy
from sb3_contrib.common.maskable.policies import ( # MaskableMultiInputActorCriticPolicy,
MaskableActorCriticCnnPolicy,
MaskableActorCriticPolicy,
)
MlpPolicy = MaskableActorCriticPolicy
CnnPolicy = MaskableActorCriticCnnPolicy
# MultiInputPolicy = MaskableMultiInputActorCriticPolicy
register_policy("MlpPolicy", MaskableActorCriticPolicy)
register_policy("CnnPolicy", MaskableActorCriticCnnPolicy)
# Currently not supported
# register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy)


@@ -0,0 +1,551 @@
import time
from collections import deque
from typing import Any, Dict, Optional, Tuple, Type, Union
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
from torch.nn import functional as F
from sb3_contrib.common.maskable.buffers import MaskableRolloutBuffer
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
class MaskablePPO(OnPolicyAlgorithm):
"""
Proximal Policy Optimization algorithm (PPO) (clip version) with Invalid Action Masking.
Based on the original Stable Baselines 3 implementation.
Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
Background on Invalid Action Masking: https://arxiv.org/abs/2006.14171
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
of the current progress remaining (from 1 to 0)
:param n_steps: The number of steps to run for each environment per update
(i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
:param batch_size: Minibatch size
:param n_epochs: Number of epochs when optimizing the surrogate loss
:param gamma: Discount factor
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
:param clip_range: Clipping parameter, it can be a function of the current progress
remaining (from 1 to 0).
:param clip_range_vf: Clipping parameter for the value function,
it can be a function of the current progress remaining (from 1 to 0).
This is a parameter specific to the OpenAI implementation. If None is passed (default),
no clipping will be done on the value function.
IMPORTANT: this clipping depends on the reward scaling.
:param ent_coef: Entropy coefficient for the loss calculation
:param vf_coef: Value function coefficient for the loss calculation
:param max_grad_norm: The maximum value for the gradient clipping
:param target_kl: Limit the KL divergence between updates,
because the clipping is not enough to prevent large update
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
By default, there is no limit on the kl div.
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
def __init__(
self,
policy: Union[str, Type[MaskableActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 3e-4,
n_steps: int = 2048,
batch_size: Optional[int] = 64,
n_epochs: int = 10,
gamma: float = 0.99,
gae_lambda: float = 0.95,
clip_range: Union[float, Schedule] = 0.2,
clip_range_vf: Union[None, float, Schedule] = None,
ent_coef: float = 0.0,
vf_coef: float = 0.5,
max_grad_norm: float = 0.5,
target_kl: Optional[float] = None,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
learning_rate=learning_rate,
n_steps=n_steps,
gamma=gamma,
gae_lambda=gae_lambda,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm,
use_sde=False,
sde_sample_freq=-1,
tensorboard_log=tensorboard_log,
create_eval_env=create_eval_env,
policy_kwargs=policy_kwargs,
policy_base=MaskableActorCriticPolicy,
verbose=verbose,
seed=seed,
device=device,
_init_setup_model=False,
supported_action_spaces=(
spaces.Discrete,
spaces.MultiDiscrete,
spaces.MultiBinary,
),
)
self.batch_size = batch_size
self.n_epochs = n_epochs
self.clip_range = clip_range
self.clip_range_vf = clip_range_vf
self.target_kl = target_kl
if _init_setup_model:
self._setup_model()
def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)
self.policy = self.policy_class(
self.observation_space,
self.action_space,
self.lr_schedule,
**self.policy_kwargs, # pytype:disable=not-instantiable
)
self.policy = self.policy.to(self.device)
if not isinstance(self.policy, MaskableActorCriticPolicy):
raise ValueError("Policy must subclass MaskableActorCriticPolicy")
self.rollout_buffer = MaskableRolloutBuffer(
self.n_steps,
self.observation_space,
self.action_space,
self.device,
gamma=self.gamma,
gae_lambda=self.gae_lambda,
n_envs=self.n_envs,
)
# Initialize schedules for policy/value clipping
self.clip_range = get_schedule_fn(self.clip_range)
if self.clip_range_vf is not None:
if isinstance(self.clip_range_vf, (float, int)):
assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
def _init_callback(
self,
callback: MaybeCallback,
eval_env: Optional[VecEnv] = None,
eval_freq: int = 10000,
n_eval_episodes: int = 5,
log_path: Optional[str] = None,
use_masking: bool = True,
) -> BaseCallback:
"""
:param callback: Callback(s) called at every step with state of the algorithm.
:param eval_freq: How many steps between evaluations; if None, do not evaluate.
:param n_eval_episodes: Number of episodes to rollout during evaluation.
:param log_path: Path to a folder where the evaluations will be saved
:param use_masking: Whether or not to use invalid action masks during evaluation
:return: A hybrid callback calling `callback` and performing evaluation.
"""
# Convert a list of callbacks into a callback
if isinstance(callback, list):
callback = CallbackList(callback)
# Convert functional callback to object
if not isinstance(callback, BaseCallback):
callback = ConvertCallback(callback)
# Create eval callback in charge of the evaluation
if eval_env is not None:
# Avoid circular import error
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
eval_callback = MaskableEvalCallback(
eval_env,
best_model_save_path=log_path,
log_path=log_path,
eval_freq=eval_freq,
n_eval_episodes=n_eval_episodes,
use_masking=use_masking,
)
callback = CallbackList([callback, eval_callback])
callback.init_callback(self)
return callback
def _setup_learn(
self,
total_timesteps: int,
eval_env: Optional[GymEnv],
callback: MaybeCallback = None,
eval_freq: int = 10000,
n_eval_episodes: int = 5,
log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
tb_log_name: str = "run",
use_masking: bool = True,
) -> Tuple[int, BaseCallback]:
"""
Initialize different variables needed for training.
:param total_timesteps: The total number of samples (env steps) to train on
:param eval_env: Environment to use for evaluation.
:param callback: Callback(s) called at every step with state of the algorithm.
:param eval_freq: How many steps between evaluations
:param n_eval_episodes: How many episodes to play per evaluation
:param log_path: Path to a folder where the evaluations will be saved
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
:param tb_log_name: the name of the run for tensorboard log
:param use_masking: Whether or not to use invalid action masks during training
:return:
"""
self.start_time = time.time()
if self.ep_info_buffer is None or reset_num_timesteps:
# Initialize buffers if they don't exist, or reinitialize if resetting counters
self.ep_info_buffer = deque(maxlen=100)
self.ep_success_buffer = deque(maxlen=100)
if reset_num_timesteps:
self.num_timesteps = 0
self._episode_num = 0
else:
# Make sure training timesteps are ahead of the internal counter
total_timesteps += self.num_timesteps
self._total_timesteps = total_timesteps
# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:
self._last_obs = self.env.reset()
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:
self._last_original_obs = self._vec_normalize_env.get_original_obs()
if eval_env is not None and self.seed is not None:
eval_env.seed(self.seed)
eval_env = self._get_eval_env(eval_env)
# Configure logger's outputs if no logger was passed
if not self._custom_logger:
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
# Create eval callback if needed
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking)
return total_timesteps, callback
def collect_rollouts(
self,
env: VecEnv,
callback: BaseCallback,
rollout_buffer: RolloutBuffer,
n_rollout_steps: int,
use_masking: bool = True,
) -> bool:
"""
Collect experiences using the current policy and fill a ``RolloutBuffer``.
The term rollout here refers to the model-free notion and should not
be used with the concept of rollout used in model-based RL or planning.
This method is largely identical to the implementation found in the parent class.
:param env: The training environment
:param callback: Callback that will be called at each step
(and at the beginning and end of the rollout)
:param rollout_buffer: Buffer to fill with rollouts
:param n_rollout_steps: Number of experiences to collect per environment
:param use_masking: Whether or not to use invalid action masks during training
:return: True if function returned with at least `n_rollout_steps`
collected, False if callback terminated rollout prematurely.
"""
assert isinstance(rollout_buffer, MaskableRolloutBuffer), "RolloutBuffer doesn't support action masking"
assert self._last_obs is not None, "No previous observation was provided"
# Switch to eval mode (this affects batch norm / dropout)
self.policy.set_training_mode(False)
n_steps = 0
action_masks = None
rollout_buffer.reset()
if use_masking and not is_masking_supported(env):
raise ValueError("Environment does not support action masking. Consider using ActionMasker wrapper")
callback.on_rollout_start()
while n_steps < n_rollout_steps:
with th.no_grad():
# Convert to pytorch tensor or to TensorDict
obs_tensor = obs_as_tensor(self._last_obs, self.device)
# This is the only change related to invalid action masking
if use_masking:
action_masks = get_action_masks(env)
actions, values, log_probs = self.policy.forward(obs_tensor, action_masks=action_masks)
actions = actions.cpu().numpy()
new_obs, rewards, dones, infos = env.step(actions)
self.num_timesteps += env.num_envs
# Give access to local variables
callback.update_locals(locals())
if callback.on_step() is False:
return False
self._update_info_buffer(infos)
n_steps += 1
if isinstance(self.action_space, spaces.Discrete):
# Reshape in case of discrete action
actions = actions.reshape(-1, 1)
rollout_buffer.add(
self._last_obs,
actions,
rewards,
self._last_episode_starts,
values,
log_probs,
action_masks=action_masks,
)
self._last_obs = new_obs
self._last_episode_starts = dones
with th.no_grad():
# Compute value for the last timestep
obs_tensor = obs_as_tensor(new_obs, self.device)
# Masking is not needed here, the choice of action doesn't matter.
# We only want the value of the current observation.
_, values, _ = self.policy.forward(obs_tensor)
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
callback.on_rollout_end()
return True
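``collect_rollouts`` above raises unless the env exposes an ``action_masks()`` method. Conceptually, the ``ActionMasker`` wrapper from this PR just attaches such a method backed by a user-supplied function; a simplified stand-in sketch (not the shipped wrapper, which subclasses ``gym.Wrapper`` — all names here are illustrative):

```python
class SimpleActionMasker:
    """Minimal stand-in for ActionMasker: attach action_masks() to any env."""

    def __init__(self, env, mask_fn):
        self.env = env
        self._mask_fn = mask_fn

    def action_masks(self):
        # Compute the mask of currently valid actions from the wrapped env.
        return self._mask_fn(self.env)

    def __getattr__(self, name):
        # Delegate everything else (reset, step, spaces, ...) to the wrapped env.
        return getattr(self.env, name)


class _ToyEnv:
    """Hypothetical env: the only valid action equals the current state."""

    state = 1
    n_actions = 3


toy = SimpleActionMasker(_ToyEnv(), lambda env: [i == env.state for i in range(env.n_actions)])
masks = toy.action_masks()
```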
def predict(
self,
observation: np.ndarray,
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
deterministic: bool = False,
action_masks: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Get the model's action(s) from an observation.
:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param deterministic: Whether or not to return deterministic actions.
:param action_masks: Action masks to apply to the action distribution.
:return: the model's action and the next state (used in recurrent policies)
"""
return self.policy.predict(observation, state, mask, deterministic, action_masks=action_masks)
def train(self) -> None:
"""
Update policy using the currently gathered rollout buffer.
"""
# Switch to train mode (this affects batch norm / dropout)
self.policy.set_training_mode(True)
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)
# Compute current clip range
clip_range = self.clip_range(self._current_progress_remaining)
# Optional: clip range for the value function
if self.clip_range_vf is not None:
clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
entropy_losses = []
pg_losses, value_losses = [], []
clip_fractions = []
continue_training = True
# train for n_epochs epochs
for epoch in range(self.n_epochs):
approx_kl_divs = []
# Do a complete pass on the rollout buffer
for rollout_data in self.rollout_buffer.get(self.batch_size):
actions = rollout_data.actions
if isinstance(self.action_space, spaces.Discrete):
# Convert discrete action from float to long
actions = rollout_data.actions.long().flatten()
values, log_prob, entropy = self.policy.evaluate_actions(
rollout_data.observations,
actions,
action_masks=rollout_data.action_masks,
)
values = values.flatten()
# Normalize advantage
advantages = rollout_data.advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration
ratio = th.exp(log_prob - rollout_data.old_log_prob)
# clipped surrogate loss
policy_loss_1 = advantages * ratio
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
# Logging
pg_losses.append(policy_loss.item())
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
clip_fractions.append(clip_fraction)
if self.clip_range_vf is None:
# No clipping
values_pred = values
else:
# Clip the difference between the old and new value
# NOTE: this depends on the reward scaling
values_pred = rollout_data.old_values + th.clamp(
values - rollout_data.old_values, -clip_range_vf, clip_range_vf
)
# Value loss using the TD(gae_lambda) target
value_loss = F.mse_loss(rollout_data.returns, values_pred)
value_losses.append(value_loss.item())
# Entropy loss favors exploration
if entropy is None:
# Approximate entropy when no analytical form
entropy_loss = -th.mean(-log_prob)
else:
entropy_loss = -th.mean(entropy)
entropy_losses.append(entropy_loss.item())
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
# Calculate approximate form of reverse KL Divergence for early stopping
# see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
# and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
# and Schulman blog: http://joschu.net/blog/kl-approx.html
with th.no_grad():
log_ratio = log_prob - rollout_data.old_log_prob
approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
approx_kl_divs.append(approx_kl_div)
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
continue_training = False
if self.verbose >= 1:
print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
break
# Optimization step
self.policy.optimizer.zero_grad()
loss.backward()
# Clip grad norm
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
if not continue_training:
break
self._n_updates += self.n_epochs
explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
# Logs
self.logger.record("train/entropy_loss", np.mean(entropy_losses))
self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
self.logger.record("train/value_loss", np.mean(value_losses))
self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
self.logger.record("train/clip_fraction", np.mean(clip_fractions))
self.logger.record("train/loss", loss.item())
self.logger.record("train/explained_variance", explained_var)
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/clip_range", clip_range)
if self.clip_range_vf is not None:
self.logger.record("train/clip_range_vf", clip_range_vf)
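The clipped surrogate objective and the approximate KL ("k3") estimator used in ``train()`` can be sketched independently of the SB3 internals (a minimal numpy illustration, not the library code; function names are illustrative):

```python
import numpy as np


def clipped_surrogate_loss(ratio, advantages, clip_range=0.2):
    # Pessimistic bound: take the elementwise minimum of the unclipped
    # and clipped objectives, then negate (we minimize the loss).
    unclipped = advantages * ratio
    clipped = advantages * np.clip(ratio, 1 - clip_range, 1 + clip_range)
    return -np.minimum(unclipped, clipped).mean()


def approx_kl(log_ratio):
    # Low-variance "k3" estimator from Schulman's blog post:
    # E[(r - 1) - log r], with r = new_prob / old_prob.
    # Pointwise non-negative, since e^x - 1 - x >= 0 for all x.
    return ((np.exp(log_ratio) - 1) - log_ratio).mean()


ratio = np.array([1.5, 0.5])
advantages = np.array([1.0, -1.0])
loss = clipped_surrogate_loss(ratio, advantages)  # -0.2
```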
def learn(
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
tb_log_name: str = "PPO",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
use_masking: bool = True,
) -> "MaskablePPO":
iteration = 0
total_timesteps, callback = self._setup_learn(
total_timesteps,
eval_env,
callback,
eval_freq,
n_eval_episodes,
eval_log_path,
reset_num_timesteps,
tb_log_name,
use_masking,
)
callback.on_training_start(locals(), globals())
while self.num_timesteps < total_timesteps:
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
if continue_training is False:
break
iteration += 1
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
# Display training infos
if log_interval is not None and iteration % log_interval == 0:
fps = int(self.num_timesteps / (time.time() - self.start_time))
self.logger.record("time/iterations", iteration, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
self.logger.record("time/fps", fps)
self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
self.logger.dump(step=self.num_timesteps)
self.train()
callback.on_training_end()
return self


@@ -1 +1 @@
1.2.0
1.2.1a0


@@ -22,9 +22,12 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators
# Ignore import not used when aliases are defined
per-file-ignores =
./sb3_contrib/__init__.py:F401
./sb3_contrib/ppo_mask/__init__.py:F401
./sb3_contrib/qrdqn/__init__.py:F401
./sb3_contrib/tqc/__init__.py:F401
./sb3_contrib/common/vec_env/wrappers/__init__.py:F401
./sb3_contrib/common/wrappers/__init__.py:F401
./sb3_contrib/common/envs/__init__.py:F401
exclude =
# No need to traverse our git directory
.git,

tests/test_distributions.py

@@ -0,0 +1,316 @@
import numpy as np
import pytest
import torch as th
from sb3_contrib.common.maskable.distributions import (
MaskableBernoulliDistribution,
MaskableCategorical,
MaskableCategoricalDistribution,
MaskableMultiCategoricalDistribution,
)
class TestMaskableCategorical:
def test_applying_mask(self):
"""
Show that probs change as a result of masking
"""
starting_probs = th.Tensor([[0.2, 0.2, 0.6], [1, 0, 0]])
expected_probs = th.Tensor([[0, 0.25, 0.75], [0, 0.5, 0.5]])
mask = np.array([[False, True, True], [False, True, True]])
distribution = MaskableCategorical(probs=starting_probs)
distribution.apply_masking(mask)
assert th.allclose(distribution.probs, expected_probs)
def test_modifying_mask(self):
"""
Show that masks apply independently of each other
"""
starting_probs = th.Tensor([[0.2, 0.2, 0.6], [1, 0, 0]])
expected_probs = th.Tensor([[0.5, 0.5, 0], [0, 1, 0]])
first_mask = np.array([[False, True, True], [False, True, True]])
second_mask = np.array([[True, True, False], [False, True, False]])
# pytorch converts probs to logits in a way that loses some precision and makes
# 0 probability outcomes slightly non-zero.
atol = 2e-07
distribution = MaskableCategorical(probs=starting_probs)
assert th.allclose(distribution.probs, starting_probs, atol=atol)
target_distribution = MaskableCategorical(probs=expected_probs)
distribution.apply_masking(first_mask)
distribution.apply_masking(second_mask)
assert th.allclose(distribution.probs, target_distribution.probs, atol=atol)
def test_removing_mask(self):
"""
Show that masking may be unapplied to recover original probs
"""
starting_probs = th.Tensor([[0.2, 0.2, 0.6], [1, 0, 0]])
mask = np.array([[False, True, True], [False, True, True]])
distribution = MaskableCategorical(probs=starting_probs)
target_distribution = MaskableCategorical(probs=starting_probs)
distribution.apply_masking(mask)
distribution.apply_masking(None)
assert th.allclose(distribution.probs, target_distribution.probs)
def test_masking_affects_entropy(self):
# All outcomes equally likely
NUM_DIMS = 3
logits = th.Tensor([[0] * NUM_DIMS])
dist = MaskableCategorical(logits=logits)
# For each possible number of valid actions v, show that e^entropy == v
for v in range(1, NUM_DIMS + 1):
masks = [j < v for j in range(NUM_DIMS)]
dist.apply_masking(masks)
assert int(dist.entropy().exp()) == v
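The behavior these tests exercise boils down to replacing the logits of invalid actions with a very large negative value and letting softmax re-normalize over the survivors; a standalone numpy sketch (illustrative, not the ``MaskableCategorical`` implementation):

```python
import numpy as np


def masked_probs(logits, mask):
    # Invalid actions get a huge negative logit, so softmax assigns them
    # (numerically) zero probability and re-normalizes the valid ones.
    masked = np.where(mask, logits, -1e8)
    z = np.exp(masked - masked.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)


logits = np.log(np.array([[0.2, 0.2, 0.6]]))
mask = np.array([[False, True, True]])
probs = masked_probs(logits, mask)  # ~[[0, 0.25, 0.75]]

# With uniform logits, masking down to v valid actions leaves a uniform
# distribution over those v actions, hence e^entropy == v as tested above.
uniform = masked_probs(np.zeros((1, 3)), np.array([[True, True, False]]))
```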
class TestMaskableCategoricalDistribution:
def test_distribution_must_be_initialized(self):
"""
Cannot use distribution before it has logits
"""
DIMS = 2
dist = MaskableCategoricalDistribution(DIMS)
with pytest.raises(AssertionError):
dist.log_prob(th.randint(DIMS - 1, (1, 3)))
with pytest.raises(AssertionError):
dist.entropy()
with pytest.raises(AssertionError):
dist.sample()
with pytest.raises(AssertionError):
dist.mode()
with pytest.raises(AssertionError):
dist.apply_masking(None)
# But now we can
dist.proba_distribution(th.randn(1, DIMS))
dist.log_prob(th.randint(DIMS - 1, (3, 1)))
dist.entropy()
dist.sample()
dist.mode()
dist.apply_masking(None)
def test_logits_must_align_with_dims(self):
NUM_DIMS = 3
dist = MaskableCategoricalDistribution(NUM_DIMS)
# There should be one logit per dim, we're one short
logits = th.randn(1, NUM_DIMS - 1)
with pytest.raises(RuntimeError):
dist.proba_distribution(logits)
# That's better
logits = th.randn(1, NUM_DIMS)
dist.proba_distribution(logits)
# Other numbers of dimensions are acceptable as long as they can be realigned
logits = th.randn(NUM_DIMS)
dist.proba_distribution(logits)
logits = th.randn(3, NUM_DIMS, 3)
dist.proba_distribution(logits)
def test_dim_masking(self):
NUM_DIMS = 2
dist = MaskableCategoricalDistribution(NUM_DIMS)
logits = th.Tensor([[0] * NUM_DIMS])
dist.proba_distribution(logits)
assert (dist.distribution.probs == 0.5).all()
assert int(dist.entropy().exp()) == NUM_DIMS
for i in range(NUM_DIMS):
mask = np.array([False] * NUM_DIMS)
mask[i] = True
dist.apply_masking(mask)
probs = dist.distribution.probs
assert probs.sum() == 1
assert probs[0][i] == 1
assert int(dist.entropy().exp()) == 1
dist.apply_masking(None)
assert (dist.distribution.probs == 0.5).all()
assert int(dist.entropy().exp()) == NUM_DIMS
class TestMaskableMultiCategoricalDistribution:
def test_distribution_must_be_initialized(self):
"""
Cannot use distribution before it has logits
"""
DIMS_PER_CAT = 2
NUM_CATS = 2
dist = MaskableMultiCategoricalDistribution([DIMS_PER_CAT] * NUM_CATS)
with pytest.raises(AssertionError):
dist.log_prob(th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS)))
with pytest.raises(AssertionError):
dist.entropy()
with pytest.raises(AssertionError):
dist.sample()
with pytest.raises(AssertionError):
dist.mode()
with pytest.raises(AssertionError):
dist.apply_masking(None)
# But now we can
dist.proba_distribution(th.randn(1, DIMS_PER_CAT * NUM_CATS))
dist.log_prob(th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS)))
dist.entropy()
dist.sample()
dist.mode()
dist.apply_masking(None)
def test_logits_must_align_with_dims(self):
DIMS_PER_CAT = 3
NUM_CATS = 2
dist = MaskableMultiCategoricalDistribution([DIMS_PER_CAT] * NUM_CATS)
# There should be one logit per dim, we're one short
logits = th.randn(1, DIMS_PER_CAT * NUM_CATS - 1)
with pytest.raises(RuntimeError):
dist.proba_distribution(logits)
# That's better
logits = th.randn(1, DIMS_PER_CAT * NUM_CATS)
dist.proba_distribution(logits)
# Other numbers of dimensions are acceptable as long as they can be realigned
logits = th.randn(DIMS_PER_CAT * NUM_CATS)
dist.proba_distribution(logits)
logits = th.randn(3, DIMS_PER_CAT * NUM_CATS, 3)
dist.proba_distribution(logits)
def test_dim_masking(self):
DIMS_PER_CAT = 2
NUM_CATS = 3
dist = MaskableMultiCategoricalDistribution([DIMS_PER_CAT] * NUM_CATS)
logits = th.Tensor([[0] * DIMS_PER_CAT * NUM_CATS])
dist.proba_distribution(logits)
assert len(dist.distributions) == NUM_CATS
for i in range(NUM_CATS):
assert (dist.distributions[i].probs == 0.5).all()
assert int(dist.entropy().exp()) == DIMS_PER_CAT ** NUM_CATS
for i in range(DIMS_PER_CAT):
mask = np.array([False] * DIMS_PER_CAT * NUM_CATS)
for j in range(NUM_CATS):
mask[j * DIMS_PER_CAT + i] = True
dist.apply_masking(mask)
for j in range(NUM_CATS):
probs = dist.distributions[j].probs
assert probs.sum() == 1
assert probs[0][i] == 1
assert int(dist.entropy().exp()) == 1
dist.apply_masking(None)
for i in range(NUM_CATS):
assert (dist.distributions[i].probs == 0.5).all()
assert int(dist.entropy().exp()) == DIMS_PER_CAT ** NUM_CATS
class TestMaskableBernoulliDistribution:
def test_distribution_must_be_initialized(self):
"""
Cannot use distribution before it has logits
"""
DIMS = 2
dist = MaskableBernoulliDistribution(DIMS)
with pytest.raises(AssertionError):
dist.log_prob(th.randint(1, (2, DIMS)))
with pytest.raises(AssertionError):
dist.entropy()
with pytest.raises(AssertionError):
dist.sample()
with pytest.raises(AssertionError):
dist.mode()
with pytest.raises(AssertionError):
dist.apply_masking(None)
# But now we can
dist.proba_distribution(th.randn(1, 2 * DIMS))
dist.log_prob(th.randint(1, (2, DIMS)))
dist.entropy()
dist.sample()
dist.mode()
dist.apply_masking(None)
def test_logits_must_align_with_dims(self):
NUM_DIMS = 3
dist = MaskableBernoulliDistribution(NUM_DIMS)
# There should be two logits per dim, we're one short
logits = th.randn(1, 2 * NUM_DIMS - 1)
with pytest.raises(RuntimeError):
dist.proba_distribution(logits)
# That's better
logits = th.randn(1, 2 * NUM_DIMS)
dist.proba_distribution(logits)
# Other numbers of dimensions are acceptable as long as they can be realigned
logits = th.randn(2 * NUM_DIMS)
dist.proba_distribution(logits)
logits = th.randn(3, 2 * NUM_DIMS, 3)
dist.proba_distribution(logits)
def test_dim_masking(self):
NUM_DIMS = 2
BINARY_STATES = 2
dist = MaskableBernoulliDistribution(NUM_DIMS)
logits = th.Tensor([[0] * BINARY_STATES * NUM_DIMS])
dist.proba_distribution(logits)
assert len(dist.distributions) == NUM_DIMS
for i in range(NUM_DIMS):
assert (dist.distributions[i].probs == 0.5).all()
assert int(dist.entropy().exp()) == BINARY_STATES * NUM_DIMS
for i in range(BINARY_STATES):
mask = np.array([False] * BINARY_STATES * NUM_DIMS)
for j in range(NUM_DIMS):
mask[j * BINARY_STATES + i] = True
dist.apply_masking(mask)
for j in range(NUM_DIMS):
probs = dist.distributions[j].probs
assert probs.sum() == 1
assert probs[0][i] == 1
assert int(dist.entropy().exp()) == 1
dist.apply_masking(None)
for i in range(NUM_DIMS):
assert (dist.distributions[i].probs == 0.5).all()
assert int(dist.entropy().exp()) == BINARY_STATES * NUM_DIMS
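The MultiCategorical and Bernoulli tests above pass a single flattened mask covering all sub-spaces (note the ``mask[j * DIMS_PER_CAT + i]`` indexing). A sketch of that assumed concatenation convention, and how a flat mask maps back onto per-sub-space masks (variable names are illustrative):

```python
import numpy as np

# MultiDiscrete([2, 3]): the flat mask concatenates a 2-element mask for
# the first sub-space with a 3-element mask for the second.
dims = [2, 3]
flat_mask = np.array([True, False, False, True, True])

# Recover one mask per sub-space by splitting at the cumulative dims.
per_space = np.split(flat_mask, np.cumsum(dims)[:-1])
```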


@@ -0,0 +1,232 @@
import random
import gym
import pytest
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.policies import ActorCriticPolicy
from sb3_contrib import MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete, InvalidActionEnvMultiBinary, InvalidActionEnvMultiDiscrete
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.utils import is_masking_supported
from sb3_contrib.common.wrappers import ActionMasker
def make_env():
return InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
class ToDictWrapper(gym.Wrapper):
"""
Simple wrapper to test MultiInputPolicy on Dict obs.
"""
def __init__(self, env):
super().__init__(env)
self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space})
def reset(self):
return {"obs": self.env.reset()}
def step(self, action):
obs, reward, done, infos = self.env.step(action)
return {"obs": obs}, reward, done, infos
def test_identity():
"""
Performance test.
A randomly initialized model cannot solve this task (score ~= 6),
nor can a model without invalid action masking (score ~= 30 after training)
given such a low training budget.
"""
env = InvalidActionEnvDiscrete(dim=70, n_invalid_actions=55)
model = MaskablePPO(
"MlpPolicy",
env,
gamma=0.4,
seed=32,
verbose=0,
)
model.learn(3000)
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)
def test_supports_discrete_action_space():
"""
No errors using algorithm with an env that has a discrete action space
"""
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
model = MaskablePPO("MlpPolicy", env, n_steps=64, seed=8)
model.learn(100)
evaluate_policy(model, env, warn=False)
# Mask all actions except the good one, a random model should succeed
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19)
model = MaskablePPO("MlpPolicy", env, seed=8)
evaluate_policy(model, env, reward_threshold=99, warn=False)
def test_supports_multi_discrete_action_space():
"""
No errors using algorithm with an env that has a multidiscrete action space
"""
env = InvalidActionEnvMultiDiscrete(dims=[2, 3], n_invalid_actions=1)
model = MaskablePPO("MlpPolicy", env, n_steps=64, seed=8)
model.learn(100)
evaluate_policy(model, env, warn=False)
# Mask all actions except the good ones, a random model should succeed
env = InvalidActionEnvMultiDiscrete(dims=[2, 3], n_invalid_actions=3)
model = MaskablePPO("MlpPolicy", env, seed=8)
evaluate_policy(model, env, reward_threshold=99, warn=False)
def test_supports_multi_binary_action_space():
"""
No errors using algorithm with an env that has a multibinary action space
"""
env = InvalidActionEnvMultiBinary(dims=3, n_invalid_actions=1)
model = MaskablePPO("MlpPolicy", env, n_steps=64, seed=8)
model.learn(100)
evaluate_policy(model, env, warn=False)
# Mask all actions except the good ones, a random model should succeed
env = InvalidActionEnvMultiBinary(dims=3, n_invalid_actions=3)
model = MaskablePPO("MlpPolicy", env, seed=8)
evaluate_policy(model, env, reward_threshold=99, warn=False)
def test_disabling_masking():
"""
Behave like normal PPO if masking is disabled, which allows for envs that don't provide masks
"""
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19)
model = MaskablePPO("MlpPolicy", env, seed=8)
evaluate_policy(model, env, reward_threshold=99, warn=False)
# With masking disabled, perfect performance disappears
with pytest.raises(AssertionError):
evaluate_policy(model, env, reward_threshold=99, warn=False, use_masking=False)
# With masking enabled (the default), learning/evaluation will fail if the env doesn't provide masks
env = IdentityEnv(dim=2)
model = MaskablePPO("MlpPolicy", env, n_steps=64, seed=8)
with pytest.raises(ValueError):
model.learn(100)
with pytest.raises(ValueError):
evaluate_policy(model, env, warn=False)
model.learn(100, use_masking=False)
evaluate_policy(model, env, warn=False, use_masking=False)
def test_masked_evaluation():
"""
Masking can be enabled or disabled for evaluation, but masking should perform better.
"""
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19)
model = MaskablePPO("MlpPolicy", env, seed=8)
masked_avg_rew, _ = evaluate_policy(model, env, warn=False)
unmasked_avg_rew, _ = evaluate_policy(model, env, warn=False, use_masking=False)
assert masked_avg_rew > unmasked_avg_rew
def test_supports_multi_envs():
"""
Learning and evaluation works with VecEnvs
"""
env = make_vec_env(make_env, n_envs=2)
assert is_masking_supported(env)
model = MaskablePPO("MlpPolicy", env, n_steps=256, gamma=0.4, seed=32, verbose=1)
model.learn(100)
evaluate_policy(model, env, warn=False)
env = make_vec_env(IdentityEnv, n_envs=2, env_kwargs={"dim": 2})
assert not is_masking_supported(env)
model = MaskablePPO("MlpPolicy", env, n_steps=256, gamma=0.4, seed=32, verbose=1)
with pytest.raises(ValueError):
model.learn(100)
with pytest.raises(ValueError):
evaluate_policy(model, env, warn=False)
model.learn(100, use_masking=False)
evaluate_policy(model, env, warn=False, use_masking=False)
def test_callback(tmp_path):
"""
No errors using MaskableEvalCallback during learning
"""
env = make_env()
eval_env = make_env()
model = MaskablePPO("MlpPolicy", env, n_steps=64, gamma=0.4, seed=32, verbose=1)
model.learn(100, callback=MaskableEvalCallback(eval_env, eval_freq=100, warn=False, log_path=tmp_path))
model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False))
def test_maskable_policy_required():
"""
MaskablePPO requires a policy that subclasses MaskableActorCriticPolicy
"""
env = make_env()
with pytest.raises(ValueError):
MaskablePPO(ActorCriticPolicy, env)
def test_discrete_action_space_required():
"""
MaskablePPO requires an env with a discrete (ie non-continuous) action space
"""
env = IdentityEnvBox()
with pytest.raises(AssertionError):
MaskablePPO("MlpPolicy", env)
def test_cnn():
def action_mask_fn(env):
random_invalid_action = random.randrange(env.action_space.n)
return [i != random_invalid_action for i in range(env.action_space.n)]
env = FakeImageEnv()
env = ActionMasker(env, action_mask_fn)
model = MaskablePPO(
"CnnPolicy",
env,
n_steps=64,
seed=32,
verbose=1,
policy_kwargs=dict(
features_extractor_kwargs=dict(features_dim=32),
),
)
model.learn(100)
evaluate_policy(model, env, warn=False)
# Dict observations are currently not supported
# def test_dict_obs():
# env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
# env = ToDictWrapper(env)
# model = MaskablePPO("MultiInputPolicy", env, n_steps=64, seed=8)
# model.learn(100)
# evaluate_policy(model, env, warn=False)
#
# # Mask all actions except the good one, a random model should succeed
# env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19)
# env = ToDictWrapper(env)
# model = MaskablePPO("MultiInputPolicy", env, seed=8)
# evaluate_policy(model, env, reward_threshold=99, warn=False)


@@ -1,3 +1,5 @@
from typing import Union
import gym
import numpy as np
import pytest
@@ -6,7 +8,9 @@ import torch.nn as nn
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from sb3_contrib import QRDQN, TQC
from sb3_contrib import QRDQN, TQC, MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.maskable.utils import get_action_masks
class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):
@@ -76,12 +80,51 @@ def clone_tqc_batch_norm_stats(
return (actor_bias, actor_running_mean, critic_bias, critic_running_mean, critic_target_bias, critic_target_running_mean)
def clone_on_policy_batch_norm(model: MaskablePPO) -> (th.Tensor, th.Tensor):
return clone_batch_norm_stats(model.policy.features_extractor.batch_norm)
CLONE_HELPERS = {
QRDQN: clone_qrdqn_batch_norm_stats,
TQC: clone_tqc_batch_norm_stats,
MaskablePPO: clone_on_policy_batch_norm,
}
def test_ppo_mask_train_eval_mode():
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
model = MaskablePPO(
"MlpPolicy",
env,
policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
seed=1,
)
bias_before, running_mean_before = clone_on_policy_batch_norm(model)
model.learn(total_timesteps=200)
bias_after, running_mean_after = clone_on_policy_batch_norm(model)
assert ~th.isclose(bias_before, bias_after).all()
assert ~th.isclose(running_mean_before, running_mean_after).all()
batch_norm_stats_before = clone_on_policy_batch_norm(model)
observation = env.reset()
action_masks = get_action_masks(env)
first_prediction, _ = model.predict(observation, action_masks=action_masks, deterministic=True)
for _ in range(5):
prediction, _ = model.predict(observation, action_masks=action_masks, deterministic=True)
np.testing.assert_allclose(first_prediction, prediction)
batch_norm_stats_after = clone_on_policy_batch_norm(model)
# No change in batch norm params
for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after):
assert th.isclose(param_before, param_after).all()
def test_qrdqn_train_with_batch_norm():
model = QRDQN(
"MlpPolicy",


@@ -0,0 +1,78 @@
from typing import List
import pytest
from gym.spaces import Discrete
from stable_baselines3.common.envs import IdentityEnv
from sb3_contrib.common.wrappers import ActionMasker
class IdentityEnvDiscrete(IdentityEnv):
def __init__(self, dim: int = 1, ep_length: int = 100):
"""
Identity environment for testing purposes
:param dim: the size of the dimensions you want to learn
:param ep_length: the length of each episode in timesteps
"""
space = Discrete(dim)
self.useless_property = 1
super().__init__(ep_length=ep_length, space=space)
def _action_masks(self) -> List[int]:
return [i == self.state for i in range(self.action_space.n)]
def action_mask_fn(env: IdentityEnvDiscrete) -> List[int]:
return [i == env.state for i in range(env.action_space.n)]
def test_wrapper_accepts_function():
"""
ActionMasker accepts a function
"""
env = IdentityEnvDiscrete()
assert not hasattr(env, "action_masks")
env = ActionMasker(env, action_mask_fn)
assert hasattr(env, "action_masks")
# Wrapper accepts the string name of a method on the underlying env
def test_wrapper_accepts_attr_name():
"""
ActionMasker accepts a string name of a method on the underlying env
"""
env = IdentityEnvDiscrete()
assert not hasattr(env, "action_masks")
env = ActionMasker(env, "_action_masks")
assert hasattr(env, "action_masks")
def test_attr_must_be_callable():
"""
Passing ActionMasker the string name of a non-callable is an error
"""
env = IdentityEnvDiscrete()
with pytest.raises(ValueError):
env = ActionMasker(env, "useless_property")
# Wrapper method returns expected results
def test_action_masks_returns_expected_result():
"""
ActionMasker-provided action_masks() method returns expected results
"""
env = IdentityEnvDiscrete()
env = ActionMasker(env, action_mask_fn)
# Only one valid action expected
masks = env.action_masks()
masks[env.state] = not masks[env.state] # Bit-flip the one expected valid action
assert all([not mask for mask in masks])