diff --git a/.readthedocs.yml b/.readthedocs.yml index 6753d87..dbb2fad 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -14,3 +14,8 @@ formats: all # Set requirements using conda env conda: environment: docs/conda_env.yml + +build: + os: ubuntu-22.04 + tools: + python: "mambaforge-22.9" diff --git a/Makefile b/Makefile index 9643530..fab99ff 100644 --- a/Makefile +++ b/Makefile @@ -22,13 +22,13 @@ lint: format: # Sort imports - isort ${LINT_PATHS} + ruff --select I ${LINT_PATHS} --fix # Reformat using black black ${LINT_PATHS} check-codestyle: # Sort imports - isort --check ${LINT_PATHS} + ruff --select I ${LINT_PATHS} # Reformat using black black --check ${LINT_PATHS} diff --git a/docs/conda_env.yml b/docs/conda_env.yml index 16c1d96..b86838e 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -13,8 +13,7 @@ dependencies: - pandas - numpy - matplotlib - - sphinx_autodoc_typehints - stable-baselines3>=2.0.0 - - sphinx>=5.3,<7.0 - - sphinx_rtd_theme>=1.0 + - sphinx>=5,<8 + - sphinx_rtd_theme>=1.3.0 - sphinx_copybutton diff --git a/docs/conf.py b/docs/conf.py index 218ef3c..7e987ca 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -64,7 +64,6 @@ release = __version__ # ones. 
extensions = [ "sphinx.ext.autodoc", - "sphinx_autodoc_typehints", "sphinx.ext.autosummary", "sphinx.ext.mathjax", "sphinx.ext.ifconfig", @@ -73,6 +72,8 @@ extensions = [ # 'sphinx.ext.doctest' ] +autodoc_typehints = "description" + if enable_spell_check: extensions.append("sphinxcontrib.spelling") diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 8c93bae..9d4a4d2 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,6 +4,34 @@ Changelog ========== +Release 2.2.0a8 (WIP) +-------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Upgraded to Stable-Baselines3 >= 2.2.0 +- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff now require a minimum version +- Dropped ``x is False`` in favor of ``not x``, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle) + +New Features: +^^^^^^^^^^^^^ +- Added ``set_options`` for ``AsyncEval`` + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ +- Fixed ``ActorCriticPolicy.extract_features()`` signature by adding an optional ``features_extractor`` argument +- Updated dependencies (accept newer Shimmy/Sphinx versions and remove ``sphinx_autodoc_typehints``) + +Documentation: +^^^^^^^^^^^^^^ + + Release 2.1.0 (2023-08-17) -------------------------- diff --git a/docs/modules/trpo.rst b/docs/modules/trpo.rst index 0a7e985..4376faa 100644 --- a/docs/modules/trpo.rst +++ b/docs/modules/trpo.rst @@ -1,4 +1,4 @@ -.. _tqc: +.. _trpo: .. automodule:: sb3_contrib.trpo @@ -105,7 +105,7 @@ Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): .. 
code-block:: bash - python train.py --algo tqc --env $ENV_ID --n-eval-envs 10 --eval-episodes 20 --eval-freq 50000 + python train.py --algo trpo --env $ENV_ID --n-eval-envs 10 --eval-episodes 20 --eval-freq 50000 Plot the results: diff --git a/pyproject.toml b/pyproject.toml index eaf4e29..925931c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,11 +19,6 @@ max-complexity = 15 [tool.black] line-length = 127 -[tool.isort] -profile = "black" -line_length = 127 -src_paths = ["sb3_contrib"] - [tool.pytype] inputs = ['sb3_contrib'] diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index 34e9f08..f6b5eeb 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -141,15 +141,26 @@ class MaskableActorCriticPolicy(BasePolicy): log_prob = distribution.log_prob(actions) return actions, values, log_prob - def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]: + def extract_features( + self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None + ) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]: """ Preprocess the observation if needed and extract features. + :param obs: Observation - :return: the output of the features extractor(s) + :param features_extractor: The features extractor to use. If None, then ``self.features_extractor`` is used. + :return: The extracted features. If features extractor is not shared, returns a tuple with the + features for the actor and the features for the critic. 
""" if self.share_features_extractor: - return super().extract_features(obs, self.features_extractor) + return super().extract_features(obs, features_extractor or self.features_extractor) else: + if features_extractor is not None: + warnings.warn( + "Provided features_extractor will be ignored because the features extractor is not shared.", + UserWarning, + ) + pi_features = super().extract_features(obs, self.pi_features_extractor) vf_features = super().extract_features(obs, self.vf_features_extractor) return pi_features, vf_features diff --git a/sb3_contrib/common/vec_env/async_eval.py b/sb3_contrib/common/vec_env/async_eval.py index 9eed7b5..a7be9ba 100644 --- a/sb3_contrib/common/vec_env/async_eval.py +++ b/sb3_contrib/common/vec_env/async_eval.py @@ -1,6 +1,6 @@ import multiprocessing as mp from collections import defaultdict -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch as th @@ -57,6 +57,9 @@ def _worker( elif cmd == "seed": # Note: the seed will only be effective at the next reset remote.send(vec_env.seed(seed=data)) + elif cmd == "set_options": + # Note: the options will only be effective at the next reset + remote.send(vec_env.set_options(data)) elif cmd == "get_obs_rms": remote.send(obs_rms) elif cmd == "sync_obs_rms": @@ -158,6 +161,19 @@ class AsyncEval: remote.send(("seed", seed + idx)) return [remote.recv() for remote in self.remotes] + def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> List[Union[None, int]]: + """ + Set environment options for all environments. + If a dict is passed instead of a list, the same options will be used for all environments. + WARNING: Those options will only be passed to the environment at the next reset. + + :param options: A dictionary of environment options to pass to each environment at the next reset. 
+ :return: + """ + for remote in self.remotes: + remote.send(("set_options", options)) + return [remote.recv() for remote in self.remotes] + def get_results(self) -> List[Tuple[int, Tuple[np.ndarray, np.ndarray]]]: """ Retreive episode rewards and lengths from each worker diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index ffd831c..3248ba7 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -309,7 +309,7 @@ class MaskablePPO(OnPolicyAlgorithm): # Give access to local variables callback.update_locals(locals()) - if callback.on_step() is False: + if not callback.on_step(): return False self._update_info_buffer(infos) @@ -525,7 +525,7 @@ class MaskablePPO(OnPolicyAlgorithm): while self.num_timesteps < total_timesteps: continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking) - if continue_training is False: + if not continue_training: break iteration += 1 diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index acd44c9..39fd941 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -257,7 +257,7 @@ class RecurrentPPO(OnPolicyAlgorithm): # Give access to local variables callback.update_locals(locals()) - if callback.on_step() is False: + if not callback.on_step(): return False self._update_info_buffer(infos) @@ -468,7 +468,7 @@ class RecurrentPPO(OnPolicyAlgorithm): while self.num_timesteps < total_timesteps: continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) - if continue_training is False: + if not continue_training: break iteration += 1 diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 7ec1d6d..f1f23b3 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -2.1.0 +2.2.0a8 diff --git a/setup.py b/setup.py index 07b1ab5..88242e2 100644 --- 
a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=2.1.0", + "stable_baselines3>=2.2.0a8,<3.0", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", @@ -95,8 +95,3 @@ setup( "Programming Language :: Python :: 3.11", ], ) - -# python setup.py sdist -# python setup.py bdist_wheel -# twine upload --repository-url https://test.pypi.org/legacy/ dist/* -# twine upload dist/* diff --git a/tests/test_run.py b/tests/test_run.py index 9da6525..f944aff 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -111,6 +111,7 @@ def test_ars_multi_env(): # with parallelism async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v1", n_envs=1)) for _ in range(2)], model.policy) async_eval.seed(0) + async_eval.set_options() model.learn(500, async_eval=async_eval)