Sync SB3 Contrib with SB3 (#213)
* Update RTD config * Switch to ruff for sorting imports * Evaluate falsy to truthy with not rather than `is False` * Add `features_extractor` argument to maskable policy * Add set_options for AsyncEval * Doc fixes
This commit is contained in:
parent
5be11deaf3
commit
4d7ed004af
|
|
@ -14,3 +14,8 @@ formats: all
|
||||||
# Set requirements using conda env
|
# Set requirements using conda env
|
||||||
conda:
|
conda:
|
||||||
environment: docs/conda_env.yml
|
environment: docs/conda_env.yml
|
||||||
|
|
||||||
|
build:
|
||||||
|
os: ubuntu-22.04
|
||||||
|
tools:
|
||||||
|
python: "mambaforge-22.9"
|
||||||
|
|
|
||||||
4
Makefile
4
Makefile
|
|
@ -22,13 +22,13 @@ lint:
|
||||||
|
|
||||||
format:
|
format:
|
||||||
# Sort imports
|
# Sort imports
|
||||||
isort ${LINT_PATHS}
|
ruff --select I ${LINT_PATHS} --fix
|
||||||
# Reformat using black
|
# Reformat using black
|
||||||
black ${LINT_PATHS}
|
black ${LINT_PATHS}
|
||||||
|
|
||||||
check-codestyle:
|
check-codestyle:
|
||||||
# Sort imports
|
# Sort imports
|
||||||
isort --check ${LINT_PATHS}
|
ruff --select I ${LINT_PATHS}
|
||||||
# Reformat using black
|
# Reformat using black
|
||||||
black --check ${LINT_PATHS}
|
black --check ${LINT_PATHS}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,8 +13,7 @@ dependencies:
|
||||||
- pandas
|
- pandas
|
||||||
- numpy
|
- numpy
|
||||||
- matplotlib
|
- matplotlib
|
||||||
- sphinx_autodoc_typehints
|
|
||||||
- stable-baselines3>=2.0.0
|
- stable-baselines3>=2.0.0
|
||||||
- sphinx>=5.3,<7.0
|
- sphinx>=5,<8
|
||||||
- sphinx_rtd_theme>=1.0
|
- sphinx_rtd_theme>=1.3.0
|
||||||
- sphinx_copybutton
|
- sphinx_copybutton
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,6 @@ release = __version__
|
||||||
# ones.
|
# ones.
|
||||||
extensions = [
|
extensions = [
|
||||||
"sphinx.ext.autodoc",
|
"sphinx.ext.autodoc",
|
||||||
"sphinx_autodoc_typehints",
|
|
||||||
"sphinx.ext.autosummary",
|
"sphinx.ext.autosummary",
|
||||||
"sphinx.ext.mathjax",
|
"sphinx.ext.mathjax",
|
||||||
"sphinx.ext.ifconfig",
|
"sphinx.ext.ifconfig",
|
||||||
|
|
@ -73,6 +72,8 @@ extensions = [
|
||||||
# 'sphinx.ext.doctest'
|
# 'sphinx.ext.doctest'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
autodoc_typehints = "description"
|
||||||
|
|
||||||
if enable_spell_check:
|
if enable_spell_check:
|
||||||
extensions.append("sphinxcontrib.spelling")
|
extensions.append("sphinxcontrib.spelling")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,34 @@ Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
|
||||||
|
Release 2.2.0a8 (WIP)
|
||||||
|
--------------------------
|
||||||
|
|
||||||
|
Breaking Changes:
|
||||||
|
^^^^^^^^^^^^^^^^^
|
||||||
|
- Upgraded to Stable-Baselines3 >= 2.2.0
|
||||||
|
- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version
|
||||||
|
- Dropped ``x is False`` in favor of ``not x``, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle)
|
||||||
|
|
||||||
|
New Features:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
- Added ``set_options`` for ``AsyncEval``
|
||||||
|
|
||||||
|
Bug Fixes:
|
||||||
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
Deprecations:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Others:
|
||||||
|
^^^^^^^
|
||||||
|
- Fixed ``ActorCriticPolicy.extract_features()`` signature by adding an optional ``features_extractor`` argument
|
||||||
|
- Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``)
|
||||||
|
|
||||||
|
Documentation:
|
||||||
|
^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|
||||||
Release 2.1.0 (2023-08-17)
|
Release 2.1.0 (2023-08-17)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
.. _tqc:
|
.. _trpo:
|
||||||
|
|
||||||
.. automodule:: sb3_contrib.trpo
|
.. automodule:: sb3_contrib.trpo
|
||||||
|
|
||||||
|
|
@ -105,7 +105,7 @@ Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
python train.py --algo tqc --env $ENV_ID --n-eval-envs 10 --eval-episodes 20 --eval-freq 50000
|
python train.py --algo trpo --env $ENV_ID --n-eval-envs 10 --eval-episodes 20 --eval-freq 50000
|
||||||
|
|
||||||
|
|
||||||
Plot the results:
|
Plot the results:
|
||||||
|
|
|
||||||
|
|
@ -19,11 +19,6 @@ max-complexity = 15
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 127
|
line-length = 127
|
||||||
|
|
||||||
[tool.isort]
|
|
||||||
profile = "black"
|
|
||||||
line_length = 127
|
|
||||||
src_paths = ["sb3_contrib"]
|
|
||||||
|
|
||||||
[tool.pytype]
|
[tool.pytype]
|
||||||
inputs = ['sb3_contrib']
|
inputs = ['sb3_contrib']
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -141,15 +141,26 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
log_prob = distribution.log_prob(actions)
|
log_prob = distribution.log_prob(actions)
|
||||||
return actions, values, log_prob
|
return actions, values, log_prob
|
||||||
|
|
||||||
def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
|
def extract_features(
|
||||||
|
self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None
|
||||||
|
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Preprocess the observation if needed and extract features.
|
Preprocess the observation if needed and extract features.
|
||||||
|
|
||||||
:param obs: Observation
|
:param obs: Observation
|
||||||
:return: the output of the features extractor(s)
|
:param features_extractor: The features extractor to use. If None, then ``self.features_extractor`` is used.
|
||||||
|
:return: The extracted features. If features extractor is not shared, returns a tuple with the
|
||||||
|
features for the actor and the features for the critic.
|
||||||
"""
|
"""
|
||||||
if self.share_features_extractor:
|
if self.share_features_extractor:
|
||||||
return super().extract_features(obs, self.features_extractor)
|
return super().extract_features(obs, features_extractor or self.features_extractor)
|
||||||
else:
|
else:
|
||||||
|
if features_extractor is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"Provided features_extractor will be ignored because the features extractor is not shared.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
pi_features = super().extract_features(obs, self.pi_features_extractor)
|
pi_features = super().extract_features(obs, self.pi_features_extractor)
|
||||||
vf_features = super().extract_features(obs, self.vf_features_extractor)
|
vf_features = super().extract_features(obs, self.vf_features_extractor)
|
||||||
return pi_features, vf_features
|
return pi_features, vf_features
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
|
|
@ -57,6 +57,9 @@ def _worker(
|
||||||
elif cmd == "seed":
|
elif cmd == "seed":
|
||||||
# Note: the seed will only be effective at the next reset
|
# Note: the seed will only be effective at the next reset
|
||||||
remote.send(vec_env.seed(seed=data))
|
remote.send(vec_env.seed(seed=data))
|
||||||
|
elif cmd == "set_options":
|
||||||
|
# Note: the options will only be effective at the next reset
|
||||||
|
remote.send(vec_env.set_options(data))
|
||||||
elif cmd == "get_obs_rms":
|
elif cmd == "get_obs_rms":
|
||||||
remote.send(obs_rms)
|
remote.send(obs_rms)
|
||||||
elif cmd == "sync_obs_rms":
|
elif cmd == "sync_obs_rms":
|
||||||
|
|
@ -158,6 +161,19 @@ class AsyncEval:
|
||||||
remote.send(("seed", seed + idx))
|
remote.send(("seed", seed + idx))
|
||||||
return [remote.recv() for remote in self.remotes]
|
return [remote.recv() for remote in self.remotes]
|
||||||
|
|
||||||
|
def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> List[Union[None, int]]:
|
||||||
|
"""
|
||||||
|
Set environment options for all environments.
|
||||||
|
If a dict is passed instead of a list, the same options will be used for all environments.
|
||||||
|
WARNING: Those options will only be passed to the environment at the next reset.
|
||||||
|
|
||||||
|
:param options: A dictionary of environment options to pass to each environment at the next reset.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for remote in self.remotes:
|
||||||
|
remote.send(("set_options", options))
|
||||||
|
return [remote.recv() for remote in self.remotes]
|
||||||
|
|
||||||
def get_results(self) -> List[Tuple[int, Tuple[np.ndarray, np.ndarray]]]:
|
def get_results(self) -> List[Tuple[int, Tuple[np.ndarray, np.ndarray]]]:
|
||||||
"""
|
"""
|
||||||
Retreive episode rewards and lengths from each worker
|
Retreive episode rewards and lengths from each worker
|
||||||
|
|
|
||||||
|
|
@ -309,7 +309,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
# Give access to local variables
|
# Give access to local variables
|
||||||
callback.update_locals(locals())
|
callback.update_locals(locals())
|
||||||
if callback.on_step() is False:
|
if not callback.on_step():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self._update_info_buffer(infos)
|
self._update_info_buffer(infos)
|
||||||
|
|
@ -525,7 +525,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
while self.num_timesteps < total_timesteps:
|
while self.num_timesteps < total_timesteps:
|
||||||
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
|
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
|
||||||
|
|
||||||
if continue_training is False:
|
if not continue_training:
|
||||||
break
|
break
|
||||||
|
|
||||||
iteration += 1
|
iteration += 1
|
||||||
|
|
|
||||||
|
|
@ -257,7 +257,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
# Give access to local variables
|
# Give access to local variables
|
||||||
callback.update_locals(locals())
|
callback.update_locals(locals())
|
||||||
if callback.on_step() is False:
|
if not callback.on_step():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self._update_info_buffer(infos)
|
self._update_info_buffer(infos)
|
||||||
|
|
@ -468,7 +468,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
while self.num_timesteps < total_timesteps:
|
while self.num_timesteps < total_timesteps:
|
||||||
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
|
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
|
||||||
|
|
||||||
if continue_training is False:
|
if not continue_training:
|
||||||
break
|
break
|
||||||
|
|
||||||
iteration += 1
|
iteration += 1
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
2.1.0
|
2.2.0a8
|
||||||
|
|
|
||||||
7
setup.py
7
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=2.1.0",
|
"stable_baselines3>=2.2.0a8,<3.0",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
@ -95,8 +95,3 @@ setup(
|
||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# python setup.py sdist
|
|
||||||
# python setup.py bdist_wheel
|
|
||||||
# twine upload --repository-url https://test.pypi.org/legacy/ dist/*
|
|
||||||
# twine upload dist/*
|
|
||||||
|
|
|
||||||
|
|
@ -111,6 +111,7 @@ def test_ars_multi_env():
|
||||||
# with parallelism
|
# with parallelism
|
||||||
async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v1", n_envs=1)) for _ in range(2)], model.policy)
|
async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v1", n_envs=1)) for _ in range(2)], model.policy)
|
||||||
async_eval.seed(0)
|
async_eval.seed(0)
|
||||||
|
async_eval.set_options()
|
||||||
model.learn(500, async_eval=async_eval)
|
model.learn(500, async_eval=async_eval)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue