Multiprocessing support for off policy algorithms (#50)
* TQC support for multienv * Add optional layer norm for TQC * Add layer nprm for all policies * Revert "Add layer nprm for all policies" This reverts commit 1306c3c64eb12613464982c66cb416a3bbc66285. * Revert "Add optional layer norm for TQC" This reverts commit 200222e3a8878007aa6032d540ae74274a4d0788. * Add experimental support to train off-policy algorithms with multiple envs * Bump version * Update version
This commit is contained in:
parent
cd0a5e516f
commit
a1b5ea67ae
|
|
@ -8,8 +8,8 @@ along with some useful characteristics: support for discrete/continuous actions,
|
|||
============ =========== ============ ================= =============== ================
|
||||
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
|
||||
============ =========== ============ ================= =============== ================
|
||||
TQC ✔️ ❌ ❌ ❌ ❌
|
||||
QR-DQN ❌ ✔️ ❌ ❌ ❌
|
||||
TQC ✔️ ❌ ❌ ❌ ✔️
|
||||
QR-DQN ❌ ✔️ ❌ ❌ ✔️
|
||||
============ =========== ============ ================= =============== ================
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,15 +4,16 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.3.1a3 (WIP)
|
||||
Release 1.3.1a4 (WIP)
|
||||
-------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Upgraded to Stable-Baselines3 >= 1.3.1a3
|
||||
- Upgraded to Stable-Baselines3 >= 1.3.1a4
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ Can I use?
|
|||
----------
|
||||
|
||||
- Recurrent policies: ❌
|
||||
- Multi processing: ❌
|
||||
- Multi processing: ✔️
|
||||
- Gym spaces:
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ Can I use?
|
|||
----------
|
||||
|
||||
- Recurrent policies: ❌
|
||||
- Multi processing: ❌
|
||||
- Multi processing: ✔️
|
||||
- Gym spaces:
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -110,6 +110,7 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
sde_support=False,
|
||||
optimize_memory_usage=optimize_memory_usage,
|
||||
supported_action_spaces=(gym.spaces.Discrete,),
|
||||
support_multi_env=True,
|
||||
)
|
||||
|
||||
self.exploration_initial_eps = exploration_initial_eps
|
||||
|
|
|
|||
|
|
@ -122,6 +122,7 @@ class TQC(OffPolicyAlgorithm):
|
|||
use_sde_at_warmup=use_sde_at_warmup,
|
||||
optimize_memory_usage=optimize_memory_usage,
|
||||
supported_action_spaces=(gym.spaces.Box,),
|
||||
support_multi_env=True,
|
||||
)
|
||||
|
||||
self.target_entropy = target_entropy
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.3.1a3
|
||||
1.3.1a4
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -63,7 +63,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=1.3.1a3",
|
||||
"stable_baselines3>=1.3.1a4",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import gym
|
||||
import pytest
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
|
||||
from sb3_contrib import QRDQN, TQC
|
||||
|
||||
|
|
@ -56,3 +58,31 @@ def test_qrdqn():
|
|||
create_eval_env=True,
|
||||
)
|
||||
model.learn(total_timesteps=500, eval_freq=250)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
|
||||
def test_offpolicy_multi_env(model_class):
|
||||
if model_class in [TQC]:
|
||||
env_id = "Pendulum-v0"
|
||||
policy_kwargs = dict(net_arch=[64], n_critics=1)
|
||||
else:
|
||||
env_id = "CartPole-v1"
|
||||
policy_kwargs = dict(net_arch=[64])
|
||||
|
||||
def make_env():
|
||||
env = gym.make(env_id)
|
||||
# to check that the code handling timeouts runs
|
||||
env = gym.wrappers.TimeLimit(env, 50)
|
||||
return env
|
||||
|
||||
env = make_vec_env(make_env, n_envs=2)
|
||||
model = model_class(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
policy_kwargs=policy_kwargs,
|
||||
learning_starts=100,
|
||||
buffer_size=10000,
|
||||
verbose=0,
|
||||
train_freq=5,
|
||||
)
|
||||
model.learn(total_timesteps=150)
|
||||
|
|
|
|||
Loading…
Reference in New Issue