Multiprocessing support for off policy algorithms (#50)
* TQC support for multienv * Add optional layer norm for TQC * Add layer nprm for all policies * Revert "Add layer nprm for all policies" This reverts commit 1306c3c64eb12613464982c66cb416a3bbc66285. * Revert "Add optional layer norm for TQC" This reverts commit 200222e3a8878007aa6032d540ae74274a4d0788. * Add experimental support to train off-policy algorithms with multiple envs * Bump version * Update version
This commit is contained in:
parent
cd0a5e516f
commit
a1b5ea67ae
|
|
@ -8,8 +8,8 @@ along with some useful characteristics: support for discrete/continuous actions,
|
||||||
============ =========== ============ ================= =============== ================
|
============ =========== ============ ================= =============== ================
|
||||||
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
|
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
|
||||||
============ =========== ============ ================= =============== ================
|
============ =========== ============ ================= =============== ================
|
||||||
TQC ✔️ ❌ ❌ ❌ ❌
|
TQC ✔️ ❌ ❌ ❌ ✔️
|
||||||
QR-DQN ❌ ✔️ ❌ ❌ ❌
|
QR-DQN ❌ ✔️ ❌ ❌ ✔️
|
||||||
============ =========== ============ ================= =============== ================
|
============ =========== ============ ================= =============== ================
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,15 +4,16 @@ Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
|
||||||
Release 1.3.1a3 (WIP)
|
Release 1.3.1a4 (WIP)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^
|
||||||
- Upgraded to Stable-Baselines3 >= 1.3.1a3
|
- Upgraded to Stable-Baselines3 >= 1.3.1a4
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
- Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` is currently not supported)
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ Can I use?
|
||||||
----------
|
----------
|
||||||
|
|
||||||
- Recurrent policies: ❌
|
- Recurrent policies: ❌
|
||||||
- Multi processing: ❌
|
- Multi processing: ✔️
|
||||||
- Gym spaces:
|
- Gym spaces:
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ Can I use?
|
||||||
----------
|
----------
|
||||||
|
|
||||||
- Recurrent policies: ❌
|
- Recurrent policies: ❌
|
||||||
- Multi processing: ❌
|
- Multi processing: ✔️
|
||||||
- Gym spaces:
|
- Gym spaces:
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -110,6 +110,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
sde_support=False,
|
sde_support=False,
|
||||||
optimize_memory_usage=optimize_memory_usage,
|
optimize_memory_usage=optimize_memory_usage,
|
||||||
supported_action_spaces=(gym.spaces.Discrete,),
|
supported_action_spaces=(gym.spaces.Discrete,),
|
||||||
|
support_multi_env=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.exploration_initial_eps = exploration_initial_eps
|
self.exploration_initial_eps = exploration_initial_eps
|
||||||
|
|
|
||||||
|
|
@ -122,6 +122,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
use_sde_at_warmup=use_sde_at_warmup,
|
use_sde_at_warmup=use_sde_at_warmup,
|
||||||
optimize_memory_usage=optimize_memory_usage,
|
optimize_memory_usage=optimize_memory_usage,
|
||||||
supported_action_spaces=(gym.spaces.Box),
|
supported_action_spaces=(gym.spaces.Box),
|
||||||
|
support_multi_env=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.target_entropy = target_entropy
|
self.target_entropy = target_entropy
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.3.1a3
|
1.3.1a4
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -63,7 +63,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.3.1a3",
|
"stable_baselines3>=1.3.1a4",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
|
import gym
|
||||||
import pytest
|
import pytest
|
||||||
|
from stable_baselines3.common.env_util import make_vec_env
|
||||||
|
|
||||||
from sb3_contrib import QRDQN, TQC
|
from sb3_contrib import QRDQN, TQC
|
||||||
|
|
||||||
|
|
@ -56,3 +58,31 @@ def test_qrdqn():
|
||||||
create_eval_env=True,
|
create_eval_env=True,
|
||||||
)
|
)
|
||||||
model.learn(total_timesteps=500, eval_freq=250)
|
model.learn(total_timesteps=500, eval_freq=250)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
def test_offpolicy_multi_env(model_class):
    """Smoke-test training an off-policy algorithm on two environments at once.

    :param model_class: the algorithm class under test (``TQC`` or ``QRDQN``)
    """
    # TQC is run on Pendulum with a single critic;
    # every other algorithm is run on CartPole.
    is_tqc = model_class is TQC
    env_id = "Pendulum-v0" if is_tqc else "CartPole-v1"
    policy_kwargs = {"net_arch": [64]}
    if is_tqc:
        policy_kwargs["n_critics"] = 1

    def make_env():
        # Add a short time limit (50 steps) on top of the base env
        # to check that the code handling timeouts runs
        return gym.wrappers.TimeLimit(gym.make(env_id), 50)

    vec_env = make_vec_env(make_env, n_envs=2)

    algo_kwargs = {
        "policy_kwargs": policy_kwargs,
        "learning_starts": 100,
        "buffer_size": 10000,
        "verbose": 0,
        "train_freq": 5,
    }
    model = model_class("MlpPolicy", vec_env, **algo_kwargs)
    model.learn(total_timesteps=150)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue