From a1b5ea67aebb71788703743eb2ea4ae97d0e2233 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 2 Dec 2021 10:40:21 +0100 Subject: [PATCH] Multiprocessing support for off policy algorithms (#50) * TQC support for multienv * Add optional layer norm for TQC * Add layer nprm for all policies * Revert "Add layer nprm for all policies" This reverts commit 1306c3c64eb12613464982c66cb416a3bbc66285. * Revert "Add optional layer norm for TQC" This reverts commit 200222e3a8878007aa6032d540ae74274a4d0788. * Add experimental support to train off-policy algorithms with multiple envs * Bump version * Update version --- docs/guide/algos.rst | 4 ++-- docs/misc/changelog.rst | 5 +++-- docs/modules/qrdqn.rst | 2 +- docs/modules/tqc.rst | 2 +- sb3_contrib/qrdqn/qrdqn.py | 1 + sb3_contrib/tqc/tqc.py | 1 + sb3_contrib/version.txt | 2 +- setup.py | 2 +- tests/test_run.py | 30 ++++++++++++++++++++++++++++++ 9 files changed, 41 insertions(+), 8 deletions(-) diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index f1b0ed5..879a84e 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -8,8 +8,8 @@ along with some useful characteristics: support for discrete/continuous actions, ============ =========== ============ ================= =============== ================ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing ============ =========== ============ ================= =============== ================ -TQC ✔️ ❌ ❌ ❌ ❌ -QR-DQN ️❌ ️✔️ ❌ ❌ ❌ +TQC ✔️ ❌ ❌ ❌ ✔️ +QR-DQN ️❌ ️✔️ ❌ ❌ ✔️ ============ =========== ============ ================= =============== ================ diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e885656..067d652 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,15 +4,16 @@ Changelog ========== -Release 1.3.1a3 (WIP) +Release 1.3.1a4 (WIP) ------------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ -- Upgraded to Stable-Baselines3 >= 1.3.1a3 +- Upgraded to Stable-Baselines3 >= 1.3.1a4 New Features: ^^^^^^^^^^^^^ +- Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported) Bug Fixes: ^^^^^^^^^^ diff --git a/docs/modules/qrdqn.rst b/docs/modules/qrdqn.rst index 1ae5a7b..9d79d22 100644 --- a/docs/modules/qrdqn.rst +++ b/docs/modules/qrdqn.rst @@ -33,7 +33,7 @@ Can I use? ---------- - Recurrent policies: ❌ -- Multi processing: ❌ +- Multi processing: ✔️ - Gym spaces: diff --git a/docs/modules/tqc.rst b/docs/modules/tqc.rst index 9b9c80d..8341f86 100644 --- a/docs/modules/tqc.rst +++ b/docs/modules/tqc.rst @@ -32,7 +32,7 @@ Can I use? ---------- - Recurrent policies: ❌ -- Multi processing: ❌ +- Multi processing: ✔️ - Gym spaces: diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index e51ca06..bd08b59 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -110,6 +110,7 @@ class QRDQN(OffPolicyAlgorithm): sde_support=False, optimize_memory_usage=optimize_memory_usage, supported_action_spaces=(gym.spaces.Discrete,), + support_multi_env=True, ) self.exploration_initial_eps = exploration_initial_eps diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 05dbd7a..687fc46 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -122,6 +122,7 @@ class TQC(OffPolicyAlgorithm): use_sde_at_warmup=use_sde_at_warmup, optimize_memory_usage=optimize_memory_usage, supported_action_spaces=(gym.spaces.Box), + support_multi_env=True, ) self.target_entropy = target_entropy diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 896c1f3..aca78c7 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.3.1a3 +1.3.1a4 diff --git a/setup.py b/setup.py index 492f059..881114b 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.3.1a3", + "stable_baselines3>=1.3.1a4", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", diff --git a/tests/test_run.py b/tests/test_run.py index b1431a3..b53641c 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,4 +1,6 @@ +import gym import pytest +from stable_baselines3.common.env_util import make_vec_env from sb3_contrib import QRDQN, TQC @@ -56,3 +58,31 @@ def test_qrdqn(): create_eval_env=True, ) model.learn(total_timesteps=500, eval_freq=250) + + +@pytest.mark.parametrize("model_class", [TQC, QRDQN]) +def test_offpolicy_multi_env(model_class): + if model_class in [TQC]: + env_id = "Pendulum-v0" + policy_kwargs = dict(net_arch=[64], n_critics=1) + else: + env_id = "CartPole-v1" + policy_kwargs = dict(net_arch=[64]) + + def make_env(): + env = gym.make(env_id) + # to check that the code handling timeouts runs + env = gym.wrappers.TimeLimit(env, 50) + return env + + env = make_vec_env(make_env, n_envs=2) + model = model_class( + "MlpPolicy", + env, + policy_kwargs=policy_kwargs, + learning_starts=100, + buffer_size=10000, + verbose=0, + train_freq=5, + ) + model.learn(total_timesteps=150)