Augmented Random Search (ARS) (#42)
* first pass at ARS, replicates initial results, still needs more testing and cleanup
* add a few docs and tests, bugfixes for ARS
* debug and comment
* break out dump logs
* rollback so there are no predict workers, some refactoring
* remove callback from self, remove torch multiprocessing
* add module docs
* run formatter
* fix load and rerun formatter
* rename to less mathy variable names, rename _validate_hypers
* refactor to use evaluate_policy, linear policy no longer uses bias or squashing
* move everything to torch, add support for discrete action spaces, bugfix for alive reward offset
* added tests, passing all of them, add support for discrete action spaces
* update documentation
* allow for reward offset when there are multiple envs
* update results again
* Reformat
* Ignore unused imports
* Renaming + Cleanup
* Experimental multiprocessing
* Cleaner multiprocessing
* Reformat
* Fixes for callback
* Fix combining stats
* 2nd way
* Make the implementation cpu only
* Fixes + POC with mp module
* POC Processes
* Cleaner async implementation
* Remove unused arg
* Add typing
* Revert vec normalize offset hack
* Add `squash_output` parameter
* Add more tests
* Add comments
* Update doc
* Add comments
* Add more logging
* Fix TRPO issue on GPU
* Tmp fix for ARS tests on GPU
* Additional tmp fixes for ARS
* update docstrings + formatting, fix bad exception string in ARSPolicy
* Add comments and docstrings
* Fix missing import
* Fix type check
* Add docstrings
* GPU support, first attempt
* Fix test
* Add missing docstring
* Typos
* Update default hyperparameters

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Parent: 3b007ae93b
Commit: 675304d8fa
@ -29,6 +29,7 @@ See documentation for the full list of included features.

- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)

**Gym Wrappers**:

- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
@ -8,9 +8,10 @@ along with some useful characteristics: support for discrete/continuous actions,

============ =========== ============ ================= =============== ================
Name         ``Box``     ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
============ =========== ============ ================= =============== ================
ARS          ✔️           ❌           ❌                 ❌               ✔️
QR-DQN       ❌           ✔️           ❌                 ❌               ✔️
TQC          ✔️           ❌           ❌                 ❌               ✔️
TRPO         ✔️           ✔️           ✔️                 ✔️               ✔️
============ =========== ============ ================= =============== ================
@ -43,7 +43,7 @@ Train a PPO with invalid action masking agent on a toy environment.

    env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
    model = MaskablePPO("MlpPolicy", env, verbose=1)
    model.learn(5000)
    model.save("maskable_toy_env")

TRPO
----
@ -57,3 +57,17 @@ Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environment

    model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1)
    model.learn(total_timesteps=100_000, log_interval=4)
    model.save("trpo_pendulum")

ARS
---

Train an agent using Augmented Random Search (ARS) on the Pendulum environment.

.. code-block:: python

    from sb3_contrib import ARS

    model = ARS("LinearPolicy", "Pendulum-v0", verbose=1)
    model.learn(total_timesteps=10000, log_interval=4)
    model.save("ars_pendulum")
@ -31,10 +31,11 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d

   :maxdepth: 1
   :caption: RL Algorithms

   modules/ars
   modules/ppo_mask
   modules/qrdqn
   modules/tqc
   modules/trpo

.. toctree::
   :maxdepth: 1
@ -4,20 +4,21 @@ Changelog
==========

Release 1.3.1a9 (WIP)
-------------------------------

**Add TRPO and ARS**

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Dropped Python 3.6 support
- Upgraded to Stable-Baselines3 >= 1.3.1a9
- ``MaskablePPO`` was updated to match latest SB3 ``PPO`` version (timeout handling and new method for the policy object)

New Features:
^^^^^^^^^^^^^
- Added ``TRPO`` (@cyprienc)
- Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported)
- Added Augmented Random Search (ARS) (@sgillen)

Bug Fixes:
^^^^^^^^^^
@ -221,4 +222,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)

Contributors:
-------------

@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen
@ -0,0 +1,160 @@
.. _ars:

.. automodule:: sb3_contrib.ars


ARS
===

Augmented Random Search (ARS) is a simple reinforcement learning algorithm that uses a direct random search over policy
parameters. It can be surprisingly effective compared to more sophisticated algorithms. In the `original paper <https://arxiv.org/abs/1803.07055>`_ the authors showed that linear policies trained with ARS were competitive with deep reinforcement learning on the MuJoCo locomotion tasks.

SB3's implementation allows for linear policies without bias or squashing function. It also supports MLP policies, which include linear policies with bias and squashing functions as a special case.

ARS is normally trained with several random seeds to obtain a reliable evaluation.

.. warning::

    ARS multi-processing is different from the classic Stable-Baselines3 multi-processing: it runs n environments
    in parallel but asynchronously. This asynchronous multi-processing is considered experimental
    and does not fully support callbacks: the ``on_step()`` event is called artificially after the evaluation episodes are over.
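The update performed at each iteration can be sketched in plain NumPy. This is an illustrative toy, not the SB3 implementation: the quadratic ``evaluate`` stand-in and all dimensions below are made up; a real run would roll out the policy parameterized by ``w`` and sum the rewards.

```python
import numpy as np

rng = np.random.default_rng(0)
n_params, n_delta, n_top = 4, 8, 4
delta_std, learning_rate = 0.05, 0.02
weights = np.zeros(n_params)

def evaluate(w):
    # Dummy stand-in for an environment rollout: higher is better
    return -np.sum((w - 1.0) ** 2)

# Perturb the current weights in both directions for each sampled delta
deltas = rng.standard_normal((n_delta, n_params))
plus = np.array([evaluate(weights + delta_std * d) for d in deltas])
minus = np.array([evaluate(weights - delta_std * d) for d in deltas])

# Keep only the n_top directions whose best return (either sign) is highest
top_idx = np.argsort(np.maximum(plus, minus))[::-1][:n_top]
return_std = np.concatenate([plus[top_idx], minus[top_idx]]).std() + 1e-6

# Gradient-free update, with the step size scaled by the return std
step_size = learning_rate / (n_top * return_std)
weights = weights + step_size * ((plus[top_idx] - minus[top_idx]) @ deltas[top_idx])
```

Each iteration thus uses ``2 * n_delta`` policy evaluations and moves the weights along the reward-weighted perturbation directions.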
.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    LinearPolicy
    MlpPolicy


Notes
-----

- Original paper: https://arxiv.org/abs/1803.07055
- Original Implementation: https://github.com/modestyachts/ARS
Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️ (cf. example)
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ✔️      ❌
Box           ✔️      ✔️
MultiDiscrete ❌      ❌
MultiBinary   ❌      ❌
Dict          ❌      ❌
============= ====== ===========
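For ``Discrete`` action spaces the policy outputs one score per action and deterministically picks the argmax. A minimal NumPy illustration of that selection rule (the actual policy operates on PyTorch tensors):

```python
import numpy as np

# One row of action scores per observation in the batch
logits = np.array([[0.1, 2.0, -0.5],
                   [1.5, -1.0, 0.2]])

# Deterministic action selection: index of the highest score per row
actions = np.argmax(logits, axis=1)
print(actions)  # -> [1 0]
```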
Example
-------

.. code-block:: python

    from sb3_contrib import ARS

    # Policy can be LinearPolicy or MlpPolicy
    model = ARS("LinearPolicy", "Pendulum-v0", verbose=1)
    model.learn(total_timesteps=10000, log_interval=4)
    model.save("ars_pendulum")


With experimental asynchronous multi-processing:

.. code-block:: python

    from sb3_contrib import ARS
    from sb3_contrib.common.vec_env import AsyncEval

    from stable_baselines3.common.env_util import make_vec_env

    env_id = "CartPole-v1"
    n_envs = 2

    model = ARS("LinearPolicy", env_id, n_delta=2, n_top=1, verbose=1)
    # Create env for asynchronous evaluation (run in different processes)
    async_eval = AsyncEval([lambda: make_vec_env(env_id) for _ in range(n_envs)], model.policy)

    model.learn(total_timesteps=200_000, log_interval=4, async_eval=async_eval)
Results
-------

The results below replicate those of the original paper on the MuJoCo benchmarks, using the same hyperparameters as the original paper and 8 random seeds.

============= ============
Environments  ARS
============= ============
HalfCheetah   4398 +/- 320
Swimmer       241 +/- 51
Hopper        3320 +/- 120
============= ============
How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Clone RL-Zoo and checkout the branch ``feat/ars``:

.. code-block:: bash

    git clone https://github.com/DLR-RM/rl-baselines3-zoo
    cd rl-baselines3-zoo/
    git checkout feat/ars

Run the benchmark. The following snippet trains 8 seeds in parallel:

.. code-block:: bash

    for ENV_ID in Swimmer-v3 HalfCheetah-v3 Hopper-v3
    do
      for SEED_NUM in {1..8}
      do
        SEED=$RANDOM
        python train.py --algo ars --env $ENV_ID --eval-episodes 10 --eval-freq 10000 -n 20000000 --seed $SEED &
        sleep 1
      done
      wait
    done

Plot the results:

.. code-block:: bash

    python scripts/all_plots.py -a ars -e HalfCheetah Swimmer Hopper -f logs/ -o logs/ars_results -max 20000000
    python scripts/plot_from_file.py -i logs/ars_results.pkl -l ARS
Parameters
----------

.. autoclass:: ARS
    :members:
    :inherited-members:

.. _ars_policies:

ARS Policies
-------------

.. autoclass:: sb3_contrib.ars.policies.ARSPolicy
    :members:
    :noindex:

.. autoclass:: LinearPolicy
    :members:
    :inherited-members:

.. autoclass:: MlpPolicy
    :members:
@ -119,3 +119,7 @@ cuda
Polyak
gSDE
rollouts
quantiles
quantile
contrib
Contrib
@ -1,5 +1,6 @@
import os

from sb3_contrib.ars import ARS
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC
@ -0,0 +1,2 @@
from sb3_contrib.ars.ars import ARS
from sb3_contrib.ars.policies import LinearPolicy, MlpPolicy
@ -0,0 +1,347 @@
import copy
import time
import warnings
from functools import partial
from typing import Any, Dict, Optional, Type, Union

import gym
import numpy as np
import torch as th
import torch.nn.utils
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_schedule_fn, safe_mean

from sb3_contrib.ars.policies import ARSPolicy
from sb3_contrib.common.vec_env.async_eval import AsyncEval
class ARS(BaseAlgorithm):
    """
    Augmented Random Search: https://arxiv.org/abs/1803.07055

    Original implementation: https://github.com/modestyachts/ARS
    C++/Cuda Implementation: https://github.com/google-research/tiny-differentiable-simulator/
    150 LOC Numpy Implementation: https://github.com/alexis-jacq/numpy_ARS/blob/master/asr.py

    :param policy: The policy to train, can be an instance of ``ARSPolicy``, or a string from ["LinearPolicy", "MlpPolicy"]
    :param env: The environment to train on, may be a string if registered with gym
    :param n_delta: How many random perturbations of the policy to try at each update step.
    :param n_top: How many of the top delta to use in each update step. Default is n_delta
    :param learning_rate: Float or schedule for the step size
    :param delta_std: Float or schedule for the exploration noise
    :param zero_policy: Boolean determining if the passed policy should have its weights zeroed before training.
    :param alive_bonus_offset: Constant added to the reward at each step, used to cancel out alive bonuses.
    :param n_eval_episodes: Number of episodes to evaluate each candidate.
    :param policy_kwargs: Keyword arguments to pass to the policy on creation
    :param policy_base: Base class to use for the policy
    :param tensorboard_log: String with the directory to put tensorboard logs
    :param seed: Random seed for the training
    :param verbose: Verbosity level: 0 no output, 1 info, 2 debug
    :param device: Torch device to use for training, defaults to "cpu"
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """
    def __init__(
        self,
        policy: Union[str, Type[ARSPolicy]],
        env: Union[GymEnv, str],
        n_delta: int = 8,
        n_top: Optional[int] = None,
        learning_rate: Union[float, Schedule] = 0.02,
        delta_std: Union[float, Schedule] = 0.05,
        zero_policy: bool = True,
        alive_bonus_offset: float = 0,
        n_eval_episodes: int = 1,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        policy_base: Type[BasePolicy] = ARSPolicy,
        tensorboard_log: Optional[str] = None,
        seed: Optional[int] = None,
        verbose: int = 0,
        device: Union[th.device, str] = "cpu",
        _init_setup_model: bool = True,
    ):

        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            tensorboard_log=tensorboard_log,
            policy_base=policy_base,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            device=device,
            supported_action_spaces=(gym.spaces.Box, gym.spaces.Discrete),
            support_multi_env=True,
            seed=seed,
        )

        self.n_delta = n_delta
        self.pop_size = 2 * n_delta
        self.delta_std_schedule = get_schedule_fn(delta_std)
        self.n_eval_episodes = n_eval_episodes

        if n_top is None:
            n_top = n_delta

        # Make sure our hyperparameters are valid and auto correct them if they are not
        if n_top > n_delta:
            warnings.warn(f"n_top = {n_top} > n_delta = {n_delta}, setting n_top = n_delta")
            n_top = n_delta

        self.n_top = n_top

        self.alive_bonus_offset = alive_bonus_offset
        self.zero_policy = zero_policy
        self.weights = None  # Need to call init model to initialize weights
        self.processes = None
        # Keep track of how many steps were elapsed before a new rollout
        # Important for syncing observation normalization between workers
        self.old_count = 0

        if _init_setup_model:
            self._setup_model()
    def _setup_model(self) -> None:
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        self.policy = self.policy_class(self.observation_space, self.action_space, **self.policy_kwargs)
        self.policy = self.policy.to(self.device)
        self.weights = th.nn.utils.parameters_to_vector(self.policy.parameters()).detach()
        self.n_params = len(self.weights)

        if self.zero_policy:
            self.weights = th.zeros_like(self.weights, requires_grad=False)
            self.policy.load_from_vector(self.weights.cpu())
    def _mimic_monitor_wrapper(self, episode_rewards: np.ndarray, episode_lengths: np.ndarray) -> None:
        """
        Helper to mimic Monitor wrapper and report episode statistics (mean reward, mean episode length).

        :param episode_rewards: List containing per-episode rewards
        :param episode_lengths: List containing per-episode lengths (in number of steps)
        """
        # Mimic Monitor Wrapper
        infos = [
            {"episode": {"r": episode_reward, "l": episode_length}}
            for episode_reward, episode_length in zip(episode_rewards, episode_lengths)
        ]

        self._update_info_buffer(infos)
    def _trigger_callback(
        self,
        _locals: Dict[str, Any],
        _globals: Dict[str, Any],
        callback: BaseCallback,
        n_envs: int,
    ) -> None:
        """
        Callback passed to the ``evaluate_policy()`` helper
        in order to increment the number of timesteps
        and trigger events in the single process version.

        :param _locals:
        :param _globals:
        :param callback: Callback that will be called at every step
        :param n_envs: Number of environments
        """
        self.num_timesteps += n_envs
        callback.on_step()
    def evaluate_candidates(
        self, candidate_weights: th.Tensor, callback: BaseCallback, async_eval: Optional[AsyncEval]
    ) -> th.Tensor:
        """
        Evaluate each candidate.

        :param candidate_weights: The candidate weights to be evaluated.
        :param callback: Callback that will be called at each step
            (or after evaluation in the multiprocess version)
        :param async_eval: The object for asynchronous evaluation of candidates.
        :return: The episodic return for each candidate.
        """

        batch_steps = 0
        # returns == sum of rewards
        candidate_returns = th.zeros(self.pop_size, device=self.device)
        train_policy = copy.deepcopy(self.policy)
        # Empty buffer to show only mean over one iteration (one set of candidates) in the logs
        self.ep_info_buffer = []
        callback.on_rollout_start()

        if async_eval is not None:
            # Multiprocess asynchronous version
            async_eval.send_jobs(candidate_weights, self.pop_size)
            results = async_eval.get_results()

            for weights_idx, (episode_rewards, episode_lengths) in results:

                # Update reward to cancel out alive bonus if needed
                candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths)
                batch_steps += np.sum(episode_lengths)
                self._mimic_monitor_wrapper(episode_rewards, episode_lengths)

            # Combine the filter stats of each process for normalization
            for worker_obs_rms in async_eval.get_obs_rms():
                if self._vec_normalize_env is not None:
                    # worker_obs_rms.count -= self.old_count
                    self._vec_normalize_env.obs_rms.combine(worker_obs_rms)
                    # Hack: don't count timesteps twice (the two are synced),
                    # otherwise it will lead to overflow;
                    # in practice we would need two RunningMeanStd
                    self._vec_normalize_env.obs_rms.count -= self.old_count

            # Synchronise VecNormalize if needed
            if self._vec_normalize_env is not None:
                async_eval.sync_obs_rms(self._vec_normalize_env.obs_rms.copy())
                self.old_count = self._vec_normalize_env.obs_rms.count

            # Hack to have Callback events
            for _ in range(batch_steps // len(async_eval.remotes)):
                self.num_timesteps += len(async_eval.remotes)
                callback.on_step()
        else:
            # Single process, synchronous version
            for weights_idx in range(self.pop_size):

                # Load current candidate weights
                train_policy.load_from_vector(candidate_weights[weights_idx].cpu())
                # Evaluate the candidate
                episode_rewards, episode_lengths = evaluate_policy(
                    train_policy,
                    self.env,
                    n_eval_episodes=self.n_eval_episodes,
                    return_episode_rewards=True,
                    # Increment num_timesteps too (slight mismatch with multi envs)
                    callback=partial(self._trigger_callback, callback=callback, n_envs=self.env.num_envs),
                    warn=False,
                )
                # Update reward to cancel out alive bonus if needed
                candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths)
                batch_steps += sum(episode_lengths)
                self._mimic_monitor_wrapper(episode_rewards, episode_lengths)

            # Note: we increment num_timesteps inside evaluate_policy(),
            # however when using multiple environments, there will be a slight
            # mismatch between the number of timesteps used and the number
            # of calls to the step() method (cf. implementation of evaluate_policy())
            # self.num_timesteps += batch_steps

        callback.on_rollout_end()

        return candidate_returns
    def _log_and_dump(self) -> None:
        """
        Dump information to the logger.
        """
        time_elapsed = time.time() - self.start_time
        fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8))
        if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
            self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
            self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
        self.logger.record("time/fps", fps)
        self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
        self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
        self.logger.dump(step=self.num_timesteps)
    def _do_one_update(self, callback: BaseCallback, async_eval: Optional[AsyncEval]) -> None:
        """
        Sample new candidates, evaluate them and then update current policy.

        :param callback: callback(s) called at every step with state of the algorithm.
        :param async_eval: The object for asynchronous evaluation of candidates.
        """
        # Retrieve current parameter noise standard deviation
        # and current learning rate
        delta_std = self.delta_std_schedule(self._current_progress_remaining)
        learning_rate = self.lr_schedule(self._current_progress_remaining)
        # Sample the parameter noise, it will be scaled by delta_std
        deltas = th.normal(mean=0.0, std=1.0, size=(self.n_delta, self.n_params), device=self.device)
        policy_deltas = deltas * delta_std
        # Generate 2 * n_delta candidate policies by adding noise to the current weights
        candidate_weights = th.cat([self.weights + policy_deltas, self.weights - policy_deltas])

        with th.no_grad():
            candidate_returns = self.evaluate_candidates(candidate_weights, callback, async_eval)

        # Returns corresponding to weights + deltas
        plus_returns = candidate_returns[: self.n_delta]
        # Returns corresponding to weights - deltas
        minus_returns = candidate_returns[self.n_delta :]

        # Keep only the top performing candidates for the update
        top_returns, _ = th.max(th.vstack((plus_returns, minus_returns)), dim=0)
        top_idx = th.argsort(top_returns, descending=True)[: self.n_top]
        plus_returns = plus_returns[top_idx]
        minus_returns = minus_returns[top_idx]
        deltas = deltas[top_idx]

        # Scale learning rate by the return standard deviation:
        # take smaller steps when there is a high variance in the returns
        return_std = th.cat([plus_returns, minus_returns]).std()
        step_size = learning_rate / (self.n_top * return_std + 1e-6)
        # Approximate gradient step
        self.weights = self.weights + step_size * ((plus_returns - minus_returns) @ deltas)
        self.policy.load_from_vector(self.weights.cpu())

        self.logger.record("train/iterations", self._n_updates, exclude="tensorboard")
        self.logger.record("train/delta_std", delta_std)
        self.logger.record("train/learning_rate", learning_rate)
        self.logger.record("train/step_size", step_size.item())
        self.logger.record("rollout/return_std", return_std.item())

        self._n_updates += 1
    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        tb_log_name: str = "ARS",
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        async_eval: Optional[AsyncEval] = None,
    ) -> "ARS":
        """
        Return a trained model.

        :param total_timesteps: The total number of samples (env steps) to train on
        :param callback: callback(s) called at every step with state of the algorithm.
        :param log_interval: The number of timesteps before logging.
        :param tb_log_name: the name of the run for TensorBoard logging
        :param eval_env: Environment that will be used to evaluate the agent
        :param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
        :param n_eval_episodes: Number of episodes to evaluate the agent
        :param eval_log_path: Path to a folder where the evaluations will be saved
        :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
        :param async_eval: The object for asynchronous evaluation of candidates.
        :return: the trained model
        """

        total_steps, callback = self._setup_learn(
            total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
        )

        callback.on_training_start(locals(), globals())

        while self.num_timesteps < total_steps:
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
            self._do_one_update(callback, async_eval)
            if log_interval is not None and self._n_updates % log_interval == 0:
                self._log_and_dump()

        if async_eval is not None:
            async_eval.close()

        callback.on_training_end()

        return self
@ -0,0 +1,119 @@
from typing import Any, Dict, List, Optional, Type

import gym
import torch as th
from stable_baselines3.common.policies import BasePolicy, register_policy
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import create_mlp
from torch import nn
class ARSPolicy(BasePolicy):
    """
    Policy network for ARS.

    :param observation_space: The observation space of the environment
    :param action_space: The action space of the environment
    :param net_arch: Network architecture, defaults to a 2-layer MLP with 64 hidden nodes.
    :param activation_fn: Activation function
    :param squash_output: For continuous actions, whether the output is squashed
        or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped.
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        net_arch: Optional[List[int]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        squash_output: bool = True,
    ):

        super().__init__(
            observation_space,
            action_space,
            squash_output=isinstance(action_space, gym.spaces.Box) and squash_output,
        )

        if net_arch is None:
            net_arch = [64, 64]

        self.net_arch = net_arch
        self.features_extractor = self.make_features_extractor()
        self.features_dim = self.features_extractor.features_dim
        self.activation_fn = activation_fn

        if isinstance(action_space, gym.spaces.Box):
            action_dim = get_action_dim(action_space)
            actor_net = create_mlp(self.features_dim, action_dim, net_arch, activation_fn, squash_output=True)
        elif isinstance(action_space, gym.spaces.Discrete):
            actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn)
        else:
            raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")

        self.action_net = nn.Sequential(*actor_net)
    def _get_constructor_parameters(self) -> Dict[str, Any]:
        # Note: super()._get_constructor_parameters() is not used here
        # because it adds normalize_images, which is not supported
        data = dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            net_arch=self.net_arch,
            activation_fn=self.activation_fn,
        )
        return data

    def forward(self, obs: th.Tensor) -> th.Tensor:

        features = self.extract_features(obs)
        if isinstance(self.action_space, gym.spaces.Box):
            return self.action_net(features)
        elif isinstance(self.action_space, gym.spaces.Discrete):
            logits = self.action_net(features)
            return th.argmax(logits, dim=1)
        else:
            raise NotImplementedError()

    def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
        # Non-deterministic actions do not really make sense for ARS,
        # so this parameter is ignored for now
        return self.forward(observation)
class ARSLinearPolicy(ARSPolicy):
    """
    Linear policy network for ARS.

    :param observation_space: The observation space of the environment
    :param action_space: The action space of the environment
    :param with_bias: With or without bias on the output
    :param squash_output: For continuous actions, whether the output is squashed
        or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        with_bias: bool = False,
        squash_output: bool = False,
    ):

        super().__init__(observation_space, action_space, squash_output=squash_output)

        if isinstance(action_space, gym.spaces.Box):
            action_dim = get_action_dim(action_space)
            self.action_net = nn.Linear(self.features_dim, action_dim, bias=with_bias)
            if squash_output:
                self.action_net = nn.Sequential(self.action_net, nn.Tanh())
        elif isinstance(action_space, gym.spaces.Discrete):
            self.action_net = nn.Linear(self.features_dim, action_space.n, bias=with_bias)
        else:
            raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")


MlpPolicy = ARSPolicy
LinearPolicy = ARSLinearPolicy


register_policy("LinearPolicy", LinearPolicy)
register_policy("MlpPolicy", MlpPolicy)
@ -0,0 +1 @@
from sb3_contrib.common.vec_env.async_eval import AsyncEval
@ -0,0 +1,208 @@
import multiprocessing as mp
from collections import defaultdict
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch as th
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper
def _worker(
    remote: mp.connection.Connection,
    parent_remote: mp.connection.Connection,
    worker_env_wrapper: CloudpickleWrapper,
    train_policy_wrapper: CloudpickleWrapper,
    n_eval_episodes: int = 1,
) -> None:
    """
    Function that will be run in each process.
    It is in charge of creating environments, evaluating candidates
    and communicating with the main process.

    :param remote: Pipe to communicate with the parent process.
    :param parent_remote:
    :param worker_env_wrapper: Callable used to create the environment inside the process.
    :param train_policy_wrapper: Callable used to create the policy inside the process.
    :param n_eval_episodes: Number of evaluation episodes per candidate.
    """
    parent_remote.close()
    env = worker_env_wrapper.var()
    train_policy = train_policy_wrapper.var
    vec_normalize = unwrap_vec_normalize(env)
    if vec_normalize is not None:
        obs_rms = vec_normalize.obs_rms
    else:
        obs_rms = None
    while True:
        try:
            cmd, data = remote.recv()
            if cmd == "eval":
                results = []
                # Evaluate each candidate and save results
                for weights_idx, candidate_weights in data:
                    train_policy.load_from_vector(candidate_weights.cpu())
                    episode_rewards, episode_lengths = evaluate_policy(
                        train_policy,
                        env,
                        n_eval_episodes=n_eval_episodes,
                        return_episode_rewards=True,
                        warn=False,
                    )
                    results.append((weights_idx, (episode_rewards, episode_lengths)))
                remote.send(results)
            elif cmd == "seed":
                remote.send(env.seed(data))
            elif cmd == "get_obs_rms":
                remote.send(obs_rms)
            elif cmd == "sync_obs_rms":
                vec_normalize.obs_rms = data
                obs_rms = data
            elif cmd == "close":
                env.close()
                remote.close()
                break
            else:
                raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
        except EOFError:
|
||||
break
|
||||
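The command loop above is a classic pipe-based worker protocol: the parent sends `("eval", jobs)` and the child replies with `(index, result)` pairs. Stripped of the RL machinery, the pattern can be reduced to a toy sketch (`_toy_worker`, `parallel_eval`, and the fitness function are illustrative stand-ins, not the PR's code):

```python
import multiprocessing as mp
from collections import defaultdict

def _toy_worker(remote, parent_remote):
    # Child process: evaluate candidates until told to close,
    # mirroring the "eval"/"close" command loop of `_worker` above.
    parent_remote.close()
    while True:
        cmd, data = remote.recv()
        if cmd == "eval":
            # Toy fitness: negative squared distance to the optimum 3.0
            remote.send([(idx, -(w - 3.0) ** 2) for idx, w in data])
        elif cmd == "close":
            remote.close()
            break

def parallel_eval(candidates, n_workers=2):
    remotes, work_remotes = zip(*[mp.Pipe() for _ in range(n_workers)])
    processes = []
    for remote, work_remote in zip(remotes, work_remotes):
        process = mp.Process(target=_toy_worker, args=(work_remote, remote), daemon=True)
        process.start()
        processes.append(process)
        work_remote.close()
    # Round-robin assignment of candidates to workers, as in `send_jobs()`
    jobs = defaultdict(list)
    for idx, w in enumerate(candidates):
        jobs[idx % n_workers].append((idx, w))
    for i, remote in enumerate(remotes):
        remote.send(("eval", jobs[i]))
    results = [r for remote in remotes for r in remote.recv()]
    for remote in remotes:
        remote.send(("close", None))
    for process in processes:
        process.join()
    return results

if __name__ == "__main__":
    results = parallel_eval([0.0, 1.0, 2.0, 3.0, 4.0])
    best_idx, _ = max(results, key=lambda t: t[1])
    print(best_idx)  # candidate 3.0 is closest to the optimum
```

Keeping the candidate index in each result tuple is what lets the parent match scores back to perturbations regardless of which worker produced them.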


class AsyncEval(object):
    """
    Helper class to do asynchronous evaluation of different policies with multiple processes.
    It is useful when implementing population-based methods like Evolution Strategies (ES),
    the Cross-Entropy Method (CEM) or Augmented Random Search (ARS).

    .. warning::

        Only 'forkserver' and 'spawn' start methods are thread-safe,
        which is important to avoid race conditions.
        However, compared to 'fork' they incur a small start-up cost
        and have restrictions on global variables.
        With those methods, users must wrap the code in an
        ``if __name__ == "__main__":`` block.
        For more information, see the multiprocessing documentation.

    :param envs_fn: Vectorized environments to run in subprocesses (callable)
    :param train_policy: The policy object that will load the different candidate
        weights.
    :param start_method: Method used to start the subprocesses.
        Must be one of the methods returned by ``multiprocessing.get_all_start_methods()``.
        Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
    :param n_eval_episodes: The number of episodes to test each agent
    """

    def __init__(
        self,
        envs_fn: List[Callable[[], VecEnv]],
        train_policy: BasePolicy,
        start_method: Optional[str] = None,
        n_eval_episodes: int = 1,
    ):
        self.waiting = False
        self.closed = False
        n_envs = len(envs_fn)

        if start_method is None:
            # Fork is not a thread-safe method (see issue #217)
            # but is more user friendly (does not require to wrap the code in
            # a `if __name__ == "__main__":`)
            forkserver_available = "forkserver" in mp.get_all_start_methods()
            start_method = "forkserver" if forkserver_available else "spawn"
        ctx = mp.get_context(start_method)

        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)])
        self.processes = []
        for work_remote, remote, worker_env in zip(self.work_remotes, self.remotes, envs_fn):
            args = (
                work_remote,
                remote,
                CloudpickleWrapper(worker_env),
                CloudpickleWrapper(train_policy),
                n_eval_episodes,
            )
            # daemon=True: if the main process crashes, we should not cause things to hang
            process = ctx.Process(target=_worker, args=args, daemon=True)  # pytype:disable=attribute-error
            process.start()
            self.processes.append(process)
            work_remote.close()

    def send_jobs(self, candidate_weights: th.Tensor, pop_size: int) -> None:
        """
        Send jobs to the workers to evaluate new candidates.

        :param candidate_weights: The weights to be evaluated.
        :param pop_size: The number of candidates (size of the population)
        """
        jobs_per_worker = defaultdict(list)
        for weights_idx in range(pop_size):
            jobs_per_worker[weights_idx % len(self.remotes)].append((weights_idx, candidate_weights[weights_idx]))

        for remote_idx, remote in enumerate(self.remotes):
            remote.send(("eval", jobs_per_worker[remote_idx]))
        self.waiting = True

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        """
        Seed the environments.

        :param seed: The seed for the pseudo-random generators.
        :return: The result of each ``env.seed()`` call.
        """
        for idx, remote in enumerate(self.remotes):
            remote.send(("seed", seed + idx))
        return [remote.recv() for remote in self.remotes]

    def get_results(self) -> List[Tuple[int, Tuple[np.ndarray, np.ndarray]]]:
        """
        Retrieve episode rewards and lengths from each worker
        for all candidates (there might be multiple candidates per worker).

        :return: A list of tuples containing each candidate index and its
            result (episodic reward and episode length)
        """
        results = [remote.recv() for remote in self.remotes]
        flat_results = [result for worker_results in results for result in worker_results]
        self.waiting = False
        return flat_results

    def get_obs_rms(self) -> List[RunningMeanStd]:
        """
        Retrieve the observation filters (observation running mean std)
        of each process; they will be combined in the main process.
        Synchronisation is done afterward using ``sync_obs_rms()``.

        :return: A list of ``RunningMeanStd`` objects (one per process)
        """
        for remote in self.remotes:
            remote.send(("get_obs_rms", None))
        return [remote.recv() for remote in self.remotes]

    def sync_obs_rms(self, obs_rms: RunningMeanStd) -> None:
        """
        Synchronise (and update) the observation filters
        (observation running mean std).

        :param obs_rms: The updated ``RunningMeanStd`` to be used
            by workers for normalizing observations.
        """
        for remote in self.remotes:
            remote.send(("sync_obs_rms", obs_rms))

    def close(self) -> None:
        """
        Close the processes.
        """
        if self.closed:
            return
        if self.waiting:
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(("close", None))
        for process in self.processes:
            process.join()
        self.closed = True
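`get_obs_rms()` returns one `RunningMeanStd` per process, and the commit log mentions fixing how these statistics are combined. Merging running mean/variance pairs is typically done with the parallel-variance (Chan et al.) update; a small self-contained sketch, where `RunningStats` is a simplified stand-in for the library's `RunningMeanStd`, not its actual implementation:

```python
import numpy as np

class RunningStats:
    """Track mean/var/count incrementally; simplified stand-in for RunningMeanStd."""

    def __init__(self):
        self.mean, self.var, self.count = 0.0, 1.0, 0

    def update(self, batch):
        batch = np.asarray(batch, dtype=np.float64)
        self.combine(batch.mean(), batch.var(), batch.size)

    def combine(self, other_mean, other_var, other_count):
        # Chan et al. parallel update of mean and variance
        delta = other_mean - self.mean
        tot = self.count + other_count
        new_mean = self.mean + delta * other_count / tot
        m_a = self.var * self.count
        m_b = other_var * other_count
        m2 = m_a + m_b + delta**2 * self.count * other_count / tot
        self.mean, self.var, self.count = new_mean, m2 / tot, tot

# Two "workers" each see half of the data; merging their filters
# recovers the statistics of the full stream.
data = np.random.RandomState(0).randn(1000)
a, b = RunningStats(), RunningStats()
a.update(data[:500])
b.update(data[500:])
a.combine(b.mean, b.var, b.count)  # merge worker b into a
print(a.mean - data.mean(), a.var - data.var())  # both ~0
```

This merge-then-broadcast shape is exactly the `get_obs_rms()` / `sync_obs_rms()` round-trip: collect per-worker filters, combine centrally, then push the merged filter back so every worker normalizes observations the same way.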
@ -206,8 +206,8 @@ class TRPO(OnPolicyAlgorithm):
             policy_objective_grad, *_ = th.autograd.grad(policy_objective, param, retain_graph=True, only_inputs=True)

             grad_shape.append(kl_param_grad.shape)
-            grad_kl.append(kl_param_grad.view(-1))
-            policy_objective_gradients.append(policy_objective_grad.view(-1))
+            grad_kl.append(kl_param_grad.reshape(-1))
+            policy_objective_gradients.append(policy_objective_grad.reshape(-1))
             actor_params.append(param)

         # Gradients are concatenated before the conjugate gradient step
@ -1 +1 @@
-1.3.1a7
+1.3.1a9
@ -22,13 +22,15 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators
# Ignore import not used when aliases are defined
per-file-ignores =
    ./sb3_contrib/__init__.py:F401
    ./sb3_contrib/ars/__init__.py:F401
    ./sb3_contrib/ppo_mask/__init__.py:F401
    ./sb3_contrib/qrdqn/__init__.py:F401
    ./sb3_contrib/tqc/__init__.py:F401
    ./sb3_contrib/trpo/__init__.py:F401
    ./sb3_contrib/common/vec_env/wrappers/__init__.py:F401
    ./sb3_contrib/common/wrappers/__init__.py:F401
    ./sb3_contrib/common/envs/__init__.py:F401
    ./sb3_contrib/common/vec_env/__init__.py:F401

exclude =
    # No need to traverse our git directory
    .git,
setup.py
@ -63,7 +63,7 @@ setup(
     packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
     package_data={"sb3_contrib": ["py.typed", "version.txt"]},
     install_requires=[
-        "stable_baselines3>=1.3.1a4",
+        "stable_baselines3>=1.3.1a9",
     ],
     description="Contrib package of Stable Baselines3, experimental code.",
     author="Antonin Raffin",
@ -0,0 +1,55 @@
import pytest
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.vec_env import VecNormalize

from sb3_contrib import ARS, QRDQN, TQC
from sb3_contrib.common.vec_env import AsyncEval

N_STEPS_TRAINING = 500
SEED = 0
ARS_MULTI = "ars_multi"


@pytest.mark.parametrize("algo", [ARS, QRDQN, TQC, ARS_MULTI])
def test_deterministic_training_common(algo):
    results = [[], []]
    rewards = [[], []]
    # Smaller network
    kwargs = {"policy_kwargs": dict(net_arch=[64])}
    env_id = "Pendulum-v0"
    if algo == ARS_MULTI:
        algo = ARS
        ars_multi = True
    else:
        ars_multi = False

    if algo in [TQC]:
        kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4})
    else:
        if algo == QRDQN:
            env_id = "CartPole-v1"
            kwargs.update({"learning_starts": 100, "target_update_interval": 100})
        elif algo == ARS:
            kwargs.update({"n_delta": 2})

    for i in range(2):
        model = algo("MlpPolicy", env_id, seed=SEED, **kwargs)

        learn_kwargs = {"total_timesteps": N_STEPS_TRAINING}
        if ars_multi:
            learn_kwargs["async_eval"] = AsyncEval(
                [lambda: VecNormalize(make_vec_env(env_id, seed=SEED, n_envs=1)) for _ in range(2)],
                model.policy,
            )

        model.learn(**learn_kwargs)
        env = model.get_env()
        obs = env.reset()
        for _ in range(100):
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, _, _ = env.step(action)
            results[i].append(action)
            rewards[i].append(reward)
    assert sum(results[0]) == sum(results[1]), results
    assert sum(rewards[0]) == sum(rewards[1]), rewards
@ -1,8 +1,10 @@
 import gym
 import pytest
 from stable_baselines3.common.env_util import make_vec_env
 from stable_baselines3.common.vec_env import VecNormalize

-from sb3_contrib import QRDQN, TQC, TRPO
+from sb3_contrib import ARS, QRDQN, TQC, TRPO
+from sb3_contrib.common.vec_env import AsyncEval


 @pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
@ -82,6 +84,38 @@ def test_trpo_params():
    model.learn(total_timesteps=500)


@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"])
def test_ars(policy_str, env_id):
    model = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0)
    model.learn(total_timesteps=500, log_interval=1, eval_freq=250)


def test_ars_multi_env():
    env = make_vec_env("Pendulum-v0", n_envs=2)
    model = ARS("MlpPolicy", env, n_delta=1)
    model.learn(total_timesteps=250)

    env = VecNormalize(make_vec_env("Pendulum-v0", n_envs=1))
    model = ARS("MlpPolicy", env, n_delta=2, seed=0)
    # with parallelism
    async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v0", n_envs=1)) for _ in range(2)], model.policy)
    async_eval.seed(0)
    model.learn(500, async_eval=async_eval)


@pytest.mark.parametrize("n_top", [2, 8])
def test_ars_n_top(n_top):
    n_delta = 3
    if n_top > n_delta:
        with pytest.warns(UserWarning):
            model = ARS("MlpPolicy", "Pendulum-v0", n_delta=n_delta, n_top=n_top)
            model.learn(total_timesteps=500)
    else:
        model = ARS("MlpPolicy", "Pendulum-v0", n_delta=n_delta, n_top=n_top)
        model.learn(total_timesteps=500)


@pytest.mark.parametrize("model_class", [TQC, QRDQN])
def test_offpolicy_multi_env(model_class):
    if model_class in [TQC]:
@ -12,9 +12,9 @@ from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnv
 from stable_baselines3.common.utils import get_device
 from stable_baselines3.common.vec_env import DummyVecEnv

-from sb3_contrib import QRDQN, TQC, TRPO
+from sb3_contrib import ARS, QRDQN, TQC, TRPO

-MODEL_LIST = [TQC, QRDQN, TRPO]
+MODEL_LIST = [ARS, QRDQN, TQC, TRPO]


 def select_env(model_class: BaseAlgorithm) -> gym.Env:
@ -266,6 +266,10 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
    :param policy_str: (str) Name of the policy.
    """
    kwargs = dict(policy_kwargs=dict(net_arch=[16]))

    if policy_str == "CnnPolicy" and model_class is ARS:
        pytest.skip("ARS does not support CnnPolicy")

    if policy_str == "MlpPolicy":
        env = select_env(model_class)
    else:
@ -332,9 +336,11 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
     if actor is not None:
         actor.save(tmp_path / "actor.pkl")

+    device = policy.device

     del policy, actor

-    policy = policy_class.load(tmp_path / "policy.pkl")
+    policy = policy_class.load(tmp_path / "policy.pkl").to(device)
     if actor_class is not None:
         actor = actor_class.load(tmp_path / "actor.pkl")