Augmented Random Search (ARS) (#42)

* first pass at ars, replicates initial results, still needs more testing, cleanup

* add a few docs and tests, bugfixes for ARS

* debug and comment

* break out dump logs

* rollback so there are no predict workers, some refactoring

* remove callback from self, remove torch multiprocessing

* add module docs

* run formatter

* fix load and rerun formatter

* rename to less mathy variable names, rename _validate_hypers

* refactor to use evaluate_policy, linear policy no longer uses bias or squashing

* move everything to torch, add support for discrete action spaces, bugfix for alive reward offset

* added tests, passing all of them, add support for discrete action spaces

* update documentation

* allow for reward offset when there are multiple envs

* update results again

* Reformat

* Ignore unused imports

* Renaming + Cleanup

* Experimental multiprocessing

* Cleaner multiprocessing

* Reformat

* Fixes for callback

* Fix combining stats

* 2nd way

* Make the implementation cpu only

* Fixes + POC with mp module

* POC Processes

* Cleaner async implementation

* Remove unused arg

* Add typing

* Revert vec normalize offset hack

* Add `squash_output` parameter

* Add more tests

* Add comments

* Update doc

* Add comments

* Add more logging

* Fix TRPO issue on GPU

* Tmp fix for ARS tests on GPU

* Additional tmp fixes for ARS

* update docstrings + formatting, fix bad exception string in ARSPolicy

* Add comments and docstrings

* Fix missing import

* Fix type check

* Add docstrings

* GPU support, first attempt

* Fix test

* Add missing docstring

* Typos

* Update default hyperparameters

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Sean Gillen 2022-01-18 04:57:27 -08:00 committed by GitHub
parent 3b007ae93b
commit 675304d8fa
20 changed files with 974 additions and 17 deletions


@ -29,6 +29,7 @@ See documentation for the full list of included features.
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)
**Gym Wrappers**:
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)


@ -8,9 +8,10 @@ along with some useful characteristics: support for discrete/continuous actions,
============ =========== ============ ================= =============== ================
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
============ =========== ============ ================= =============== ================
ARS ✔️ ❌️ ❌ ❌ ✔️
QR-DQN ️❌ ️✔️ ❌ ❌ ✔️
TQC ✔️ ❌ ❌ ❌ ✔️
TRPO ✔️ ✔️ ✔️ ✔️ ✔️
============ =========== ============ ================= =============== ================


@ -43,7 +43,7 @@ Train a PPO with invalid action masking agent on a toy environment.
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
model = MaskablePPO("MlpPolicy", env, verbose=1)
model.learn(5000)
model.save("qrdqn_cartpole")
model.save("maskable_toy_env")
TRPO
----
@ -57,3 +57,17 @@ Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environmen
model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1)
model.learn(total_timesteps=100_000, log_interval=4)
model.save("trpo_pendulum")
ARS
---
Train an Augmented Random Search (ARS) agent on the Pendulum environment
.. code-block:: python
from sb3_contrib import ARS
model = ARS("LinearPolicy", "Pendulum-v0", verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("ars_pendulum")


@ -31,10 +31,11 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
:maxdepth: 1
:caption: RL Algorithms
modules/ars
modules/ppo_mask
modules/qrdqn
modules/tqc
modules/trpo
.. toctree::
:maxdepth: 1


@ -4,20 +4,21 @@ Changelog
==========
Release 1.3.1a7 (WIP)
Release 1.3.1a9 (WIP)
-------------------------------
**Add TRPO**
**Add TRPO and ARS**
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Dropped python 3.6 support
- Upgraded to Stable-Baselines3 >= 1.3.1a4
- Upgraded to Stable-Baselines3 >= 1.3.1a9
- ``MaskablePPO`` was updated to match latest SB3 ``PPO`` version (timeout handling and new method for the policy object)
New Features:
^^^^^^^^^^^^^
- Added ``TRPO`` (@cyprienc)
- Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported)
- Added Augmented Random Search (ARS) (@sgillen)
Bug Fixes:
^^^^^^^^^^
@ -221,4 +222,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
Contributors:
-------------
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen

docs/modules/ars.rst (new file)

@ -0,0 +1,160 @@
.. _ars:
.. automodule:: sb3_contrib.ars
ARS
===
Augmented Random Search (ARS) is a simple reinforcement learning algorithm that performs a direct random search over policy
parameters. It can be surprisingly effective compared to more sophisticated algorithms. In the `original paper <https://arxiv.org/abs/1803.07055>`_ the authors showed that linear policies trained with ARS were competitive with deep reinforcement learning on the MuJoCo locomotion tasks.
SB3's implementation supports linear policies without bias or squashing function, as well as MLP policies, which include linear policies with bias and squashing functions as a special case.
Normally, ARS should be trained with several seeds to evaluate its performance properly.
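The algorithm itself fits in a few lines. Below is a minimal NumPy sketch of one update step (an illustration under simplified assumptions, not the SB3 implementation; ``evaluate`` is a placeholder for rolling out a candidate policy and summing its episode rewards):

```python
import numpy as np

def ars_update(weights, evaluate, n_delta=8, n_top=4, step_size=0.02, delta_std=0.05, rng=None):
    """One ARS-style update: perturb in both directions, evaluate, average the top directions."""
    if rng is None:
        rng = np.random.default_rng()
    # Sample n_delta random perturbation directions
    deltas = rng.standard_normal((n_delta, weights.size))
    # Evaluate each candidate pair: weights + delta and weights - delta
    plus = np.array([evaluate(weights + delta_std * d) for d in deltas])
    minus = np.array([evaluate(weights - delta_std * d) for d in deltas])
    # Rank each direction by the better of its two opposite returns, keep the top n_top
    top = np.argsort(np.maximum(plus, minus))[::-1][:n_top]
    # Scale the step by the return standard deviation of the kept candidates
    sigma_r = np.concatenate([plus[top], minus[top]]).std() + 1e-6
    return weights + step_size / (n_top * sigma_r) * ((plus[top] - minus[top]) @ deltas[top])
```

On a smooth toy objective this converges quickly; the real implementation adds observation normalization and asynchronous evaluation on top of this core loop.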
.. warning::
ARS multi-processing is different from the classic Stable-Baselines3 multi-processing: it runs n environments
in parallel but asynchronously. This asynchronous multi-processing is considered experimental
and does not fully support callbacks: the ``on_step()`` event is called artificially after the evaluation episodes are over.
.. rubric:: Available Policies
.. autosummary::
:nosignatures:
LinearPolicy
MlpPolicy
Notes
-----
- Original paper: https://arxiv.org/abs/1803.07055
- Original Implementation: https://github.com/modestyachts/ARS
Can I use?
----------
- Recurrent policies: ❌
- Multi processing: ✔️ (cf. example)
- Gym spaces:
============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ✔️ ❌
Box ✔️ ✔️
MultiDiscrete ❌ ❌
MultiBinary ❌ ❌
Dict ❌ ❌
============= ====== ===========
Example
-------
.. code-block:: python
from sb3_contrib import ARS
# Policy can be LinearPolicy or MlpPolicy
model = ARS("LinearPolicy", "Pendulum-v0", verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("ars_pendulum")
With experimental asynchronous multi-processing:
.. code-block:: python
from sb3_contrib import ARS
from sb3_contrib.common.vec_env import AsyncEval
from stable_baselines3.common.env_util import make_vec_env
env_id = "CartPole-v1"
n_envs = 2
model = ARS("LinearPolicy", env_id, n_delta=2, n_top=1, verbose=1)
# Create env for asynchronous evaluation (run in different processes)
async_eval = AsyncEval([lambda: make_vec_env(env_id) for _ in range(n_envs)], model.policy)
model.learn(total_timesteps=200_000, log_interval=4, async_eval=async_eval)
Results
-------
Results replicating the original paper, which used the MuJoCo benchmarks: the same hyperparameters as the original paper, with 8 seeds.
============= ============
Environments ARS
============= ============
HalfCheetah 4398 +/- 320
Swimmer 241 +/- 51
Hopper 3320 +/- 120
============= ============
How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Clone RL-Zoo and checkout the branch ``feat/ars``
.. code-block:: bash
git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/
git checkout feat/ars
Run the benchmark. The following code snippet trains 8 seeds in parallel:
.. code-block:: bash
for ENV_ID in Swimmer-v3 HalfCheetah-v3 Hopper-v3
do
for SEED_NUM in {1..8}
do
SEED=$RANDOM
python train.py --algo ars --env $ENV_ID --eval-episodes 10 --eval-freq 10000 -n 20000000 --seed $SEED &
sleep 1
done
wait
done
Plot the results:
.. code-block:: bash
python scripts/all_plots.py -a ars -e HalfCheetah Swimmer Hopper -f logs/ -o logs/ars_results -max 20000000
python scripts/plot_from_file.py -i logs/ars_results.pkl -l ARS
Parameters
----------
.. autoclass:: ARS
:members:
:inherited-members:
.. _ars_policies:
ARS Policies
-------------
.. autoclass:: sb3_contrib.ars.policies.ARSPolicy
:members:
:noindex:
.. autoclass:: LinearPolicy
:members:
:inherited-members:
.. autoclass:: MlpPolicy
:members:


@ -119,3 +119,7 @@ cuda
Polyak
gSDE
rollouts
quantiles
quantile
contrib
Contrib


@ -1,5 +1,6 @@
import os
from sb3_contrib.ars import ARS
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC


@ -0,0 +1,2 @@
from sb3_contrib.ars.ars import ARS
from sb3_contrib.ars.policies import LinearPolicy, MlpPolicy

sb3_contrib/ars/ars.py (new file)

@ -0,0 +1,347 @@
import copy
import time
import warnings
from functools import partial
from typing import Any, Dict, Optional, Type, Union
import gym
import numpy as np
import torch as th
import torch.nn.utils
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_schedule_fn, safe_mean
from sb3_contrib.ars.policies import ARSPolicy
from sb3_contrib.common.vec_env.async_eval import AsyncEval
class ARS(BaseAlgorithm):
"""
Augmented Random Search: https://arxiv.org/abs/1803.07055
Original implementation: https://github.com/modestyachts/ARS
C++/Cuda Implementation: https://github.com/google-research/tiny-differentiable-simulator/
150 LOC Numpy Implementation: https://github.com/alexis-jacq/numpy_ARS/blob/master/asr.py
:param policy: The policy to train, can be an instance of ``ARSPolicy``, or a string from ["LinearPolicy", "MlpPolicy"]
:param env: The environment to train on, may be a string if registered with gym
:param n_delta: How many random perturbations of the policy to try at each update step.
:param n_top: How many of the top delta to use in each update step. Default is n_delta
:param learning_rate: Float or schedule for the step size
:param delta_std: Float or schedule for the exploration noise
:param zero_policy: Boolean determining if the passed policy should have its weights zeroed before training.
:param alive_bonus_offset: Constant added to the reward at each step, used to cancel out alive bonuses.
:param n_eval_episodes: Number of episodes to evaluate each candidate.
:param policy_kwargs: Keyword arguments to pass to the policy on creation
:param policy_base: Base class to use for the policy
:param tensorboard_log: String with the directory to put tensorboard logs
:param seed: Random seed for the training
:param verbose: Verbosity level: 0 no output, 1 info, 2 debug
:param device: Torch device to use for training, defaults to "cpu"
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
def __init__(
self,
policy: Union[str, Type[ARSPolicy]],
env: Union[GymEnv, str],
n_delta: int = 8,
n_top: Optional[int] = None,
learning_rate: Union[float, Schedule] = 0.02,
delta_std: Union[float, Schedule] = 0.05,
zero_policy: bool = True,
alive_bonus_offset: float = 0,
n_eval_episodes: int = 1,
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_base: Type[BasePolicy] = ARSPolicy,
tensorboard_log: Optional[str] = None,
seed: Optional[int] = None,
verbose: int = 0,
device: Union[th.device, str] = "cpu",
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
learning_rate=learning_rate,
tensorboard_log=tensorboard_log,
policy_base=policy_base,
policy_kwargs=policy_kwargs,
verbose=verbose,
device=device,
supported_action_spaces=(gym.spaces.Box, gym.spaces.Discrete),
support_multi_env=True,
seed=seed,
)
self.n_delta = n_delta
self.pop_size = 2 * n_delta
self.delta_std_schedule = get_schedule_fn(delta_std)
self.n_eval_episodes = n_eval_episodes
if n_top is None:
n_top = n_delta
# Make sure our hyperparameters are valid and auto-correct them if they are not
if n_top > n_delta:
warnings.warn(f"n_top = {n_top} > n_delta = {n_delta}, setting n_top = n_delta")
n_top = n_delta
self.n_top = n_top
self.alive_bonus_offset = alive_bonus_offset
self.zero_policy = zero_policy
self.weights = None # Need to call init model to initialize weight
self.processes = None
# Keep track of how many steps elapsed before a new rollout
# Important for syncing observation normalization between workers
self.old_count = 0
if _init_setup_model:
self._setup_model()
def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)
self.policy = self.policy_class(self.observation_space, self.action_space, **self.policy_kwargs)
self.policy = self.policy.to(self.device)
self.weights = th.nn.utils.parameters_to_vector(self.policy.parameters()).detach()
self.n_params = len(self.weights)
if self.zero_policy:
self.weights = th.zeros_like(self.weights, requires_grad=False)
self.policy.load_from_vector(self.weights.cpu())
def _mimic_monitor_wrapper(self, episode_rewards: np.ndarray, episode_lengths: np.ndarray) -> None:
"""
Helper to mimic Monitor wrapper and report episode statistics (mean reward, mean episode length).
:param episode_rewards: List containing per-episode rewards
:param episode_lengths: List containing per-episode lengths (in number of steps)
"""
# Mimic Monitor Wrapper
infos = [
{"episode": {"r": episode_reward, "l": episode_length}}
for episode_reward, episode_length in zip(episode_rewards, episode_lengths)
]
self._update_info_buffer(infos)
def _trigger_callback(
self,
_locals: Dict[str, Any],
_globals: Dict[str, Any],
callback: BaseCallback,
n_envs: int,
) -> None:
"""
Callback passed to the ``evaluate_policy()`` helper
in order to increment the number of timesteps
and trigger events in the single process version.
:param _locals:
:param _globals:
:param callback: Callback that will be called at every step
:param n_envs: Number of environments
"""
self.num_timesteps += n_envs
callback.on_step()
def evaluate_candidates(
self, candidate_weights: th.Tensor, callback: BaseCallback, async_eval: Optional[AsyncEval]
) -> th.Tensor:
"""
Evaluate each candidate.
:param candidate_weights: The candidate weights to be evaluated.
:param callback: Callback that will be called at each step
(or after evaluation in the multiprocess version)
:param async_eval: The object for asynchronous evaluation of candidates.
:return: The episodic return for each candidate.
"""
batch_steps = 0
# returns == sum of rewards
candidate_returns = th.zeros(self.pop_size, device=self.device)
train_policy = copy.deepcopy(self.policy)
# Empty buffer to show only mean over one iteration (one set of candidates) in the logs
self.ep_info_buffer = []
callback.on_rollout_start()
if async_eval is not None:
# Multiprocess asynchronous version
async_eval.send_jobs(candidate_weights, self.pop_size)
results = async_eval.get_results()
for weights_idx, (episode_rewards, episode_lengths) in results:
# Update reward to cancel out alive bonus if needed
candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths)
batch_steps += np.sum(episode_lengths)
self._mimic_monitor_wrapper(episode_rewards, episode_lengths)
# Combine the filter stats of each process for normalization
for worker_obs_rms in async_eval.get_obs_rms():
if self._vec_normalize_env is not None:
# worker_obs_rms.count -= self.old_count
self._vec_normalize_env.obs_rms.combine(worker_obs_rms)
# Hack: don't count timesteps twice (the stats between the two syncs),
# otherwise it would lead to an overflow;
# in practice we would need two RunningMeanStd instances
self._vec_normalize_env.obs_rms.count -= self.old_count
# Synchronise VecNormalize if needed
if self._vec_normalize_env is not None:
async_eval.sync_obs_rms(self._vec_normalize_env.obs_rms.copy())
self.old_count = self._vec_normalize_env.obs_rms.count
# Hack to have Callback events
for _ in range(batch_steps // len(async_eval.remotes)):
self.num_timesteps += len(async_eval.remotes)
callback.on_step()
else:
# Single process, synchronous version
for weights_idx in range(self.pop_size):
# Load current candidate weights
train_policy.load_from_vector(candidate_weights[weights_idx].cpu())
# Evaluate the candidate
episode_rewards, episode_lengths = evaluate_policy(
train_policy,
self.env,
n_eval_episodes=self.n_eval_episodes,
return_episode_rewards=True,
# Increment num_timesteps too (slight mismatch with multi envs)
callback=partial(self._trigger_callback, callback=callback, n_envs=self.env.num_envs),
warn=False,
)
# Update reward to cancel out alive bonus if needed
candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths)
batch_steps += sum(episode_lengths)
self._mimic_monitor_wrapper(episode_rewards, episode_lengths)
# Note: we increment the num_timesteps inside the evaluate_policy()
# however when using multiple environments, there will be a slight
# mismatch between the number of timesteps used and the number
# of calls to the step() method (cf. implementation of evaluate_policy())
# self.num_timesteps += batch_steps
callback.on_rollout_end()
return candidate_returns
def _log_and_dump(self) -> None:
"""
Dump information to the logger.
"""
time_elapsed = time.time() - self.start_time
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8))
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
self.logger.record("time/fps", fps)
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
self.logger.dump(step=self.num_timesteps)
def _do_one_update(self, callback: BaseCallback, async_eval: Optional[AsyncEval]) -> None:
"""
Sample new candidates, evaluate them and then update current policy.
:param callback: callback(s) called at every step with state of the algorithm.
:param async_eval: The object for asynchronous evaluation of candidates.
"""
# Retrieve current parameter noise standard deviation
# and current learning rate
delta_std = self.delta_std_schedule(self._current_progress_remaining)
learning_rate = self.lr_schedule(self._current_progress_remaining)
# Sample the parameter noise, it will be scaled by delta_std
deltas = th.normal(mean=0.0, std=1.0, size=(self.n_delta, self.n_params), device=self.device)
policy_deltas = deltas * delta_std
# Generate 2 * n_delta candidate policies by adding noise to the current weights
candidate_weights = th.cat([self.weights + policy_deltas, self.weights - policy_deltas])
with th.no_grad():
candidate_returns = self.evaluate_candidates(candidate_weights, callback, async_eval)
# Returns corresponding to weights + deltas
plus_returns = candidate_returns[: self.n_delta]
# Returns corresponding to weights - deltas
minus_returns = candidate_returns[self.n_delta :]
# Keep only the top performing candidates for update
top_returns, _ = th.max(th.vstack((plus_returns, minus_returns)), dim=0)
top_idx = th.argsort(top_returns, descending=True)[: self.n_top]
plus_returns = plus_returns[top_idx]
minus_returns = minus_returns[top_idx]
deltas = deltas[top_idx]
# Scale learning rate by the return standard deviation:
# take smaller steps when there is a high variance in the returns
return_std = th.cat([plus_returns, minus_returns]).std()
step_size = learning_rate / (self.n_top * return_std + 1e-6)
# Approximate gradient step
self.weights = self.weights + step_size * ((plus_returns - minus_returns) @ deltas)
self.policy.load_from_vector(self.weights.cpu())
self.logger.record("train/iterations", self._n_updates, exclude="tensorboard")
self.logger.record("train/delta_std", delta_std)
self.logger.record("train/learning_rate", learning_rate)
self.logger.record("train/step_size", step_size.item())
self.logger.record("rollout/return_std", return_std.item())
self._n_updates += 1
def learn(
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
tb_log_name: str = "ARS",
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
async_eval: Optional[AsyncEval] = None,
) -> "ARS":
"""
Return a trained model.
:param total_timesteps: The total number of samples (env steps) to train on
:param callback: callback(s) called at every step with state of the algorithm.
:param log_interval: The number of timesteps before logging.
:param tb_log_name: the name of the run for TensorBoard logging
:param eval_env: Environment that will be used to evaluate the agent
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
:param n_eval_episodes: Number of episode to evaluate the agent
:param eval_log_path: Path to a folder where the evaluations will be saved
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
:param async_eval: The object for asynchronous evaluation of candidates.
:return: the trained model
"""
total_steps, callback = self._setup_learn(
total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
)
callback.on_training_start(locals(), globals())
while self.num_timesteps < total_steps:
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
self._do_one_update(callback, async_eval)
if log_interval is not None and self._n_updates % log_interval == 0:
self._log_and_dump()
if async_eval is not None:
async_eval.close()
callback.on_training_end()
return self
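The top-direction selection and step-size scaling performed in ``_do_one_update`` above can be traced with toy tensors (the return values and deltas below are made up for illustration):

```python
import torch as th

n_top, learning_rate = 2, 0.02
# Toy returns for weights + delta and weights - delta (made-up values)
plus_returns = th.tensor([1.0, 5.0, 3.0])
minus_returns = th.tensor([2.0, 0.0, 4.0])
deltas = th.eye(3)  # one unit direction per candidate, for illustration only

# Rank each direction by the better of its two opposite returns
top_returns, _ = th.max(th.vstack((plus_returns, minus_returns)), dim=0)  # [2., 5., 4.]
top_idx = th.argsort(top_returns, descending=True)[:n_top]  # keeps directions 1 and 2

plus_top, minus_top = plus_returns[top_idx], minus_returns[top_idx]
# Smaller steps when the kept returns have high variance
return_std = th.cat([plus_top, minus_top]).std()
step_size = learning_rate / (n_top * return_std + 1e-6)
update = step_size * ((plus_top - minus_top) @ deltas[top_idx])
```

Direction 0 is discarded entirely (its update component is zero), while direction 1 is pushed along positively and direction 2 negatively, matching the sign of each kept return difference.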

sb3_contrib/ars/policies.py (new file)

@ -0,0 +1,119 @@
from typing import Any, Dict, List, Optional, Type
import gym
import torch as th
from stable_baselines3.common.policies import BasePolicy, register_policy
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import create_mlp
from torch import nn
class ARSPolicy(BasePolicy):
"""
Policy network for ARS.
:param observation_space: The observation space of the environment
:param action_space: The action space of the environment
:param net_arch: Network architecture, defaults to a 2 layers MLP with 64 hidden nodes.
:param activation_fn: Activation function
:param squash_output: For continuous actions, whether the output is squashed
or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped.
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
squash_output: bool = True,
):
super().__init__(
observation_space,
action_space,
squash_output=isinstance(action_space, gym.spaces.Box) and squash_output,
)
if net_arch is None:
net_arch = [64, 64]
self.net_arch = net_arch
self.features_extractor = self.make_features_extractor()
self.features_dim = self.features_extractor.features_dim
self.activation_fn = activation_fn
if isinstance(action_space, gym.spaces.Box):
action_dim = get_action_dim(action_space)
actor_net = create_mlp(self.features_dim, action_dim, net_arch, activation_fn, squash_output=True)
elif isinstance(action_space, gym.spaces.Discrete):
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn)
else:
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
self.action_net = nn.Sequential(*actor_net)
def _get_constructor_parameters(self) -> Dict[str, Any]:
# data = super()._get_constructor_parameters() this adds normalize_images, which we don't support...
data = dict(
observation_space=self.observation_space,
action_space=self.action_space,
net_arch=self.net_arch,
activation_fn=self.activation_fn,
)
return data
def forward(self, obs: th.Tensor) -> th.Tensor:
features = self.extract_features(obs)
if isinstance(self.action_space, gym.spaces.Box):
return self.action_net(features)
elif isinstance(self.action_space, gym.spaces.Discrete):
logits = self.action_net(features)
return th.argmax(logits, dim=1)
else:
raise NotImplementedError()
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
# Non-deterministic actions do not really make sense for ARS; we ignore this parameter for now
return self.forward(observation)
class ARSLinearPolicy(ARSPolicy):
"""
Linear policy network for ARS.
:param observation_space: The observation space of the environment
:param action_space: The action space of the environment
:param with_bias: With or without bias on the output
:param squash_output: For continuous actions, whether the output is squashed
or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped.
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
with_bias: bool = False,
squash_output: bool = False,
):
super().__init__(observation_space, action_space, squash_output=squash_output)
if isinstance(action_space, gym.spaces.Box):
action_dim = get_action_dim(action_space)
self.action_net = nn.Linear(self.features_dim, action_dim, bias=with_bias)
if squash_output:
self.action_net = nn.Sequential(self.action_net, nn.Tanh())
elif isinstance(action_space, gym.spaces.Discrete):
self.action_net = nn.Linear(self.features_dim, action_space.n, bias=with_bias)
else:
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
MlpPolicy = ARSPolicy
LinearPolicy = ARSLinearPolicy
register_policy("LinearPolicy", LinearPolicy)
register_policy("MlpPolicy", MlpPolicy)


@ -0,0 +1 @@
from sb3_contrib.common.vec_env.async_eval import AsyncEval


@ -0,0 +1,208 @@
import multiprocessing as mp
from collections import defaultdict
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch as th
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper
def _worker(
remote: mp.connection.Connection,
parent_remote: mp.connection.Connection,
worker_env_wrapper: CloudpickleWrapper,
train_policy_wrapper: CloudpickleWrapper,
n_eval_episodes: int = 1,
) -> None:
"""
Function that will be run in each process.
It is in charge of creating environments, evaluating candidates
and communicating with the main process.
:param remote: Pipe to communicate with the parent process.
:param parent_remote:
:param worker_env_wrapper: Callable used to create the environment inside the process.
:param train_policy_wrapper: Callable used to create the policy inside the process.
:param n_eval_episodes: Number of evaluation episodes per candidate.
"""
parent_remote.close()
env = worker_env_wrapper.var()
train_policy = train_policy_wrapper.var
vec_normalize = unwrap_vec_normalize(env)
if vec_normalize is not None:
obs_rms = vec_normalize.obs_rms
else:
obs_rms = None
while True:
try:
cmd, data = remote.recv()
if cmd == "eval":
results = []
# Evaluate each candidate and save results
for weights_idx, candidate_weights in data:
train_policy.load_from_vector(candidate_weights.cpu())
episode_rewards, episode_lengths = evaluate_policy(
train_policy,
env,
n_eval_episodes=n_eval_episodes,
return_episode_rewards=True,
warn=False,
)
results.append((weights_idx, (episode_rewards, episode_lengths)))
remote.send(results)
elif cmd == "seed":
remote.send(env.seed(data))
elif cmd == "get_obs_rms":
remote.send(obs_rms)
elif cmd == "sync_obs_rms":
vec_normalize.obs_rms = data
obs_rms = data
elif cmd == "close":
env.close()
remote.close()
break
else:
raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
except EOFError:
break
class AsyncEval(object):
"""
Helper class to do asynchronous evaluation of different policies with multiple processes.
It is useful when implementing population based methods like Evolution Strategies (ES),
Cross Entropy Method (CEM) or Augmented Random Search (ARS).
.. warning::
Only 'forkserver' and 'spawn' start methods are thread-safe,
which is important to avoid race conditions.
However, compared to
'fork' they incur a small start-up cost and have restrictions on
global variables. With those methods, users must wrap the code in an
``if __name__ == "__main__":`` block.
For more information, see the multiprocessing documentation.
:param envs_fn: Vectorized environments to run in subprocesses (callable)
:param train_policy: The policy object that will load the different candidate
weights.
:param start_method: method used to start the subprocesses.
Must be one of the methods returned by ``multiprocessing.get_all_start_methods()``.
Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
:param n_eval_episodes: The number of episodes to test each agent
"""
def __init__(
self,
envs_fn: List[Callable[[], VecEnv]],
train_policy: BasePolicy,
start_method: Optional[str] = None,
n_eval_episodes: int = 1,
):
self.waiting = False
self.closed = False
n_envs = len(envs_fn)
if start_method is None:
# Fork is not a thread safe method (see issue #217)
# but is more user friendly (does not require to wrap the code in
# a `if __name__ == "__main__":`)
forkserver_available = "forkserver" in mp.get_all_start_methods()
start_method = "forkserver" if forkserver_available else "spawn"
ctx = mp.get_context(start_method)
self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)])
self.processes = []
for work_remote, remote, worker_env in zip(self.work_remotes, self.remotes, envs_fn):
args = (
work_remote,
remote,
CloudpickleWrapper(worker_env),
CloudpickleWrapper(train_policy),
n_eval_episodes,
)
# daemon=True: if the main process crashes, we should not cause things to hang
process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error
process.start()
self.processes.append(process)
work_remote.close()
def send_jobs(self, candidate_weights: th.Tensor, pop_size: int) -> None:
"""
Send jobs to the workers to evaluate new candidates.
:param candidate_weights: The weights to be evaluated.
:param pop_size: The number of candidates (size of the population)
"""
jobs_per_worker = defaultdict(list)
for weights_idx in range(pop_size):
jobs_per_worker[weights_idx % len(self.remotes)].append((weights_idx, candidate_weights[weights_idx]))
for remote_idx, remote in enumerate(self.remotes):
remote.send(("eval", jobs_per_worker[remote_idx]))
self.waiting = True
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
"""
Seed the environments.
:param seed: The seed for the pseudo-random generators.
:return:
"""
for idx, remote in enumerate(self.remotes):
remote.send(("seed", seed + idx))
return [remote.recv() for remote in self.remotes]
def get_results(self) -> List[Tuple[int, Tuple[np.ndarray, np.ndarray]]]:
"""
Retrieve episode rewards and lengths from each worker
for all candidates (there might be multiple candidates per worker).
:return: A list of tuples containing each candidate index and its
result (episodic reward and episode length)
"""
results = [remote.recv() for remote in self.remotes]
flat_results = [result for worker_results in results for result in worker_results]
self.waiting = False
return flat_results
def get_obs_rms(self) -> List[RunningMeanStd]:
"""
Retrieve the observation filters (observation running mean std)
of each process; they will be combined in the main process.
Synchronisation is done afterward using ``sync_obs_rms()``.
:return: A list of ``RunningMeanStd`` objects (one per process)
"""
for remote in self.remotes:
remote.send(("get_obs_rms", None))
return [remote.recv() for remote in self.remotes]
def sync_obs_rms(self, obs_rms: RunningMeanStd) -> None:
"""
Synchronise (and update) the observation filters
(observation running mean std)
:param obs_rms: The updated ``RunningMeanStd`` to be used
by workers for normalizing observations.
"""
for remote in self.remotes:
remote.send(("sync_obs_rms", obs_rms))
def close(self) -> None:
"""
Close the processes.
"""
if self.closed:
return
if self.waiting:
for remote in self.remotes:
remote.recv()
for remote in self.remotes:
remote.send(("close", None))
for process in self.processes:
process.join()
self.closed = True
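The `get_obs_rms` / `sync_obs_rms` pair above implies that per-worker observation statistics are merged in the main process (see the "Fix combining stats" commit). A minimal sketch of how two `RunningMeanStd`-style statistics can be combined using the parallel-variance formula; the `combine_stats` helper name is hypothetical, not part of the AsyncEval API:

```python
import numpy as np


def combine_stats(mean_a, var_a, count_a, mean_b, var_b, count_b):
    """Merge two sets of running mean/variance statistics
    (parallel algorithm, Chan et al.)."""
    delta = mean_b - mean_a
    total = count_a + count_b
    mean = mean_a + delta * count_b / total
    # m2 is the combined sum of squared deviations from the merged mean
    m2 = var_a * count_a + var_b * count_b + delta**2 * count_a * count_b / total
    return mean, m2 / total, total


# Sanity check: merging the stats of two halves matches the stats of the full batch
obs = np.random.default_rng(0).normal(size=(100, 3))
a, b = obs[:60], obs[60:]
mean, var, count = combine_stats(a.mean(0), a.var(0), len(a), b.mean(0), b.var(0), len(b))
assert np.allclose(mean, obs.mean(0)) and np.allclose(var, obs.var(0)) and count == 100
```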

View File

@@ -206,8 +206,8 @@ class TRPO(OnPolicyAlgorithm):
policy_objective_grad, *_ = th.autograd.grad(policy_objective, param, retain_graph=True, only_inputs=True)
grad_shape.append(kl_param_grad.shape)
-grad_kl.append(kl_param_grad.view(-1))
-policy_objective_gradients.append(policy_objective_grad.view(-1))
+grad_kl.append(kl_param_grad.reshape(-1))
+policy_objective_gradients.append(policy_objective_grad.reshape(-1))
actor_params.append(param)
# Gradients are concatenated before the conjugate gradient step

View File

@@ -1 +1 @@
-1.3.1a7
+1.3.1a9

View File

@@ -22,13 +22,15 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators
# Ignore import not used when aliases are defined
per-file-ignores =
./sb3_contrib/__init__.py:F401
./sb3_contrib/ars/__init__.py:F401
./sb3_contrib/ppo_mask/__init__.py:F401
./sb3_contrib/qrdqn/__init__.py:F401
./sb3_contrib/tqc/__init__.py:F401
./sb3_contrib/trpo/__init__.py:F401
./sb3_contrib/common/vec_env/wrappers/__init__.py:F401
./sb3_contrib/common/wrappers/__init__.py:F401
./sb3_contrib/common/envs/__init__.py:F401
./sb3_contrib/common/vec_env/__init__.py:F401
exclude =
# No need to traverse our git directory
.git,

View File

@@ -63,7 +63,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
-    "stable_baselines3>=1.3.1a4",
+    "stable_baselines3>=1.3.1a9",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",

View File

@@ -0,0 +1,55 @@
import pytest
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.vec_env import VecNormalize
from sb3_contrib import ARS, QRDQN, TQC
from sb3_contrib.common.vec_env import AsyncEval
N_STEPS_TRAINING = 500
SEED = 0
ARS_MULTI = "ars_multi"
@pytest.mark.parametrize("algo", [ARS, QRDQN, TQC, ARS_MULTI])
def test_deterministic_training_common(algo):
results = [[], []]
rewards = [[], []]
# Smaller network
kwargs = {"policy_kwargs": dict(net_arch=[64])}
env_id = "Pendulum-v0"
if algo == ARS_MULTI:
algo = ARS
ars_multi = True
else:
ars_multi = False
if algo in [TQC]:
kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4})
else:
if algo == QRDQN:
env_id = "CartPole-v1"
kwargs.update({"learning_starts": 100, "target_update_interval": 100})
elif algo == ARS:
kwargs.update({"n_delta": 2})
for i in range(2):
model = algo("MlpPolicy", env_id, seed=SEED, **kwargs)
learn_kwargs = {"total_timesteps": N_STEPS_TRAINING}
if ars_multi:
learn_kwargs["async_eval"] = AsyncEval(
[lambda: VecNormalize(make_vec_env(env_id, seed=SEED, n_envs=1)) for _ in range(2)],
model.policy,
)
model.learn(**learn_kwargs)
env = model.get_env()
obs = env.reset()
for _ in range(100):
action, _ = model.predict(obs, deterministic=False)
obs, reward, _, _ = env.step(action)
results[i].append(action)
rewards[i].append(reward)
assert sum(results[0]) == sum(results[1]), results
assert sum(rewards[0]) == sum(rewards[1]), rewards

View File

@@ -1,8 +1,10 @@
import gym
import pytest
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize
-from sb3_contrib import QRDQN, TQC, TRPO
+from sb3_contrib import ARS, QRDQN, TQC, TRPO
+from sb3_contrib.common.vec_env import AsyncEval
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
@@ -82,6 +84,38 @@ def test_trpo_params():
model.learn(total_timesteps=500)
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"])
def test_ars(policy_str, env_id):
model = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0)
model.learn(total_timesteps=500, log_interval=1, eval_freq=250)
def test_ars_multi_env():
env = make_vec_env("Pendulum-v0", n_envs=2)
model = ARS("MlpPolicy", env, n_delta=1)
model.learn(total_timesteps=250)
env = VecNormalize(make_vec_env("Pendulum-v0", n_envs=1))
model = ARS("MlpPolicy", env, n_delta=2, seed=0)
# with parallelism
async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v0", n_envs=1)) for _ in range(2)], model.policy)
async_eval.seed(0)
model.learn(500, async_eval=async_eval)
@pytest.mark.parametrize("n_top", [2, 8])
def test_ars_n_top(n_top):
n_delta = 3
if n_top > n_delta:
with pytest.warns(UserWarning):
model = ARS("MlpPolicy", "Pendulum-v0", n_delta=n_delta, n_top=n_top)
model.learn(total_timesteps=500)
else:
model = ARS("MlpPolicy", "Pendulum-v0", n_delta=n_delta, n_top=n_top)
model.learn(total_timesteps=500)
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
def test_offpolicy_multi_env(model_class):
if model_class in [TQC]:

View File

@@ -12,9 +12,9 @@ from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnv
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv
-from sb3_contrib import QRDQN, TQC, TRPO
+from sb3_contrib import ARS, QRDQN, TQC, TRPO
-MODEL_LIST = [TQC, QRDQN, TRPO]
+MODEL_LIST = [ARS, QRDQN, TQC, TRPO]
def select_env(model_class: BaseAlgorithm) -> gym.Env:
@@ -266,6 +266,10 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
:param policy_str: (str) Name of the policy.
"""
kwargs = dict(policy_kwargs=dict(net_arch=[16]))
if policy_str == "CnnPolicy" and model_class is ARS:
pytest.skip("ARS does not support CnnPolicy")
if policy_str == "MlpPolicy":
env = select_env(model_class)
else:
@@ -332,9 +336,11 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
if actor is not None:
actor.save(tmp_path / "actor.pkl")
device = policy.device
del policy, actor
-policy = policy_class.load(tmp_path / "policy.pkl")
+policy = policy_class.load(tmp_path / "policy.pkl").to(device)
if actor_class is not None:
actor = actor_class.load(tmp_path / "actor.pkl")
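The test change above records `policy.device` before deleting the policy and moves the reloaded one back with `.to(device)`, so the round-tripped policy ends up on the same device as the tensors it is compared against. A plain-PyTorch sketch of the same pattern, assuming nothing about the SB3 `load` internals (the in-memory buffer stands in for `tmp_path / "policy.pkl"`):

```python
import io

import torch

model = torch.nn.Linear(3, 2)
device = next(model.parameters()).device  # remember the device before serializing

buf = io.BytesIO()
torch.save(model.state_dict(), buf)  # stands in for saving policy.pkl
buf.seek(0)

# map_location="cpu" makes loading safe on CPU-only machines;
# .to(device) then restores the original device, mirroring the test fix
reloaded = torch.nn.Linear(3, 2)
reloaded.load_state_dict(torch.load(buf, map_location="cpu"))
reloaded = reloaded.to(device)
assert next(reloaded.parameters()).device == device
```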