MaskablePPO dictionary observation support (#47)
* Add dictionary observation support for ppo_mask. * Improving naming consistency. * Update changelog. * Reformat and add test * Update doc * Update README and setup Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
91f9b1ed34
commit
d6c5cea644
|
|
@ -51,5 +51,11 @@ Describe the characteristic of your environment:
|
|||
* Gym version
|
||||
* Versions of any other relevant libraries
|
||||
|
||||
You can use `sb3.get_system_info()` to print relevant packages info:
|
||||
```python
|
||||
import stable_baselines3 as sb3
|
||||
sb3.get_system_info()
|
||||
```
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ See documentation for the full list of included features.
|
|||
**RL Algorithms**:
|
||||
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
|
||||
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
|
||||
- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
|
||||
|
||||
**Gym Wrappers**:
|
||||
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment.
|
|||
from sb3_contrib import TQC
|
||||
|
||||
model = TQC("MlpPolicy", "Pendulum-v0", top_quantiles_to_drop_per_net=2, verbose=1)
|
||||
model.learn(total_timesteps=10000, log_interval=4)
|
||||
model.learn(total_timesteps=10_000, log_interval=4)
|
||||
model.save("tqc_pendulum")
|
||||
|
||||
QR-DQN
|
||||
|
|
@ -27,40 +27,20 @@ Train a Quantile Regression DQN (QR-DQN) agent on the CartPole environment.
|
|||
|
||||
policy_kwargs = dict(n_quantiles=50)
|
||||
model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
|
||||
model.learn(total_timesteps=10000, log_interval=4)
|
||||
model.learn(total_timesteps=10_000, log_interval=4)
|
||||
model.save("qrdqn_cartpole")
|
||||
|
||||
MaskablePPO
|
||||
-----------
|
||||
|
||||
.. PyBullet: Normalizing input features
|
||||
.. ------------------------------------
|
||||
..
|
||||
.. Normalizing input features may be essential to successful training of an RL agent
|
||||
.. (by default, images are scaled but not other types of input),
|
||||
.. for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`__ environments. For that, a wrapper exists and
|
||||
.. will compute a running average and standard deviation of input features (it can do the same for rewards).
|
||||
..
|
||||
..
|
||||
.. .. note::
|
||||
..
|
||||
.. you need to install pybullet with ``pip install pybullet``
|
||||
..
|
||||
..
|
||||
.. .. image:: ../_static/img/colab-badge.svg
|
||||
.. :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
|
||||
..
|
||||
..
|
||||
.. .. code-block:: python
|
||||
..
|
||||
.. import gym
|
||||
.. import pybullet_envs
|
||||
..
|
||||
.. from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
|
||||
.. from stable_baselines3 import PPO
|
||||
..
|
||||
.. env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
|
||||
.. # Automatically normalize the input features and reward
|
||||
.. env = VecNormalize(env, norm_obs=True, norm_reward=True,
|
||||
.. clip_obs=10.)
|
||||
..
|
||||
.. model = PPO('MlpPolicy', env)
|
||||
.. model.learn(total_timesteps=2000)
|
||||
Train a PPO with invalid action masking agent on a toy environment.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from sb3_contrib import MaskablePPO
|
||||
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
|
||||
|
||||
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
|
||||
model = MaskablePPO("MlpPolicy", env, verbose=1)
|
||||
model.learn(5000)
|
||||
model.save("qrdqn_cartpole")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.2.1a2 (WIP)
|
||||
Release 1.2.1a3 (WIP)
|
||||
-------------------------------
|
||||
|
||||
**Invalid action masking for PPO**
|
||||
|
|
@ -15,6 +15,7 @@ Breaking Changes:
|
|||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added ``MaskablePPO`` algorithm (@kronion)
|
||||
- ``MaskablePPO`` Dictionary Observation support (@glmcdona)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
@ -185,4 +186,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
|
|||
Contributors:
|
||||
-------------
|
||||
|
||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion
|
||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ algorithm. Other than adding support for action masking, the behavior is the sam
|
|||
|
||||
MlpPolicy
|
||||
CnnPolicy
|
||||
.. MultiInputPolicy
|
||||
MultiInputPolicy
|
||||
|
||||
|
||||
Notes
|
||||
|
|
@ -42,7 +42,7 @@ Discrete ✔️ ✔️
|
|||
Box ❌ ✔️
|
||||
MultiDiscrete ✔️ ✔️
|
||||
MultiBinary ✔️ ✔️
|
||||
Dict ❌ ❌
|
||||
Dict ❌ ✔️
|
||||
============= ====== ===========
|
||||
|
||||
|
||||
|
|
@ -248,9 +248,9 @@ MaskablePPO Policies
|
|||
:members:
|
||||
:noindex:
|
||||
|
||||
.. .. autoclass:: MultiInputPolicy
|
||||
.. :members:
|
||||
..
|
||||
.. .. autoclass:: sb3_contrib.common.maskable.policies.MaskableMultiInputActorCriticPolicy
|
||||
.. :members:
|
||||
.. :noindex:
|
||||
.. autoclass:: MultiInputPolicy
|
||||
:members:
|
||||
|
||||
.. autoclass:: sb3_contrib.common.maskable.policies.MaskableMultiInputActorCriticPolicy
|
||||
:members:
|
||||
:noindex:
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Generator, NamedTuple, Optional
|
||||
from typing import Generator, NamedTuple, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.buffers import RolloutBuffer
|
||||
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
|
||||
from stable_baselines3.common.type_aliases import TensorDict
|
||||
from stable_baselines3.common.vec_env import VecNormalize
|
||||
|
||||
|
||||
|
|
@ -17,6 +18,16 @@ class MaskableRolloutBufferSamples(NamedTuple):
|
|||
action_masks: th.Tensor
|
||||
|
||||
|
||||
class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples):
|
||||
observations: TensorDict
|
||||
actions: th.Tensor
|
||||
old_values: th.Tensor
|
||||
old_log_prob: th.Tensor
|
||||
advantages: th.Tensor
|
||||
returns: th.Tensor
|
||||
action_masks: th.Tensor
|
||||
|
||||
|
||||
class MaskableRolloutBuffer(RolloutBuffer):
|
||||
"""
|
||||
Rollout buffer that also stores the invalid action masks associated with each observation.
|
||||
|
|
@ -96,3 +107,104 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
|||
self.action_masks[batch_inds].reshape(-1, self.mask_dims),
|
||||
)
|
||||
return MaskableRolloutBufferSamples(*map(self.to_torch, data))
|
||||
|
||||
|
||||
class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
||||
"""
|
||||
Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
|
||||
Extends the RolloutBuffer to use dictionary observations
|
||||
|
||||
It corresponds to ``buffer_size`` transitions collected
|
||||
using the current policy.
|
||||
This experience will be discarded after the policy update.
|
||||
In order to use PPO objective, we also store the current value of each state
|
||||
and the log probability of each taken action.
|
||||
|
||||
The term rollout here refers to the model-free notion and should not
|
||||
be used with the concept of rollout used in model-based RL or planning.
|
||||
Hence, it is only involved in policy and value function training but not action selection.
|
||||
|
||||
:param buffer_size: Max number of element in the buffer
|
||||
:param observation_space: Observation space
|
||||
:param action_space: Action space
|
||||
:param device:
|
||||
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
|
||||
Equivalent to classic advantage when set to 1.
|
||||
:param gamma: Discount factor
|
||||
:param n_envs: Number of parallel environments
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
buffer_size: int,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
device: Union[th.device, str] = "cpu",
|
||||
gae_lambda: float = 1,
|
||||
gamma: float = 0.99,
|
||||
n_envs: int = 1,
|
||||
):
|
||||
self.action_masks = None
|
||||
super(MaskableDictRolloutBuffer, self).__init__(
|
||||
buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
mask_dims = self.action_space.n
|
||||
elif isinstance(self.action_space, spaces.MultiDiscrete):
|
||||
mask_dims = sum(self.action_space.nvec)
|
||||
elif isinstance(self.action_space, spaces.MultiBinary):
|
||||
mask_dims = 2 * self.action_space.n # One mask per binary outcome
|
||||
else:
|
||||
raise ValueError(f"Unsupported action space {type(self.action_space)}")
|
||||
|
||||
self.mask_dims = mask_dims
|
||||
self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)
|
||||
|
||||
super(MaskableDictRolloutBuffer, self).reset()
|
||||
|
||||
def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
|
||||
"""
|
||||
:param action_masks: Masks applied to constrain the choice of possible actions.
|
||||
"""
|
||||
if action_masks is not None:
|
||||
self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))
|
||||
|
||||
super(MaskableDictRolloutBuffer, self).add(*args, **kwargs)
|
||||
|
||||
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]:
|
||||
assert self.full, ""
|
||||
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
||||
# Prepare the data
|
||||
if not self.generator_ready:
|
||||
|
||||
for key, obs in self.observations.items():
|
||||
self.observations[key] = self.swap_and_flatten(obs)
|
||||
|
||||
_tensor_names = ["actions", "values", "log_probs", "advantages", "returns", "action_masks"]
|
||||
|
||||
for tensor in _tensor_names:
|
||||
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
|
||||
self.generator_ready = True
|
||||
|
||||
# Return everything, don't create minibatches
|
||||
if batch_size is None:
|
||||
batch_size = self.buffer_size * self.n_envs
|
||||
|
||||
start_idx = 0
|
||||
while start_idx < self.buffer_size * self.n_envs:
|
||||
yield self._get_samples(indices[start_idx : start_idx + batch_size])
|
||||
start_idx += batch_size
|
||||
|
||||
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:
|
||||
|
||||
return MaskableDictRolloutBufferSamples(
|
||||
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
|
||||
actions=self.to_torch(self.actions[batch_inds]),
|
||||
old_values=self.to_torch(self.values[batch_inds].flatten()),
|
||||
old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
|
||||
advantages=self.to_torch(self.advantages[batch_inds].flatten()),
|
||||
returns=self.to_torch(self.returns[batch_inds].flatten()),
|
||||
action_masks=self.to_torch(self.action_masks[batch_inds].reshape(-1, self.mask_dims)),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy # , MultiInputPolicy
|
||||
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||
from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
from stable_baselines3.common.policies import register_policy
|
||||
|
||||
from sb3_contrib.common.maskable.policies import ( # MaskableMultiInputActorCriticPolicy,
|
||||
from sb3_contrib.common.maskable.policies import (
|
||||
MaskableActorCriticCnnPolicy,
|
||||
MaskableActorCriticPolicy,
|
||||
MaskableMultiInputActorCriticPolicy,
|
||||
)
|
||||
|
||||
MlpPolicy = MaskableActorCriticPolicy
|
||||
CnnPolicy = MaskableActorCriticCnnPolicy
|
||||
# MultiInputPolicy = MaskableMultiInputActorCriticPolicy
|
||||
MultiInputPolicy = MaskableMultiInputActorCriticPolicy
|
||||
|
||||
register_policy("MlpPolicy", MaskableActorCriticPolicy)
|
||||
register_policy("CnnPolicy", MaskableActorCriticCnnPolicy)
|
||||
# Currently not supported
|
||||
# register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy)
|
||||
register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import time
|
|||
from collections import deque
|
||||
from typing import Any, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
|
|
@ -14,7 +15,7 @@ from stable_baselines3.common.utils import explained_variance, get_schedule_fn,
|
|||
from stable_baselines3.common.vec_env import VecEnv
|
||||
from torch.nn import functional as F
|
||||
|
||||
from sb3_contrib.common.maskable.buffers import MaskableRolloutBuffer
|
||||
from sb3_contrib.common.maskable.buffers import MaskableDictRolloutBuffer, MaskableRolloutBuffer
|
||||
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
|
||||
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
|
||||
|
||||
|
|
@ -127,6 +128,10 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
self._setup_lr_schedule()
|
||||
self.set_random_seed(self.seed)
|
||||
|
||||
buffer_cls = (
|
||||
MaskableDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else MaskableRolloutBuffer
|
||||
)
|
||||
|
||||
self.policy = self.policy_class(
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
|
|
@ -138,7 +143,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
if not isinstance(self.policy, MaskableActorCriticPolicy):
|
||||
raise ValueError("Policy must subclass MaskableActorCriticPolicy")
|
||||
|
||||
self.rollout_buffer = MaskableRolloutBuffer(
|
||||
self.rollout_buffer = buffer_cls(
|
||||
self.n_steps,
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
|
|
@ -288,7 +293,9 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
collected, False if callback terminated rollout prematurely.
|
||||
"""
|
||||
|
||||
assert isinstance(rollout_buffer, MaskableRolloutBuffer), "RolloutBuffer doesn't support action masking"
|
||||
assert isinstance(
|
||||
rollout_buffer, (MaskableRolloutBuffer, MaskableDictRolloutBuffer)
|
||||
), "RolloutBuffer doesn't support action masking"
|
||||
assert self._last_obs is not None, "No previous observation was provided"
|
||||
# Switch to eval mode (this affects batch norm / dropout)
|
||||
self.policy.set_training_mode(False)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.2.1a2
|
||||
1.2.1a3
|
||||
|
|
|
|||
10
setup.py
10
setup.py
|
|
@ -32,6 +32,7 @@ See documentation for the full list of included features.
|
|||
**RL Algorithms**:
|
||||
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
|
||||
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
|
||||
- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
|
||||
|
||||
**Gym Wrappers**:
|
||||
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
|
||||
|
|
@ -74,6 +75,15 @@ setup(
|
|||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
version=__version__,
|
||||
python_requires=">=3.6",
|
||||
# PyPI package information.
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
],
|
||||
)
|
||||
|
||||
# python setup.py sdist
|
||||
|
|
|
|||
|
|
@ -217,16 +217,15 @@ def test_cnn():
|
|||
evaluate_policy(model, env, warn=False)
|
||||
|
||||
|
||||
# Dict observations are currently not supported
|
||||
# def test_dict_obs():
|
||||
# env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
||||
# env = ToDictWrapper(env)
|
||||
# model = MaskablePPO("MultiInputPolicy", env, n_steps=64, seed=8)
|
||||
# model.learn(100)
|
||||
# evaluate_policy(model, env, warn=False)
|
||||
#
|
||||
# # Mask all actions except the good one, a random model should succeed
|
||||
# env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19)
|
||||
# env = ToDictWrapper(env)
|
||||
# model = MaskablePPO("MultiInputPolicy", env, seed=8)
|
||||
# evaluate_policy(model, env, reward_threshold=99, warn=False)
|
||||
def test_dict_obs():
|
||||
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
||||
env = ToDictWrapper(env)
|
||||
model = MaskablePPO("MultiInputPolicy", env, n_steps=64, seed=8)
|
||||
model.learn(100)
|
||||
evaluate_policy(model, env, warn=False)
|
||||
|
||||
# Mask all actions except the good one, a random model should succeed
|
||||
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19)
|
||||
env = ToDictWrapper(env)
|
||||
model = MaskablePPO("MultiInputPolicy", env, seed=8)
|
||||
evaluate_policy(model, env, reward_threshold=99, warn=False)
|
||||
|
|
|
|||
Loading…
Reference in New Issue