Multiprocessing support for off policy algorithms (#50)

* TQC support for multienv

* Add optional layer norm for TQC

* Add layer nprm for all policies

* Revert "Add layer nprm for all policies"

This reverts commit 1306c3c64eb12613464982c66cb416a3bbc66285.

* Revert "Add optional layer norm for TQC"

This reverts commit 200222e3a8878007aa6032d540ae74274a4d0788.

* Add experimental support to train off-policy algorithms with multiple envs

* Bump version

* Update version
This commit is contained in:
Antonin RAFFIN 2021-12-02 10:40:21 +01:00 committed by GitHub
parent cd0a5e516f
commit a1b5ea67ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 41 additions and 8 deletions

View File

@ -8,8 +8,8 @@ along with some useful characteristics: support for discrete/continuous actions,
============ =========== ============ ================= =============== ================
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
============ =========== ============ ================= =============== ================
TQC ✔️ ❌ ❌ ❌
QR-DQN ️❌ ️✔️ ❌ ❌
TQC ✔️ ❌ ❌ ❌ ✔️
QR-DQN ️❌ ️✔️ ❌ ❌ ✔️
============ =========== ============ ================= =============== ================

View File

@ -4,15 +4,16 @@ Changelog
==========
Release 1.3.1a3 (WIP)
Release 1.3.1a4 (WIP)
-------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 1.3.1a3
- Upgraded to Stable-Baselines3 >= 1.3.1a4
New Features:
^^^^^^^^^^^^^
- Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported)
Bug Fixes:
^^^^^^^^^^

View File

@ -33,7 +33,7 @@ Can I use?
----------
- Recurrent policies: ❌
- Multi processing:
- Multi processing: ✔️
- Gym spaces:

View File

@ -32,7 +32,7 @@ Can I use?
----------
- Recurrent policies: ❌
- Multi processing:
- Multi processing: ✔️
- Gym spaces:

View File

@ -110,6 +110,7 @@ class QRDQN(OffPolicyAlgorithm):
sde_support=False,
optimize_memory_usage=optimize_memory_usage,
supported_action_spaces=(gym.spaces.Discrete,),
support_multi_env=True,
)
self.exploration_initial_eps = exploration_initial_eps

View File

@ -122,6 +122,7 @@ class TQC(OffPolicyAlgorithm):
use_sde_at_warmup=use_sde_at_warmup,
optimize_memory_usage=optimize_memory_usage,
supported_action_spaces=(gym.spaces.Box),
support_multi_env=True,
)
self.target_entropy = target_entropy

View File

@ -1 +1 @@
1.3.1a3
1.3.1a4

View File

@ -63,7 +63,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.3.1a3",
"stable_baselines3>=1.3.1a4",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",

View File

@ -1,4 +1,6 @@
import gym
import pytest
from stable_baselines3.common.env_util import make_vec_env
from sb3_contrib import QRDQN, TQC
@ -56,3 +58,31 @@ def test_qrdqn():
create_eval_env=True,
)
model.learn(total_timesteps=500, eval_freq=250)
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
def test_offpolicy_multi_env(model_class):
    """Smoke-test that off-policy algorithms can train on a vectorized (2-env) setup."""
    # Continuous-action algorithms (TQC) run on Pendulum; discrete ones (QR-DQN) on CartPole.
    if model_class in [TQC]:
        env_id = "Pendulum-v0"
        policy_kwargs = dict(net_arch=[64], n_critics=1)
    else:
        env_id = "CartPole-v1"
        policy_kwargs = dict(net_arch=[64])

    def _short_episode_env():
        # Truncate episodes early so the timeout-handling code path is exercised.
        return gym.wrappers.TimeLimit(gym.make(env_id), 50)

    vec_env = make_vec_env(_short_episode_env, n_envs=2)

    model = model_class(
        "MlpPolicy",
        vec_env,
        policy_kwargs=policy_kwargs,
        learning_starts=100,
        buffer_size=10000,
        verbose=0,
        train_freq=5,
    )
    # A short run is enough to cover collection, replay storage, and a few gradient steps.
    model.learn(total_timesteps=150)