MaskablePPO dictionary observation support (#47)

* Add dictionary observation support for ppo_mask.

* Improve naming consistency.

* Update changelog.

* Reformat and add test

* Update doc

* Update README and setup

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
Geoff McDonald 2021-10-23 08:05:37 -07:00 committed by GitHub
parent 91f9b1ed34
commit d6c5cea644
12 changed files with 185 additions and 69 deletions

View File

@@ -51,5 +51,11 @@ Describe the characteristic of your environment:
 * Gym version
 * Versions of any other relevant libraries
+
+You can use `sb3.get_system_info()` to print relevant packages info:
+```python
+import stable_baselines3 as sb3
+sb3.get_system_info()
+```
 
 **Additional context**
 Add any other context about the problem here.

View File

@@ -27,6 +27,7 @@ See documentation for the full list of included features.
 **RL Algorithms**:
 - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
 - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
+- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
 
 **Gym Wrappers**:
 - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)

View File

@@ -13,7 +13,7 @@ Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment.
     from sb3_contrib import TQC
 
     model = TQC("MlpPolicy", "Pendulum-v0", top_quantiles_to_drop_per_net=2, verbose=1)
-    model.learn(total_timesteps=10000, log_interval=4)
+    model.learn(total_timesteps=10_000, log_interval=4)
     model.save("tqc_pendulum")
 
 QR-DQN
@@ -27,40 +27,20 @@ Train a Quantile Regression DQN (QR-DQN) agent on the CartPole environment.
     policy_kwargs = dict(n_quantiles=50)
     model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
-    model.learn(total_timesteps=10000, log_interval=4)
+    model.learn(total_timesteps=10_000, log_interval=4)
     model.save("qrdqn_cartpole")
 
-.. PyBullet: Normalizing input features
-.. ------------------------------------
-..
-.. Normalizing input features may be essential to successful training of an RL agent
-.. (by default, images are scaled but not other types of input),
-.. for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`__ environments. For that, a wrapper exists and
-.. will compute a running average and standard deviation of input features (it can do the same for rewards).
-..
-.. .. note::
-..
-..   you need to install pybullet with ``pip install pybullet``
-..
-..
-.. .. image:: ../_static/img/colab-badge.svg
-..    :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
-..
-..
-.. .. code-block:: python
-..
-..   import gym
-..   import pybullet_envs
-..
-..   from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
-..   from stable_baselines3 import PPO
-..
-..   env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
-..   # Automatically normalize the input features and reward
-..   env = VecNormalize(env, norm_obs=True, norm_reward=True,
-..                      clip_obs=10.)
-..
-..   model = PPO('MlpPolicy', env)
-..   model.learn(total_timesteps=2000)
+MaskablePPO
+-----------
+
+Train an agent with MaskablePPO (PPO with invalid action masking) on a toy environment.
+
+.. code-block:: python
+
+    from sb3_contrib import MaskablePPO
+    from sb3_contrib.common.envs import InvalidActionEnvDiscrete
+
+    env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
+    model = MaskablePPO("MlpPolicy", env, verbose=1)
+    model.learn(5000)
+    model.save("ppo_mask")
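Under the hood, invalid action masking amounts to assigning zero probability to masked actions before sampling. A minimal NumPy illustration of the idea (this is a sketch of the concept, not the library's implementation, which operates on PyTorch distributions):

```python
import numpy as np

def masked_softmax(logits, mask):
    """Zero out invalid actions by setting their logits to -inf before
    the softmax. Assumes at least one action is valid (otherwise the
    result is undefined)."""
    masked_logits = np.where(mask, logits, -np.inf)
    # Subtract the max for numerical stability; exp(-inf) becomes 0.
    shifted = masked_logits - masked_logits.max()
    exp = np.exp(shifted)
    return exp / exp.sum()
```

Masked actions end up with exactly zero probability, while the remaining probabilities renormalize to one.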

View File

@@ -3,7 +3,7 @@
 Changelog
 ==========
 
-Release 1.2.1a2 (WIP)
+Release 1.2.1a3 (WIP)
 -------------------------------
 
 **Invalid action masking for PPO**
@@ -15,6 +15,7 @@ Breaking Changes:
 New Features:
 ^^^^^^^^^^^^^
 - Added ``MaskablePPO`` algorithm (@kronion)
+- ``MaskablePPO`` Dictionary Observation support (@glmcdona)
 
 Bug Fixes:
 ^^^^^^^^^^
@@ -185,4 +186,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
 Contributors:
 -------------
-@ku2482 @guyk1971 @minhlong94 @ayeright @kronion
+@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona

View File

@@ -16,7 +16,7 @@ algorithm. Other than adding support for action masking, the behavior is the same.
   MlpPolicy
   CnnPolicy
-  .. MultiInputPolicy
+  MultiInputPolicy
 
 Notes
@@ -42,7 +42,7 @@
 Discrete ✔️ ✔️
 Box ❌ ✔️
 MultiDiscrete ✔️ ✔️
 MultiBinary ✔️ ✔️
-Dict ❌
+Dict ❌ ✔️
 ============= ====== ===========
@@ -248,9 +248,9 @@ MaskablePPO Policies
   :members:
   :noindex:
 
-.. .. autoclass:: MultiInputPolicy
-..   :members:
-..
-.. .. autoclass:: sb3_contrib.common.maskable.policies.MaskableMultiInputActorCriticPolicy
-..   :members:
-..   :noindex:
+.. autoclass:: MultiInputPolicy
+  :members:
+
+.. autoclass:: sb3_contrib.common.maskable.policies.MaskableMultiInputActorCriticPolicy
+  :members:
+  :noindex:

View File

@@ -1,9 +1,10 @@
-from typing import Generator, NamedTuple, Optional
+from typing import Generator, NamedTuple, Optional, Union
 
 import numpy as np
 import torch as th
 from gym import spaces
-from stable_baselines3.common.buffers import RolloutBuffer
+from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
+from stable_baselines3.common.type_aliases import TensorDict
 from stable_baselines3.common.vec_env import VecNormalize
@@ -17,6 +18,16 @@ class MaskableRolloutBufferSamples(NamedTuple):
     action_masks: th.Tensor
 
+class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples):
+    observations: TensorDict
+    actions: th.Tensor
+    old_values: th.Tensor
+    old_log_prob: th.Tensor
+    advantages: th.Tensor
+    returns: th.Tensor
+    action_masks: th.Tensor
+
 class MaskableRolloutBuffer(RolloutBuffer):
     """
     Rollout buffer that also stores the invalid action masks associated with each observation.
@@ -96,3 +107,104 @@ class MaskableRolloutBuffer(RolloutBuffer):
             self.action_masks[batch_inds].reshape(-1, self.mask_dims),
         )
         return MaskableRolloutBufferSamples(*map(self.to_torch, data))
class MaskableDictRolloutBuffer(DictRolloutBuffer):
"""
Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
Extends the RolloutBuffer to use dictionary observations
It corresponds to ``buffer_size`` transitions collected
using the current policy.
This experience will be discarded after the policy update.
In order to use PPO objective, we also store the current value of each state
and the log probability of each taken action.
The term rollout here refers to the model-free notion and should not
be used with the concept of rollout used in model-based RL or planning.
Hence, it is only involved in policy and value function training but not action selection.
    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
"""
def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "cpu",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
self.action_masks = None
super(MaskableDictRolloutBuffer, self).__init__(
buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs
)
def reset(self) -> None:
if isinstance(self.action_space, spaces.Discrete):
mask_dims = self.action_space.n
elif isinstance(self.action_space, spaces.MultiDiscrete):
mask_dims = sum(self.action_space.nvec)
elif isinstance(self.action_space, spaces.MultiBinary):
mask_dims = 2 * self.action_space.n # One mask per binary outcome
else:
raise ValueError(f"Unsupported action space {type(self.action_space)}")
self.mask_dims = mask_dims
self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)
super(MaskableDictRolloutBuffer, self).reset()
def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
"""
:param action_masks: Masks applied to constrain the choice of possible actions.
"""
if action_masks is not None:
self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))
super(MaskableDictRolloutBuffer, self).add(*args, **kwargs)
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]:
        assert self.full, "Rollout buffer must be full before sampling from it"
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
if not self.generator_ready:
for key, obs in self.observations.items():
self.observations[key] = self.swap_and_flatten(obs)
_tensor_names = ["actions", "values", "log_probs", "advantages", "returns", "action_masks"]
for tensor in _tensor_names:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
self.generator_ready = True
# Return everything, don't create minibatches
if batch_size is None:
batch_size = self.buffer_size * self.n_envs
start_idx = 0
while start_idx < self.buffer_size * self.n_envs:
yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:
return MaskableDictRolloutBufferSamples(
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
actions=self.to_torch(self.actions[batch_inds]),
old_values=self.to_torch(self.values[batch_inds].flatten()),
old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
advantages=self.to_torch(self.advantages[batch_inds].flatten()),
returns=self.to_torch(self.returns[batch_inds].flatten()),
action_masks=self.to_torch(self.action_masks[batch_inds].reshape(-1, self.mask_dims)),
)
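The `reset()` method above sizes the flat action mask according to the action-space type. That sizing rule can be sketched standalone, using minimal stand-in classes instead of `gym.spaces` (attribute names are assumed to match gym's):

```python
import numpy as np

# Stand-ins for gym.spaces types (illustrative; real code uses gym.spaces).
class Discrete:
    def __init__(self, n): self.n = n          # n possible actions

class MultiDiscrete:
    def __init__(self, nvec): self.nvec = np.array(nvec)  # actions per dimension

class MultiBinary:
    def __init__(self, n): self.n = n          # n independent binary choices

def mask_dims(space):
    """Flat mask length per action space, mirroring the dispatch in
    MaskableDictRolloutBuffer.reset()."""
    if isinstance(space, Discrete):
        return int(space.n)            # one mask entry per action
    if isinstance(space, MultiDiscrete):
        return int(sum(space.nvec))    # concatenated per-dimension masks
    if isinstance(space, MultiBinary):
        return 2 * int(space.n)        # one mask per binary outcome
    raise ValueError(f"Unsupported action space {type(space)}")
```

For example, a `MultiDiscrete([3, 2])` space needs a flat mask of length 5: three entries for the first sub-action and two for the second.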

View File

@@ -1,2 +1,2 @@
-from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy  # , MultiInputPolicy
+from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
 from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO

View File

@@ -1,15 +1,15 @@
 from stable_baselines3.common.policies import register_policy
 
-from sb3_contrib.common.maskable.policies import (  # MaskableMultiInputActorCriticPolicy,
+from sb3_contrib.common.maskable.policies import (
     MaskableActorCriticCnnPolicy,
     MaskableActorCriticPolicy,
+    MaskableMultiInputActorCriticPolicy,
 )
 
 MlpPolicy = MaskableActorCriticPolicy
 CnnPolicy = MaskableActorCriticCnnPolicy
-# MultiInputPolicy = MaskableMultiInputActorCriticPolicy
+MultiInputPolicy = MaskableMultiInputActorCriticPolicy
 
 register_policy("MlpPolicy", MaskableActorCriticPolicy)
 register_policy("CnnPolicy", MaskableActorCriticCnnPolicy)
-# Currently not supported
-# register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy)
+register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy)
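`register_policy` maps string aliases like `"MultiInputPolicy"` to policy classes, so users can pass a name instead of importing the class. A simplified, hypothetical sketch of such a registry (the real SB3 helper does more bookkeeping, e.g. grouping registrations by base class and rejecting duplicates):

```python
# Module-level registry mapping alias -> policy class (sketch only).
_policy_registry = {}

def register_policy(name, policy_class):
    """Associate a string alias with a policy class."""
    if name in _policy_registry:
        raise ValueError(f"Policy {name} already registered")
    _policy_registry[name] = policy_class

def get_policy_from_name(name):
    """Resolve an alias back to its class, as algorithms do when the
    user passes a string like "MlpPolicy"."""
    if name not in _policy_registry:
        raise KeyError(f"Policy {name} unknown")
    return _policy_registry[name]
```

With this in place, `MaskablePPO("MultiInputPolicy", env)` can look up the class at construction time rather than requiring an import at every call site.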

View File

@@ -2,6 +2,7 @@ import time
 from collections import deque
 from typing import Any, Dict, Optional, Tuple, Type, Union
 
+import gym
 import numpy as np
 import torch as th
 from gym import spaces
@@ -14,7 +15,7 @@ from stable_baselines3.common.utils import explained_variance, get_schedule_fn,
 from stable_baselines3.common.vec_env import VecEnv
 from torch.nn import functional as F
 
-from sb3_contrib.common.maskable.buffers import MaskableRolloutBuffer
+from sb3_contrib.common.maskable.buffers import MaskableDictRolloutBuffer, MaskableRolloutBuffer
 from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
 from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
@@ -127,6 +128,10 @@ class MaskablePPO(OnPolicyAlgorithm):
         self._setup_lr_schedule()
         self.set_random_seed(self.seed)
 
+        buffer_cls = (
+            MaskableDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else MaskableRolloutBuffer
+        )
+
         self.policy = self.policy_class(
             self.observation_space,
             self.action_space,
@@ -138,7 +143,7 @@ class MaskablePPO(OnPolicyAlgorithm):
         if not isinstance(self.policy, MaskableActorCriticPolicy):
             raise ValueError("Policy must subclass MaskableActorCriticPolicy")
 
-        self.rollout_buffer = MaskableRolloutBuffer(
+        self.rollout_buffer = buffer_cls(
             self.n_steps,
             self.observation_space,
             self.action_space,
@@ -288,7 +293,9 @@ class MaskablePPO(OnPolicyAlgorithm):
             collected, False if callback terminated rollout prematurely.
         """
-        assert isinstance(rollout_buffer, MaskableRolloutBuffer), "RolloutBuffer doesn't support action masking"
+        assert isinstance(
+            rollout_buffer, (MaskableRolloutBuffer, MaskableDictRolloutBuffer)
+        ), "RolloutBuffer doesn't support action masking"
         assert self._last_obs is not None, "No previous observation was provided"
 
         # Switch to eval mode (this affects batch norm / dropout)
         self.policy.set_training_mode(False)
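During `train()`, MaskablePPO consumes the rollout buffer in shuffled minibatches via the buffers' `get()` method. The shuffle-and-slice pattern, taken in isolation, looks like this NumPy sketch (`iter_minibatches` is an illustrative name, not a library function):

```python
import numpy as np

def iter_minibatches(n_samples, batch_size, seed=0):
    """Yield index minibatches over a shuffled permutation of
    0..n_samples-1, mirroring the sampling loop in
    MaskableDictRolloutBuffer.get(). The final batch may be short
    when batch_size does not divide n_samples."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(n_samples)
    start = 0
    while start < n_samples:
        yield indices[start : start + batch_size]
        start += batch_size
```

Each transition index appears exactly once per pass, which is what makes one iteration of the generator equal to one epoch over the collected rollout.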

View File

@@ -1 +1 @@
-1.2.1a2
+1.2.1a3

View File

@@ -32,6 +32,7 @@ See documentation for the full list of included features.
 **RL Algorithms**:
 - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
 - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
+- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
 
 **Gym Wrappers**:
 - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
@@ -74,6 +75,15 @@ setup(
     long_description=long_description,
     long_description_content_type="text/markdown",
     version=__version__,
+    python_requires=">=3.6",
+    # PyPI package information.
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+    ],
 )
 
 # python setup.py sdist

View File

@@ -217,16 +217,15 @@ def test_cnn():
     evaluate_policy(model, env, warn=False)
 
-# Dict observations are currently not supported
-# def test_dict_obs():
-#     env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
-#     env = ToDictWrapper(env)
-#     model = MaskablePPO("MultiInputPolicy", env, n_steps=64, seed=8)
-#     model.learn(100)
-#     evaluate_policy(model, env, warn=False)
-#
-#     # Mask all actions except the good one, a random model should succeed
-#     env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19)
-#     env = ToDictWrapper(env)
-#     model = MaskablePPO("MultiInputPolicy", env, seed=8)
-#     evaluate_policy(model, env, reward_threshold=99, warn=False)
+def test_dict_obs():
+    env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
+    env = ToDictWrapper(env)
+    model = MaskablePPO("MultiInputPolicy", env, n_steps=64, seed=8)
+    model.learn(100)
+    evaluate_policy(model, env, warn=False)
+
+    # Mask all actions except the good one, a random model should succeed
+    env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19)
+    env = ToDictWrapper(env)
+    model = MaskablePPO("MultiInputPolicy", env, seed=8)
+    evaluate_policy(model, env, reward_threshold=99, warn=False)