Augmented Random Search (ARS) (#42)
* first pass at ars, replicates initial results, still needs more testing, cleanup * add a few docs and tests, bugfixes for ARS * debug and comment * break out dump logs * rollback so there are now predict workers, some refactoring * remove callback from self, remove torch multiprocessing * add module docs * run formatter * fix load and rerun formatter * rename to less mathy variable names, rename _validate_hypers * refactor to use evaluatate_policy, linear policy no longer uses bias or squashing * move everything to torch, add support for discrete action spaces, bugfix for alive reward offset * added tests, passing all of them, add support for discrete action spaces * update documentation * allow for reward offset when there are multiple envs * update results again * Reformat * Ignore unused imports * Renaming + Cleanup * Experimental multiprocessing * Cleaner multiprocessing * Reformat * Fixes for callback * Fix combining stats * 2nd way * Make the implementation cpu only * Fixes + POC with mp module * POC Processes * Cleaner aync implementation * Remove unused arg * Add typing * Revert vec normalize offset hack * Add `squash_output` parameter * Add more tests * Add comments * Update doc * Add comments * Add more logging * Fix TRPO issue on GPU * Tmp fix for ARS tests on GPU * Additional tmp fixes for ARS * update docstrings + formatting, fix bad exceptioe string in ARSPolicy * Add comments and docstrings * Fix missing import * Fix type check * Add dosctrings * GPU support, first attempt * Fix test * Add missing docstring * Typos * Update defaults hyperparameters Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
3b007ae93b
commit
675304d8fa
|
|
@ -29,6 +29,7 @@ See documentation for the full list of included features.
|
||||||
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
|
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
|
||||||
- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
|
- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
|
||||||
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
|
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
|
||||||
|
- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)
|
||||||
|
|
||||||
**Gym Wrappers**:
|
**Gym Wrappers**:
|
||||||
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
|
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
|
||||||
|
|
|
||||||
|
|
@ -8,9 +8,10 @@ along with some useful characteristics: support for discrete/continuous actions,
|
||||||
============ =========== ============ ================= =============== ================
|
============ =========== ============ ================= =============== ================
|
||||||
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
|
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
|
||||||
============ =========== ============ ================= =============== ================
|
============ =========== ============ ================= =============== ================
|
||||||
|
ARS ✔️ ❌️ ❌ ❌ ✔️
|
||||||
|
QR-DQN ️❌ ️✔️ ❌ ❌ ✔️
|
||||||
TQC ✔️ ❌ ❌ ❌ ✔️
|
TQC ✔️ ❌ ❌ ❌ ✔️
|
||||||
TRPO ✔️ ✔️ ✔️ ✔️ ✔️
|
TRPO ✔️ ✔️ ✔️ ✔️ ✔️
|
||||||
QR-DQN ️❌ ️✔️ ❌ ❌ ✔️
|
|
||||||
============ =========== ============ ================= =============== ================
|
============ =========== ============ ================= =============== ================
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ Train a PPO with invalid action masking agent on a toy environment.
|
||||||
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
|
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
|
||||||
model = MaskablePPO("MlpPolicy", env, verbose=1)
|
model = MaskablePPO("MlpPolicy", env, verbose=1)
|
||||||
model.learn(5000)
|
model.learn(5000)
|
||||||
model.save("qrdqn_cartpole")
|
model.save("maskable_toy_env")
|
||||||
|
|
||||||
TRPO
|
TRPO
|
||||||
----
|
----
|
||||||
|
|
@ -57,3 +57,17 @@ Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environmen
|
||||||
model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1)
|
model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1)
|
||||||
model.learn(total_timesteps=100_000, log_interval=4)
|
model.learn(total_timesteps=100_000, log_interval=4)
|
||||||
model.save("trpo_pendulum")
|
model.save("trpo_pendulum")
|
||||||
|
|
||||||
|
|
||||||
|
ARS
|
||||||
|
---
|
||||||
|
|
||||||
|
Train an agent using Augmented Random Search (ARS) agent on the Pendulum environment
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from sb3_contrib import ARS
|
||||||
|
|
||||||
|
model = ARS("LinearPolicy", "Pendulum-v0", verbose=1)
|
||||||
|
model.learn(total_timesteps=10000, log_interval=4)
|
||||||
|
model.save("ars_pendulum")
|
||||||
|
|
|
||||||
|
|
@ -31,10 +31,11 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
|
||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
:caption: RL Algorithms
|
:caption: RL Algorithms
|
||||||
|
|
||||||
|
modules/ars
|
||||||
|
modules/ppo_mask
|
||||||
|
modules/qrdqn
|
||||||
modules/tqc
|
modules/tqc
|
||||||
modules/trpo
|
modules/trpo
|
||||||
modules/qrdqn
|
|
||||||
modules/ppo_mask
|
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|
|
||||||
|
|
@ -4,20 +4,21 @@ Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
|
||||||
Release 1.3.1a7 (WIP)
|
Release 1.3.1a9 (WIP)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
**Add TRPO**
|
**Add TRPO and ARS**
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^
|
||||||
- Dropped python 3.6 support
|
- Dropped python 3.6 support
|
||||||
- Upgraded to Stable-Baselines3 >= 1.3.1a4
|
- Upgraded to Stable-Baselines3 >= 1.3.1a9
|
||||||
- ``MaskablePPO`` was updated to match latest SB3 ``PPO`` version (timeout handling and new method for the policy object)
|
- ``MaskablePPO`` was updated to match latest SB3 ``PPO`` version (timeout handling and new method for the policy object)
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
- Added ``TRPO`` (@cyprienc)
|
- Added ``TRPO`` (@cyprienc)
|
||||||
- Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported)
|
- Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported)
|
||||||
|
- Added Augmented Random Search (ARS) (@sgillen)
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
@ -221,4 +222,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
|
||||||
Contributors:
|
Contributors:
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc
|
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,160 @@
|
||||||
|
.. _ars:
|
||||||
|
|
||||||
|
.. automodule:: sb3_contrib.ars
|
||||||
|
|
||||||
|
|
||||||
|
ARS
|
||||||
|
===
|
||||||
|
|
||||||
|
|
||||||
|
Augmented Random Search (ARS) is a simple reinforcement algorithm that uses a direct random search over policy
|
||||||
|
parameters. It can be surprisingly effective compared to more sophisticated algorithms. In the `original paper <https://arxiv.org/abs/1803.07055>`_ the authors showed that linear policies trained with ARS were competitive with deep reinforcement learning for the MuJuCo locomotion tasks.
|
||||||
|
|
||||||
|
SB3s implementation allows for linear policies without bias or squashing function, it also allows for training MLP policies, which include linear policies with bias and squashing functions as a special case.
|
||||||
|
|
||||||
|
Normally one wants to train ARS with several seeds to properly evaluate.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
ARS multi-processing is different from the classic Stable-Baselines3 multi-processing: it runs n environments
|
||||||
|
in parallel but asynchronously. This asynchronous multi-processing is considered experimental
|
||||||
|
and does not fully support callbacks: the ``on_step()`` event is called artificially after the evaluation episodes are over.
|
||||||
|
|
||||||
|
|
||||||
|
.. rubric:: Available Policies
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:nosignatures:
|
||||||
|
|
||||||
|
LinearPolicy
|
||||||
|
MlpPolicy
|
||||||
|
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
|
||||||
|
- Original paper: https://arxiv.org/abs/1803.07055
|
||||||
|
- Original Implementation: https://github.com/modestyachts/ARS
|
||||||
|
|
||||||
|
|
||||||
|
Can I use?
|
||||||
|
----------
|
||||||
|
|
||||||
|
- Recurrent policies: ❌
|
||||||
|
- Multi processing: ✔️ (cf. example)
|
||||||
|
- Gym spaces:
|
||||||
|
|
||||||
|
|
||||||
|
============= ====== ===========
|
||||||
|
Space Action Observation
|
||||||
|
============= ====== ===========
|
||||||
|
Discrete ✔️ ❌
|
||||||
|
Box ✔️ ✔️
|
||||||
|
MultiDiscrete ❌ ❌
|
||||||
|
MultiBinary ❌ ❌
|
||||||
|
Dict ❌ ❌
|
||||||
|
============= ====== ===========
|
||||||
|
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from sb3_contrib import ARS
|
||||||
|
|
||||||
|
# Policy can be LinearPolicy or MlpPolicy
|
||||||
|
model = ARS("LinearPolicy", "Pendulum-v0", verbose=1)
|
||||||
|
model.learn(total_timesteps=10000, log_interval=4)
|
||||||
|
model.save("ars_pendulum")
|
||||||
|
|
||||||
|
|
||||||
|
With experimental asynchronous multi-processing:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from sb3_contrib import ARS
|
||||||
|
from sb3_contrib.common.vec_env import AsyncEval
|
||||||
|
|
||||||
|
from stable_baselines3.common.env_util import make_vec_env
|
||||||
|
|
||||||
|
env_id = "CartPole-v1"
|
||||||
|
n_envs = 2
|
||||||
|
|
||||||
|
model = ARS("LinearPolicy", env_id, n_delta=2, n_top=1, verbose=1)
|
||||||
|
# Create env for asynchronous evaluation (run in different processes)
|
||||||
|
async_eval = AsyncEval([lambda: make_vec_env(env_id) for _ in range(n_envs)], model.policy)
|
||||||
|
|
||||||
|
model.learn(total_timesteps=200_000, log_interval=4, async_eval=async_eval)
|
||||||
|
|
||||||
|
|
||||||
|
Results
|
||||||
|
-------
|
||||||
|
|
||||||
|
Replicating results from the original paper, which used the Mujoco benchmarks. Same parameters from the original paper, using 8 seeds.
|
||||||
|
|
||||||
|
============= ============
|
||||||
|
Environments ARS
|
||||||
|
============= ============
|
||||||
|
\
|
||||||
|
HalfCheetah 4398 +/- 320
|
||||||
|
Swimmer 241 +/- 51
|
||||||
|
Hopper 3320 +/- 120
|
||||||
|
============= ============
|
||||||
|
|
||||||
|
How to replicate the results?
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
Clone RL-Zoo and checkout the branch ``feat/ars``
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
git clone https://github.com/DLR-RM/rl-baselines3-zoo
|
||||||
|
cd rl-baselines3-zoo/
|
||||||
|
git checkout feat/ars
|
||||||
|
|
||||||
|
Run the benchmark. The following code snippet trains 8 seeds in parallel
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
for ENV_ID in Swimmer-v3 HalfCheetah-v3 Hopper-v3
|
||||||
|
do
|
||||||
|
for SEED_NUM in {1..8}
|
||||||
|
do
|
||||||
|
SEED=$RANDOM
|
||||||
|
python train.py --algo ars --env $ENV_ID --eval-episodes 10 --eval-freq 10000 -n 20000000 --seed $SEED &
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
wait
|
||||||
|
done
|
||||||
|
|
||||||
|
Plot the results:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
python scripts/all_plots.py -a ars -e HalfCheetah Swimmer Hopper -f logs/ -o logs/ars_results -max 20000000
|
||||||
|
python scripts/plot_from_file.py -i logs/ars_results.pkl -l ARS
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
|
||||||
|
.. autoclass:: ARS
|
||||||
|
:members:
|
||||||
|
:inherited-members:
|
||||||
|
|
||||||
|
.. _ars_policies:
|
||||||
|
|
||||||
|
ARS Policies
|
||||||
|
-------------
|
||||||
|
|
||||||
|
.. autoclass:: sb3_contrib.ars.policies.ARSPolicy
|
||||||
|
:members:
|
||||||
|
:noindex:
|
||||||
|
|
||||||
|
.. autoclass:: LinearPolicy
|
||||||
|
:members:
|
||||||
|
:inherited-members:
|
||||||
|
|
||||||
|
.. autoclass:: MlpPolicy
|
||||||
|
:members:
|
||||||
|
|
@ -119,3 +119,7 @@ cuda
|
||||||
Polyak
|
Polyak
|
||||||
gSDE
|
gSDE
|
||||||
rollouts
|
rollouts
|
||||||
|
quantiles
|
||||||
|
quantile
|
||||||
|
contrib
|
||||||
|
Contrib
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from sb3_contrib.ars import ARS
|
||||||
from sb3_contrib.ppo_mask import MaskablePPO
|
from sb3_contrib.ppo_mask import MaskablePPO
|
||||||
from sb3_contrib.qrdqn import QRDQN
|
from sb3_contrib.qrdqn import QRDQN
|
||||||
from sb3_contrib.tqc import TQC
|
from sb3_contrib.tqc import TQC
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
from sb3_contrib.ars.ars import ARS
|
||||||
|
from sb3_contrib.ars.policies import LinearPolicy, MlpPolicy
|
||||||
|
|
@ -0,0 +1,347 @@
|
||||||
|
import copy
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any, Dict, Optional, Type, Union
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import torch as th
|
||||||
|
import torch.nn.utils
|
||||||
|
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||||
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
|
from stable_baselines3.common.utils import get_schedule_fn, safe_mean
|
||||||
|
|
||||||
|
from sb3_contrib.ars.policies import ARSPolicy
|
||||||
|
from sb3_contrib.common.vec_env.async_eval import AsyncEval
|
||||||
|
|
||||||
|
|
||||||
|
class ARS(BaseAlgorithm):
|
||||||
|
"""
|
||||||
|
Augmented Random Search: https://arxiv.org/abs/1803.07055
|
||||||
|
|
||||||
|
Original implementation: https://github.com/modestyachts/ARS
|
||||||
|
C++/Cuda Implementation: https://github.com/google-research/tiny-differentiable-simulator/
|
||||||
|
150 LOC Numpy Implementation: https://github.com/alexis-jacq/numpy_ARS/blob/master/asr.py
|
||||||
|
|
||||||
|
:param policy: The policy to train, can be an instance of ``ARSPolicy``, or a string from ["LinearPolicy", "MlpPolicy"]
|
||||||
|
:param env: The environment to train on, may be a string if registered with gym
|
||||||
|
:param n_delta: How many random perturbations of the policy to try at each update step.
|
||||||
|
:param n_top: How many of the top delta to use in each update step. Default is n_delta
|
||||||
|
:param learning_rate: Float or schedule for the step size
|
||||||
|
:param delta_std: Float or schedule for the exploration noise
|
||||||
|
:param zero_policy: Boolean determining if the passed policy should have it's weights zeroed before training.
|
||||||
|
:param alive_bonus_offset: Constant added to the reward at each step, used to cancel out alive bonuses.
|
||||||
|
:param n_eval_episodes: Number of episodes to evaluate each candidate.
|
||||||
|
:param policy_kwargs: Keyword arguments to pass to the policy on creation
|
||||||
|
:param policy_base: Base class to use for the policy
|
||||||
|
:param tensorboard_log: String with the directory to put tensorboard logs:
|
||||||
|
:param seed: Random seed for the training
|
||||||
|
:param verbose: Verbosity level: 0 no output, 1 info, 2 debug
|
||||||
|
:param device: Torch device to use for training, defaults to "cpu"
|
||||||
|
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
policy: Union[str, Type[ARSPolicy]],
|
||||||
|
env: Union[GymEnv, str],
|
||||||
|
n_delta: int = 8,
|
||||||
|
n_top: Optional[int] = None,
|
||||||
|
learning_rate: Union[float, Schedule] = 0.02,
|
||||||
|
delta_std: Union[float, Schedule] = 0.05,
|
||||||
|
zero_policy: bool = True,
|
||||||
|
alive_bonus_offset: float = 0,
|
||||||
|
n_eval_episodes: int = 1,
|
||||||
|
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
policy_base: Type[BasePolicy] = ARSPolicy,
|
||||||
|
tensorboard_log: Optional[str] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
verbose: int = 0,
|
||||||
|
device: Union[th.device, str] = "cpu",
|
||||||
|
_init_setup_model: bool = True,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
policy,
|
||||||
|
env,
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
tensorboard_log=tensorboard_log,
|
||||||
|
policy_base=policy_base,
|
||||||
|
policy_kwargs=policy_kwargs,
|
||||||
|
verbose=verbose,
|
||||||
|
device=device,
|
||||||
|
supported_action_spaces=(gym.spaces.Box, gym.spaces.Discrete),
|
||||||
|
support_multi_env=True,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.n_delta = n_delta
|
||||||
|
self.pop_size = 2 * n_delta
|
||||||
|
self.delta_std_schedule = get_schedule_fn(delta_std)
|
||||||
|
self.n_eval_episodes = n_eval_episodes
|
||||||
|
|
||||||
|
if n_top is None:
|
||||||
|
n_top = n_delta
|
||||||
|
|
||||||
|
# Make sure our hyper parameters are valid and auto correct them if they are not
|
||||||
|
if n_top > n_delta:
|
||||||
|
warnings.warn(f"n_top = {n_top} > n_delta = {n_top}, setting n_top = n_delta")
|
||||||
|
n_top = n_delta
|
||||||
|
|
||||||
|
self.n_top = n_top
|
||||||
|
|
||||||
|
self.alive_bonus_offset = alive_bonus_offset
|
||||||
|
self.zero_policy = zero_policy
|
||||||
|
self.weights = None # Need to call init model to initialize weight
|
||||||
|
self.processes = None
|
||||||
|
# Keep track of how many steps where elapsed before a new rollout
|
||||||
|
# Important for syncing observation normalization between workers
|
||||||
|
self.old_count = 0
|
||||||
|
|
||||||
|
if _init_setup_model:
|
||||||
|
self._setup_model()
|
||||||
|
|
||||||
|
def _setup_model(self) -> None:
|
||||||
|
self._setup_lr_schedule()
|
||||||
|
self.set_random_seed(self.seed)
|
||||||
|
|
||||||
|
self.policy = self.policy_class(self.observation_space, self.action_space, **self.policy_kwargs)
|
||||||
|
self.policy = self.policy.to(self.device)
|
||||||
|
self.weights = th.nn.utils.parameters_to_vector(self.policy.parameters()).detach()
|
||||||
|
self.n_params = len(self.weights)
|
||||||
|
|
||||||
|
if self.zero_policy:
|
||||||
|
self.weights = th.zeros_like(self.weights, requires_grad=False)
|
||||||
|
self.policy.load_from_vector(self.weights.cpu())
|
||||||
|
|
||||||
|
def _mimic_monitor_wrapper(self, episode_rewards: np.ndarray, episode_lengths: np.ndarray) -> None:
|
||||||
|
"""
|
||||||
|
Helper to mimic Monitor wrapper and report episode statistics (mean reward, mean episode length).
|
||||||
|
|
||||||
|
:param episode_rewards: List containing per-episode rewards
|
||||||
|
:param episode_lengths: List containing per-episode lengths (in number of steps)
|
||||||
|
"""
|
||||||
|
# Mimic Monitor Wrapper
|
||||||
|
infos = [
|
||||||
|
{"episode": {"r": episode_reward, "l": episode_length}}
|
||||||
|
for episode_reward, episode_length in zip(episode_rewards, episode_lengths)
|
||||||
|
]
|
||||||
|
|
||||||
|
self._update_info_buffer(infos)
|
||||||
|
|
||||||
|
def _trigger_callback(
|
||||||
|
self,
|
||||||
|
_locals: Dict[str, Any],
|
||||||
|
_globals: Dict[str, Any],
|
||||||
|
callback: BaseCallback,
|
||||||
|
n_envs: int,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Callback passed to the ``evaluate_policy()`` helper
|
||||||
|
in order to increment the number of timesteps
|
||||||
|
and trigger events in the single process version.
|
||||||
|
|
||||||
|
:param _locals:
|
||||||
|
:param _globals:
|
||||||
|
:param callback: Callback that will be called at every step
|
||||||
|
:param n_envs: Number of environments
|
||||||
|
"""
|
||||||
|
self.num_timesteps += n_envs
|
||||||
|
callback.on_step()
|
||||||
|
|
||||||
|
def evaluate_candidates(
|
||||||
|
self, candidate_weights: th.Tensor, callback: BaseCallback, async_eval: Optional[AsyncEval]
|
||||||
|
) -> th.Tensor:
|
||||||
|
"""
|
||||||
|
Evaluate each candidate.
|
||||||
|
|
||||||
|
:param candidate_weights: The candidate weights to be evaluated.
|
||||||
|
:param callback: Callback that will be called at each step
|
||||||
|
(or after evaluation in the multiprocess version)
|
||||||
|
:param async_eval: The object for asynchronous evaluation of candidates.
|
||||||
|
:return: The episodic return for each candidate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
batch_steps = 0
|
||||||
|
# returns == sum of rewards
|
||||||
|
candidate_returns = th.zeros(self.pop_size, device=self.device)
|
||||||
|
train_policy = copy.deepcopy(self.policy)
|
||||||
|
# Empty buffer to show only mean over one iteration (one set of candidates) in the logs
|
||||||
|
self.ep_info_buffer = []
|
||||||
|
callback.on_rollout_start()
|
||||||
|
|
||||||
|
if async_eval is not None:
|
||||||
|
# Multiprocess asynchronous version
|
||||||
|
async_eval.send_jobs(candidate_weights, self.pop_size)
|
||||||
|
results = async_eval.get_results()
|
||||||
|
|
||||||
|
for weights_idx, (episode_rewards, episode_lengths) in results:
|
||||||
|
|
||||||
|
# Update reward to cancel out alive bonus if needed
|
||||||
|
candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths)
|
||||||
|
batch_steps += np.sum(episode_lengths)
|
||||||
|
self._mimic_monitor_wrapper(episode_rewards, episode_lengths)
|
||||||
|
|
||||||
|
# Combine the filter stats of each process for normalization
|
||||||
|
for worker_obs_rms in async_eval.get_obs_rms():
|
||||||
|
if self._vec_normalize_env is not None:
|
||||||
|
# worker_obs_rms.count -= self.old_count
|
||||||
|
self._vec_normalize_env.obs_rms.combine(worker_obs_rms)
|
||||||
|
# Hack: don't count timesteps twice (between the two are synced)
|
||||||
|
# otherwise it will lead to overflow,
|
||||||
|
# in practice we would need two RunningMeanStats
|
||||||
|
self._vec_normalize_env.obs_rms.count -= self.old_count
|
||||||
|
|
||||||
|
# Synchronise VecNormalize if needed
|
||||||
|
if self._vec_normalize_env is not None:
|
||||||
|
async_eval.sync_obs_rms(self._vec_normalize_env.obs_rms.copy())
|
||||||
|
self.old_count = self._vec_normalize_env.obs_rms.count
|
||||||
|
|
||||||
|
# Hack to have Callback events
|
||||||
|
for _ in range(batch_steps // len(async_eval.remotes)):
|
||||||
|
self.num_timesteps += len(async_eval.remotes)
|
||||||
|
callback.on_step()
|
||||||
|
else:
|
||||||
|
# Single process, synchronous version
|
||||||
|
for weights_idx in range(self.pop_size):
|
||||||
|
|
||||||
|
# Load current candidate weights
|
||||||
|
train_policy.load_from_vector(candidate_weights[weights_idx].cpu())
|
||||||
|
# Evaluate the candidate
|
||||||
|
episode_rewards, episode_lengths = evaluate_policy(
|
||||||
|
train_policy,
|
||||||
|
self.env,
|
||||||
|
n_eval_episodes=self.n_eval_episodes,
|
||||||
|
return_episode_rewards=True,
|
||||||
|
# Increment num_timesteps too (slight mismatch with multi envs)
|
||||||
|
callback=partial(self._trigger_callback, callback=callback, n_envs=self.env.num_envs),
|
||||||
|
warn=False,
|
||||||
|
)
|
||||||
|
# Update reward to cancel out alive bonus if needed
|
||||||
|
candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths)
|
||||||
|
batch_steps += sum(episode_lengths)
|
||||||
|
self._mimic_monitor_wrapper(episode_rewards, episode_lengths)
|
||||||
|
|
||||||
|
# Note: we increment the num_timesteps inside the evaluate_policy()
|
||||||
|
# however when using multiple environments, there will be a slight
|
||||||
|
# mismatch between the number of timesteps used and the number
|
||||||
|
# of calls to the step() method (cf. implementation of evaluate_policy())
|
||||||
|
# self.num_timesteps += batch_steps
|
||||||
|
|
||||||
|
callback.on_rollout_end()
|
||||||
|
|
||||||
|
return candidate_returns
|
||||||
|
|
||||||
|
def _log_and_dump(self) -> None:
|
||||||
|
"""
|
||||||
|
Dump information to the logger.
|
||||||
|
"""
|
||||||
|
time_elapsed = time.time() - self.start_time
|
||||||
|
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8))
|
||||||
|
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||||
|
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||||
|
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
||||||
|
self.logger.record("time/fps", fps)
|
||||||
|
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
||||||
|
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
||||||
|
self.logger.dump(step=self.num_timesteps)
|
||||||
|
|
||||||
|
def _do_one_update(self, callback: BaseCallback, async_eval: Optional[AsyncEval]) -> None:
|
||||||
|
"""
|
||||||
|
Sample new candidates, evaluate them and then update current policy.
|
||||||
|
|
||||||
|
:param callback: callback(s) called at every step with state of the algorithm.
|
||||||
|
:param async_eval: The object for asynchronous evaluation of candidates.
|
||||||
|
"""
|
||||||
|
# Retrieve current parameter noise standard deviation
|
||||||
|
# and current learning rate
|
||||||
|
delta_std = self.delta_std_schedule(self._current_progress_remaining)
|
||||||
|
learning_rate = self.lr_schedule(self._current_progress_remaining)
|
||||||
|
# Sample the parameter noise, it will be scaled by delta_std
|
||||||
|
deltas = th.normal(mean=0.0, std=1.0, size=(self.n_delta, self.n_params), device=self.device)
|
||||||
|
policy_deltas = deltas * delta_std
|
||||||
|
# Generate 2 * n_delta candidate policies by adding noise to the current weights
|
||||||
|
candidate_weights = th.cat([self.weights + policy_deltas, self.weights - policy_deltas])
|
||||||
|
|
||||||
|
with th.no_grad():
|
||||||
|
candidate_returns = self.evaluate_candidates(candidate_weights, callback, async_eval)
|
||||||
|
|
||||||
|
# Returns corresponding to weights + deltas
|
||||||
|
plus_returns = candidate_returns[: self.n_delta]
|
||||||
|
# Returns corresponding to weights - deltas
|
||||||
|
minus_returns = candidate_returns[self.n_delta :]
|
||||||
|
|
||||||
|
# Keep only the top performing candidates for update
|
||||||
|
top_returns, _ = th.max(th.vstack((plus_returns, minus_returns)), dim=0)
|
||||||
|
top_idx = th.argsort(top_returns, descending=True)[: self.n_top]
|
||||||
|
plus_returns = plus_returns[top_idx]
|
||||||
|
minus_returns = minus_returns[top_idx]
|
||||||
|
deltas = deltas[top_idx]
|
||||||
|
|
||||||
|
# Scale learning rate by the return standard deviation:
|
||||||
|
# take smaller steps when there is a high variance in the returns
|
||||||
|
return_std = th.cat([plus_returns, minus_returns]).std()
|
||||||
|
step_size = learning_rate / (self.n_top * return_std + 1e-6)
|
||||||
|
# Approximate gradient step
|
||||||
|
self.weights = self.weights + step_size * ((plus_returns - minus_returns) @ deltas)
|
||||||
|
self.policy.load_from_vector(self.weights.cpu())
|
||||||
|
|
||||||
|
self.logger.record("train/iterations", self._n_updates, exclude="tensorboard")
|
||||||
|
self.logger.record("train/delta_std", delta_std)
|
||||||
|
self.logger.record("train/learning_rate", learning_rate)
|
||||||
|
self.logger.record("train/step_size", step_size.item())
|
||||||
|
self.logger.record("rollout/return_std", return_std.item())
|
||||||
|
|
||||||
|
self._n_updates += 1
|
||||||
|
|
||||||
|
def learn(
|
||||||
|
self,
|
||||||
|
total_timesteps: int,
|
||||||
|
callback: MaybeCallback = None,
|
||||||
|
log_interval: int = 1,
|
||||||
|
tb_log_name: str = "ARS",
|
||||||
|
eval_env: Optional[GymEnv] = None,
|
||||||
|
eval_freq: int = -1,
|
||||||
|
n_eval_episodes: int = 5,
|
||||||
|
eval_log_path: Optional[str] = None,
|
||||||
|
reset_num_timesteps: bool = True,
|
||||||
|
async_eval: Optional[AsyncEval] = None,
|
||||||
|
) -> "ARS":
|
||||||
|
"""
|
||||||
|
Return a trained model.
|
||||||
|
|
||||||
|
:param total_timesteps: The total number of samples (env steps) to train on
|
||||||
|
:param callback: callback(s) called at every step with state of the algorithm.
|
||||||
|
:param log_interval: The number of timesteps before logging.
|
||||||
|
:param tb_log_name: the name of the run for TensorBoard logging
|
||||||
|
:param eval_env: Environment that will be used to evaluate the agent
|
||||||
|
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
|
||||||
|
:param n_eval_episodes: Number of episode to evaluate the agent
|
||||||
|
:param eval_log_path: Path to a folder where the evaluations will be saved
|
||||||
|
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
|
||||||
|
:param async_eval: The object for asynchronous evaluation of candidates.
|
||||||
|
:return: the trained model
|
||||||
|
"""
|
||||||
|
|
||||||
|
total_steps, callback = self._setup_learn(
|
||||||
|
total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
|
||||||
|
)
|
||||||
|
|
||||||
|
callback.on_training_start(locals(), globals())
|
||||||
|
|
||||||
|
while self.num_timesteps < total_steps:
|
||||||
|
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
|
||||||
|
self._do_one_update(callback, async_eval)
|
||||||
|
if log_interval is not None and self._n_updates % log_interval == 0:
|
||||||
|
self._log_and_dump()
|
||||||
|
|
||||||
|
if async_eval is not None:
|
||||||
|
async_eval.close()
|
||||||
|
|
||||||
|
callback.on_training_end()
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
@ -0,0 +1,119 @@
|
||||||
|
from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import torch as th
|
||||||
|
from stable_baselines3.common.policies import BasePolicy, register_policy
|
||||||
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
|
from stable_baselines3.common.torch_layers import create_mlp
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class ARSPolicy(BasePolicy):
    """
    Policy network for ARS.

    :param observation_space: The observation space of the environment
    :param action_space: The action space of the environment
    :param net_arch: Network architecture, defaults to a 2 layers MLP with 64 hidden nodes.
    :param activation_fn: Activation function
    :param squash_output: For continuous actions, whether the output is squashed
        or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        net_arch: Optional[List[int]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        squash_output: bool = True,
    ):
        # Squashing only makes sense for continuous (Box) action spaces
        super().__init__(
            observation_space,
            action_space,
            squash_output=isinstance(action_space, gym.spaces.Box) and squash_output,
        )

        if net_arch is None:
            net_arch = [64, 64]

        self.net_arch = net_arch
        self.features_extractor = self.make_features_extractor()
        self.features_dim = self.features_extractor.features_dim
        self.activation_fn = activation_fn

        if isinstance(action_space, gym.spaces.Box):
            action_dim = get_action_dim(action_space)
            # Fix: honor the (possibly disabled) squashing instead of hard-coding
            # ``squash_output=True`` — otherwise ``squash_output=False`` still produced
            # a tanh layer, contradicting the documented "clipped instead" behavior.
            actor_net = create_mlp(self.features_dim, action_dim, net_arch, activation_fn, squash_output=self.squash_output)
        elif isinstance(action_space, gym.spaces.Discrete):
            actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn)
        else:
            raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")

        self.action_net = nn.Sequential(*actor_net)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """
        Return the arguments needed to re-create this policy with ``load()``.

        :return: Constructor keyword arguments.
        """
        # data = super()._get_constructor_parameters() this adds normalize_images, which we don't support...
        data = dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            net_arch=self.net_arch,
            activation_fn=self.activation_fn,
            # Fix: also persist the squashing setting, otherwise it was silently
            # reset to the default on save()/load().
            squash_output=self.squash_output,
        )
        return data

    def forward(self, obs: th.Tensor) -> th.Tensor:
        """
        Compute the action for the given observation.

        :param obs: Observation tensor.
        :return: Raw network output for Box spaces, argmax over logits for Discrete spaces.
        """
        features = self.extract_features(obs)
        if isinstance(self.action_space, gym.spaces.Box):
            return self.action_net(features)
        elif isinstance(self.action_space, gym.spaces.Discrete):
            logits = self.action_net(features)
            return th.argmax(logits, dim=1)
        else:
            raise NotImplementedError()

    def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
        # Non deterministic action does not really make sense for ARS, we ignore this parameter for now..
        return self.forward(observation)
