Sync SB3 Contrib with SB3 (#213)
* Update RTD config * Switch to ruff for sorting imports * Evaluate falsy to truthy with not rather than `is False` * Add `features_extractor` argument to maskable policy * Add set_options for AsyncEval * Doc fixes
This commit is contained in:
parent
5be11deaf3
commit
4d7ed004af
|
|
@ -14,3 +14,8 @@ formats: all
|
|||
# Set requirements using conda env
|
||||
conda:
|
||||
environment: docs/conda_env.yml
|
||||
|
||||
build:
|
||||
os: ubuntu-22.04
|
||||
tools:
|
||||
python: "mambaforge-22.9"
|
||||
|
|
|
|||
4
Makefile
4
Makefile
|
|
@ -22,13 +22,13 @@ lint:
|
|||
|
||||
format:
|
||||
# Sort imports
|
||||
isort ${LINT_PATHS}
|
||||
ruff --select I ${LINT_PATHS} --fix
|
||||
# Reformat using black
|
||||
black ${LINT_PATHS}
|
||||
|
||||
check-codestyle:
|
||||
# Sort imports
|
||||
isort --check ${LINT_PATHS}
|
||||
ruff --select I ${LINT_PATHS}
|
||||
# Reformat using black
|
||||
black --check ${LINT_PATHS}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,8 +13,7 @@ dependencies:
|
|||
- pandas
|
||||
- numpy
|
||||
- matplotlib
|
||||
- sphinx_autodoc_typehints
|
||||
- stable-baselines3>=2.0.0
|
||||
- sphinx>=5.3,<7.0
|
||||
- sphinx_rtd_theme>=1.0
|
||||
- sphinx>=5,<8
|
||||
- sphinx_rtd_theme>=1.3.0
|
||||
- sphinx_copybutton
|
||||
|
|
|
|||
|
|
@ -64,7 +64,6 @@ release = __version__
|
|||
# ones.
|
||||
extensions = [
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx_autodoc_typehints",
|
||||
"sphinx.ext.autosummary",
|
||||
"sphinx.ext.mathjax",
|
||||
"sphinx.ext.ifconfig",
|
||||
|
|
@ -73,6 +72,8 @@ extensions = [
|
|||
# 'sphinx.ext.doctest'
|
||||
]
|
||||
|
||||
autodoc_typehints = "description"
|
||||
|
||||
if enable_spell_check:
|
||||
extensions.append("sphinxcontrib.spelling")
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,34 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 2.2.0a8 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Upgraded to Stable-Baselines3 >= 2.2.0
|
||||
- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version
|
||||
- Dropped ``x is False`` in favor of ``not x``, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle)
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added ``set_options`` for ``AsyncEval``
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Fixed ``ActorCriticPolicy.extract_features()`` signature by adding an optional ``features_extractor`` argument
|
||||
- Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``)
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
Release 2.1.0 (2023-08-17)
|
||||
--------------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
.. _tqc:
|
||||
.. _trpo:
|
||||
|
||||
.. automodule:: sb3_contrib.trpo
|
||||
|
||||
|
|
@ -105,7 +105,7 @@ Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):
|
|||
|
||||
.. code-block:: bash
|
||||
|
||||
python train.py --algo tqc --env $ENV_ID --n-eval-envs 10 --eval-episodes 20 --eval-freq 50000
|
||||
python train.py --algo trpo --env $ENV_ID --n-eval-envs 10 --eval-episodes 20 --eval-freq 50000
|
||||
|
||||
|
||||
Plot the results:
|
||||
|
|
|
|||
|
|
@ -19,11 +19,6 @@ max-complexity = 15
|
|||
[tool.black]
|
||||
line-length = 127
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 127
|
||||
src_paths = ["sb3_contrib"]
|
||||
|
||||
[tool.pytype]
|
||||
inputs = ['sb3_contrib']
|
||||
|
||||
|
|
|
|||
|
|
@ -141,15 +141,26 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
log_prob = distribution.log_prob(actions)
|
||||
return actions, values, log_prob
|
||||
|
||||
def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
|
||||
def extract_features(
|
||||
self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None
|
||||
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
|
||||
"""
|
||||
Preprocess the observation if needed and extract features.
|
||||
|
||||
:param obs: Observation
|
||||
:return: the output of the features extractor(s)
|
||||
:param features_extractor: The features extractor to use. If None, then ``self.features_extractor`` is used.
|
||||
:return: The extracted features. If features extractor is not shared, returns a tuple with the
|
||||
features for the actor and the features for the critic.
|
||||
"""
|
||||
if self.share_features_extractor:
|
||||
return super().extract_features(obs, self.features_extractor)
|
||||
return super().extract_features(obs, features_extractor or self.features_extractor)
|
||||
else:
|
||||
if features_extractor is not None:
|
||||
warnings.warn(
|
||||
"Provided features_extractor will be ignored because the features extractor is not shared.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
pi_features = super().extract_features(obs, self.pi_features_extractor)
|
||||
vf_features = super().extract_features(obs, self.vf_features_extractor)
|
||||
return pi_features, vf_features
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import multiprocessing as mp
|
||||
from collections import defaultdict
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -57,6 +57,9 @@ def _worker(
|
|||
elif cmd == "seed":
|
||||
# Note: the seed will only be effective at the next reset
|
||||
remote.send(vec_env.seed(seed=data))
|
||||
elif cmd == "set_options":
|
||||
# Note: the options will only be effective at the next reset
|
||||
remote.send(vec_env.set_options(data))
|
||||
elif cmd == "get_obs_rms":
|
||||
remote.send(obs_rms)
|
||||
elif cmd == "sync_obs_rms":
|
||||
|
|
@ -158,6 +161,19 @@ class AsyncEval:
|
|||
remote.send(("seed", seed + idx))
|
||||
return [remote.recv() for remote in self.remotes]
|
||||
|
||||
def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> List[Union[None, int]]:
|
||||
"""
|
||||
Set environment options for all environments.
|
||||
If a dict is passed instead of a list, the same options will be used for all environments.
|
||||
WARNING: Those options will only be passed to the environment at the next reset.
|
||||
|
||||
:param options: A dictionary of environment options to pass to each environment at the next reset.
|
||||
:return:
|
||||
"""
|
||||
for remote in self.remotes:
|
||||
remote.send(("set_options", options))
|
||||
return [remote.recv() for remote in self.remotes]
|
||||
|
||||
def get_results(self) -> List[Tuple[int, Tuple[np.ndarray, np.ndarray]]]:
|
||||
"""
|
||||
Retreive episode rewards and lengths from each worker
|
||||
|
|
|
|||
|
|
@ -309,7 +309,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
|
||||
# Give access to local variables
|
||||
callback.update_locals(locals())
|
||||
if callback.on_step() is False:
|
||||
if not callback.on_step():
|
||||
return False
|
||||
|
||||
self._update_info_buffer(infos)
|
||||
|
|
@ -525,7 +525,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
while self.num_timesteps < total_timesteps:
|
||||
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
|
||||
|
||||
if continue_training is False:
|
||||
if not continue_training:
|
||||
break
|
||||
|
||||
iteration += 1
|
||||
|
|
|
|||
|
|
@ -257,7 +257,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
|
||||
# Give access to local variables
|
||||
callback.update_locals(locals())
|
||||
if callback.on_step() is False:
|
||||
if not callback.on_step():
|
||||
return False
|
||||
|
||||
self._update_info_buffer(infos)
|
||||
|
|
@ -468,7 +468,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
while self.num_timesteps < total_timesteps:
|
||||
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
|
||||
|
||||
if continue_training is False:
|
||||
if not continue_training:
|
||||
break
|
||||
|
||||
iteration += 1
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.1.0
|
||||
2.2.0a8
|
||||
|
|
|
|||
7
setup.py
7
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=2.1.0",
|
||||
"stable_baselines3>=2.2.0a8,<3.0",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
@ -95,8 +95,3 @@ setup(
|
|||
"Programming Language :: Python :: 3.11",
|
||||
],
|
||||
)
|
||||
|
||||
# python setup.py sdist
|
||||
# python setup.py bdist_wheel
|
||||
# twine upload --repository-url https://test.pypi.org/legacy/ dist/*
|
||||
# twine upload dist/*
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ def test_ars_multi_env():
|
|||
# with parallelism
|
||||
async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v1", n_envs=1)) for _ in range(2)], model.policy)
|
||||
async_eval.seed(0)
|
||||
async_eval.set_options()
|
||||
model.learn(500, async_eval=async_eval)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue