From d6c5cea644844f1aab6d2bbd55d85113c7e72904 Mon Sep 17 00:00:00 2001 From: Geoff McDonald Date: Sat, 23 Oct 2021 08:05:37 -0700 Subject: [PATCH] MaskablePPO dictionary observation support (#47) * Add dictionary observation support for ppo_mask. * Improving naming consistency. * Update changelog. * Reformat and add test * Update doc * Update README and setup Co-authored-by: Antonin Raffin --- .github/ISSUE_TEMPLATE/issue-template.md | 6 ++ README.md | 1 + docs/guide/examples.rst | 50 +++------- docs/misc/changelog.rst | 5 +- docs/modules/ppo_mask.rst | 16 ++-- sb3_contrib/common/maskable/buffers.py | 116 ++++++++++++++++++++++- sb3_contrib/ppo_mask/__init__.py | 2 +- sb3_contrib/ppo_mask/policies.py | 8 +- sb3_contrib/ppo_mask/ppo_mask.py | 13 ++- sb3_contrib/version.txt | 2 +- setup.py | 10 ++ tests/test_invalid_actions.py | 25 +++-- 12 files changed, 185 insertions(+), 69 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/issue-template.md b/.github/ISSUE_TEMPLATE/issue-template.md index 4c3ddaa..b84c0aa 100644 --- a/.github/ISSUE_TEMPLATE/issue-template.md +++ b/.github/ISSUE_TEMPLATE/issue-template.md @@ -51,5 +51,11 @@ Describe the characteristic of your environment: * Gym version * Versions of any other relevant libraries + You can use `sb3.get_system_info()` to print relevant packages info: + ```python + import stable_baselines3 as sb3 + sb3.get_system_info() + ``` + **Additional context** Add any other context about the problem here. diff --git a/README.md b/README.md index 5d7e774..b8d205e 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ See documentation for the full list of included features. 
**RL Algorithms**: - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269) - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044) +- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171) **Gym Wrappers**: - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 12b7d71..7d70788 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -13,7 +13,7 @@ Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment. from sb3_contrib import TQC model = TQC("MlpPolicy", "Pendulum-v0", top_quantiles_to_drop_per_net=2, verbose=1) - model.learn(total_timesteps=10000, log_interval=4) + model.learn(total_timesteps=10_000, log_interval=4) model.save("tqc_pendulum") QR-DQN @@ -27,40 +27,20 @@ Train a Quantile Regression DQN (QR-DQN) agent on the CartPole environment. policy_kwargs = dict(n_quantiles=50) model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1) - model.learn(total_timesteps=10000, log_interval=4) + model.learn(total_timesteps=10_000, log_interval=4) model.save("qrdqn_cartpole") +MaskablePPO +----------- -.. PyBullet: Normalizing input features -.. ------------------------------------ -.. -.. Normalizing input features may be essential to successful training of an RL agent -.. (by default, images are scaled but not other types of input), -.. for instance when training on `PyBullet `__ environments. For that, a wrapper exists and -.. will compute a running average and standard deviation of input features (it can do the same for rewards). -.. -.. -.. .. note:: -.. -.. you need to install pybullet with ``pip install pybullet`` -.. -.. -.. .. image:: ../_static/img/colab-badge.svg -.. :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb -.. -.. -.. .. code-block:: python -.. -.. import gym -.. import pybullet_envs -.. -.. 
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize -.. from stable_baselines3 import PPO -.. -.. env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")]) -.. # Automatically normalize the input features and reward -.. env = VecNormalize(env, norm_obs=True, norm_reward=True, -.. clip_obs=10.) -.. -.. model = PPO('MlpPolicy', env) -.. model.learn(total_timesteps=2000) +Train a PPO with invalid action masking agent on a toy environment. + +.. code-block:: python + + from sb3_contrib import MaskablePPO + from sb3_contrib.common.envs import InvalidActionEnvDiscrete + + env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60) + model = MaskablePPO("MlpPolicy", env, verbose=1) + model.learn(5000) + model.save("ppo_mask") diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index fcc94da..d437efc 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 1.2.1a2 (WIP) +Release 1.2.1a3 (WIP) ------------------------------- **Invalid action masking for PPO** @@ -15,6 +15,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Added ``MaskablePPO`` algorithm (@kronion) +- ``MaskablePPO`` Dictionary Observation support (@glmcdona) Bug Fixes: ^^^^^^^^^^ @@ -185,4 +186,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_) Contributors: ------------- -@ku2482 @guyk1971 @minhlong94 @ayeright @kronion +@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona diff --git a/docs/modules/ppo_mask.rst b/docs/modules/ppo_mask.rst index 1b5a60d..9580ff3 100644 --- a/docs/modules/ppo_mask.rst +++ b/docs/modules/ppo_mask.rst @@ -16,7 +16,7 @@ algorithm. Other than adding support for action masking, the behavior is the sam MlpPolicy CnnPolicy - ..
MultiInputPolicy + MultiInputPolicy Notes @@ -42,7 +42,7 @@ Discrete ✔️ ✔️ Box ❌ ✔️ MultiDiscrete ✔️ ✔️ MultiBinary ✔️ ✔️ -Dict ❌ ❌ +Dict ❌ ✔️ ============= ====== =========== @@ -248,9 +248,9 @@ MaskablePPO Policies :members: :noindex: -.. .. autoclass:: MultiInputPolicy -.. :members: -.. -.. .. autoclass:: sb3_contrib.common.maskable.policies.MaskableMultiInputActorCriticPolicy -.. :members: -.. :noindex: +.. autoclass:: MultiInputPolicy + :members: + +.. autoclass:: sb3_contrib.common.maskable.policies.MaskableMultiInputActorCriticPolicy + :members: + :noindex: diff --git a/sb3_contrib/common/maskable/buffers.py b/sb3_contrib/common/maskable/buffers.py index 672435e..71229ef 100644 --- a/sb3_contrib/common/maskable/buffers.py +++ b/sb3_contrib/common/maskable/buffers.py @@ -1,9 +1,10 @@ -from typing import Generator, NamedTuple, Optional +from typing import Generator, NamedTuple, Optional, Union import numpy as np import torch as th from gym import spaces -from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer +from stable_baselines3.common.type_aliases import TensorDict from stable_baselines3.common.vec_env import VecNormalize @@ -17,6 +18,16 @@ class MaskableRolloutBufferSamples(NamedTuple): action_masks: th.Tensor +class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples): + observations: TensorDict + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + action_masks: th.Tensor + + class MaskableRolloutBuffer(RolloutBuffer): """ Rollout buffer that also stores the invalid action masks associated with each observation. 
@@ -96,3 +107,104 @@ class MaskableRolloutBuffer(RolloutBuffer): self.action_masks[batch_inds].reshape(-1, self.mask_dims), ) return MaskableRolloutBufferSamples(*map(self.to_torch, data)) + + +class MaskableDictRolloutBuffer(DictRolloutBuffer): + """ + Dict Rollout buffer used in on-policy algorithms like A2C/PPO. + Extends the RolloutBuffer to use dictionary observations + + It corresponds to ``buffer_size`` transitions collected + using the current policy. + This experience will be discarded after the policy update. + In order to use PPO objective, we also store the current value of each state + and the log probability of each taken action. + + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + Hence, it is only involved in policy and value function training but not action selection. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. 
+ :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "cpu", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + self.action_masks = None + super(MaskableDictRolloutBuffer, self).__init__( + buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs + ) + + def reset(self) -> None: + if isinstance(self.action_space, spaces.Discrete): + mask_dims = self.action_space.n + elif isinstance(self.action_space, spaces.MultiDiscrete): + mask_dims = sum(self.action_space.nvec) + elif isinstance(self.action_space, spaces.MultiBinary): + mask_dims = 2 * self.action_space.n # One mask per binary outcome + else: + raise ValueError(f"Unsupported action space {type(self.action_space)}") + + self.mask_dims = mask_dims + self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32) + + super(MaskableDictRolloutBuffer, self).reset() + + def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None: + """ + :param action_masks: Masks applied to constrain the choice of possible actions. 
+ """ + if action_masks is not None: + self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims)) + + super(MaskableDictRolloutBuffer, self).add(*args, **kwargs) + + def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + + for key, obs in self.observations.items(): + self.observations[key] = self.swap_and_flatten(obs) + + _tensor_names = ["actions", "values", "log_probs", "advantages", "returns", "action_masks"] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples: + + return MaskableDictRolloutBufferSamples( + observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, + actions=self.to_torch(self.actions[batch_inds]), + old_values=self.to_torch(self.values[batch_inds].flatten()), + old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), + advantages=self.to_torch(self.advantages[batch_inds].flatten()), + returns=self.to_torch(self.returns[batch_inds].flatten()), + action_masks=self.to_torch(self.action_masks[batch_inds].reshape(-1, self.mask_dims)), + ) diff --git a/sb3_contrib/ppo_mask/__init__.py b/sb3_contrib/ppo_mask/__init__.py index 014cd93..c26ace8 100644 --- a/sb3_contrib/ppo_mask/__init__.py +++ b/sb3_contrib/ppo_mask/__init__.py @@ -1,2 +1,2 @@ -from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy # , 
MultiInputPolicy +from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO diff --git a/sb3_contrib/ppo_mask/policies.py b/sb3_contrib/ppo_mask/policies.py index 9164363..ad76571 100644 --- a/sb3_contrib/ppo_mask/policies.py +++ b/sb3_contrib/ppo_mask/policies.py @@ -1,15 +1,15 @@ from stable_baselines3.common.policies import register_policy -from sb3_contrib.common.maskable.policies import ( # MaskableMultiInputActorCriticPolicy, +from sb3_contrib.common.maskable.policies import ( MaskableActorCriticCnnPolicy, MaskableActorCriticPolicy, + MaskableMultiInputActorCriticPolicy, ) MlpPolicy = MaskableActorCriticPolicy CnnPolicy = MaskableActorCriticCnnPolicy -# MultiInputPolicy = MaskableMultiInputActorCriticPolicy +MultiInputPolicy = MaskableMultiInputActorCriticPolicy register_policy("MlpPolicy", MaskableActorCriticPolicy) register_policy("CnnPolicy", MaskableActorCriticCnnPolicy) -# Currently not supported -# register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy) +register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy) diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index be9fb12..4648af3 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -2,6 +2,7 @@ import time from collections import deque from typing import Any, Dict, Optional, Tuple, Type, Union +import gym import numpy as np import torch as th from gym import spaces @@ -14,7 +15,7 @@ from stable_baselines3.common.utils import explained_variance, get_schedule_fn, from stable_baselines3.common.vec_env import VecEnv from torch.nn import functional as F -from sb3_contrib.common.maskable.buffers import MaskableRolloutBuffer +from sb3_contrib.common.maskable.buffers import MaskableDictRolloutBuffer, MaskableRolloutBuffer from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy from sb3_contrib.common.maskable.utils 
import get_action_masks, is_masking_supported @@ -127,6 +128,10 @@ class MaskablePPO(OnPolicyAlgorithm): self._setup_lr_schedule() self.set_random_seed(self.seed) + buffer_cls = ( + MaskableDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else MaskableRolloutBuffer + ) + self.policy = self.policy_class( self.observation_space, self.action_space, @@ -138,7 +143,7 @@ class MaskablePPO(OnPolicyAlgorithm): if not isinstance(self.policy, MaskableActorCriticPolicy): raise ValueError("Policy must subclass MaskableActorCriticPolicy") - self.rollout_buffer = MaskableRolloutBuffer( + self.rollout_buffer = buffer_cls( self.n_steps, self.observation_space, self.action_space, @@ -288,7 +293,9 @@ class MaskablePPO(OnPolicyAlgorithm): collected, False if callback terminated rollout prematurely. """ - assert isinstance(rollout_buffer, MaskableRolloutBuffer), "RolloutBuffer doesn't support action masking" + assert isinstance( + rollout_buffer, (MaskableRolloutBuffer, MaskableDictRolloutBuffer) + ), "RolloutBuffer doesn't support action masking" assert self._last_obs is not None, "No previous observation was provided" # Switch to eval mode (this affects batch norm / dropout) self.policy.set_training_mode(False) diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index c4baa5c..90ebae4 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.2.1a2 +1.2.1a3 diff --git a/setup.py b/setup.py index 1b75a55..0266d2b 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ See documentation for the full list of included features. 
**RL Algorithms**: - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269) - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044) +- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171) **Gym Wrappers**: - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378) @@ -74,6 +75,15 @@ setup( long_description=long_description, long_description_content_type="text/markdown", version=__version__, + python_requires=">=3.6", + # PyPI package information. + classifiers=[ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + ], ) # python setup.py sdist diff --git a/tests/test_invalid_actions.py b/tests/test_invalid_actions.py index 89245d9..222191f 100644 --- a/tests/test_invalid_actions.py +++ b/tests/test_invalid_actions.py @@ -217,16 +217,15 @@ def test_cnn(): evaluate_policy(model, env, warn=False) -# Dict observations are currently not supported -# def test_dict_obs(): -# env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10) -# env = ToDictWrapper(env) -# model = MaskablePPO("MultiInputPolicy", env, n_steps=64, seed=8) -# model.learn(100) -# evaluate_policy(model, env, warn=False) -# -# # Mask all actions except the good one, a random model should succeed -# env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19) -# env = ToDictWrapper(env) -# model = MaskablePPO("MultiInputPolicy", env, seed=8) -# evaluate_policy(model, env, reward_threshold=99, warn=False) +def test_dict_obs(): + env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10) + env = ToDictWrapper(env) + model = MaskablePPO("MultiInputPolicy", env, n_steps=64, seed=8) + model.learn(100) + evaluate_policy(model, env, warn=False) + + # Mask all actions except the good one, a random model should succeed + env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19) + env = 
ToDictWrapper(env) + model = MaskablePPO("MultiInputPolicy", env, seed=8) + evaluate_policy(model, env, reward_threshold=99, warn=False)