From 675304d8fa06287344b559849aa82e57cb8cb000 Mon Sep 17 00:00:00 2001 From: Sean Gillen Date: Tue, 18 Jan 2022 04:57:27 -0800 Subject: [PATCH] Augmented Random Search (ARS) (#42) * first pass at ars, replicates initial results, still needs more testing, cleanup * add a few docs and tests, bugfixes for ARS * debug and comment * break out dump logs * rollback so there are now predict workers, some refactoring * remove callback from self, remove torch multiprocessing * add module docs * run formatter * fix load and rerun formatter * rename to less mathy variable names, rename _validate_hypers * refactor to use evaluatate_policy, linear policy no longer uses bias or squashing * move everything to torch, add support for discrete action spaces, bugfix for alive reward offset * added tests, passing all of them, add support for discrete action spaces * update documentation * allow for reward offset when there are multiple envs * update results again * Reformat * Ignore unused imports * Renaming + Cleanup * Experimental multiprocessing * Cleaner multiprocessing * Reformat * Fixes for callback * Fix combining stats * 2nd way * Make the implementation cpu only * Fixes + POC with mp module * POC Processes * Cleaner aync implementation * Remove unused arg * Add typing * Revert vec normalize offset hack * Add `squash_output` parameter * Add more tests * Add comments * Update doc * Add comments * Add more logging * Fix TRPO issue on GPU * Tmp fix for ARS tests on GPU * Additional tmp fixes for ARS * update docstrings + formatting, fix bad exceptioe string in ARSPolicy * Add comments and docstrings * Fix missing import * Fix type check * Add dosctrings * GPU support, first attempt * Fix test * Add missing docstring * Typos * Update defaults hyperparameters Co-authored-by: Antonin RAFFIN --- README.md | 1 + docs/guide/algos.rst | 3 +- docs/guide/examples.rst | 16 +- docs/index.rst | 5 +- docs/misc/changelog.rst | 9 +- docs/modules/ars.rst | 160 +++++++++++ 
docs/spelling_wordlist.txt | 4 + sb3_contrib/__init__.py | 1 + sb3_contrib/ars/__init__.py | 2 + sb3_contrib/ars/ars.py | 347 +++++++++++++++++++++++ sb3_contrib/ars/policies.py | 119 ++++++++ sb3_contrib/common/vec_env/__init__.py | 1 + sb3_contrib/common/vec_env/async_eval.py | 208 ++++++++++++++ sb3_contrib/trpo/trpo.py | 4 +- sb3_contrib/version.txt | 2 +- setup.cfg | 4 +- setup.py | 2 +- tests/test_deterministic.py | 55 ++++ tests/test_run.py | 36 ++- tests/test_save_load.py | 12 +- 20 files changed, 974 insertions(+), 17 deletions(-) create mode 100644 docs/modules/ars.rst create mode 100644 sb3_contrib/ars/__init__.py create mode 100644 sb3_contrib/ars/ars.py create mode 100644 sb3_contrib/ars/policies.py create mode 100644 sb3_contrib/common/vec_env/__init__.py create mode 100644 sb3_contrib/common/vec_env/async_eval.py create mode 100644 tests/test_deterministic.py diff --git a/README.md b/README.md index 815956d..59f4e88 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ See documentation for the full list of included features. 
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044) - [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171) - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477) +- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055) **Gym Wrappers**: - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378) diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 81770b0..4b2cfb4 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -8,9 +8,10 @@ along with some useful characteristics: support for discrete/continuous actions, ============ =========== ============ ================= =============== ================ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing ============ =========== ============ ================= =============== ================ +ARS ✔️ ❌️ ❌ ❌ ✔️ +QR-DQN ️❌ ️✔️ ❌ ❌ ✔️ TQC ✔️ ❌ ❌ ❌ ✔️ TRPO ✔️ ✔️ ✔️ ✔️ ✔️ -QR-DQN ️❌ ️✔️ ❌ ❌ ✔️ ============ =========== ============ ================= =============== ================ diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index c39ce26..6a60be8 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -43,7 +43,7 @@ Train a PPO with invalid action masking agent on a toy environment. env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60) model = MaskablePPO("MlpPolicy", env, verbose=1) model.learn(5000) - model.save("qrdqn_cartpole") + model.save("maskable_toy_env") TRPO ---- @@ -57,3 +57,17 @@ Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environmen model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1) model.learn(total_timesteps=100_000, log_interval=4) model.save("trpo_pendulum") + + +ARS +--- + +Train an agent using Augmented Random Search (ARS) agent on the Pendulum environment + +.. 
code-block:: python + + from sb3_contrib import ARS + + model = ARS("LinearPolicy", "Pendulum-v0", verbose=1) + model.learn(total_timesteps=10000, log_interval=4) + model.save("ars_pendulum") diff --git a/docs/index.rst b/docs/index.rst index ac610bd..f713dfa 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -31,10 +31,11 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d :maxdepth: 1 :caption: RL Algorithms + modules/ars + modules/ppo_mask + modules/qrdqn modules/tqc modules/trpo - modules/qrdqn - modules/ppo_mask .. toctree:: :maxdepth: 1 diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 1b117e6..b4e42d8 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,20 +4,21 @@ Changelog ========== -Release 1.3.1a7 (WIP) +Release 1.3.1a9 (WIP) ------------------------------- -**Add TRPO** +**Add TRPO and ARS** Breaking Changes: ^^^^^^^^^^^^^^^^^ - Dropped python 3.6 support -- Upgraded to Stable-Baselines3 >= 1.3.1a4 +- Upgraded to Stable-Baselines3 >= 1.3.1a9 - ``MaskablePPO`` was updated to match latest SB3 ``PPO`` version (timeout handling and new method for the policy object) New Features: ^^^^^^^^^^^^^ - Added ``TRPO`` (@cyprienc) - Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported) +- Added Augmented Random Search (ARS) (@sgillen) Bug Fixes: ^^^^^^^^^^ @@ -221,4 +222,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_) Contributors: ------------- -@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc +@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen diff --git a/docs/modules/ars.rst b/docs/modules/ars.rst new file mode 100644 index 0000000..7ca1141 --- /dev/null +++ b/docs/modules/ars.rst @@ -0,0 +1,160 @@ +.. _ars: + +.. 
automodule:: sb3_contrib.ars + + +ARS +=== + + +Augmented Random Search (ARS) is a simple reinforcement algorithm that uses a direct random search over policy +parameters. It can be surprisingly effective compared to more sophisticated algorithms. In the `original paper `_ the authors showed that linear policies trained with ARS were competitive with deep reinforcement learning for the MuJuCo locomotion tasks. + +SB3s implementation allows for linear policies without bias or squashing function, it also allows for training MLP policies, which include linear policies with bias and squashing functions as a special case. + +Normally one wants to train ARS with several seeds to properly evaluate. + +.. warning:: + + ARS multi-processing is different from the classic Stable-Baselines3 multi-processing: it runs n environments + in parallel but asynchronously. This asynchronous multi-processing is considered experimental + and does not fully support callbacks: the ``on_step()`` event is called artificially after the evaluation episodes are over. + + +.. rubric:: Available Policies + +.. autosummary:: + :nosignatures: + + LinearPolicy + MlpPolicy + + +Notes +----- + +- Original paper: https://arxiv.org/abs/1803.07055 +- Original Implementation: https://github.com/modestyachts/ARS + + +Can I use? +---------- + +- Recurrent policies: ❌ +- Multi processing: ✔️ (cf. example) +- Gym spaces: + + +============= ====== =========== +Space Action Observation +============= ====== =========== +Discrete ✔️ ❌ +Box ✔️ ✔️ +MultiDiscrete ❌ ❌ +MultiBinary ❌ ❌ +Dict ❌ ❌ +============= ====== =========== + + +Example +------- + +.. code-block:: python + + from sb3_contrib import ARS + + # Policy can be LinearPolicy or MlpPolicy + model = ARS("LinearPolicy", "Pendulum-v0", verbose=1) + model.learn(total_timesteps=10000, log_interval=4) + model.save("ars_pendulum") + + +With experimental asynchronous multi-processing: + +.. 
code-block:: python + + from sb3_contrib import ARS + from sb3_contrib.common.vec_env import AsyncEval + + from stable_baselines3.common.env_util import make_vec_env + + env_id = "CartPole-v1" + n_envs = 2 + + model = ARS("LinearPolicy", env_id, n_delta=2, n_top=1, verbose=1) + # Create env for asynchronous evaluation (run in different processes) + async_eval = AsyncEval([lambda: make_vec_env(env_id) for _ in range(n_envs)], model.policy) + + model.learn(total_timesteps=200_000, log_interval=4, async_eval=async_eval) + + +Results +------- + +Replicating results from the original paper, which used the Mujoco benchmarks. Same parameters from the original paper, using 8 seeds. + +============= ============ +Environments ARS +============= ============ +\ +HalfCheetah 4398 +/- 320 +Swimmer 241 +/- 51 +Hopper 3320 +/- 120 +============= ============ + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Clone RL-Zoo and checkout the branch ``feat/ars`` + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + git checkout feat/ars + +Run the benchmark. The following code snippet trains 8 seeds in parallel + +.. code-block:: bash + + for ENV_ID in Swimmer-v3 HalfCheetah-v3 Hopper-v3 + do + for SEED_NUM in {1..8} + do + SEED=$RANDOM + python train.py --algo ars --env $ENV_ID --eval-episodes 10 --eval-freq 10000 -n 20000000 --seed $SEED & + sleep 1 + done + wait + done + +Plot the results: + +.. code-block:: bash + + python scripts/all_plots.py -a ars -e HalfCheetah Swimmer Hopper -f logs/ -o logs/ars_results -max 20000000 + python scripts/plot_from_file.py -i logs/ars_results.pkl -l ARS + + + +Parameters +---------- + +.. autoclass:: ARS + :members: + :inherited-members: + +.. _ars_policies: + +ARS Policies +------------- + +.. autoclass:: sb3_contrib.ars.policies.ARSPolicy + :members: + :noindex: + +.. autoclass:: LinearPolicy + :members: + :inherited-members: + +.. 
autoclass:: MlpPolicy + :members: diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 42669bf..09e577c 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -119,3 +119,7 @@ cuda Polyak gSDE rollouts +quantiles +quantile +contrib +Contrib diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 790eaaa..2ff218d 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -1,5 +1,6 @@ import os +from sb3_contrib.ars import ARS from sb3_contrib.ppo_mask import MaskablePPO from sb3_contrib.qrdqn import QRDQN from sb3_contrib.tqc import TQC diff --git a/sb3_contrib/ars/__init__.py b/sb3_contrib/ars/__init__.py new file mode 100644 index 0000000..655ac8b --- /dev/null +++ b/sb3_contrib/ars/__init__.py @@ -0,0 +1,2 @@ +from sb3_contrib.ars.ars import ARS +from sb3_contrib.ars.policies import LinearPolicy, MlpPolicy diff --git a/sb3_contrib/ars/ars.py b/sb3_contrib/ars/ars.py new file mode 100644 index 0000000..8fed745 --- /dev/null +++ b/sb3_contrib/ars/ars.py @@ -0,0 +1,347 @@ +import copy +import time +import warnings +from functools import partial +from typing import Any, Dict, Optional, Type, Union + +import gym +import numpy as np +import torch as th +import torch.nn.utils +from stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import get_schedule_fn, safe_mean + +from sb3_contrib.ars.policies import ARSPolicy +from sb3_contrib.common.vec_env.async_eval import AsyncEval + + +class ARS(BaseAlgorithm): + """ + Augmented Random Search: https://arxiv.org/abs/1803.07055 + + Original implementation: https://github.com/modestyachts/ARS + C++/Cuda Implementation: 
https://github.com/google-research/tiny-differentiable-simulator/ + 150 LOC Numpy Implementation: https://github.com/alexis-jacq/numpy_ARS/blob/master/asr.py + + :param policy: The policy to train, can be an instance of ``ARSPolicy``, or a string from ["LinearPolicy", "MlpPolicy"] + :param env: The environment to train on, may be a string if registered with gym + :param n_delta: How many random perturbations of the policy to try at each update step. + :param n_top: How many of the top delta to use in each update step. Default is n_delta + :param learning_rate: Float or schedule for the step size + :param delta_std: Float or schedule for the exploration noise + :param zero_policy: Boolean determining if the passed policy should have it's weights zeroed before training. + :param alive_bonus_offset: Constant added to the reward at each step, used to cancel out alive bonuses. + :param n_eval_episodes: Number of episodes to evaluate each candidate. + :param policy_kwargs: Keyword arguments to pass to the policy on creation + :param policy_base: Base class to use for the policy + :param tensorboard_log: String with the directory to put tensorboard logs: + :param seed: Random seed for the training + :param verbose: Verbosity level: 0 no output, 1 info, 2 debug + :param device: Torch device to use for training, defaults to "cpu" + :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + def __init__( + self, + policy: Union[str, Type[ARSPolicy]], + env: Union[GymEnv, str], + n_delta: int = 8, + n_top: Optional[int] = None, + learning_rate: Union[float, Schedule] = 0.02, + delta_std: Union[float, Schedule] = 0.05, + zero_policy: bool = True, + alive_bonus_offset: float = 0, + n_eval_episodes: int = 1, + policy_kwargs: Optional[Dict[str, Any]] = None, + policy_base: Type[BasePolicy] = ARSPolicy, + tensorboard_log: Optional[str] = None, + seed: Optional[int] = None, + verbose: int = 0, + device: Union[th.device, str] = "cpu", + 
_init_setup_model: bool = True, ): + super().__init__( + policy, + env, + learning_rate=learning_rate, + tensorboard_log=tensorboard_log, + policy_base=policy_base, + policy_kwargs=policy_kwargs, + verbose=verbose, + device=device, + supported_action_spaces=(gym.spaces.Box, gym.spaces.Discrete), + support_multi_env=True, + seed=seed, + ) + + self.n_delta = n_delta + self.pop_size = 2 * n_delta + self.delta_std_schedule = get_schedule_fn(delta_std) + self.n_eval_episodes = n_eval_episodes + + if n_top is None: + n_top = n_delta + + # Make sure our hyper parameters are valid and auto correct them if they are not + if n_top > n_delta: + warnings.warn(f"n_top = {n_top} > n_delta = {n_delta}, setting n_top = n_delta") + n_top = n_delta + + self.n_top = n_top + + self.alive_bonus_offset = alive_bonus_offset + self.zero_policy = zero_policy + self.weights = None # Need to call init model to initialize weight + self.processes = None + # Keep track of how many steps were elapsed before a new rollout + # Important for syncing observation normalization between workers + self.old_count = 0 + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + self._setup_lr_schedule() + self.set_random_seed(self.seed) + + self.policy = self.policy_class(self.observation_space, self.action_space, **self.policy_kwargs) + self.policy = self.policy.to(self.device) + self.weights = th.nn.utils.parameters_to_vector(self.policy.parameters()).detach() + self.n_params = len(self.weights) + + if self.zero_policy: + self.weights = th.zeros_like(self.weights, requires_grad=False) + self.policy.load_from_vector(self.weights.cpu()) + + def _mimic_monitor_wrapper(self, episode_rewards: np.ndarray, episode_lengths: np.ndarray) -> None: + """ + Helper to mimic Monitor wrapper and report episode statistics (mean reward, mean episode length). 
+ + :param episode_rewards: List containing per-episode rewards + :param episode_lengths: List containing per-episode lengths (in number of steps) + """ + # Mimic Monitor Wrapper + infos = [ + {"episode": {"r": episode_reward, "l": episode_length}} + for episode_reward, episode_length in zip(episode_rewards, episode_lengths) + ] + + self._update_info_buffer(infos) + + def _trigger_callback( + self, + _locals: Dict[str, Any], + _globals: Dict[str, Any], + callback: BaseCallback, + n_envs: int, + ) -> None: + """ + Callback passed to the ``evaluate_policy()`` helper + in order to increment the number of timesteps + and trigger events in the single process version. + + :param _locals: + :param _globals: + :param callback: Callback that will be called at every step + :param n_envs: Number of environments + """ + self.num_timesteps += n_envs + callback.on_step() + + def evaluate_candidates( + self, candidate_weights: th.Tensor, callback: BaseCallback, async_eval: Optional[AsyncEval] + ) -> th.Tensor: + """ + Evaluate each candidate. + + :param candidate_weights: The candidate weights to be evaluated. + :param callback: Callback that will be called at each step + (or after evaluation in the multiprocess version) + :param async_eval: The object for asynchronous evaluation of candidates. + :return: The episodic return for each candidate. 
+ """ + + batch_steps = 0 + # returns == sum of rewards + candidate_returns = th.zeros(self.pop_size, device=self.device) + train_policy = copy.deepcopy(self.policy) + # Empty buffer to show only mean over one iteration (one set of candidates) in the logs + self.ep_info_buffer = [] + callback.on_rollout_start() + + if async_eval is not None: + # Multiprocess asynchronous version + async_eval.send_jobs(candidate_weights, self.pop_size) + results = async_eval.get_results() + + for weights_idx, (episode_rewards, episode_lengths) in results: + + # Update reward to cancel out alive bonus if needed + candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths) + batch_steps += np.sum(episode_lengths) + self._mimic_monitor_wrapper(episode_rewards, episode_lengths) + + # Combine the filter stats of each process for normalization + for worker_obs_rms in async_eval.get_obs_rms(): + if self._vec_normalize_env is not None: + # worker_obs_rms.count -= self.old_count + self._vec_normalize_env.obs_rms.combine(worker_obs_rms) + # Hack: don't count timesteps twice (between the two are synced) + # otherwise it will lead to overflow, + # in practice we would need two RunningMeanStats + self._vec_normalize_env.obs_rms.count -= self.old_count + + # Synchronise VecNormalize if needed + if self._vec_normalize_env is not None: + async_eval.sync_obs_rms(self._vec_normalize_env.obs_rms.copy()) + self.old_count = self._vec_normalize_env.obs_rms.count + + # Hack to have Callback events + for _ in range(batch_steps // len(async_eval.remotes)): + self.num_timesteps += len(async_eval.remotes) + callback.on_step() + else: + # Single process, synchronous version + for weights_idx in range(self.pop_size): + + # Load current candidate weights + train_policy.load_from_vector(candidate_weights[weights_idx].cpu()) + # Evaluate the candidate + episode_rewards, episode_lengths = evaluate_policy( + train_policy, + self.env, + n_eval_episodes=self.n_eval_episodes, 
+ return_episode_rewards=True, + # Increment num_timesteps too (slight mismatch with multi envs) + callback=partial(self._trigger_callback, callback=callback, n_envs=self.env.num_envs), + warn=False, + ) + # Update reward to cancel out alive bonus if needed + candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths) + batch_steps += sum(episode_lengths) + self._mimic_monitor_wrapper(episode_rewards, episode_lengths) + + # Note: we increment the num_timesteps inside the evaluate_policy() + # however when using multiple environments, there will be a slight + # mismatch between the number of timesteps used and the number + # of calls to the step() method (cf. implementation of evaluate_policy()) + # self.num_timesteps += batch_steps + + callback.on_rollout_end() + + return candidate_returns + + def _log_and_dump(self) -> None: + """ + Dump information to the logger. + """ + time_elapsed = time.time() - self.start_time + fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8)) + if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: + self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) + self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) + self.logger.record("time/fps", fps) + self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") + self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") + self.logger.dump(step=self.num_timesteps) + + def _do_one_update(self, callback: BaseCallback, async_eval: Optional[AsyncEval]) -> None: + """ + Sample new candidates, evaluate them and then update current policy. + + :param callback: callback(s) called at every step with state of the algorithm. + :param async_eval: The object for asynchronous evaluation of candidates. 
+ """ + # Retrieve current parameter noise standard deviation + # and current learning rate + delta_std = self.delta_std_schedule(self._current_progress_remaining) + learning_rate = self.lr_schedule(self._current_progress_remaining) + # Sample the parameter noise, it will be scaled by delta_std + deltas = th.normal(mean=0.0, std=1.0, size=(self.n_delta, self.n_params), device=self.device) + policy_deltas = deltas * delta_std + # Generate 2 * n_delta candidate policies by adding noise to the current weights + candidate_weights = th.cat([self.weights + policy_deltas, self.weights - policy_deltas]) + + with th.no_grad(): + candidate_returns = self.evaluate_candidates(candidate_weights, callback, async_eval) + + # Returns corresponding to weights + deltas + plus_returns = candidate_returns[: self.n_delta] + # Returns corresponding to weights - deltas + minus_returns = candidate_returns[self.n_delta :] + + # Keep only the top performing candidates for update + top_returns, _ = th.max(th.vstack((plus_returns, minus_returns)), dim=0) + top_idx = th.argsort(top_returns, descending=True)[: self.n_top] + plus_returns = plus_returns[top_idx] + minus_returns = minus_returns[top_idx] + deltas = deltas[top_idx] + + # Scale learning rate by the return standard deviation: + # take smaller steps when there is a high variance in the returns + return_std = th.cat([plus_returns, minus_returns]).std() + step_size = learning_rate / (self.n_top * return_std + 1e-6) + # Approximate gradient step + self.weights = self.weights + step_size * ((plus_returns - minus_returns) @ deltas) + self.policy.load_from_vector(self.weights.cpu()) + + self.logger.record("train/iterations", self._n_updates, exclude="tensorboard") + self.logger.record("train/delta_std", delta_std) + self.logger.record("train/learning_rate", learning_rate) + self.logger.record("train/step_size", step_size.item()) + self.logger.record("rollout/return_std", return_std.item()) + + self._n_updates += 1 + + def learn( + self, + 
total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + tb_log_name: str = "ARS", + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + async_eval: Optional[AsyncEval] = None, + ) -> "ARS": + """ + Return a trained model. + + :param total_timesteps: The total number of samples (env steps) to train on + :param callback: callback(s) called at every step with state of the algorithm. + :param log_interval: The number of timesteps before logging. + :param tb_log_name: the name of the run for TensorBoard logging + :param eval_env: Environment that will be used to evaluate the agent + :param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little) + :param n_eval_episodes: Number of episode to evaluate the agent + :param eval_log_path: Path to a folder where the evaluations will be saved + :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging) + :param async_eval: The object for asynchronous evaluation of candidates. 
+ :return: the trained model + """ + + total_steps, callback = self._setup_learn( + total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name + ) + + callback.on_training_start(locals(), globals()) + + while self.num_timesteps < total_steps: + self._update_current_progress_remaining(self.num_timesteps, total_timesteps) + self._do_one_update(callback, async_eval) + if log_interval is not None and self._n_updates % log_interval == 0: + self._log_and_dump() + + if async_eval is not None: + async_eval.close() + + callback.on_training_end() + + return self diff --git a/sb3_contrib/ars/policies.py b/sb3_contrib/ars/policies.py new file mode 100644 index 0000000..e52ed15 --- /dev/null +++ b/sb3_contrib/ars/policies.py @@ -0,0 +1,119 @@ +from typing import Any, Dict, List, Optional, Type + +import gym +import torch as th +from stable_baselines3.common.policies import BasePolicy, register_policy +from stable_baselines3.common.preprocessing import get_action_dim +from stable_baselines3.common.torch_layers import create_mlp +from torch import nn + + +class ARSPolicy(BasePolicy): + """ + Policy network for ARS. + + :param observation_space: The observation space of the environment + :param action_space: The action space of the environment + :param net_arch: Network architecture, defaults to a 2 layers MLP with 64 hidden nodes. + :param activation_fn: Activation function + :param squash_output: For continuous actions, whether the output is squashed + or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped. 
+ """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + squash_output: bool = True, + ): + + super().__init__( + observation_space, + action_space, + squash_output=isinstance(action_space, gym.spaces.Box) and squash_output, + ) + + if net_arch is None: + net_arch = [64, 64] + + self.net_arch = net_arch + self.features_extractor = self.make_features_extractor() + self.features_dim = self.features_extractor.features_dim + self.activation_fn = activation_fn + + if isinstance(action_space, gym.spaces.Box): + action_dim = get_action_dim(action_space) + actor_net = create_mlp(self.features_dim, action_dim, net_arch, activation_fn, squash_output=True) + elif isinstance(action_space, gym.spaces.Discrete): + actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn) + else: + raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.") + + self.action_net = nn.Sequential(*actor_net) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + # data = super()._get_constructor_parameters() this adds normalize_images, which we don't support... + data = dict( + observation_space=self.observation_space, + action_space=self.action_space, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + ) + return data + + def forward(self, obs: th.Tensor) -> th.Tensor: + + features = self.extract_features(obs) + if isinstance(self.action_space, gym.spaces.Box): + return self.action_net(features) + elif isinstance(self.action_space, gym.spaces.Discrete): + logits = self.action_net(features) + return th.argmax(logits, dim=1) + else: + raise NotImplementedError() + + def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor: + # Non deterministic action does not really make sense for ARS, we ignore this parameter for now.. 
+ return self.forward(observation) + + +class ARSLinearPolicy(ARSPolicy): + """ + Linear policy network for ARS. + + :param observation_space: The observation space of the environment + :param action_space: The action space of the environment + :param with_bias: With or without bias on the output + :param squash_output: For continuous actions, whether the output is squashed + or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped. + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + with_bias: bool = False, + squash_output: bool = False, + ): + + super().__init__(observation_space, action_space, squash_output=squash_output) + + if isinstance(action_space, gym.spaces.Box): + action_dim = get_action_dim(action_space) + self.action_net = nn.Linear(self.features_dim, action_dim, bias=with_bias) + if squash_output: + self.action_net = nn.Sequential(self.action_net, nn.Tanh()) + elif isinstance(action_space, gym.spaces.Discrete): + self.action_net = nn.Linear(self.features_dim, action_space.n, bias=with_bias) + else: + raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.") + + +MlpPolicy = ARSPolicy +LinearPolicy = ARSLinearPolicy + + +register_policy("LinearPolicy", LinearPolicy) +register_policy("MlpPolicy", MlpPolicy) diff --git a/sb3_contrib/common/vec_env/__init__.py b/sb3_contrib/common/vec_env/__init__.py new file mode 100644 index 0000000..488abf8 --- /dev/null +++ b/sb3_contrib/common/vec_env/__init__.py @@ -0,0 +1 @@ +from sb3_contrib.common.vec_env.async_eval import AsyncEval diff --git a/sb3_contrib/common/vec_env/async_eval.py b/sb3_contrib/common/vec_env/async_eval.py new file mode 100644 index 0000000..767ff44 --- /dev/null +++ b/sb3_contrib/common/vec_env/async_eval.py @@ -0,0 +1,208 @@ +import multiprocessing +import multiprocessing as mp +from collections import defaultdict +from typing import Callable, 
List, Optional, Tuple, Union + +import numpy as np +import torch as th +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.running_mean_std import RunningMeanStd +from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize +from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper + + +def _worker( + remote: mp.connection.Connection, + parent_remote: mp.connection.Connection, + worker_env_wrapper: CloudpickleWrapper, + train_policy_wrapper: CloudpickleWrapper, + n_eval_episodes: int = 1, +) -> None: + """ + Function that will be run in each process. + It is in charge of creating environments, evaluating candidates + and communicating with the main process. + + :param remote: Pipe to communicate with the parent process. + :param parent_remote: + :param worker_env_wrapper: Callable used to create the environment inside the process. + :param train_policy_wrapper: Callable used to create the policy inside the process. + :param n_eval_episodes: Number of evaluation episodes per candidate. 
+ """ + parent_remote.close() + env = worker_env_wrapper.var() + train_policy = train_policy_wrapper.var + vec_normalize = unwrap_vec_normalize(env) + if vec_normalize is not None: + obs_rms = vec_normalize.obs_rms + else: + obs_rms = None + while True: + try: + cmd, data = remote.recv() + if cmd == "eval": + results = [] + # Evaluate each candidate and save results + for weights_idx, candidate_weights in data: + train_policy.load_from_vector(candidate_weights.cpu()) + episode_rewards, episode_lengths = evaluate_policy( + train_policy, + env, + n_eval_episodes=n_eval_episodes, + return_episode_rewards=True, + warn=False, + ) + results.append((weights_idx, (episode_rewards, episode_lengths))) + remote.send(results) + elif cmd == "seed": + remote.send(env.seed(data)) + elif cmd == "get_obs_rms": + remote.send(obs_rms) + elif cmd == "sync_obs_rms": + vec_normalize.obs_rms = data + obs_rms = data + elif cmd == "close": + env.close() + remote.close() + break + else: + raise NotImplementedError(f"`{cmd}` is not implemented in the worker") + except EOFError: + break + + +class AsyncEval(object): + """ + Helper class to do asynchronous evaluation of different policies with multiple processes. + It is useful when implementing population based methods like Evolution Strategies (ES), + Cross Entropy Method (CEM) or Augmented Random Search (ARS). + + .. warning:: + + Only 'forkserver' and 'spawn' start methods are thread-safe, + which is important to avoid race conditions. + However, compared to + 'fork' they incur a small start-up cost and have restrictions on + global variables. With those methods, users must wrap the code in an + ``if __name__ == "__main__":`` block. + For more information, see the multiprocessing documentation. + + :param envs_fn: Vectorized environments to run in subprocesses (callable) + :param train_policy: The policy object that will load the different candidate + weights. + :param start_method: method used to start the subprocesses. 
+ Must be one of the methods returned by ``multiprocessing.get_all_start_methods()``. + Defaults to 'forkserver' on available platforms, and 'spawn' otherwise. + :param n_eval_episodes: The number of episodes to test each agent + """ + + def __init__( + self, + envs_fn: List[Callable[[], VecEnv]], + train_policy: BasePolicy, + start_method: Optional[str] = None, + n_eval_episodes: int = 1, + ): + self.waiting = False + self.closed = False + n_envs = len(envs_fn) + + if start_method is None: + # Fork is not a thread safe method (see issue #217) + # but is more user friendly (does not require to wrap the code in + # a `if __name__ == "__main__":`) + forkserver_available = "forkserver" in mp.get_all_start_methods() + start_method = "forkserver" if forkserver_available else "spawn" + ctx = mp.get_context(start_method) + + self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)]) + self.processes = [] + for work_remote, remote, worker_env in zip(self.work_remotes, self.remotes, envs_fn): + args = ( + work_remote, + remote, + CloudpickleWrapper(worker_env), + CloudpickleWrapper(train_policy), + n_eval_episodes, + ) + # daemon=True: if the main process crashes, we should not cause things to hang + process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error + process.start() + self.processes.append(process) + work_remote.close() + + def send_jobs(self, candidate_weights: th.Tensor, pop_size: int) -> None: + """ + Send jobs to the workers to evaluate new candidates. + + :param candidate_weights: The weights to be evaluated. 
+ :param pop_size: The number of candidates (size of the population) + """ + jobs_per_worker = defaultdict(list) + for weights_idx in range(pop_size): + jobs_per_worker[weights_idx % len(self.remotes)].append((weights_idx, candidate_weights[weights_idx])) + + for remote_idx, remote in enumerate(self.remotes): + remote.send(("eval", jobs_per_worker[remote_idx])) + self.waiting = True + + def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + """ + Seed the environments. + + :param seed: The seed for the pseudo-random generators. + :return: + """ + for idx, remote in enumerate(self.remotes): + remote.send(("seed", seed + idx)) + return [remote.recv() for remote in self.remotes] + + def get_results(self) -> List[Tuple[int, Tuple[np.ndarray, np.ndarray]]]: + """ + Retrieve episode rewards and lengths from each worker + for all candidates (there might be multiple candidates per worker) + + :return: A list of tuples containing each candidate index and its + result (episodic reward and episode length) + """ + results = [remote.recv() for remote in self.remotes] + flat_results = [result for worker_results in results for result in worker_results] + self.waiting = False + return flat_results + + def get_obs_rms(self) -> List[RunningMeanStd]: + """ + Retrieve the observation filters (observation running mean std) + of each process, they will be combined in the main process. + Synchronisation is done afterward using ``sync_obs_rms()``. + :return: A list of ``RunningMeanStd`` objects (one per process) + """ + for remote in self.remotes: + remote.send(("get_obs_rms", None)) + return [remote.recv() for remote in self.remotes] + + def sync_obs_rms(self, obs_rms: RunningMeanStd) -> None: + """ + Synchronise (and update) the observation filters + (observation running mean std) + :param obs_rms: The updated ``RunningMeanStd`` to be used + by workers for normalizing observations.
+ """ + for remote in self.remotes: + remote.send(("sync_obs_rms", obs_rms)) + + def close(self) -> None: + """ + Close the processes. + """ + if self.closed: + return + if self.waiting: + for remote in self.remotes: + remote.recv() + for remote in self.remotes: + remote.send(("close", None)) + for process in self.processes: + process.join() + self.closed = True diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index 9c46103..5eeeef0 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -206,8 +206,8 @@ class TRPO(OnPolicyAlgorithm): policy_objective_grad, *_ = th.autograd.grad(policy_objective, param, retain_graph=True, only_inputs=True) grad_shape.append(kl_param_grad.shape) - grad_kl.append(kl_param_grad.view(-1)) - policy_objective_gradients.append(policy_objective_grad.view(-1)) + grad_kl.append(kl_param_grad.reshape(-1)) + policy_objective_gradients.append(policy_objective_grad.reshape(-1)) actor_params.append(param) # Gradients are concatenated before the conjugate gradient step diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index f625807..89cccac 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.3.1a7 +1.3.1a9 diff --git a/setup.cfg b/setup.cfg index dc9cf36..1100948 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,13 +22,15 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators # Ignore import not used when aliases are defined per-file-ignores = ./sb3_contrib/__init__.py:F401 + ./sb3_contrib/ars/__init__.py:F401 ./sb3_contrib/ppo_mask/__init__.py:F401 ./sb3_contrib/qrdqn/__init__.py:F401 ./sb3_contrib/tqc/__init__.py:F401 ./sb3_contrib/trpo/__init__.py:F401 - ./sb3_contrib/common/vec_env/wrappers/__init__.py:F401 ./sb3_contrib/common/wrappers/__init__.py:F401 ./sb3_contrib/common/envs/__init__.py:F401 + ./sb3_contrib/common/vec_env/__init__.py:F401 + exclude = # No need to traverse our git directory .git, diff --git a/setup.py b/setup.py index 
2426ebc..8da3d1d 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.3.1a4", + "stable_baselines3>=1.3.1a9", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", diff --git a/tests/test_deterministic.py b/tests/test_deterministic.py new file mode 100644 index 0000000..19e7175 --- /dev/null +++ b/tests/test_deterministic.py @@ -0,0 +1,55 @@ +import pytest +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.noise import NormalActionNoise +from stable_baselines3.common.vec_env import VecNormalize + +from sb3_contrib import ARS, QRDQN, TQC +from sb3_contrib.common.vec_env import AsyncEval + +N_STEPS_TRAINING = 500 +SEED = 0 +ARS_MULTI = "ars_multi" + + +@pytest.mark.parametrize("algo", [ARS, QRDQN, TQC, ARS_MULTI]) +def test_deterministic_training_common(algo): + results = [[], []] + rewards = [[], []] + # Smaller network + kwargs = {"policy_kwargs": dict(net_arch=[64])} + env_id = "Pendulum-v0" + if algo == ARS_MULTI: + algo = ARS + ars_multi = True + else: + ars_multi = False + + if algo in [TQC]: + kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4}) + else: + if algo == QRDQN: + env_id = "CartPole-v1" + kwargs.update({"learning_starts": 100, "target_update_interval": 100}) + elif algo == ARS: + kwargs.update({"n_delta": 2}) + + for i in range(2): + model = algo("MlpPolicy", env_id, seed=SEED, **kwargs) + + learn_kwargs = {"total_timesteps": N_STEPS_TRAINING} + if ars_multi: + learn_kwargs["async_eval"] = AsyncEval( + [lambda: VecNormalize(make_vec_env(env_id, seed=SEED, n_envs=1)) for _ in range(2)], + model.policy, + ) + + model.learn(**learn_kwargs) + env = model.get_env() + obs = env.reset() + for _ in range(100): + action, _ = 
model.predict(obs, deterministic=False) + obs, reward, _, _ = env.step(action) + results[i].append(action) + rewards[i].append(reward) + assert sum(results[0]) == sum(results[1]), results + assert sum(rewards[0]) == sum(rewards[1]), rewards diff --git a/tests/test_run.py b/tests/test_run.py index ab2d311..be8aef3 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,8 +1,10 @@ import gym import pytest from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.vec_env import VecNormalize -from sb3_contrib import QRDQN, TQC, TRPO +from sb3_contrib import ARS, QRDQN, TQC, TRPO +from sb3_contrib.common.vec_env import AsyncEval @pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"]) @@ -82,6 +84,38 @@ def test_trpo_params(): model.learn(total_timesteps=500) +@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"]) +@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"]) +def test_ars(policy_str, env_id): + model = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0) + model.learn(total_timesteps=500, log_interval=1, eval_freq=250) + + +def test_ars_multi_env(): + env = make_vec_env("Pendulum-v0", n_envs=2) + model = ARS("MlpPolicy", env, n_delta=1) + model.learn(total_timesteps=250) + + env = VecNormalize(make_vec_env("Pendulum-v0", n_envs=1)) + model = ARS("MlpPolicy", env, n_delta=2, seed=0) + # with parallelism + async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v0", n_envs=1)) for _ in range(2)], model.policy) + async_eval.seed(0) + model.learn(500, async_eval=async_eval) + + +@pytest.mark.parametrize("n_top", [2, 8]) +def test_ars_n_top(n_top): + n_delta = 3 + if n_top > n_delta: + with pytest.warns(UserWarning): + model = ARS("MlpPolicy", "Pendulum-v0", n_delta=n_delta, n_top=n_top) + model.learn(total_timesteps=500) + else: + model = ARS("MlpPolicy", "Pendulum-v0", n_delta=n_delta, n_top=n_top) + model.learn(total_timesteps=500) + + @pytest.mark.parametrize("model_class", 
[TQC, QRDQN]) def test_offpolicy_multi_env(model_class): if model_class in [TQC]: diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 10202e7..477b7a1 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -12,9 +12,9 @@ from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnv from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv -from sb3_contrib import QRDQN, TQC, TRPO +from sb3_contrib import ARS, QRDQN, TQC, TRPO -MODEL_LIST = [TQC, QRDQN, TRPO] +MODEL_LIST = [ARS, QRDQN, TQC, TRPO] def select_env(model_class: BaseAlgorithm) -> gym.Env: @@ -266,6 +266,10 @@ def test_save_load_policy(tmp_path, model_class, policy_str): :param policy_str: (str) Name of the policy. """ kwargs = dict(policy_kwargs=dict(net_arch=[16])) + + if policy_str == "CnnPolicy" and model_class is ARS: + pytest.skip("ARS does not support CnnPolicy") + if policy_str == "MlpPolicy": env = select_env(model_class) else: @@ -332,9 +336,11 @@ def test_save_load_policy(tmp_path, model_class, policy_str): if actor is not None: actor.save(tmp_path / "actor.pkl") + device = policy.device + del policy, actor - policy = policy_class.load(tmp_path / "policy.pkl") + policy = policy_class.load(tmp_path / "policy.pkl").to(device) if actor_class is not None: actor = actor_class.load(tmp_path / "actor.pkl")