Sync SB3 Contrib with SB3 (#213)

* Update RTD config

* Switch to ruff for sorting imports

* Evaluate falsy to truthy with not rather than `is False`

* Add `features_extractor` argument to maskable policy

* Add set_options for AsyncEval

* Doc fixes
This commit is contained in:
Antonin RAFFIN 2023-10-25 14:32:51 +02:00 committed by Andreas Schaarschmidt
parent 5be11deaf3
commit 4d7ed004af
14 changed files with 79 additions and 28 deletions

View File

@ -14,3 +14,8 @@ formats: all
# Set requirements using conda env
conda:
environment: docs/conda_env.yml
build:
os: ubuntu-22.04
tools:
python: "mambaforge-22.9"

View File

@ -22,13 +22,13 @@ lint:
format:
# Sort imports
isort ${LINT_PATHS}
ruff --select I ${LINT_PATHS} --fix
# Reformat using black
black ${LINT_PATHS}
check-codestyle:
# Sort imports
isort --check ${LINT_PATHS}
ruff --select I ${LINT_PATHS}
# Reformat using black
black --check ${LINT_PATHS}

View File

@ -13,8 +13,7 @@ dependencies:
- pandas
- numpy
- matplotlib
- sphinx_autodoc_typehints
- stable-baselines3>=2.0.0
- sphinx>=5.3,<7.0
- sphinx_rtd_theme>=1.0
- sphinx>=5,<8
- sphinx_rtd_theme>=1.3.0
- sphinx_copybutton

View File

@ -64,7 +64,6 @@ release = __version__
# ones.
extensions = [
"sphinx.ext.autodoc",
"sphinx_autodoc_typehints",
"sphinx.ext.autosummary",
"sphinx.ext.mathjax",
"sphinx.ext.ifconfig",
@ -73,6 +72,8 @@ extensions = [
# 'sphinx.ext.doctest'
]
autodoc_typehints = "description"
if enable_spell_check:
extensions.append("sphinxcontrib.spelling")

View File

@ -4,6 +4,34 @@ Changelog
==========
Release 2.2.0a8 (WIP)
--------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 2.2.0
- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version
- Dropped ``x is False`` in favor of ``not x``, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle)
New Features:
^^^^^^^^^^^^^
- Added ``set_options`` for ``AsyncEval``
Bug Fixes:
^^^^^^^^^^
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
- Fixed ``ActorCriticPolicy.extract_features()`` signature by adding an optional ``features_extractor`` argument
- Updated dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``)
Documentation:
^^^^^^^^^^^^^^
Release 2.1.0 (2023-08-17)
--------------------------

View File

@ -1,4 +1,4 @@
.. _tqc:
.. _trpo:
.. automodule:: sb3_contrib.trpo
@ -105,7 +105,7 @@ Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):
.. code-block:: bash
python train.py --algo tqc --env $ENV_ID --n-eval-envs 10 --eval-episodes 20 --eval-freq 50000
python train.py --algo trpo --env $ENV_ID --n-eval-envs 10 --eval-episodes 20 --eval-freq 50000
Plot the results:

View File

@ -19,11 +19,6 @@ max-complexity = 15
[tool.black]
line-length = 127
[tool.isort]
profile = "black"
line_length = 127
src_paths = ["sb3_contrib"]
[tool.pytype]
inputs = ['sb3_contrib']

View File

@ -141,15 +141,26 @@ class MaskableActorCriticPolicy(BasePolicy):
log_prob = distribution.log_prob(actions)
return actions, values, log_prob
def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
def extract_features(
self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
"""
Preprocess the observation if needed and extract features.
:param obs: Observation
:return: the output of the features extractor(s)
:param features_extractor: The features extractor to use. If None, then ``self.features_extractor`` is used.
:return: The extracted features. If features extractor is not shared, returns a tuple with the
features for the actor and the features for the critic.
"""
if self.share_features_extractor:
return super().extract_features(obs, self.features_extractor)
return super().extract_features(obs, features_extractor or self.features_extractor)
else:
if features_extractor is not None:
warnings.warn(
"Provided features_extractor will be ignored because the features extractor is not shared.",
UserWarning,
)
pi_features = super().extract_features(obs, self.pi_features_extractor)
vf_features = super().extract_features(obs, self.vf_features_extractor)
return pi_features, vf_features

View File

@ -1,6 +1,6 @@
import multiprocessing as mp
from collections import defaultdict
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch as th
@ -57,6 +57,9 @@ def _worker(
elif cmd == "seed":
# Note: the seed will only be effective at the next reset
remote.send(vec_env.seed(seed=data))
elif cmd == "set_options":
# Note: the options will only be effective at the next reset
remote.send(vec_env.set_options(data))
elif cmd == "get_obs_rms":
remote.send(obs_rms)
elif cmd == "sync_obs_rms":
@ -158,6 +161,19 @@ class AsyncEval:
remote.send(("seed", seed + idx))
return [remote.recv() for remote in self.remotes]
def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> List[Union[None, int]]:
"""
Set environment options for all environments.
If a dict is passed instead of a list, the same options will be used for all environments.
WARNING: Those options will only be passed to the environment at the next reset.
:param options: A dictionary of environment options to pass to each environment at the next reset.
:return:
"""
for remote in self.remotes:
remote.send(("set_options", options))
return [remote.recv() for remote in self.remotes]
def get_results(self) -> List[Tuple[int, Tuple[np.ndarray, np.ndarray]]]:
"""
Retrieve episode rewards and lengths from each worker

View File

@ -309,7 +309,7 @@ class MaskablePPO(OnPolicyAlgorithm):
# Give access to local variables
callback.update_locals(locals())
if callback.on_step() is False:
if not callback.on_step():
return False
self._update_info_buffer(infos)
@ -525,7 +525,7 @@ class MaskablePPO(OnPolicyAlgorithm):
while self.num_timesteps < total_timesteps:
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
if continue_training is False:
if not continue_training:
break
iteration += 1

View File

@ -257,7 +257,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
# Give access to local variables
callback.update_locals(locals())
if callback.on_step() is False:
if not callback.on_step():
return False
self._update_info_buffer(infos)
@ -468,7 +468,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
while self.num_timesteps < total_timesteps:
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
if continue_training is False:
if not continue_training:
break
iteration += 1

View File

@ -1 +1 @@
2.1.0
2.2.0a8

View File

@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=2.1.0",
"stable_baselines3>=2.2.0a8,<3.0",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
@ -95,8 +95,3 @@ setup(
"Programming Language :: Python :: 3.11",
],
)
# python setup.py sdist
# python setup.py bdist_wheel
# twine upload --repository-url https://test.pypi.org/legacy/ dist/*
# twine upload dist/*

View File

@ -111,6 +111,7 @@ def test_ars_multi_env():
# with parallelism
async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v1", n_envs=1)) for _ in range(2)], model.policy)
async_eval.seed(0)
async_eval.set_options()
model.learn(500, async_eval=async_eval)