diff --git a/Makefile b/Makefile index a8f34af..d36ad2e 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ check-codestyle: # Sort imports isort --check ${LINT_PATHS} # Reformat using black - black --check -l 127 ${LINT_PATHS} + black --check ${LINT_PATHS} commit-checks: format type lint diff --git a/docs/conda_env.yml b/docs/conda_env.yml index 627daac..52f7498 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -6,7 +6,7 @@ dependencies: - cpuonly=1.0=0 - pip=20.2 - python=3.6 - - pytorch=1.5.0=py3.6_cpu_0 + - pytorch=1.8.1=py3.6_cpu_0 - pip: - gym>=0.17.2 - cloudpickle @@ -15,4 +15,6 @@ dependencies: - numpy - matplotlib - sphinx_autodoc_typehints - - stable-baselines3>=0.10.0a0 + - stable-baselines3>=1.2.0 + # See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 + - sphinx_rtd_theme>=1.0 diff --git a/docs/images/10x10_combined.png b/docs/images/10x10_combined.png new file mode 100644 index 0000000..e93cd1a Binary files /dev/null and b/docs/images/10x10_combined.png differ diff --git a/docs/images/10x10_mask.png b/docs/images/10x10_mask.png new file mode 100644 index 0000000..5e2fe04 Binary files /dev/null and b/docs/images/10x10_mask.png differ diff --git a/docs/images/10x10_no_mask.png b/docs/images/10x10_no_mask.png new file mode 100644 index 0000000..e3d2a9f Binary files /dev/null and b/docs/images/10x10_no_mask.png differ diff --git a/docs/images/4x4_combined.png b/docs/images/4x4_combined.png new file mode 100644 index 0000000..0dcfda9 Binary files /dev/null and b/docs/images/4x4_combined.png differ diff --git a/docs/images/4x4_mask.png b/docs/images/4x4_mask.png new file mode 100644 index 0000000..7fa76ee Binary files /dev/null and b/docs/images/4x4_mask.png differ diff --git a/docs/images/4x4_no_mask.png b/docs/images/4x4_no_mask.png new file mode 100644 index 0000000..a78ecb2 Binary files /dev/null and b/docs/images/4x4_no_mask.png differ diff --git a/docs/index.rst b/docs/index.rst index f86c47e..8e37c71 100644 --- 
a/docs/index.rst +++ b/docs/index.rst @@ -33,6 +33,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d modules/tqc modules/qrdqn + modules/ppo_mask .. toctree:: :maxdepth: 1 diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 07a42a0..79693dd 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,31 @@ Changelog ========== +Release 1.2.1a0 (WIP) +------------------------------- + +**Invalid action masking for PPO** + +Breaking Changes: +^^^^^^^^^^^^^^^^^ + +New Features: +^^^^^^^^^^^^^ +- Added ``MaskablePPO`` algorithm (@kronion) + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + + +Documentation: +^^^^^^^^^^^^^^ + Release 1.2.0 (2021-09-08) ------------------------------- @@ -13,6 +38,9 @@ Breaking Changes: ^^^^^^^^^^^^^^^^^ - Upgraded to Stable-Baselines3 >= 1.2.0 +New Features: +^^^^^^^^^^^^^ + Bug Fixes: ^^^^^^^^^^ - QR-DQN and TQC updated so that their policies are switched between train and eval mode at the correct time (@ayeright) @@ -156,4 +184,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_) Contributors: ------------- -@ku2482 @guyk1971 @minhlong94 @ayeright +@ku2482 @guyk1971 @minhlong94 @ayeright @kronion diff --git a/docs/modules/ppo_mask.rst b/docs/modules/ppo_mask.rst new file mode 100644 index 0000000..1b5a60d --- /dev/null +++ b/docs/modules/ppo_mask.rst @@ -0,0 +1,256 @@ +.. _ppo_mask: + +.. automodule:: sb3_contrib.ppo_mask + +Maskable PPO +============ + +Implementation of `invalid action masking `_ for the Proximal Policy Optimization(PPO) +algorithm. Other than adding support for action masking, the behavior is the same as in SB3's core PPO algorithm. + + +.. rubric:: Available Policies + +.. autosummary:: + :nosignatures: + + MlpPolicy + CnnPolicy + .. 
MultiInputPolicy + + +Notes +----- + +- Paper: https://arxiv.org/abs/2006.14171 +- Blog post: https://costa.sh/blog-a-closer-look-at-invalid-action-masking-in-policy-gradient-algorithms.html +- Additional Blog post: https://boring-guy.sh/posts/masking-rl/ + + +Can I use? +---------- + +- Recurrent policies: ❌ +- Multi processing: ✔️ +- Gym spaces: + + +============= ====== =========== +Space Action Observation +============= ====== =========== +Discrete ✔️ ✔️ +Box ❌ ✔️ +MultiDiscrete ✔️ ✔️ +MultiBinary ✔️ ✔️ +Dict ❌ ❌ +============= ====== =========== + + +Example +------- + +Train a PPO agent on ``InvalidActionEnvDiscrete``. ``InvalidActionEnvDiscrete`` has a ``action_masks`` method that +returns the invalid action mask (``True`` if the action is valid, ``False`` otherwise). + +.. code-block:: python + + from sb3_contrib import MaskablePPO + from sb3_contrib.common.envs import InvalidActionEnvDiscrete + from sb3_contrib.common.maskable.evaluation import evaluate_policy + from sb3_contrib.common.maskable.utils import get_action_masks + + + env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60) + model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, verbose=1) + model.learn(5000) + + evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False) + + model.save("ppo_mask") + del model # remove to demonstrate saving and loading + + model = MaskablePPO.load("ppo_mask") + + obs = env.reset() + while True: + # Retrieve current action mask + action_masks = get_action_masks(env) + action, _states = model.predict(obs, action_masks=action_masks) + obs, rewards, dones, info = env.step(action) + env.render() + + +If the environment implements the invalid action mask but using a different name, you can use the ``ActionMasker`` +to specify the name (see `PR #25 `_): + + +.. 
code-block:: python + + import gym + import numpy as np + + from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy + from sb3_contrib.common.wrappers import ActionMasker + from sb3_contrib.ppo_mask import MaskablePPO + + + def mask_fn(env: gym.Env) -> np.ndarray: + # Do whatever you'd like in this function to return the action mask + # for the current env. In this example, we assume the env has a + # helpful method we can rely on. + return env.valid_action_mask() + + + env = ... # Initialize env + env = ActionMasker(env, mask_fn) # Wrap to enable masking + + # MaskablePPO behaves the same as SB3's PPO unless the env is wrapped + # with ActionMasker. If the wrapper is detected, the masks are automatically + # retrieved and used when learning. Note that MaskablePPO does not accept + # a new action_mask_fn kwarg, as it did in an earlier draft. + model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1) + model.learn() + + # Note that use of masks is manual and optional outside of learning, + # so masking can be "removed" at testing time + model.predict(observation, action_masks=valid_action_array) + + +Results +------- + +Results are shown for two MicroRTS benchmarks: MicrortsMining4x4F9-v0 (600K steps) and MicrortsMining10x10F9-v0 +(1.5M steps). For each, models were trained with and without masking, using 3 seeds. + +4x4 +^^^ + +No masking +"""""""""" + +.. image:: /images/4x4_no_mask.png + +With masking +"""""""""""" + +.. image:: /images/4x4_mask.png + +Combined +"""""""" + +.. image:: /images/4x4_combined.png + +10x10 +^^^^^ + +No masking +"""""""""" + +.. image:: /images/10x10_no_mask.png + +With masking +"""""""""""" + +"""""""" + +.. image:: /images/10x10_mask.png + +Combined +"""""""" + +.. image:: /images/10x10_combined.png + +More information may be found in the +`associated PR `_. + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the repo for the experiment: + +.. 
code-block:: bash + + git clone git@github.com:kronion/microrts-ppo-comparison.git + cd microrts-ppo-comparison + +Install dependencies: + +.. code-block:: bash + + # Install MicroRTS: + rm -fR ~/microrts && mkdir ~/microrts && \ + wget -O ~/microrts/microrts.zip http://microrts.s3.amazonaws.com/microrts/artifacts/202004222224.microrts.zip && \ + unzip ~/microrts/microrts.zip -d ~/microrts/ + + # You may want to make a venv before installing packages + pip install -r requirements.txt + +Train several times with various seeds, with and without masking: + +.. code-block:: bash + + # python sb/train_ppo.py [output dir] [MicroRTS map size] [--mask] [--seed int] + + # 4x4 unmasked + python sb3/train_ppo.py zoo 4 --seed 42 + python sb3/train_ppo.py zoo 4 --seed 43 + python sb3/train_ppo.py zoo 4 --seed 44 + + # 4x4 masked + python sb3/train_ppo.py zoo 4 --mask --seed 42 + python sb3/train_ppo.py zoo 4 --mask --seed 43 + python sb3/train_ppo.py zoo 4 --mask --seed 44 + + # 10x10 unmasked + python sb3/train_ppo.py zoo 10 --seed 42 + python sb3/train_ppo.py zoo 10 --seed 43 + python sb3/train_ppo.py zoo 10 --seed 44 + + # 10x10 masked + python sb3/train_ppo.py zoo 10 --mask --seed 42 + python sb3/train_ppo.py zoo 10 --mask --seed 43 + python sb3/train_ppo.py zoo 10 --mask --seed 44 + +View the tensorboard log output: + +.. code-block:: bash + + # For 4x4 environment + tensorboard --logdir zoo/4x4/runs + + # For 10x10 environment + tensorboard --logdir zoo/10x10/runs + + +Parameters +---------- + +.. autoclass:: MaskablePPO + :members: + :inherited-members: + + +MaskablePPO Policies +-------------------- + +.. autoclass:: MlpPolicy + :members: + :inherited-members: + +.. autoclass:: sb3_contrib.common.maskable.policies.MaskableActorCriticPolicy + :members: + :noindex: + +.. autoclass:: CnnPolicy + :members: + +.. autoclass:: sb3_contrib.common.maskable.policies.MaskableActorCriticCnnPolicy + :members: + :noindex: + +.. .. autoclass:: MultiInputPolicy +.. :members: +.. +.. .. 
autoclass:: sb3_contrib.common.maskable.policies.MaskableMultiInputActorCriticPolicy +.. :members: +.. :noindex: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..dcff8a7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +line-length = 127 diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 8f253e1..c90336a 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -1,5 +1,6 @@ import os +from sb3_contrib.ppo_mask import MaskablePPO from sb3_contrib.qrdqn import QRDQN from sb3_contrib.tqc import TQC diff --git a/sb3_contrib/common/envs/__init__.py b/sb3_contrib/common/envs/__init__.py new file mode 100644 index 0000000..20bdffe --- /dev/null +++ b/sb3_contrib/common/envs/__init__.py @@ -0,0 +1,5 @@ +from sb3_contrib.common.envs.invalid_actions_env import ( + InvalidActionEnvDiscrete, + InvalidActionEnvMultiBinary, + InvalidActionEnvMultiDiscrete, +) diff --git a/sb3_contrib/common/envs/invalid_actions_env.py b/sb3_contrib/common/envs/invalid_actions_env.py new file mode 100644 index 0000000..aed67db --- /dev/null +++ b/sb3_contrib/common/envs/invalid_actions_env.py @@ -0,0 +1,116 @@ +from typing import List, Optional + +import numpy as np +from gym import spaces +from stable_baselines3.common.envs import IdentityEnv + + +class InvalidActionEnvDiscrete(IdentityEnv): + """ + Identity env with a discrete action space. Supports action masking. 
+ """ + + def __init__( + self, + dim: Optional[int] = None, + ep_length: int = 100, + n_invalid_actions: int = 0, + ): + if dim is None: + dim = 1 + assert n_invalid_actions < dim, f"Too many invalid actions: {n_invalid_actions} < {dim}" + + space = spaces.Discrete(dim) + self.n_invalid_actions = n_invalid_actions + self.possible_actions = np.arange(space.n) + self.invalid_actions: List[int] = [] + super().__init__(space=space, ep_length=ep_length) + + def _choose_next_state(self) -> None: + self.state = self.action_space.sample() + # Randomly choose invalid actions that are not the current state + potential_invalid_actions = [i for i in self.possible_actions if i != self.state] + self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False) + + def action_masks(self) -> List[bool]: + return [action not in self.invalid_actions for action in self.possible_actions] + + +class InvalidActionEnvMultiDiscrete(IdentityEnv): + """ + Identity env with a multidiscrete action space. Supports action masking. + """ + + def __init__( + self, + dims: Optional[List[int]] = None, + ep_length: int = 100, + n_invalid_actions: int = 0, + ): + if dims is None: + dims = [1, 1] + + if n_invalid_actions > sum(dims) - len(dims): + raise ValueError(f"Cannot find a valid action for each dim. 
Set n_invalid_actions <= {sum(dims) - len(dims)}") + + space = spaces.MultiDiscrete(dims) + self.n_invalid_actions = n_invalid_actions + self.possible_actions = np.arange(sum(dims)) + self.invalid_actions: List[int] = [] + super().__init__(space=space, ep_length=ep_length) + + def _choose_next_state(self) -> None: + self.state = self.action_space.sample() + + converted_state: List[int] = [] + running_total = 0 + for i in range(len(self.action_space.nvec)): + converted_state.append(running_total + self.state[i]) + running_total += self.action_space.nvec[i] + + # Randomly choose invalid actions that are not the current state + potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state] + self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False) + + def action_masks(self) -> List[bool]: + return [action not in self.invalid_actions for action in self.possible_actions] + + +class InvalidActionEnvMultiBinary(IdentityEnv): + """ + Identity env with a multibinary action space. Supports action masking. + """ + + def __init__( + self, + dims: Optional[int] = None, + ep_length: int = 100, + n_invalid_actions: int = 0, + ): + if dims is None: + dims = 1 + + if n_invalid_actions > dims: + raise ValueError(f"Cannot find a valid action for each dim. 
Set n_invalid_actions <= {dims}") + + space = spaces.MultiBinary(dims) + self.n_invalid_actions = n_invalid_actions + self.possible_actions = np.arange(2 * dims) + self.invalid_actions: List[int] = [] + super().__init__(space=space, ep_length=ep_length) + + def _choose_next_state(self) -> None: + self.state = self.action_space.sample() + + converted_state: List[int] = [] + running_total = 0 + for i in range(self.action_space.n): + converted_state.append(running_total + self.state[i]) + running_total += 2 + + # Randomly choose invalid actions that are not the current state + potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state] + self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False) + + def action_masks(self) -> List[bool]: + return [action not in self.invalid_actions for action in self.possible_actions] diff --git a/sb3_contrib/common/maskable/__init__.py b/sb3_contrib/common/maskable/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sb3_contrib/common/maskable/buffers.py b/sb3_contrib/common/maskable/buffers.py new file mode 100644 index 0000000..672435e --- /dev/null +++ b/sb3_contrib/common/maskable/buffers.py @@ -0,0 +1,98 @@ +from typing import Generator, NamedTuple, Optional + +import numpy as np +import torch as th +from gym import spaces +from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.vec_env import VecNormalize + + +class MaskableRolloutBufferSamples(NamedTuple): + observations: th.Tensor + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + action_masks: th.Tensor + + +class MaskableRolloutBuffer(RolloutBuffer): + """ + Rollout buffer that also stores the invalid action masks associated with each observation. 
+ + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. + :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__(self, *args, **kwargs): + self.action_masks = None + super().__init__(*args, **kwargs) + + def reset(self) -> None: + if isinstance(self.action_space, spaces.Discrete): + mask_dims = self.action_space.n + elif isinstance(self.action_space, spaces.MultiDiscrete): + mask_dims = sum(self.action_space.nvec) + elif isinstance(self.action_space, spaces.MultiBinary): + mask_dims = 2 * self.action_space.n # One mask per binary outcome + else: + raise ValueError(f"Unsupported action space {type(self.action_space)}") + + self.mask_dims = mask_dims + self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32) + + super().reset() + + def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None: + """ + :param action_masks: Masks applied to constrain the choice of possible actions. 
+ """ + if action_masks is not None: + self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims)) + + super().add(*args, **kwargs) + + def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + for tensor in [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + "action_masks", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples: + data = ( + self.observations[batch_inds], + self.actions[batch_inds], + self.values[batch_inds].flatten(), + self.log_probs[batch_inds].flatten(), + self.advantages[batch_inds].flatten(), + self.returns[batch_inds].flatten(), + self.action_masks[batch_inds].reshape(-1, self.mask_dims), + ) + return MaskableRolloutBufferSamples(*map(self.to_torch, data)) diff --git a/sb3_contrib/common/maskable/callbacks.py b/sb3_contrib/common/maskable/callbacks.py new file mode 100644 index 0000000..9832548 --- /dev/null +++ b/sb3_contrib/common/maskable/callbacks.py @@ -0,0 +1,107 @@ +import os + +import numpy as np +from stable_baselines3.common.callbacks import EvalCallback +from stable_baselines3.common.vec_env import sync_envs_normalization + +from sb3_contrib.common.maskable.evaluation import evaluate_policy + + +class MaskableEvalCallback(EvalCallback): + """ + Callback for evaluating an agent. Supports invalid action masking. 
+ + :param eval_env: The environment used for initialization + :param callback_on_new_best: Callback to trigger + when there is a new best model according to the ``mean_reward`` + :param n_eval_episodes: The number of episodes to test the agent + :param eval_freq: Evaluate the agent every eval_freq call of the callback. + :param log_path: Path to a folder where the evaluations (``evaluations.npz``) + will be saved. It will be updated at each evaluation. + :param best_model_save_path: Path to a folder where the best model + according to performance on the eval env will be saved. + :param deterministic: Whether the evaluation should + use a stochastic or deterministic actions. + :param render: Whether to render or not the environment during evaluation + :param verbose: + :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been + wrapped with a Monitor wrapper) + :param use_masking: Whether or not to use invalid action masks during evaluation + """ + + def __init__(self, *args, use_masking: bool = True, **kwargs): + super().__init__(*args, **kwargs) + self.use_masking = use_masking + + def _on_step(self) -> bool: + if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: + # Sync training and eval env if there is VecNormalize + sync_envs_normalization(self.training_env, self.eval_env) + + # Reset success rate buffer + self._is_success_buffer = [] + + # Note that evaluate_policy() has been patched to support masking + episode_rewards, episode_lengths = evaluate_policy( + self.model, + self.eval_env, + n_eval_episodes=self.n_eval_episodes, + render=self.render, + deterministic=self.deterministic, + return_episode_rewards=True, + warn=self.warn, + callback=self._log_success_callback, + use_masking=self.use_masking, + ) + + if self.log_path is not None: + self.evaluations_timesteps.append(self.num_timesteps) + self.evaluations_results.append(episode_rewards) + self.evaluations_length.append(episode_lengths) + + kwargs = {} + # Save success log 
if present + if len(self._is_success_buffer) > 0: + self.evaluations_successes.append(self._is_success_buffer) + kwargs = dict(successes=self.evaluations_successes) + + np.savez( + self.log_path, + timesteps=self.evaluations_timesteps, + results=self.evaluations_results, + ep_lengths=self.evaluations_length, + **kwargs, + ) + + mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards) + mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths) + self.last_mean_reward = mean_reward + + if self.verbose > 0: + print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}") + print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}") + # Add to current Logger + self.logger.record("eval/mean_reward", float(mean_reward)) + self.logger.record("eval/mean_ep_length", mean_ep_length) + + if len(self._is_success_buffer) > 0: + success_rate = np.mean(self._is_success_buffer) + if self.verbose > 0: + print(f"Success rate: {100 * success_rate:.2f}%") + self.logger.record("eval/success_rate", success_rate) + + # Dump log so the evaluation results are printed with the correct timestep + self.logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard") + self.logger.dump(self.num_timesteps) + + if mean_reward > self.best_mean_reward: + if self.verbose > 0: + print("New best mean reward!") + if self.best_model_save_path is not None: + self.model.save(os.path.join(self.best_model_save_path, "best_model")) + self.best_mean_reward = mean_reward + # Trigger callback if needed + if self.callback is not None: + return self._on_event() + + return True diff --git a/sb3_contrib/common/maskable/distributions.py b/sb3_contrib/common/maskable/distributions.py new file mode 100644 index 0000000..1ccf5ee --- /dev/null +++ b/sb3_contrib/common/maskable/distributions.py @@ -0,0 +1,272 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple + +import numpy 
as np +import torch as th +from gym import spaces +from stable_baselines3.common.distributions import Distribution +from torch import nn +from torch.distributions import Categorical +from torch.distributions.utils import logits_to_probs + + +class MaskableCategorical(Categorical): + """ + Modified PyTorch Categorical distribution with support for invalid action masking. + + To instantiate, must provide either probs or logits, but not both. + + :param probs: Tensor containing finite non-negative values, which will be renormalized + to sum to 1 along the last dimension. + :param logits: Tensor of unnormalized log probabilities. + :param validate_args: Whether or not to validate that arguments to methods like lob_prob() + and icdf() match the distribution's shape, support, etc. + :param masks: An optional boolean ndarray of compatible shape with the distribution. + If True, the corresponding choice's logit value is preserved. If False, it is set to a + large negative value, resulting in near 0 probability. + """ + + def __init__( + self, + probs: Optional[th.Tensor] = None, + logits: Optional[th.Tensor] = None, + validate_args: Optional[bool] = None, + masks: Optional[np.ndarray] = None, + ): + self.masks: Optional[th.Tensor] = None + super().__init__(probs, logits, validate_args) + self._original_logits = self.logits + self.apply_masking(masks) + + def apply_masking(self, masks: Optional[np.ndarray]) -> None: + """ + Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0. + + :param masks: An optional boolean ndarray of compatible shape with the distribution. + If True, the corresponding choice's logit value is preserved. If False, it is set + to a large negative value, resulting in near 0 probability. If masks is None, any + previously applied masking is removed, and the original logits are restored. 
+ """ + + if masks is not None: + device = self.logits.device + self.masks = th.as_tensor(masks, dtype=th.bool, device=device).reshape(self.logits.shape) + HUGE_NEG = th.tensor(-1e8, dtype=self.logits.dtype, device=device) + + logits = th.where(self.masks, self._original_logits, HUGE_NEG) + else: + self.masks = None + logits = self._original_logits + + # Reinitialize with updated logits + super().__init__(logits=logits) + + # self.probs may already be cached, so we must force an update + self.probs = logits_to_probs(self.logits) + + def entropy(self) -> th.Tensor: + if self.masks is None: + return super().entropy() + + # Highly negative logits don't result in 0 probs, so we must replace + # with 0s to ensure 0 contribution to the distribution's entropy, since + # masked actions possess no uncertainty. + device = self.logits.device + p_log_p = self.logits * self.probs + p_log_p = th.where(self.masks, p_log_p, th.tensor(0.0, device=device)) + return -p_log_p.sum(-1) + + +class MaskableDistribution(Distribution, ABC): + @abstractmethod + def apply_masking(self, masks: Optional[np.ndarray]) -> None: + """ + Eliminate ("mask out") chosen distribution outcomes by setting their probability to 0. + + :param masks: An optional boolean ndarray of compatible shape with the distribution. + If True, the corresponding choice's logit value is preserved. If False, it is set + to a large negative value, resulting in near 0 probability. If masks is None, any + previously applied masking is removed, and the original logits are restored. + """ + + +class MaskableCategoricalDistribution(MaskableDistribution): + """ + Categorical distribution for discrete actions. Supports invalid action masking. 
+ + :param action_dim: Number of discrete actions + """ + + def __init__(self, action_dim: int): + super().__init__() + self.distribution: Optional[MaskableCategorical] = None + self.action_dim = action_dim + + def proba_distribution_net(self, latent_dim: int) -> nn.Module: + """ + Create the layer that represents the distribution: + it will be the logits of the Categorical distribution. + You can then get probabilities using a softmax. + + :param latent_dim: Dimension of the last layer + of the policy network (before the action layer) + :return: + """ + action_logits = nn.Linear(latent_dim, self.action_dim) + return action_logits + + def proba_distribution(self, action_logits: th.Tensor) -> "MaskableCategoricalDistribution": + # Restructure shape to align with logits + reshaped_logits = action_logits.view(-1, self.action_dim) + self.distribution = MaskableCategorical(logits=reshaped_logits) + return self + + def log_prob(self, actions: th.Tensor) -> th.Tensor: + assert self.distribution is not None, "Must set distribution parameters" + return self.distribution.log_prob(actions) + + def entropy(self) -> th.Tensor: + assert self.distribution is not None, "Must set distribution parameters" + return self.distribution.entropy() + + def sample(self) -> th.Tensor: + assert self.distribution is not None, "Must set distribution parameters" + return self.distribution.sample() + + def mode(self) -> th.Tensor: + assert self.distribution is not None, "Must set distribution parameters" + return th.argmax(self.distribution.probs, dim=1) + + def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: + # Update the proba distribution + self.proba_distribution(action_logits) + return self.get_actions(deterministic=deterministic) + + def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + actions = self.actions_from_params(action_logits) + log_prob = self.log_prob(actions) + return actions, log_prob + + def 
apply_masking(self, masks: Optional[np.ndarray]) -> None: + assert self.distribution is not None, "Must set distribution parameters" + self.distribution.apply_masking(masks) + + +class MaskableMultiCategoricalDistribution(MaskableDistribution): + """ + MultiCategorical distribution for multi discrete actions. Supports invalid action masking. + + :param action_dims: List of sizes of discrete action spaces + """ + + def __init__(self, action_dims: List[int]): + super().__init__() + self.distributions: List[MaskableCategorical] = [] + self.action_dims = action_dims + + def proba_distribution_net(self, latent_dim: int) -> nn.Module: + """ + Create the layer that represents the distribution: + it will be the logits (flattened) of the MultiCategorical distribution. + You can then get probabilities using a softmax on each sub-space. + + :param latent_dim: Dimension of the last layer + of the policy network (before the action layer) + :return: + """ + + action_logits = nn.Linear(latent_dim, sum(self.action_dims)) + return action_logits + + def proba_distribution(self, action_logits: th.Tensor) -> "MaskableMultiCategoricalDistribution": + # Restructure shape to align with logits + reshaped_logits = action_logits.view(-1, sum(self.action_dims)) + + self.distributions = [ + MaskableCategorical(logits=split) for split in th.split(reshaped_logits, tuple(self.action_dims), dim=1) + ] + return self + + def log_prob(self, actions: th.Tensor) -> th.Tensor: + assert len(self.distributions) > 0, "Must set distribution parameters" + + # Restructure shape to align with each categorical + actions = actions.view(-1, len(self.action_dims)) + + # Extract each discrete action and compute log prob for their respective distributions + return th.stack( + [dist.log_prob(action) for dist, action in zip(self.distributions, th.unbind(actions, dim=1))], dim=1 + ).sum(dim=1) + + def entropy(self) -> th.Tensor: + assert len(self.distributions) > 0, "Must set distribution parameters" + return 
th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1) + + def sample(self) -> th.Tensor: + assert len(self.distributions) > 0, "Must set distribution parameters" + return th.stack([dist.sample() for dist in self.distributions], dim=1) + + def mode(self) -> th.Tensor: + assert len(self.distributions) > 0, "Must set distribution parameters" + return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1) + + def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: + # Update the proba distribution + self.proba_distribution(action_logits) + return self.get_actions(deterministic=deterministic) + + def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + actions = self.actions_from_params(action_logits) + log_prob = self.log_prob(actions) + return actions, log_prob + + def apply_masking(self, masks: Optional[np.ndarray]) -> None: + assert len(self.distributions) > 0, "Must set distribution parameters" + + split_masks = [None] * len(self.distributions) + if masks is not None: + masks = th.as_tensor(masks) + + # Restructure shape to align with logits + masks = masks.view(-1, sum(self.action_dims)) + + # Then split columnwise for each discrete action + split_masks = th.split(masks, tuple(self.action_dims), dim=1) + + for distribution, mask in zip(self.distributions, split_masks): + distribution.apply_masking(mask) + + +class MaskableBernoulliDistribution(MaskableMultiCategoricalDistribution): + """ + Bernoulli distribution for multibinary actions. Supports invalid action masking. 
+ + :param action_dim: Number of binary actions + """ + + def __init__(self, action_dim: int): + # Two states per binary action + action_dims = [2] * action_dim + super().__init__(action_dims) + + +def make_masked_proba_distribution(action_space: spaces.Space) -> MaskableDistribution: + """ + Return an instance of Distribution for the correct type of action space + + :param action_space: the input action space + :return: the appropriate Distribution object + """ + + if isinstance(action_space, spaces.Discrete): + return MaskableCategoricalDistribution(action_space.n) + elif isinstance(action_space, spaces.MultiDiscrete): + return MaskableMultiCategoricalDistribution(action_space.nvec) + elif isinstance(action_space, spaces.MultiBinary): + return MaskableBernoulliDistribution(action_space.n) + else: + raise NotImplementedError( + "Error: probability distribution, not implemented for action space" + f"of type {type(action_space)}." + " Must be of type Gym Spaces: Discrete, MultiDiscrete." + ) diff --git a/sb3_contrib/common/maskable/evaluation.py b/sb3_contrib/common/maskable/evaluation.py new file mode 100644 index 0000000..82b8cc3 --- /dev/null +++ b/sb3_contrib/common/maskable/evaluation.py @@ -0,0 +1,148 @@ +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import gym +import numpy as np +from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped + +from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported +from sb3_contrib.ppo_mask import MaskablePPO + + +def evaluate_policy( # noqa: C901 + model: MaskablePPO, + env: Union[gym.Env, VecEnv], + n_eval_episodes: int = 10, + deterministic: bool = True, + render: bool = False, + callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None, + reward_threshold: Optional[float] = None, + return_episode_rewards: bool = False, + warn: bool = True, + 
use_masking: bool = True, +) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]: + """ + Runs policy for ``n_eval_episodes`` episodes and returns average reward. + If a vector env is passed in, this divides the episodes to evaluate onto the + different elements of the vector env. This static division of work is done to + remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more + details and discussion. + + .. note:: + If environment has not been wrapped with ``Monitor`` wrapper, reward and + episode lengths are counted as it appears with ``env.step`` calls. If + the environment contains wrappers that modify rewards or episode lengths + (e.g. reward scaling, early episode reset), these will affect the evaluation + results as well. You can avoid this by wrapping environment with ``Monitor`` + wrapper before anything else. + + :param model: The RL agent you want to evaluate. + :param env: The gym environment. In the case of a ``VecEnv`` + this must contain only one environment. + :param n_eval_episodes: Number of episodes to evaluate the agent + :param deterministic: Whether to use deterministic or stochastic actions + :param render: Whether to render the environment or not + :param callback: callback function to do additional checks, + called after each step. Gets locals() and globals() passed as parameters. + :param reward_threshold: Minimum expected reward per episode, + this will raise an error if the performance is not met + :param return_episode_rewards: If True, a list of rewards and episode lengths + per episode will be returned instead of the mean. + :param warn: If True (default), warns user about lack of a Monitor wrapper in the + evaluation environment. + :param use_masking: Whether or not to use invalid action masks during evaluation + :return: Mean reward per episode, std of reward per episode.
+ Returns ([float], [int]) when ``return_episode_rewards`` is True, first + list containing per-episode rewards and second containing per-episode lengths + (in number of steps). + """ + + if use_masking and not is_masking_supported(env): + raise ValueError("Environment does not support action masking. Consider using ActionMasker wrapper") + + is_monitor_wrapped = False + + if not isinstance(env, VecEnv): + env = DummyVecEnv([lambda: env]) + + is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0] + + if not is_monitor_wrapped and warn: + warnings.warn( + "Evaluation environment is not wrapped with a ``Monitor`` wrapper. " + "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. " + "Consider wrapping environment first with ``Monitor`` wrapper.", + UserWarning, + ) + + n_envs = env.num_envs + episode_rewards = [] + episode_lengths = [] + + episode_counts = np.zeros(n_envs, dtype="int") + # Divides episodes among different sub environments in the vector as evenly as possible + episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int") + + current_rewards = np.zeros(n_envs) + current_lengths = np.zeros(n_envs, dtype="int") + observations = env.reset() + states = None + + while (episode_counts < episode_count_targets).any(): + if use_masking: + action_masks = get_action_masks(env) + actions, state = model.predict( + observations, + state=states, + deterministic=deterministic, + action_masks=action_masks, + ) + else: + actions, states = model.predict(observations, state=states, deterministic=deterministic) + observations, rewards, dones, infos = env.step(actions) + current_rewards += rewards + current_lengths += 1 + for i in range(n_envs): + if episode_counts[i] < episode_count_targets[i]: + + # unpack values so that the callback can access the local variables + reward = rewards[i] + done = dones[i] + info = infos[i] + + if callback is 
not None: + callback(locals(), globals()) + + if dones[i]: + if is_monitor_wrapped: + # Atari wrapper can send a "done" signal when + # the agent loses a life, but it does not correspond + # to the true end of episode + if "episode" in info.keys(): + # Do not trust "done" with episode endings. + # Monitor wrapper includes "episode" key in info if environment + # has been wrapped with it. Use those rewards instead. + episode_rewards.append(info["episode"]["r"]) + episode_lengths.append(info["episode"]["l"]) + # Only increment at the real end of an episode + episode_counts[i] += 1 + else: + episode_rewards.append(current_rewards[i]) + episode_lengths.append(current_lengths[i]) + episode_counts[i] += 1 + current_rewards[i] = 0 + current_lengths[i] = 0 + if states is not None: + states[i] *= 0 + + if render: + env.render() + + mean_reward = np.mean(episode_rewards) + std_reward = np.std(episode_rewards) + if reward_threshold is not None: + assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}" + if return_episode_rewards: + return episode_rewards, episode_lengths + return mean_reward, std_reward diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py new file mode 100644 index 0000000..d7c9522 --- /dev/null +++ b/sb3_contrib/common/maskable/policies.py @@ -0,0 +1,402 @@ +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import gym +import numpy as np +import torch as th +from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + MlpExtractor, + NatureCNN, +) +from stable_baselines3.common.type_aliases import Schedule +from torch import nn + +from sb3_contrib.common.maskable.distributions import MaskableDistribution, make_masked_proba_distribution + + +class MaskableActorCriticPolicy(BasePolicy): + """ + 
Policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + + if optimizer_kwargs is None: + optimizer_kwargs = {} + # Small values to avoid NaN in Adam optimizer + if optimizer_class == th.optim.Adam: + optimizer_kwargs["eps"] = 1e-5 + + super().__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=False, + ) + + # Default network architecture, from stable-baselines + if net_arch is None: + if features_extractor_class == NatureCNN: + net_arch = [] + else: + 
net_arch = [dict(pi=[64, 64], vf=[64, 64])] + + self.net_arch = net_arch + self.activation_fn = activation_fn + self.ortho_init = ortho_init + + self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) + self.features_dim = self.features_extractor.features_dim + + self.normalize_images = normalize_images + + # Action distribution + self.action_dist = make_masked_proba_distribution(action_space) + + self._build(lr_schedule) + + def forward( + self, + obs: th.Tensor, + deterministic: bool = False, + action_masks: Optional[np.ndarray] = None, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Forward pass in all the networks (actor and critic) + + :param obs: Observation + :param deterministic: Whether to sample or use deterministic actions + :param action_masks: Action masks to apply to the action distribution + :return: action, value and log probability of the action + """ + latent_pi, latent_vf = self._get_latent(obs) + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) + if action_masks is not None: + distribution.apply_masking(action_masks) + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + return actions, values, log_prob + + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() + + data.update( + dict( + net_arch=self.net_arch, + activation_fn=self.activation_fn, + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + ortho_init=self.ortho_init, + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) + return data + + def _build_mlp_extractor(self) -> None: + """ + Create the policy and value networks. 
+ Part of the layers can be shared. + """ + # Note: If net_arch is None and some features extractor is used, + # net_arch here is an empty list and mlp_extractor does not + # really contain any layers (acts like an identity module). + self.mlp_extractor = MlpExtractor( + self.features_dim, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + device=self.device, + ) + + def _build(self, lr_schedule: Schedule) -> None: + """ + Create the networks and the optimizer. + + :param lr_schedule: Learning rate schedule + lr_schedule(1) is the initial learning rate + """ + self._build_mlp_extractor() + + self.action_net = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi) + self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1) + + # Init weights: use orthogonal initialization + # with small initial weight for the output + if self.ortho_init: + # TODO: check for features_extractor + # Values from stable-baselines. + # features_extractor/mlp values are + # originally from openai/baselines (default gains/init_scales). + module_gains = { + self.features_extractor: np.sqrt(2), + self.mlp_extractor: np.sqrt(2), + self.action_net: 0.01, + self.value_net: 1, + } + for module, gain in module_gains.items(): + module.apply(partial(self.init_weights, gain=gain)) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + """ + Get the latent code (i.e., activations of the last layer of each network) + for the different networks. 
+ + :param obs: Observation + :return: Latent codes + for the actor, the value function and for gSDE function + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + latent_pi, latent_vf = self.mlp_extractor(features) + + return latent_pi, latent_vf + + def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution: + """ + Retrieve action distribution given the latent codes. + + :param latent_pi: Latent code for the actor + :return: Action distribution + """ + action_logits = self.action_net(latent_pi) + return self.action_dist.proba_distribution(action_logits=action_logits) + + def _predict( + self, + observation: th.Tensor, + deterministic: bool = False, + action_masks: Optional[np.ndarray] = None, + ) -> th.Tensor: + """ + Get the action according to the policy for a given observation. + + :param observation: + :param deterministic: Whether to use stochastic or deterministic actions + :param action_masks: Action masks to apply to the action distribution + :return: Taken action according to the policy + """ + latent_pi, _ = self._get_latent(observation) + distribution = self._get_action_dist_from_latent(latent_pi) + if action_masks is not None: + distribution.apply_masking(action_masks) + return distribution.get_actions(deterministic=deterministic) + + def predict( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + state: Optional[np.ndarray] = None, + mask: Optional[np.ndarray] = None, + deterministic: bool = False, + action_masks: Optional[np.ndarray] = None, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """ + Get the policy action and state from an observation (and optional state). + Includes sugar-coating to handle different observations (e.g. normalizing images). 
+ + :param observation: the input observation + :param state: The last states (can be None, used in recurrent policies) + :param mask: The last masks (can be None, used in recurrent policies) + :param deterministic: Whether or not to return deterministic actions. + :param action_masks: Action masks to apply to the action distribution + :return: the model's action and the next state + (used in recurrent policies) + """ + # TODO (GH/1): add support for RNN policies + # if state is None: + # state = self.initial_state + # if mask is None: + # mask = [False for _ in range(self.n_envs)] + + # Switch to eval mode (this affects batch norm / dropout) + self.set_training_mode(False) + + observation, vectorized_env = self.obs_to_tensor(observation) + + with th.no_grad(): + actions = self._predict(observation, deterministic=deterministic, action_masks=action_masks) + # Convert to numpy + actions = actions.cpu().numpy() + + if isinstance(self.action_space, gym.spaces.Box): + if self.squash_output: + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = np.clip(actions, self.action_space.low, self.action_space.high) + + if not vectorized_env: + if state is not None: + raise ValueError("Error: The environment must be vectorized when using recurrent policies.") + actions = actions[0] + + return actions, state + + def evaluate_actions( + self, + obs: th.Tensor, + actions: th.Tensor, + action_masks: Optional[np.ndarray] = None, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: + :param actions: + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. 
+ """ + latent_pi, latent_vf = self._get_latent(obs) + distribution = self._get_action_dist_from_latent(latent_pi) + if action_masks is not None: + distribution.apply_masking(action_masks) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_vf) + return values, log_prob, distribution.entropy() + + +class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy): + """ + CNN policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super(MaskableActorCriticCnnPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + 
features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + +class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy): + """ + MultiInputActorClass policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space (Tuple) + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param features_extractor_class: Uses the CombinedExtractor + :param features_extractor_kwargs: Keyword arguments + to pass to the feature extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Dict, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super(MaskableMultiInputActorCriticPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) diff --git 
a/sb3_contrib/common/maskable/utils.py b/sb3_contrib/common/maskable/utils.py new file mode 100644 index 0000000..b4f0990 --- /dev/null +++ b/sb3_contrib/common/maskable/utils.py @@ -0,0 +1,38 @@ +import numpy as np +from stable_baselines3.common.type_aliases import GymEnv +from stable_baselines3.common.vec_env import VecEnv + +EXPECTED_METHOD_NAME = "action_masks" + + +def get_action_masks(env: GymEnv) -> np.ndarray: + """ + Checks whether gym env exposes a method returning invalid action masks + + :param env: the Gym environment to get masks from + :return: A numpy array of the masks + """ + + if isinstance(env, VecEnv): + return np.stack(env.env_method(EXPECTED_METHOD_NAME)) + else: + return getattr(env, EXPECTED_METHOD_NAME)() + + +def is_masking_supported(env: GymEnv) -> bool: + """ + Checks whether gym env exposes a method returning invalid action masks + + :param env: the Gym environment to check + :return: True if the method is found, False otherwise + """ + + if isinstance(env, VecEnv): + try: + # TODO: add VecEnv.has_attr() + env.get_attr(EXPECTED_METHOD_NAME) + return True + except AttributeError: + return False + else: + return hasattr(env, EXPECTED_METHOD_NAME) diff --git a/sb3_contrib/common/wrappers/__init__.py b/sb3_contrib/common/wrappers/__init__.py index 5e59c76..8b591b4 100644 --- a/sb3_contrib/common/wrappers/__init__.py +++ b/sb3_contrib/common/wrappers/__init__.py @@ -1 +1,2 @@ +from sb3_contrib.common.wrappers.action_masker import ActionMasker from sb3_contrib.common.wrappers.time_feature import TimeFeatureWrapper diff --git a/sb3_contrib/common/wrappers/action_masker.py b/sb3_contrib/common/wrappers/action_masker.py new file mode 100644 index 0000000..5850f2c --- /dev/null +++ b/sb3_contrib/common/wrappers/action_masker.py @@ -0,0 +1,32 @@ +from typing import Callable, Union + +import gym +import numpy as np + + +class ActionMasker(gym.Wrapper): + """ + Env wrapper providing the method required to support masking. 
+ + Exposes a method called action_masks(), which returns masks for the wrapped env. + This wrapper is not needed if the env exposes the expected method itself. + + :param env: the Gym environment to wrap + :param action_mask_fn: A function that takes a Gym environment and returns an action mask, + or the name of such a method provided by the environment. + """ + + def __init__(self, env: gym.Env, action_mask_fn: Union[str, Callable[[gym.Env], np.ndarray]]): + super().__init__(env) + + if isinstance(action_mask_fn, str): + found_method = getattr(self.env, action_mask_fn) + if not callable(found_method): + raise ValueError(f"Environment attribute {action_mask_fn} is not a method") + + self._action_mask_fn = found_method + else: + self._action_mask_fn = action_mask_fn + + def action_masks(self) -> np.ndarray: + return self._action_mask_fn(self.env) diff --git a/sb3_contrib/ppo_mask/__init__.py b/sb3_contrib/ppo_mask/__init__.py new file mode 100644 index 0000000..014cd93 --- /dev/null +++ b/sb3_contrib/ppo_mask/__init__.py @@ -0,0 +1,2 @@ +from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy # , MultiInputPolicy +from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO diff --git a/sb3_contrib/ppo_mask/policies.py b/sb3_contrib/ppo_mask/policies.py new file mode 100644 index 0000000..9164363 --- /dev/null +++ b/sb3_contrib/ppo_mask/policies.py @@ -0,0 +1,15 @@ +from stable_baselines3.common.policies import register_policy + +from sb3_contrib.common.maskable.policies import ( # MaskableMultiInputActorCriticPolicy, + MaskableActorCriticCnnPolicy, + MaskableActorCriticPolicy, +) + +MlpPolicy = MaskableActorCriticPolicy +CnnPolicy = MaskableActorCriticCnnPolicy +# MultiInputPolicy = MaskableMultiInputActorCriticPolicy + +register_policy("MlpPolicy", MaskableActorCriticPolicy) +register_policy("CnnPolicy", MaskableActorCriticCnnPolicy) +# Currently not supported +# register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy) diff --git 
a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py new file mode 100644 index 0000000..be9fb12 --- /dev/null +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -0,0 +1,551 @@ +import time +from collections import deque +from typing import Any, Dict, Optional, Tuple, Type, Union + +import numpy as np +import torch as th +from gym import spaces +from stable_baselines3.common import utils +from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean +from stable_baselines3.common.vec_env import VecEnv +from torch.nn import functional as F + +from sb3_contrib.common.maskable.buffers import MaskableRolloutBuffer +from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy +from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported + + +class MaskablePPO(OnPolicyAlgorithm): + """ + Proximal Policy Optimization algorithm (PPO) (clip version) with Invalid Action Masking. + + Based on the original Stable Baselines 3 implementation. + + Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html + Background on Invalid Action Masking: https://arxiv.org/abs/2006.14171 + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) + :param n_steps: The number of steps to run for each environment per update + (i.e. 
batch size is n_steps * n_env where n_env is number of environment copies running in parallel) + :param batch_size: Minibatch size + :param n_epochs: Number of epoch when optimizing the surrogate loss + :param gamma: Discount factor + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + :param clip_range: Clipping parameter, it can be a function of the current progress + remaining (from 1 to 0). + :param clip_range_vf: Clipping parameter for the value function, + it can be a function of the current progress remaining (from 1 to 0). + This is a parameter specific to the OpenAI implementation. If None is passed (default), + no clipping will be done on the value function. + IMPORTANT: this clipping depends on the reward scaling. + :param ent_coef: Entropy coefficient for the loss calculation + :param vf_coef: Value function coefficient for the loss calculation + :param max_grad_norm: The maximum value for the gradient clipping + :param target_kl: Limit the KL divergence between updates, + because the clipping is not enough to prevent large update + see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) + By default, there is no limit on the kl div. + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param create_eval_env: Whether to create a second environment that will be + used for evaluating the agent periodically. (Only available when passing string for the environment) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. 
+ :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + def __init__( + self, + policy: Union[str, Type[MaskableActorCriticPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 3e-4, + n_steps: int = 2048, + batch_size: Optional[int] = 64, + n_epochs: int = 10, + gamma: float = 0.99, + gae_lambda: float = 0.95, + clip_range: Union[float, Schedule] = 0.2, + clip_range_vf: Union[None, float, Schedule] = None, + ent_coef: float = 0.0, + vf_coef: float = 0.5, + max_grad_norm: float = 0.5, + target_kl: Optional[float] = None, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + super().__init__( + policy, + env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + use_sde=False, + sde_sample_freq=-1, + tensorboard_log=tensorboard_log, + create_eval_env=create_eval_env, + policy_kwargs=policy_kwargs, + policy_base=MaskableActorCriticPolicy, + verbose=verbose, + seed=seed, + device=device, + _init_setup_model=False, + supported_action_spaces=( + spaces.Discrete, + spaces.MultiDiscrete, + spaces.MultiBinary, + ), + ) + + self.batch_size = batch_size + self.n_epochs = n_epochs + self.clip_range = clip_range + self.clip_range_vf = clip_range_vf + self.target_kl = target_kl + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + self._setup_lr_schedule() + self.set_random_seed(self.seed) + + self.policy = self.policy_class( + self.observation_space, + self.action_space, + self.lr_schedule, + **self.policy_kwargs, # pytype:disable=not-instantiable + ) + self.policy = self.policy.to(self.device) + + if not isinstance(self.policy, MaskableActorCriticPolicy): + raise 
ValueError("Policy must subclass MaskableActorCriticPolicy") + + self.rollout_buffer = MaskableRolloutBuffer( + self.n_steps, + self.observation_space, + self.action_space, + self.device, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + n_envs=self.n_envs, + ) + + # Initialize schedules for policy/value clipping + self.clip_range = get_schedule_fn(self.clip_range) + if self.clip_range_vf is not None: + if isinstance(self.clip_range_vf, (float, int)): + assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping" + + self.clip_range_vf = get_schedule_fn(self.clip_range_vf) + + def _init_callback( + self, + callback: MaybeCallback, + eval_env: Optional[VecEnv] = None, + eval_freq: int = 10000, + n_eval_episodes: int = 5, + log_path: Optional[str] = None, + use_masking: bool = True, + ) -> BaseCallback: + """ + :param callback: Callback(s) called at every step with state of the algorithm. + :param eval_freq: How many steps between evaluations; if None, do not evaluate. + :param n_eval_episodes: How many episodes to play per evaluation + :param n_eval_episodes: Number of episodes to rollout during evaluation. + :param log_path: Path to a folder where the evaluations will be saved + :param use_masking: Whether or not to use invalid action masks during evaluation + :return: A hybrid callback calling `callback` and performing evaluation. 
+ """ + # Convert a list of callbacks into a callback + if isinstance(callback, list): + callback = CallbackList(callback) + + # Convert functional callback to object + if not isinstance(callback, BaseCallback): + callback = ConvertCallback(callback) + + # Create eval callback in charge of the evaluation + if eval_env is not None: + # Avoid circular import error + from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback + + eval_callback = MaskableEvalCallback( + eval_env, + best_model_save_path=log_path, + log_path=log_path, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + use_masking=use_masking, + ) + callback = CallbackList([callback, eval_callback]) + + callback.init_callback(self) + return callback + + def _setup_learn( + self, + total_timesteps: int, + eval_env: Optional[GymEnv], + callback: MaybeCallback = None, + eval_freq: int = 10000, + n_eval_episodes: int = 5, + log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + tb_log_name: str = "run", + use_masking: bool = True, + ) -> Tuple[int, BaseCallback]: + """ + Initialize different variables needed for training. + + :param total_timesteps: The total number of samples (env steps) to train on + :param eval_env: Environment to use for evaluation. + :param callback: Callback(s) called at every step with state of the algorithm. 
+ :param eval_freq: How many steps between evaluations + :param n_eval_episodes: How many episodes to play per evaluation + :param log_path: Path to a folder where the evaluations will be saved + :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute + :param tb_log_name: the name of the run for tensorboard log + :param use_masking: Whether or not to use invalid action masks during training + :return: + """ + + self.start_time = time.time() + if self.ep_info_buffer is None or reset_num_timesteps: + # Initialize buffers if they don't exist, or reinitialize if resetting counters + self.ep_info_buffer = deque(maxlen=100) + self.ep_success_buffer = deque(maxlen=100) + + if reset_num_timesteps: + self.num_timesteps = 0 + self._episode_num = 0 + else: + # Make sure training timesteps are ahead of the internal counter + total_timesteps += self.num_timesteps + self._total_timesteps = total_timesteps + + # Avoid resetting the environment when calling ``.learn()`` consecutive times + if reset_num_timesteps or self._last_obs is None: + self._last_obs = self.env.reset() + self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool) + # Retrieve unnormalized observation for saving into the buffer + if self._vec_normalize_env is not None: + self._last_original_obs = self._vec_normalize_env.get_original_obs() + + if eval_env is not None and self.seed is not None: + eval_env.seed(self.seed) + + eval_env = self._get_eval_env(eval_env) + + # Configure logger's outputs if no logger was passed + if not self._custom_logger: + self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps) + + # Create eval callback if needed + callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking) + + return total_timesteps, callback + + def collect_rollouts( + self, + env: VecEnv, + callback: BaseCallback, + rollout_buffer: RolloutBuffer, + n_rollout_steps: int, + 
use_masking: bool = True, + ) -> bool: + """ + Collect experiences using the current policy and fill a ``RolloutBuffer``. + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + + This method is largely identical to the implementation found in the parent class. + + :param env: The training environment + :param callback: Callback that will be called at each step + (and at the beginning and end of the rollout) + :param rollout_buffer: Buffer to fill with rollouts + :param n_steps: Number of experiences to collect per environment + :param use_masking: Whether or not to use invalid action masks during training + :return: True if function returned with at least `n_rollout_steps` + collected, False if callback terminated rollout prematurely. + """ + + assert isinstance(rollout_buffer, MaskableRolloutBuffer), "RolloutBuffer doesn't support action masking" + assert self._last_obs is not None, "No previous observation was provided" + # Switch to eval mode (this affects batch norm / dropout) + self.policy.set_training_mode(False) + n_steps = 0 + action_masks = None + rollout_buffer.reset() + + if use_masking and not is_masking_supported(env): + raise ValueError("Environment does not support action masking. 
Consider using ActionMasker wrapper") + + callback.on_rollout_start() + + while n_steps < n_rollout_steps: + with th.no_grad(): + # Convert to pytorch tensor or to TensorDict + obs_tensor = obs_as_tensor(self._last_obs, self.device) + + # This is the only change related to invalid action masking + if use_masking: + action_masks = get_action_masks(env) + + actions, values, log_probs = self.policy.forward(obs_tensor, action_masks=action_masks) + + actions = actions.cpu().numpy() + new_obs, rewards, dones, infos = env.step(actions) + + self.num_timesteps += env.num_envs + + # Give access to local variables + callback.update_locals(locals()) + if callback.on_step() is False: + return False + + self._update_info_buffer(infos) + n_steps += 1 + + if isinstance(self.action_space, spaces.Discrete): + # Reshape in case of discrete action + actions = actions.reshape(-1, 1) + + rollout_buffer.add( + self._last_obs, + actions, + rewards, + self._last_episode_starts, + values, + log_probs, + action_masks=action_masks, + ) + self._last_obs = new_obs + self._last_episode_starts = dones + + with th.no_grad(): + # Compute value for the last timestep + obs_tensor = obs_as_tensor(new_obs, self.device) + + # Masking is not needed here, the choice of action doesn't matter. + # We only want the value of the current observation. + _, values, _ = self.policy.forward(obs_tensor) + + rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) + + callback.on_rollout_end() + + return True + + def predict( + self, + observation: np.ndarray, + state: Optional[np.ndarray] = None, + mask: Optional[np.ndarray] = None, + deterministic: bool = False, + action_masks: Optional[np.ndarray] = None, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """ + Get the model's action(s) from an observation. 
+ + :param observation: the input observation + :param state: The last states (can be None, used in recurrent policies) + :param mask: The last masks (can be None, used in recurrent policies) + :param deterministic: Whether or not to return deterministic actions. + :param action_masks: Action masks to apply to the action distribution. + :return: the model's action and the next state (used in recurrent policies) + """ + return self.policy.predict(observation, state, mask, deterministic, action_masks=action_masks) + + def train(self) -> None: + """ + Update policy using the currently gathered rollout buffer. + """ + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + # Update optimizer learning rate + self._update_learning_rate(self.policy.optimizer) + # Compute current clip range + clip_range = self.clip_range(self._current_progress_remaining) + # Optional: clip range for the value function + if self.clip_range_vf is not None: + clip_range_vf = self.clip_range_vf(self._current_progress_remaining) + + entropy_losses = [] + pg_losses, value_losses = [], [] + clip_fractions = [] + + continue_training = True + + # train for n_epochs epochs + for epoch in range(self.n_epochs): + approx_kl_divs = [] + # Do a complete pass on the rollout buffer + for rollout_data in self.rollout_buffer.get(self.batch_size): + actions = rollout_data.actions + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action from float to long + actions = rollout_data.actions.long().flatten() + + values, log_prob, entropy = self.policy.evaluate_actions( + rollout_data.observations, + actions, + action_masks=rollout_data.action_masks, + ) + + values = values.flatten() + # Normalize advantage + advantages = rollout_data.advantages + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + # ratio between old and new policy, should be one at the first iteration + ratio = th.exp(log_prob - rollout_data.old_log_prob) 
+ + # clipped surrogate loss + policy_loss_1 = advantages * ratio + policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) + policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() + + # Logging + pg_losses.append(policy_loss.item()) + clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() + clip_fractions.append(clip_fraction) + + if self.clip_range_vf is None: + # No clipping + values_pred = values + else: + # Clip the different between old and new value + # NOTE: this depends on the reward scaling + values_pred = rollout_data.old_values + th.clamp( + values - rollout_data.old_values, -clip_range_vf, clip_range_vf + ) + # Value loss using the TD(gae_lambda) target + value_loss = F.mse_loss(rollout_data.returns, values_pred) + value_losses.append(value_loss.item()) + + # Entropy loss favor exploration + if entropy is None: + # Approximate entropy when no analytical form + entropy_loss = -th.mean(-log_prob) + else: + entropy_loss = -th.mean(entropy) + + entropy_losses.append(entropy_loss.item()) + + loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + + # Calculate approximate form of reverse KL Divergence for early stopping + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + with th.no_grad(): + log_ratio = log_prob - rollout_data.old_log_prob + approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() + approx_kl_divs.append(approx_kl_div) + + if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: + continue_training = False + if self.verbose >= 1: + print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") + break + + # Optimization step + self.policy.optimizer.zero_grad() + loss.backward() + # Clip grad norm + th.nn.utils.clip_grad_norm_(self.policy.parameters(), 
self.max_grad_norm) + self.policy.optimizer.step() + + if not continue_training: + break + + self._n_updates += self.n_epochs + explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) + + # Logs + self.logger.record("train/entropy_loss", np.mean(entropy_losses)) + self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) + self.logger.record("train/value_loss", np.mean(value_losses)) + self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) + self.logger.record("train/clip_fraction", np.mean(clip_fractions)) + self.logger.record("train/loss", loss.item()) + self.logger.record("train/explained_variance", explained_var) + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/clip_range", clip_range) + if self.clip_range_vf is not None: + self.logger.record("train/clip_range_vf", clip_range_vf) + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "PPO", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + use_masking: bool = True, + ) -> "MaskablePPO": + iteration = 0 + + total_timesteps, callback = self._setup_learn( + total_timesteps, + eval_env, + callback, + eval_freq, + n_eval_episodes, + eval_log_path, + reset_num_timesteps, + tb_log_name, + use_masking, + ) + + callback.on_training_start(locals(), globals()) + + while self.num_timesteps < total_timesteps: + continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking) + + if continue_training is False: + break + + iteration += 1 + self._update_current_progress_remaining(self.num_timesteps, total_timesteps) + + # Display training infos + if log_interval is not None and iteration % log_interval == 0: + fps = int(self.num_timesteps / (time.time() - 
self.start_time)) + self.logger.record("time/iterations", iteration, exclude="tensorboard") + if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: + self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) + self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) + self.logger.record("time/fps", fps) + self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard") + self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") + self.logger.dump(step=self.num_timesteps) + + self.train() + + callback.on_training_end() + + return self diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 26aaba0..348e216 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.2.0 +1.2.1a0 diff --git a/setup.cfg b/setup.cfg index 3ddbd03..cf162f7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,9 +22,12 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators # Ignore import not used when aliases are defined per-file-ignores = ./sb3_contrib/__init__.py:F401 + ./sb3_contrib/ppo_mask/__init__.py:F401 ./sb3_contrib/qrdqn/__init__.py:F401 ./sb3_contrib/tqc/__init__.py:F401 + ./sb3_contrib/common/vec_env/wrappers/__init__.py:F401 ./sb3_contrib/common/wrappers/__init__.py:F401 + ./sb3_contrib/common/envs/__init__.py:F401 exclude = # No need to traverse our git directory .git, diff --git a/tests/test_distributions.py b/tests/test_distributions.py new file mode 100644 index 0000000..a54a9a5 --- /dev/null +++ b/tests/test_distributions.py @@ -0,0 +1,316 @@ +import numpy as np +import pytest +import torch as th + +from sb3_contrib.common.maskable.distributions import ( + MaskableBernoulliDistribution, + MaskableCategorical, + MaskableCategoricalDistribution, + MaskableMultiCategoricalDistribution, +) + + +class TestMaskableCategorical: + def test_applying_mask(self): 
+ """ + Show that probs change as a result of masking + """ + + starting_probs = th.Tensor([[0.2, 0.2, 0.6], [1, 0, 0]]) + expected_probs = th.Tensor([[0, 0.25, 0.75], [0, 0.5, 0.5]]) + mask = np.array([[False, True, True], [False, True, True]]) + + distribution = MaskableCategorical(probs=starting_probs) + distribution.apply_masking(mask) + assert th.allclose(distribution.probs, expected_probs) + + def test_modifying_mask(self): + """ + Show that masks apply independently of each other + """ + + starting_probs = th.Tensor([[0.2, 0.2, 0.6], [1, 0, 0]]) + expected_probs = th.Tensor([[0.5, 0.5, 0], [0, 1, 0]]) + first_mask = np.array([[False, True, True], [False, True, True]]) + second_mask = np.array([[True, True, False], [False, True, False]]) + + # pytorch converts probs to logits in a way that loses some precision and makes + # 0 probability outcomes slightly non-zero. + atol = 2e-07 + distribution = MaskableCategorical(probs=starting_probs) + assert th.allclose(distribution.probs, starting_probs, atol=atol) + + target_distribution = MaskableCategorical(probs=expected_probs) + + distribution.apply_masking(first_mask) + distribution.apply_masking(second_mask) + + assert th.allclose(distribution.probs, target_distribution.probs, atol=atol) + + def test_removing_mask(self): + """ + Show that masking may be unapplied to recover original probs + """ + + starting_probs = th.Tensor([[0.2, 0.2, 0.6], [1, 0, 0]]) + mask = np.array([[False, True, True], [False, True, True]]) + + distribution = MaskableCategorical(probs=starting_probs) + target_distribution = MaskableCategorical(probs=starting_probs) + distribution.apply_masking(mask) + distribution.apply_masking(None) + assert th.allclose(distribution.probs, target_distribution.probs) + + def test_masking_affects_entropy(self): + # All outcomes equally likely + NUM_DIMS = 3 + logits = th.Tensor([[0] * NUM_DIMS]) + dist = MaskableCategorical(logits=logits) + + # For each possible number of valid actions v, show that 
e^entropy == v + for v in range(1, NUM_DIMS + 1): + masks = [j < v for j in range(NUM_DIMS)] + dist.apply_masking(masks) + assert int(dist.entropy().exp()) == v + + +class TestMaskableCategoricalDistribution: + def test_distribution_must_be_initialized(self): + """ + Cannot use distribution before it has logits + """ + + DIMS = 2 + dist = MaskableCategoricalDistribution(DIMS) + with pytest.raises(AssertionError): + dist.log_prob(th.randint(DIMS - 1, (1, 3))) + + with pytest.raises(AssertionError): + dist.entropy() + + with pytest.raises(AssertionError): + dist.sample() + + with pytest.raises(AssertionError): + dist.mode() + + with pytest.raises(AssertionError): + dist.apply_masking(None) + + # But now we can + dist.proba_distribution(th.randn(1, DIMS)) + dist.log_prob(th.randint(DIMS - 1, (3, 1))) + dist.entropy() + dist.sample() + dist.mode() + dist.apply_masking(None) + + def test_logits_must_align_with_dims(self): + NUM_DIMS = 3 + dist = MaskableCategoricalDistribution(NUM_DIMS) + + # There should be one logit per dim, we're one short + logits = th.randn(1, NUM_DIMS - 1) + with pytest.raises(RuntimeError): + dist.proba_distribution(logits) + + # That's better + logits = th.randn(1, NUM_DIMS) + dist.proba_distribution(logits) + + # Other numbers of dimensions are acceptable as long as they can be realigned + logits = th.randn(NUM_DIMS) + dist.proba_distribution(logits) + logits = th.randn(3, NUM_DIMS, 3) + dist.proba_distribution(logits) + + def test_dim_masking(self): + NUM_DIMS = 2 + dist = MaskableCategoricalDistribution(NUM_DIMS) + + logits = th.Tensor([[0] * NUM_DIMS]) + dist.proba_distribution(logits) + + assert (dist.distribution.probs == 0.5).all() + assert int(dist.entropy().exp()) == NUM_DIMS + + for i in range(NUM_DIMS): + mask = np.array([False] * NUM_DIMS) + mask[i] = True + dist.apply_masking(mask) + probs = dist.distribution.probs + assert probs.sum() == 1 + assert probs[0][i] == 1 + assert int(dist.entropy().exp()) == 1 + + 
dist.apply_masking(None) + assert (dist.distribution.probs == 0.5).all() + assert int(dist.entropy().exp()) == NUM_DIMS + + +class TestMaskableMultiCategoricalDistribution: + def test_distribution_must_be_initialized(self): + """ + Cannot use distribution before it has logits + """ + + DIMS_PER_CAT = 2 + NUM_CATS = 2 + dist = MaskableMultiCategoricalDistribution([DIMS_PER_CAT] * NUM_CATS) + + with pytest.raises(AssertionError): + dist.log_prob(th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS))) + + with pytest.raises(AssertionError): + dist.entropy() + + with pytest.raises(AssertionError): + dist.sample() + + with pytest.raises(AssertionError): + dist.mode() + + with pytest.raises(AssertionError): + dist.apply_masking(None) + + # But now we can + dist.proba_distribution(th.randn(1, DIMS_PER_CAT * NUM_CATS)) + dist.log_prob(th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS))) + dist.entropy() + dist.sample() + dist.mode() + dist.apply_masking(None) + + def test_logits_must_align_with_dims(self): + DIMS_PER_CAT = 3 + NUM_CATS = 2 + dist = MaskableMultiCategoricalDistribution([DIMS_PER_CAT] * NUM_CATS) + + # There should be one logit per dim, we're one short + logits = th.randn(1, DIMS_PER_CAT * NUM_CATS - 1) + with pytest.raises(RuntimeError): + dist.proba_distribution(logits) + + # That's better + logits = th.randn(1, DIMS_PER_CAT * NUM_CATS) + dist.proba_distribution(logits) + + # Other numbers of dimensions are acceptable as long as they can be realigned + logits = th.randn(DIMS_PER_CAT * NUM_CATS) + dist.proba_distribution(logits) + logits = th.randn(3, DIMS_PER_CAT * NUM_CATS, 3) + dist.proba_distribution(logits) + + def test_dim_masking(self): + DIMS_PER_CAT = 2 + NUM_CATS = 3 + dist = MaskableMultiCategoricalDistribution([DIMS_PER_CAT] * NUM_CATS) + + logits = th.Tensor([[0] * DIMS_PER_CAT * NUM_CATS]) + dist.proba_distribution(logits) + + assert len(dist.distributions) == NUM_CATS + for i in range(NUM_CATS): + assert (dist.distributions[i].probs == 0.5).all() + assert 
int(dist.entropy().exp()) == DIMS_PER_CAT ** NUM_CATS + + for i in range(DIMS_PER_CAT): + mask = np.array([False] * DIMS_PER_CAT * NUM_CATS) + for j in range(NUM_CATS): + mask[j * DIMS_PER_CAT + i] = True + + dist.apply_masking(mask) + for j in range(NUM_CATS): + probs = dist.distributions[j].probs + assert probs.sum() == 1 + assert probs[0][i] == 1 + + assert int(dist.entropy().exp()) == 1 + + dist.apply_masking(None) + for i in range(NUM_CATS): + assert (dist.distributions[i].probs == 0.5).all() + assert int(dist.entropy().exp()) == DIMS_PER_CAT ** NUM_CATS + + +class TestMaskableBernoulliDistribution: + def test_distribution_must_be_initialized(self): + """ + Cannot use distribution before it has logits + """ + + DIMS = 2 + dist = MaskableBernoulliDistribution(DIMS) + + with pytest.raises(AssertionError): + dist.log_prob(th.randint(1, (2, DIMS))) + + with pytest.raises(AssertionError): + dist.entropy() + + with pytest.raises(AssertionError): + dist.sample() + + with pytest.raises(AssertionError): + dist.mode() + + with pytest.raises(AssertionError): + dist.apply_masking(None) + + # But now we can + dist.proba_distribution(th.randn(1, 2 * DIMS)) + dist.log_prob(th.randint(1, (2, DIMS))) + dist.entropy() + dist.sample() + dist.mode() + dist.apply_masking(None) + + def test_logits_must_align_with_dims(self): + NUM_DIMS = 3 + dist = MaskableBernoulliDistribution(NUM_DIMS) + + # There should be two logits per dim, we're one short + logits = th.randn(1, 2 * NUM_DIMS - 1) + with pytest.raises(RuntimeError): + dist.proba_distribution(logits) + + # That's better + logits = th.randn(1, 2 * NUM_DIMS) + dist.proba_distribution(logits) + + # Other numbers of dimensions are acceptable as long as they can be realigned + logits = th.randn(2 * NUM_DIMS) + dist.proba_distribution(logits) + logits = th.randn(3, 2 * NUM_DIMS, 3) + dist.proba_distribution(logits) + + def test_dim_masking(self): + NUM_DIMS = 2 + BINARY_STATES = 2 + dist = MaskableBernoulliDistribution(NUM_DIMS) + + 
logits = th.Tensor([[0] * BINARY_STATES * NUM_DIMS]) + dist.proba_distribution(logits) + + assert len(dist.distributions) == NUM_DIMS + for i in range(NUM_DIMS): + assert (dist.distributions[i].probs == 0.5).all() + assert int(dist.entropy().exp()) == BINARY_STATES * NUM_DIMS + + for i in range(BINARY_STATES): + mask = np.array([False] * BINARY_STATES * NUM_DIMS) + for j in range(NUM_DIMS): + mask[j * BINARY_STATES + i] = True + + dist.apply_masking(mask) + for j in range(NUM_DIMS): + probs = dist.distributions[j].probs + assert probs.sum() == 1 + assert probs[0][i] == 1 + + assert int(dist.entropy().exp()) == 1 + + dist.apply_masking(None) + for i in range(NUM_DIMS): + assert (dist.distributions[i].probs == 0.5).all() + assert int(dist.entropy().exp()) == BINARY_STATES * NUM_DIMS diff --git a/tests/test_invalid_actions.py b/tests/test_invalid_actions.py new file mode 100644 index 0000000..89245d9 --- /dev/null +++ b/tests/test_invalid_actions.py @@ -0,0 +1,232 @@ +import random + +import gym +import pytest +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox +from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.policies import ActorCriticPolicy + +from sb3_contrib import MaskablePPO +from sb3_contrib.common.envs import InvalidActionEnvDiscrete, InvalidActionEnvMultiBinary, InvalidActionEnvMultiDiscrete +from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback +from sb3_contrib.common.maskable.evaluation import evaluate_policy +from sb3_contrib.common.maskable.utils import is_masking_supported +from sb3_contrib.common.wrappers import ActionMasker + + +def make_env(): + return InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10) + + +class ToDictWrapper(gym.Wrapper): + """ + Simple wrapper to test MultInputPolicy on Dict obs. 
+ """ + + def __init__(self, env): + super().__init__(env) + self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space}) + + def reset(self): + return {"obs": self.env.reset()} + + def step(self, action): + obs, reward, done, infos = self.env.step(action) + return {"obs": obs}, reward, done, infos + + +def test_identity(): + """ + Performance test. + A randomly initialized model cannot solve that task (score ~=6), + nor a model without invalid action masking (score ~=30 after training) + which such a low training budget. + """ + env = InvalidActionEnvDiscrete(dim=70, n_invalid_actions=55) + model = MaskablePPO( + "MlpPolicy", + env, + gamma=0.4, + seed=32, + verbose=0, + ) + model.learn(3000) + evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False) + + +def test_supports_discrete_action_space(): + """ + No errors using algorithm with an env that has a discrete action space + """ + + env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10) + model = MaskablePPO("MlpPolicy", env, n_steps=64, seed=8) + model.learn(100) + evaluate_policy(model, env, warn=False) + + # Mask all actions except the good one, a random model should succeed + env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19) + model = MaskablePPO("MlpPolicy", env, seed=8) + evaluate_policy(model, env, reward_threshold=99, warn=False) + + +def test_supports_multi_discrete_action_space(): + """ + No errors using algorithm with an env that has a multidiscrete action space + """ + + env = InvalidActionEnvMultiDiscrete(dims=[2, 3], n_invalid_actions=1) + model = MaskablePPO("MlpPolicy", env, n_steps=64, seed=8) + model.learn(100) + evaluate_policy(model, env, warn=False) + + # Mask all actions except the good ones, a random model should succeed + env = InvalidActionEnvMultiDiscrete(dims=[2, 3], n_invalid_actions=3) + model = MaskablePPO("MlpPolicy", env, seed=8) + evaluate_policy(model, env, reward_threshold=99, warn=False) + + +def 
test_supports_multi_binary_action_space(): + """ + No errors using algorithm with an env that has a multidiscrete action space + """ + + env = InvalidActionEnvMultiBinary(dims=3, n_invalid_actions=1) + model = MaskablePPO("MlpPolicy", env, n_steps=64, seed=8) + model.learn(100) + evaluate_policy(model, env, warn=False) + + # Mask all actions except the good ones, a random model should succeed + env = InvalidActionEnvMultiBinary(dims=3, n_invalid_actions=3) + model = MaskablePPO("MlpPolicy", env, seed=8) + evaluate_policy(model, env, reward_threshold=99, warn=False) + + +def test_disabling_masking(): + """ + Behave like normal PPO if masking is disabled, which allows for envs that don't provide masks + """ + + env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19) + model = MaskablePPO("MlpPolicy", env, seed=8) + evaluate_policy(model, env, reward_threshold=99, warn=False) + + # With masking disabled, perfect performance disappears + with pytest.raises(AssertionError): + evaluate_policy(model, env, reward_threshold=99, warn=False, use_masking=False) + + # Without masking disabled, learning/evaluation will fail if the env doesn't provide masks + env = IdentityEnv(dim=2) + model = MaskablePPO("MlpPolicy", env, n_steps=64, seed=8) + with pytest.raises(ValueError): + model.learn(100) + with pytest.raises(ValueError): + evaluate_policy(model, env, warn=False) + + model.learn(100, use_masking=False) + evaluate_policy(model, env, warn=False, use_masking=False) + + +def test_masked_evaluation(): + """ + Masking can be enabled or disabled for evaluation, but masking should perform better. 
+ """ + + env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19) + model = MaskablePPO("MlpPolicy", env, seed=8) + masked_avg_rew, _ = evaluate_policy(model, env, warn=False) + unmasked_avg_rew, _ = evaluate_policy(model, env, warn=False, use_masking=False) + assert masked_avg_rew > unmasked_avg_rew + + +def test_supports_multi_envs(): + """ + Learning and evaluation works with VecEnvs + """ + + env = make_vec_env(make_env, n_envs=2) + assert is_masking_supported(env) + model = MaskablePPO("MlpPolicy", env, n_steps=256, gamma=0.4, seed=32, verbose=1) + model.learn(100) + evaluate_policy(model, env, warn=False) + + env = make_vec_env(IdentityEnv, n_envs=2, env_kwargs={"dim": 2}) + assert not is_masking_supported(env) + model = MaskablePPO("MlpPolicy", env, n_steps=256, gamma=0.4, seed=32, verbose=1) + with pytest.raises(ValueError): + model.learn(100) + with pytest.raises(ValueError): + evaluate_policy(model, env, warn=False) + model.learn(100, use_masking=False) + evaluate_policy(model, env, warn=False, use_masking=False) + + +def test_callback(tmp_path): + """ + No errors using MaskableEvalCallback during learning + """ + + env = make_env() + eval_env = make_env() + model = MaskablePPO("MlpPolicy", env, n_steps=64, gamma=0.4, seed=32, verbose=1) + model.learn(100, callback=MaskableEvalCallback(eval_env, eval_freq=100, warn=False, log_path=tmp_path)) + + model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False)) + + +def test_maskable_policy_required(): + """ + MaskablePPO requires a policy that subclasses MaskableActorCriticPolicy + """ + + env = make_env() + with pytest.raises(ValueError): + MaskablePPO(ActorCriticPolicy, env) + + +def test_discrete_action_space_required(): + """ + MaskablePPO requires an env with a discrete (ie non-continuous) action space + """ + + env = IdentityEnvBox() + with pytest.raises(AssertionError): + MaskablePPO("MlpPolicy", env) + + +def test_cnn(): + def action_mask_fn(env): + 
random_invalid_action = random.randrange(env.action_space.n) + return [i != random_invalid_action for i in range(env.action_space.n)] + + env = FakeImageEnv() + env = ActionMasker(env, action_mask_fn) + + model = MaskablePPO( + "CnnPolicy", + env, + n_steps=64, + seed=32, + verbose=1, + policy_kwargs=dict( + features_extractor_kwargs=dict(features_dim=32), + ), + ) + model.learn(100) + evaluate_policy(model, env, warn=False) + + +# Dict observations are currently not supported +# def test_dict_obs(): +# env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10) +# env = ToDictWrapper(env) +# model = MaskablePPO("MultiInputPolicy", env, n_steps=64, seed=8) +# model.learn(100) +# evaluate_policy(model, env, warn=False) +# +# # Mask all actions except the good one, a random model should succeed +# env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=19) +# env = ToDictWrapper(env) +# model = MaskablePPO("MultiInputPolicy", env, seed=8) +# evaluate_policy(model, env, reward_threshold=99, warn=False) diff --git a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py index cbab03f..025a143 100644 --- a/tests/test_train_eval_mode.py +++ b/tests/test_train_eval_mode.py @@ -1,3 +1,5 @@ +from typing import Union + import gym import numpy as np import pytest @@ -6,7 +8,9 @@ import torch.nn as nn from stable_baselines3.common.preprocessing import get_flattened_obs_dim from stable_baselines3.common.torch_layers import BaseFeaturesExtractor -from sb3_contrib import QRDQN, TQC +from sb3_contrib import QRDQN, TQC, MaskablePPO +from sb3_contrib.common.envs import InvalidActionEnvDiscrete +from sb3_contrib.common.maskable.utils import get_action_masks class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor): @@ -76,12 +80,51 @@ def clone_tqc_batch_norm_stats( return (actor_bias, actor_running_mean, critic_bias, critic_running_mean, critic_target_bias, critic_target_running_mean) +def clone_on_policy_batch_norm(model: Union[MaskablePPO]) -> (th.Tensor, th.Tensor): 
+ return clone_batch_norm_stats(model.policy.features_extractor.batch_norm) + + CLONE_HELPERS = { QRDQN: clone_qrdqn_batch_norm_stats, TQC: clone_tqc_batch_norm_stats, + MaskablePPO: clone_on_policy_batch_norm, } +def test_ppo_mask_train_eval_mode(): + env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10) + model = MaskablePPO( + "MlpPolicy", + env, + policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor), + seed=1, + ) + + bias_before, running_mean_before = clone_on_policy_batch_norm(model) + + model.learn(total_timesteps=200) + + bias_after, running_mean_after = clone_on_policy_batch_norm(model) + + assert ~th.isclose(bias_before, bias_after).all() + assert ~th.isclose(running_mean_before, running_mean_after).all() + + batch_norm_stats_before = clone_on_policy_batch_norm(model) + + observation = env.reset() + action_masks = get_action_masks(env) + first_prediction, _ = model.predict(observation, action_masks=action_masks, deterministic=True) + for _ in range(5): + prediction, _ = model.predict(observation, action_masks=action_masks, deterministic=True) + np.testing.assert_allclose(first_prediction, prediction) + + batch_norm_stats_after = clone_on_policy_batch_norm(model) + + # No change in batch norm params + for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after): + assert th.isclose(param_before, param_after).all() + + def test_qrdqn_train_with_batch_norm(): model = QRDQN( "MlpPolicy", diff --git a/tests/wrappers/test_action_masker.py b/tests/wrappers/test_action_masker.py new file mode 100644 index 0000000..e130959 --- /dev/null +++ b/tests/wrappers/test_action_masker.py @@ -0,0 +1,78 @@ +from typing import List + +import pytest +from gym.spaces import Discrete +from stable_baselines3.common.envs import IdentityEnv + +from sb3_contrib.common.wrappers import ActionMasker + + +class IdentityEnvDiscrete(IdentityEnv): + def __init__(self, dim: int = 1, ep_length: int = 100): + """ + 
Identity environment for testing purposes + + :param dim: the size of the dimensions you want to learn + :param ep_length: the length of each episode in timesteps + """ + space = Discrete(dim) + self.useless_property = 1 + super().__init__(ep_length=ep_length, space=space) + + def _action_masks(self) -> List[int]: + return [i == self.state for i in range(self.action_space.n)] + + +def action_mask_fn(env: IdentityEnvDiscrete) -> List[int]: + return [i == env.state for i in range(env.action_space.n)] + + +def test_wrapper_accepts_function(): + """ + ActionMasker accepts a function + """ + + env = IdentityEnvDiscrete() + + assert not hasattr(env, "action_masks") + env = ActionMasker(env, action_mask_fn) + assert hasattr(env, "action_masks") + + +# Wrapper accepts as string name of a method on the underlying env +def test_wrapper_accepts_attr_name(): + """ + ActionMasker accepts a string name of a method on the underlying env + """ + + env = IdentityEnvDiscrete() + + assert not hasattr(env, "action_masks") + env = ActionMasker(env, "_action_masks") + assert hasattr(env, "action_masks") + + +def test_attr_must_be_callable(): + """ + Passing ActionMasker the string name of a non-callable is an error + """ + + env = IdentityEnvDiscrete() + + with pytest.raises(ValueError): + env = ActionMasker(env, "useless_property") + + +# Wrapper method returns expected results +def test_action_masks_returns_expected_result(): + """ + ActionMasker-provided action_masks() method returns expected results + """ + + env = IdentityEnvDiscrete() + env = ActionMasker(env, action_mask_fn) + + # Only one valid action expected + masks = env.action_masks() + masks[env.state] = not masks[env.state] # Bit-flip the one expected valid action + assert all([not mask for mask in masks]) diff --git a/tests/test_wrappers.py b/tests/wrappers/test_time_feature.py similarity index 100% rename from tests/test_wrappers.py rename to tests/wrappers/test_time_feature.py