|
||||||
|
class ARSLinearPolicy(ARSPolicy):
    """
    Linear policy network for ARS.

    :param observation_space: The observation space of the environment
    :param action_space: The action space of the environment
    :param with_bias: With or without bias on the output
    :param squash_output: For continuous actions, whether the output is squashed
        or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        with_bias: bool = False,
        squash_output: bool = False,
    ):

        super().__init__(observation_space, action_space, squash_output=squash_output)
        # Stored so that _get_constructor_parameters() can persist it
        self.with_bias = with_bias

        if isinstance(action_space, gym.spaces.Box):
            action_dim = get_action_dim(action_space)
            # Replace the MLP built by the parent class with a single linear layer
            self.action_net = nn.Linear(self.features_dim, action_dim, bias=with_bias)
            if squash_output:
                self.action_net = nn.Sequential(self.action_net, nn.Tanh())
        elif isinstance(action_space, gym.spaces.Discrete):
            self.action_net = nn.Linear(self.features_dim, action_space.n, bias=with_bias)
        else:
            raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """
        Return the arguments needed to re-create this policy with ``load()``.

        Fix: the inherited implementation returned ``net_arch``/``activation_fn``,
        which this constructor does not accept, so a saved ``ARSLinearPolicy``
        could not be loaded back (TypeError); ``with_bias`` was also lost.

        :return: Constructor keyword arguments.
        """
        return dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            with_bias=self.with_bias,
            squash_output=self.squash_output,
        )
||||||
|
|
||||||
|
# Aliases used by the ``ARS(policy="...")`` string lookup
MlpPolicy = ARSPolicy
LinearPolicy = ARSLinearPolicy

# Make both policies available by name in the global policy registry
register_policy("LinearPolicy", LinearPolicy)
register_policy("MlpPolicy", MlpPolicy)
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from sb3_contrib.common.vec_env.async_eval import AsyncEval
|
||||||
|
|
@ -0,0 +1,208 @@
|
||||||
|
import multiprocessing
|
||||||
|
import multiprocessing as mp
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch as th
|
||||||
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
|
from stable_baselines3.common.running_mean_std import RunningMeanStd
|
||||||
|
from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize
|
||||||
|
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper
|
||||||
|
|
||||||
|
|
||||||
|
def _worker(
    remote: mp.connection.Connection,
    parent_remote: mp.connection.Connection,
    worker_env_wrapper: CloudpickleWrapper,
    train_policy_wrapper: CloudpickleWrapper,
    n_eval_episodes: int = 1,
) -> None:
    """
    Function that will be run in each process.
    It is in charge of creating environments, evaluating candidates
    and communicating with the main process.

    :param remote: Pipe to communicate with the parent process.
    :param parent_remote: Parent end of the pipe; closed in the child so that
        an ``EOFError`` is raised here when the main process exits.
    :param worker_env_wrapper: Callable used to create the environment inside the process.
    :param train_policy_wrapper: Callable used to create the policy inside the process.
    :param n_eval_episodes: Number of evaluation episodes per candidate.
    """
    parent_remote.close()
    env = worker_env_wrapper.var()
    # Note: the wrapped object is the policy itself, not a constructor
    train_policy = train_policy_wrapper.var
    vec_normalize = unwrap_vec_normalize(env)
    if vec_normalize is not None:
        obs_rms = vec_normalize.obs_rms
    else:
        obs_rms = None
    while True:
        try:
            cmd, data = remote.recv()
            if cmd == "eval":
                results = []
                # Evaluate each candidate and save results
                for weights_idx, candidate_weights in data:
                    # Load the candidate weights into the policy (move to CPU first)
                    train_policy.load_from_vector(candidate_weights.cpu())
                    episode_rewards, episode_lengths = evaluate_policy(
                        train_policy,
                        env,
                        n_eval_episodes=n_eval_episodes,
                        return_episode_rewards=True,
                        warn=False,
                    )
                    results.append((weights_idx, (episode_rewards, episode_lengths)))
                remote.send(results)
            elif cmd == "seed":
                remote.send(env.seed(data))
            elif cmd == "get_obs_rms":
                remote.send(obs_rms)
            elif cmd == "sync_obs_rms":
                # Fix: guard against envs that are not wrapped in a VecNormalize,
                # the original code raised an AttributeError on ``None.obs_rms``.
                if vec_normalize is not None:
                    vec_normalize.obs_rms = data
                obs_rms = data
            elif cmd == "close":
                env.close()
                remote.close()
                break
            else:
                raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
        except EOFError:
            # The parent process closed the pipe: shut down quietly
            break
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncEval:
    """
    Helper class to do asynchronous evaluation of different policies with multiple processes.
    It is useful when implementing population based methods like Evolution Strategies (ES),
    Cross Entropy Method (CEM) or Augmented Random Search (ARS).

    .. warning::

        Only 'forkserver' and 'spawn' start methods are thread-safe,
        which is important to avoid race conditions.
        However, compared to
        'fork' they incur a small start-up cost and have restrictions on
        global variables. With those methods, users must wrap the code in an
        ``if __name__ == "__main__":`` block.
        For more information, see the multiprocessing documentation.

    :param envs_fn: Vectorized environments to run in subprocesses (callable)
    :param train_policy: The policy object that will load the different candidate
        weights.
    :param start_method: method used to start the subprocesses.
        Must be one of the methods returned by ``multiprocessing.get_all_start_methods()``.
        Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
    :param n_eval_episodes: The number of episodes to test each agent
    """

    def __init__(
        self,
        envs_fn: List[Callable[[], VecEnv]],
        train_policy: BasePolicy,
        start_method: Optional[str] = None,
        n_eval_episodes: int = 1,
    ):
        self.waiting = False
        self.closed = False
        n_envs = len(envs_fn)

        if start_method is None:
            # Fork is not a thread safe method (see issue #217)
            # but is more user friendly (does not require to wrap the code in
            # a `if __name__ == "__main__":`)
            forkserver_available = "forkserver" in mp.get_all_start_methods()
            start_method = "forkserver" if forkserver_available else "spawn"
        ctx = mp.get_context(start_method)

        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)])
        self.processes = []
        for work_remote, remote, worker_env in zip(self.work_remotes, self.remotes, envs_fn):
            args = (
                work_remote,
                remote,
                CloudpickleWrapper(worker_env),
                CloudpickleWrapper(train_policy),
                n_eval_episodes,
            )
            # daemon=True: if the main process crashes, we should not cause things to hang
            process = ctx.Process(target=_worker, args=args, daemon=True)  # pytype:disable=attribute-error
            process.start()
            self.processes.append(process)
            # The child keeps its own copy of this end of the pipe
            work_remote.close()

    def send_jobs(self, candidate_weights: th.Tensor, pop_size: int) -> None:
        """
        Send jobs to the workers to evaluate new candidates.

        :param candidate_weights: The weights to be evaluated.
        :param pop_size: The number of candidates (size of the population)
        """
        jobs_per_worker = defaultdict(list)
        # Round-robin assignment of the candidates to the workers
        for weights_idx in range(pop_size):
            jobs_per_worker[weights_idx % len(self.remotes)].append((weights_idx, candidate_weights[weights_idx]))

        for remote_idx, remote in enumerate(self.remotes):
            remote.send(("eval", jobs_per_worker[remote_idx]))
        self.waiting = True

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        """
        Seed the environments.

        :param seed: The seed for the pseudo-random generators.
            When None, this is a no-op (the previous implementation crashed
            with a ``TypeError`` when computing ``None + idx``).
        :return: The result of each ``env.seed()`` call
            (empty list when no seed is given).
        """
        if seed is None:
            return []
        # Use a different seed per worker to decorrelate the environments
        for idx, remote in enumerate(self.remotes):
            remote.send(("seed", seed + idx))
        return [remote.recv() for remote in self.remotes]

    def get_results(self) -> List[Tuple[int, Tuple[np.ndarray, np.ndarray]]]:
        """
        Retrieve episode rewards and lengths from each worker
        for all candidates (there might be multiple candidates per worker)

        :return: A list of tuples containing each candidate index and its
            result (episodic reward and episode length)
        """
        results = [remote.recv() for remote in self.remotes]
        flat_results = [result for worker_results in results for result in worker_results]
        self.waiting = False
        return flat_results

    def get_obs_rms(self) -> List[RunningMeanStd]:
        """
        Retrieve the observation filters (observation running mean std)
        of each process, they will be combined in the main process.
        Synchronisation is done afterward using ``sync_obs_rms()``.

        :return: A list of ``RunningMeanStd`` objects (one per process)
        """
        for remote in self.remotes:
            remote.send(("get_obs_rms", None))
        return [remote.recv() for remote in self.remotes]

    def sync_obs_rms(self, obs_rms: RunningMeanStd) -> None:
        """
        Synchronise (and update) the observation filters
        (observation running mean std)

        :param obs_rms: The updated ``RunningMeanStd`` to be used
            by workers for normalizing observations.
        """
        for remote in self.remotes:
            remote.send(("sync_obs_rms", obs_rms))

    def close(self) -> None:
        """
        Close the processes.
        """
        if self.closed:
            return
        if self.waiting:
            # Drain pending results so workers are not blocked on send()
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(("close", None))
        for process in self.processes:
            process.join()
        self.closed = True
|
||||||
|
|
@ -206,8 +206,8 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
policy_objective_grad, *_ = th.autograd.grad(policy_objective, param, retain_graph=True, only_inputs=True)
|
policy_objective_grad, *_ = th.autograd.grad(policy_objective, param, retain_graph=True, only_inputs=True)
|
||||||
|
|
||||||
grad_shape.append(kl_param_grad.shape)
|
grad_shape.append(kl_param_grad.shape)
|
||||||
grad_kl.append(kl_param_grad.view(-1))
|
grad_kl.append(kl_param_grad.reshape(-1))
|
||||||
policy_objective_gradients.append(policy_objective_grad.view(-1))
|
policy_objective_gradients.append(policy_objective_grad.reshape(-1))
|
||||||
actor_params.append(param)
|
actor_params.append(param)
|
||||||
|
|
||||||
# Gradients are concatenated before the conjugate gradient step
|
# Gradients are concatenated before the conjugate gradient step
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.3.1a7
|
1.3.1a9
|
||||||
|
|
|
||||||
|
|
@ -22,13 +22,15 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators
|
||||||
# Ignore import not used when aliases are defined
|
# Ignore import not used when aliases are defined
|
||||||
per-file-ignores =
|
per-file-ignores =
|
||||||
./sb3_contrib/__init__.py:F401
|
./sb3_contrib/__init__.py:F401
|
||||||
|
./sb3_contrib/ars/__init__.py:F401
|
||||||
./sb3_contrib/ppo_mask/__init__.py:F401
|
./sb3_contrib/ppo_mask/__init__.py:F401
|
||||||
./sb3_contrib/qrdqn/__init__.py:F401
|
./sb3_contrib/qrdqn/__init__.py:F401
|
||||||
./sb3_contrib/tqc/__init__.py:F401
|
./sb3_contrib/tqc/__init__.py:F401
|
||||||
./sb3_contrib/trpo/__init__.py:F401
|
./sb3_contrib/trpo/__init__.py:F401
|
||||||
./sb3_contrib/common/vec_env/wrappers/__init__.py:F401
|
|
||||||
./sb3_contrib/common/wrappers/__init__.py:F401
|
./sb3_contrib/common/wrappers/__init__.py:F401
|
||||||
./sb3_contrib/common/envs/__init__.py:F401
|
./sb3_contrib/common/envs/__init__.py:F401
|
||||||
|
./sb3_contrib/common/vec_env/__init__.py:F401
|
||||||
|
|
||||||
exclude =
|
exclude =
|
||||||
# No need to traverse our git directory
|
# No need to traverse our git directory
|
||||||
.git,
|
.git,
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -63,7 +63,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.3.1a4",
|
"stable_baselines3>=1.3.1a9",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,55 @@
|
||||||
|
import pytest
|
||||||
|
from stable_baselines3.common.env_util import make_vec_env
|
||||||
|
from stable_baselines3.common.noise import NormalActionNoise
|
||||||
|
from stable_baselines3.common.vec_env import VecNormalize
|
||||||
|
|
||||||
|
from sb3_contrib import ARS, QRDQN, TQC
|
||||||
|
from sb3_contrib.common.vec_env import AsyncEval
|
||||||
|
|
||||||
|
# Shared settings for the determinism tests
N_STEPS_TRAINING = 500
SEED = 0
# Sentinel used to request the multiprocess (AsyncEval) variant of ARS
ARS_MULTI = "ars_multi"
||||||
|
|
||||||
|
@pytest.mark.parametrize("algo", [ARS, QRDQN, TQC, ARS_MULTI])
def test_deterministic_training_common(algo):
    """Training twice with the same seed must yield identical actions and rewards."""
    results = [[], []]
    rewards = [[], []]
    # Use a smaller network to speed things up
    kwargs = {"policy_kwargs": dict(net_arch=[64])}
    env_id = "Pendulum-v0"

    ars_multi = algo == ARS_MULTI
    if ars_multi:
        algo = ARS

    if algo in [TQC]:
        kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4})
    elif algo == QRDQN:
        env_id = "CartPole-v1"
        kwargs.update({"learning_starts": 100, "target_update_interval": 100})
    elif algo == ARS:
        kwargs.update({"n_delta": 2})

    for run_idx in range(2):
        model = algo("MlpPolicy", env_id, seed=SEED, **kwargs)

        learn_kwargs = {"total_timesteps": N_STEPS_TRAINING}
        if ars_multi:
            # Evaluate candidates in two subprocesses
            learn_kwargs["async_eval"] = AsyncEval(
                [lambda: VecNormalize(make_vec_env(env_id, seed=SEED, n_envs=1)) for _ in range(2)],
                model.policy,
            )

        model.learn(**learn_kwargs)
        env = model.get_env()
        obs = env.reset()
        for _ in range(100):
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, _, _ = env.step(action)
            results[run_idx].append(action)
            rewards[run_idx].append(reward)
    assert sum(results[0]) == sum(results[1]), results
    assert sum(rewards[0]) == sum(rewards[1]), rewards
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
import gym
|
import gym
|
||||||
import pytest
|
import pytest
|
||||||
from stable_baselines3.common.env_util import make_vec_env
|
from stable_baselines3.common.env_util import make_vec_env
|
||||||
|
from stable_baselines3.common.vec_env import VecNormalize
|
||||||
|
|
||||||
from sb3_contrib import QRDQN, TQC, TRPO
|
from sb3_contrib import ARS, QRDQN, TQC, TRPO
|
||||||
|
from sb3_contrib.common.vec_env import AsyncEval
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
|
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
|
||||||
|
|
@ -82,6 +84,38 @@ def test_trpo_params():
|
||||||
model.learn(total_timesteps=500)
|
model.learn(total_timesteps=500)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"])
def test_ars(policy_str, env_id):
    """Smoke test: ARS trains on discrete and continuous envs with both policies."""
    model = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0)
    model.learn(total_timesteps=500, log_interval=1, eval_freq=250)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ars_multi_env():
    """ARS must work with a vectorized env, and with parallel (AsyncEval) evaluation."""
    # Several envs in the same process
    vec_env = make_vec_env("Pendulum-v0", n_envs=2)
    model = ARS("MlpPolicy", vec_env, n_delta=1)
    model.learn(total_timesteps=250)

    # Single normalized env, candidates evaluated in subprocesses
    norm_env = VecNormalize(make_vec_env("Pendulum-v0", n_envs=1))
    model = ARS("MlpPolicy", norm_env, n_delta=2, seed=0)
    async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v0", n_envs=1)) for _ in range(2)], model.policy)
    async_eval.seed(0)
    model.learn(500, async_eval=async_eval)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("n_top", [2, 8])
def test_ars_n_top(n_top):
    """When ``n_top > n_delta``, ARS should emit a warning; training must still succeed."""
    n_delta = 3
    if n_top > n_delta:
        with pytest.warns(UserWarning):
            model = ARS("MlpPolicy", "Pendulum-v0", n_delta=n_delta, n_top=n_top)
    else:
        model = ARS("MlpPolicy", "Pendulum-v0", n_delta=n_delta, n_top=n_top)
    model.learn(total_timesteps=500)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
|
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
|
||||||
def test_offpolicy_multi_env(model_class):
|
def test_offpolicy_multi_env(model_class):
|
||||||
if model_class in [TQC]:
|
if model_class in [TQC]:
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,9 @@ from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnv
|
||||||
from stable_baselines3.common.utils import get_device
|
from stable_baselines3.common.utils import get_device
|
||||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||||
|
|
||||||
from sb3_contrib import QRDQN, TQC, TRPO
|
from sb3_contrib import ARS, QRDQN, TQC, TRPO
|
||||||
|
|
||||||
MODEL_LIST = [TQC, QRDQN, TRPO]
|
MODEL_LIST = [ARS, QRDQN, TQC, TRPO]
|
||||||
|
|
||||||
|
|
||||||
def select_env(model_class: BaseAlgorithm) -> gym.Env:
|
def select_env(model_class: BaseAlgorithm) -> gym.Env:
|
||||||
|
|
@ -266,6 +266,10 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
||||||
:param policy_str: (str) Name of the policy.
|
:param policy_str: (str) Name of the policy.
|
||||||
"""
|
"""
|
||||||
kwargs = dict(policy_kwargs=dict(net_arch=[16]))
|
kwargs = dict(policy_kwargs=dict(net_arch=[16]))
|
||||||
|
|
||||||
|
if policy_str == "CnnPolicy" and model_class is ARS:
|
||||||
|
pytest.skip("ARS does not support CnnPolicy")
|
||||||
|
|
||||||
if policy_str == "MlpPolicy":
|
if policy_str == "MlpPolicy":
|
||||||
env = select_env(model_class)
|
env = select_env(model_class)
|
||||||
else:
|
else:
|
||||||
|
|
@ -332,9 +336,11 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
||||||
if actor is not None:
|
if actor is not None:
|
||||||
actor.save(tmp_path / "actor.pkl")
|
actor.save(tmp_path / "actor.pkl")
|
||||||
|
|
||||||
|
device = policy.device
|
||||||
|
|
||||||
del policy, actor
|
del policy, actor
|
||||||
|
|
||||||
policy = policy_class.load(tmp_path / "policy.pkl")
|
policy = policy_class.load(tmp_path / "policy.pkl").to(device)
|
||||||
if actor_class is not None:
|
if actor_class is not None:
|
||||||
actor = actor_class.load(tmp_path / "actor.pkl")
|
actor = actor_class.load(tmp_path / "actor.pkl")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue