diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 62770eb..228ee93 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,15 +29,16 @@ jobs: python -m pip install --upgrade pip # cpu version of pytorch pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + # Install dependencies for docs and tests pip install stable_baselines3[extra,tests,docs] # Install master version pip install git+https://github.com/DLR-RM/stable-baselines3 pip install . # Use headless version pip install opencv-python-headless - # - name: Build the doc - # run: | - # make doc + - name: Build the doc + run: | + make doc - name: Type check run: | make type diff --git a/Makefile b/Makefile index 639d266..d740b60 100644 --- a/Makefile +++ b/Makefile @@ -28,6 +28,11 @@ check-codestyle: commit-checks: format type lint +doc: + cd docs && make html + +spelling: + cd docs && make spelling # PyPi package release release: @@ -41,4 +46,4 @@ test-release: python setup.py bdist_wheel twine upload --repository-url https://test.pypi.org/legacy/ dist/* -.PHONY: lint format check-codestyle commit-checks +.PHONY: lint format check-codestyle commit-checks doc spelling diff --git a/README.md b/README.md index 36cdb24..7280244 100644 --- a/README.md +++ b/README.md @@ -18,3 +18,19 @@ Implemented: ``` pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib ``` + + +## Citing the Project + +To cite this repository in publications (please cite SB3 directly): + +``` +@misc{stable-baselines3, + author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah}, + title = {Stable Baselines3}, + year = {2019}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}}, +} +``` diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..47f98cd --- /dev/null +++ b/docs/Makefile @@ 
-0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = -W # make warnings fatal +SPHINXBUILD = sphinx-build +SPHINXPROJ = StableBaselines +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000..1427a79 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,25 @@ +## Stable Baselines3 Documentation + +This folder contains documentation for the RL baselines. + + +### Build the Documentation + +#### Install Sphinx and Theme + +``` +pip install sphinx sphinx-autobuild sphinx-rtd-theme +``` + +#### Building the Docs + +In the `docs/` folder: +``` +make html +``` + +If you want to rebuild the docs each time a file is changed: + +``` +sphinx-autobuild . 
_build/html +``` diff --git a/docs/_static/css/baselines_theme.css b/docs/_static/css/baselines_theme.css new file mode 100644 index 0000000..450864e --- /dev/null +++ b/docs/_static/css/baselines_theme.css @@ -0,0 +1,61 @@ +/* Main colors adapted from pytorch doc */ +:root{ + --main-bg-color: #343A40; + --link-color: #FD7E14; +} + +/* Header fonts y */ +h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption { + font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; +} + + +/* Docs background */ +.wy-side-nav-search{ + background-color: var(--main-bg-color); +} + +/* Mobile version */ +.wy-nav-top{ + background-color: var(--main-bg-color); +} + +/* Change link colors (except for the menu) */ +a { + color: var(--link-color); +} + +a:hover { + color: #4F778F; +} + +.wy-menu a { + color: #b3b3b3; +} + +.wy-menu a:hover { + color: #b3b3b3; +} + +a.icon.icon-home { + color: #b3b3b3; +} + +.version{ + color: var(--link-color) !important; +} + + +/* Make code blocks have a background */ +.codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] { + background: #f8f8f8;; +} + +/* Change style of types in the docstrings .rst-content .field-list */ +.field-list .xref.py.docutils, .field-list code.docutils, .field-list .docutils.literal.notranslate +{ + border: None; + padding-left: 0; + padding-right: 0; + color: #404040; +} diff --git a/docs/_static/img/colab-badge.svg b/docs/_static/img/colab-badge.svg new file mode 100644 index 0000000..c08066e --- /dev/null +++ b/docs/_static/img/colab-badge.svg @@ -0,0 +1 @@ + Open in ColabOpen in Colab diff --git a/docs/_static/img/colab.svg b/docs/_static/img/colab.svg new file mode 100644 index 0000000..c2d30e9 --- /dev/null +++ b/docs/_static/img/colab.svg @@ -0,0 +1,7 @@ + + + + + + + \ No newline at end of file diff --git a/docs/_static/img/logo.png b/docs/_static/img/logo.png new file mode 100644 index 0000000..7cec1c3 Binary files 
/dev/null and b/docs/_static/img/logo.png differ diff --git a/docs/_static/img/logo.png~ b/docs/_static/img/logo.png~ new file mode 100644 index 0000000..ef592f5 Binary files /dev/null and b/docs/_static/img/logo.png~ differ diff --git a/docs/conda_env.yml b/docs/conda_env.yml new file mode 100644 index 0000000..35e604d --- /dev/null +++ b/docs/conda_env.yml @@ -0,0 +1,18 @@ +name: root +channels: + - pytorch + - defaults +dependencies: + - cpuonly=1.0=0 + - pip=20.2 + - python=3.6 + - pytorch=1.5.0=py3.6_cpu_0 + - pip: + - gym>=0.17.2 + - cloudpickle + - opencv-python-headless + - pandas + - numpy + - matplotlib + - sphinx_autodoc_typehints + - stable-baselines3>=0.9.0 diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..ed733df --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,219 @@ +# -*- coding: utf-8 -*- +# +# Configuration file for the Sphinx documentation builder. +# +# This file does only contain a selection of the most common options. For a +# full list see the documentation: +# http://www.sphinx-doc.org/en/master/config + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys +from unittest.mock import MagicMock + +# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support +# PyEnchant. 
+try: + import sphinxcontrib.spelling # noqa: F401 + + enable_spell_check = True +except ImportError: + enable_spell_check = False + +# Source code directory, relative to this file, for sphinx-autobuild +sys.path.insert(0, os.path.abspath("..")) + + +class Mock(MagicMock):  # stand-in object placed in sys.modules for every name listed in MOCK_MODULES + __subclasses__ = [] + + @classmethod + def __getattr__(cls, name): + return MagicMock()  # whatever the requested name, hand back a fresh MagicMock + + +# Mock modules that require C modules +# Note: because of that we cannot test examples using CI +# 'torch', 'torch.nn', 'torch.nn.functional', +# Do not mock any module for now (empty list); we will need to do that for Read the Docs later +MOCK_MODULES = [] +sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) + +# Read the package version from sb3_contrib/version.txt (path relative to this file) +version_file = os.path.join(os.path.dirname(__file__), "../sb3_contrib", "version.txt") +with open(version_file, "r") as file_handler: + __version__ = file_handler.read().strip() + +# -- Project information ----------------------------------------------------- + +project = "Stable Baselines3 - Contrib" +copyright = "2020, Stable Baselines3" +author = "Stable Baselines3 Contributors" + +# The short X.Y version +version = "master (" + __version__ + " )" +# The full version, including alpha/beta/rc tags +release = __version__ + + +# -- General configuration --------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx_autodoc_typehints", + "sphinx.ext.autosummary", + "sphinx.ext.mathjax", + "sphinx.ext.ifconfig", + "sphinx.ext.viewcode", + # 'sphinx.ext.intersphinx', + # 'sphinx.ext.doctest' +] + +if enable_spell_check: + extensions.append("sphinxcontrib.spelling") + +# Add any paths that contain templates here, relative to this directory. 
+templates_path = ["_templates"] + +# The suffix(es) of source filenames. +# You can specify multiple suffixes as a list of strings: +# +# source_suffix = ['.rst', '.md'] +source_suffix = ".rst" + +# The master toctree document. +master_doc = "index" + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. + +# Fix for Read the Docs: when READTHEDOCS == "True", use the "default" theme instead of sphinx_rtd_theme +on_rtd = os.environ.get("READTHEDOCS") == "True" +if on_rtd: + html_theme = "default" +else: + html_theme = "sphinx_rtd_theme" + +html_logo = "_static/img/logo.png" + + +def setup(app):  # registers the custom stylesheet (found under html_static_path) on the given app + app.add_css_file("css/baselines_theme.css") + + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +# html_theme_options = {} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] + +# Custom sidebar templates, must be a dictionary that maps document names +# to template names. +# +# The default sidebars (for documents that don't match any pattern) are +# defined by theme itself. 
Builtin themes are using these templates by +# default: ``['localtoc.html', 'relations.html', 'sourcelink.html', +# 'searchbox.html']``. +# +# html_sidebars = {} + + +# -- Options for HTMLHelp output --------------------------------------------- + +# Output file base name for HTML help builder. +htmlhelp_basename = "StableBaselines3doc" + + +# -- Options for LaTeX output ------------------------------------------------ + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, "StableBaselines3.tex", "Stable Baselines3 Documentation", "Stable Baselines3 Contributors", "manual"), +] + + +# -- Options for manual page output ------------------------------------------ + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [(master_doc, "stablebaselines3", "Stable Baselines3 Documentation", [author], 1)] + + +# -- Options for Texinfo output ---------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ( + master_doc, + "StableBaselines3", + "Stable Baselines3 Documentation", + author, + "StableBaselines3", + "One line description of project.", + "Miscellaneous", + ), +] + + +# -- Extension configuration ------------------------------------------------- + +# Example configuration for intersphinx: refer to the Python standard library. 
+# intersphinx_mapping = { +# 'python': ('https://docs.python.org/3/', None), +# 'numpy': ('http://docs.scipy.org/doc/numpy/', None), +# 'torch': ('http://pytorch.org/docs/master/', None), +# } diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst new file mode 100644 index 0000000..1d863ea --- /dev/null +++ b/docs/guide/algos.rst @@ -0,0 +1,30 @@ +RL Algorithms +============= + +This table displays the rl algorithms that are implemented in the Stable Baselines3 contrib project, +along with some useful characteristics: support for discrete/continuous actions, multiprocessing. + +.. +.. ============ =========== ============ ================= =============== ================ +.. Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing +.. ============ =========== ============ ================= =============== ================ +.. A2C ✔️ ✔️ ✔️ ✔️ ✔️ +.. DDPG ✔️ ❌ ❌ ❌ ❌ +.. DQN ❌ ✔️ ❌ ❌ ❌ +.. PPO ✔️ ✔️ ✔️ ✔️ ✔️ +.. SAC ✔️ ❌ ❌ ❌ ❌ +.. TD3 ✔️ ❌ ❌ ❌ ❌ +.. ============ =========== ============ ================= =============== ================ + + +.. .. note:: +.. Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm. +.. +.. Actions ``gym.spaces``: +.. +.. - ``Box``: A N-dimensional box that contains every point in the action +.. space. +.. - ``Discrete``: A list of possible actions, where each timestep only +.. one of the actions can be used. +.. - ``MultiDiscrete``: A list of possible actions, where each timestep only one action of each discrete set can be used. +.. - ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination. diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst new file mode 100644 index 0000000..65b2d08 --- /dev/null +++ b/docs/guide/examples.rst @@ -0,0 +1,40 @@ +.. _examples: + +Examples +======== + +WIP + +.. PyBullet: Normalizing input features +.. ------------------------------------ +.. +.. 
Normalizing input features may be essential to successful training of an RL agent +.. (by default, images are scaled but not other types of input), +.. for instance when training on `PyBullet `__ environments. For that, a wrapper exists and +.. will compute a running average and standard deviation of input features (it can do the same for rewards). +.. + +.. .. note:: +.. +.. you need to install pybullet with ``pip install pybullet`` +.. +.. +.. .. image:: ../_static/img/colab-badge.svg +.. :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb +.. +.. +.. .. code-block:: python +.. +.. import gym +.. import pybullet_envs +.. +.. from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize +.. from stable_baselines3 import PPO +.. +.. env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")]) +.. # Automatically normalize the input features and reward +.. env = VecNormalize(env, norm_obs=True, norm_reward=True, +.. clip_obs=10.) +.. +.. model = PPO('MlpPolicy', env) +.. model.learn(total_timesteps=2000) diff --git a/docs/guide/install.rst b/docs/guide/install.rst new file mode 100644 index 0000000..dcffed8 --- /dev/null +++ b/docs/guide/install.rst @@ -0,0 +1,37 @@ +.. _install: + +Installation +============ + +Prerequisites +------------- + +Please read `Stable-Baselines3 installation guide `_ first. + + +Stable Release +~~~~~~~~~~~~~~ +To install Stable Baselines3 contrib with pip, execute: + +.. code-block:: bash + + pip install sb3-contrib + + +Bleeding-edge version +--------------------- + +.. code-block:: bash + + pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/ + + +Development version +------------------- + +To contribute to Stable-Baselines3, with support for running tests and building the documentation. + +.. 
code-block:: bash + + git clone https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/ && cd stable-baselines3-contrib + pip install -e . diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..b77a85f --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,78 @@ +.. Stable Baselines3 documentation master file, created by + sphinx-quickstart on Thu Sep 26 11:06:54 2019. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to Stable Baselines3 Contrib docs! +========================================== + +Contrib package for `Stable Baselines3 <https://github.com/DLR-RM/stable-baselines3>`_ - Experimental code. + + +Github repository: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib + +SB3 repository: https://github.com/DLR-RM/stable-baselines3 + +RL Baselines3 Zoo (collection of pre-trained agents): https://github.com/DLR-RM/rl-baselines3-zoo + +RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and do hyperparameter tuning. + + +.. toctree:: + :maxdepth: 2 + :caption: User Guide + + guide/install + guide/algos + guide/examples + + +.. .. toctree:: +.. :maxdepth: 1 +.. :caption: RL Algorithms +.. +.. modules/a2c + +.. .. toctree:: +.. :maxdepth: 1 +.. :caption: Common +.. +.. common/atari_wrappers + +.. toctree:: + :maxdepth: 1 + :caption: Misc + + misc/changelog + + +Citing Stable Baselines3 +------------------------ +To cite this project in publications: + +.. code-block:: bibtex + + @misc{stable-baselines3, + author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah}, + title = {Stable Baselines3}, + year = {2019}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}}, + } + +Contributing +------------ + +To anyone interested in making the RL baselines better: there are still some improvements +that need to be made. 
+You can check issues in the `repo `_. + +If you want to contribute, please read `CONTRIBUTING.md `_ first. + +Indices and tables +------------------- + +* :ref:`genindex` +* :ref:`search` +* :ref:`modindex` diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..22b5fff --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,36 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build +set SPHINXPROJ=StableBaselines + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst new file mode 100644 index 0000000..6cb9a38 --- /dev/null +++ b/docs/misc/changelog.rst @@ -0,0 +1,42 @@ +.. _changelog: + +Changelog +========== + + +Pre-Release 0.10.0a0 (WIP) +------------------------------ + +Breaking Changes: +^^^^^^^^^^^^^^^^^ + +New Features: +^^^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + +Documentation: +^^^^^^^^^^^^^^ + + +Maintainers +----------- + +Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_), `Ashley Hill`_ (aka @hill-a), +`Maximilian Ernestus`_ (aka @erniejunior), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_). + +.. _Ashley Hill: https://github.com/hill-a +.. _Antonin Raffin: https://araffin.github.io/ +.. 
_Maximilian Ernestus: https://github.com/erniejunior +.. _Adam Gleave: https://gleave.me/ +.. _@araffin: https://github.com/araffin +.. _@AdamGleave: https://github.com/adamgleave +.. _Anssi Kanervisto: https://github.com/Miffyli +.. _@Miffyli: https://github.com/Miffyli diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt new file mode 100644 index 0000000..42669bf --- /dev/null +++ b/docs/spelling_wordlist.txt @@ -0,0 +1,121 @@ +py +env +atari +argparse +Argparse +TensorFlow +feedforward +envs +VecEnv +pretrain +petrained +tf +th +nn +np +str +mujoco +cpu +ndarray +ndarrays +timestep +timesteps +stepsize +dataset +adam +fn +normalisation +Kullback +Leibler +boolean +deserialized +pretrained +minibatch +subprocesses +ArgumentParser +Tensorflow +Gaussian +approximator +minibatches +hyperparameters +hyperparameter +vectorized +rl +colab +dataloader +npz +datasets +vf +logits +num +Utils +backpropagate +prepend +NaN +preprocessing +Cloudpickle +async +multiprocess +tensorflow +mlp +cnn +neglogp +tanh +coef +repo +Huber +params +ppo +arxiv +Arxiv +func +DQN +Uhlenbeck +Ornstein +multithread +cancelled +Tensorboard +parallelize +customising +serializable +Multiprocessed +cartpole +toolset +lstm +rescale +ffmpeg +avconv +unnormalized +Github +pre +preprocess +backend +attr +preprocess +Antonin +Raffin +araffin +Homebrew +Numpy +Theano +rollout +kfac +Piecewise +csv +nvidia +visdom +tensorboard +preprocessed +namespace +sklearn +GoalEnv +Torchy +pytorch +dicts +optimizers +Deprecations +forkserver +cuda +Polyak +gSDE +rollouts diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index e58cb11..f728c3a 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -17,25 +17,25 @@ class Actor(BasePolicy): """ Actor network (policy) for TQC. 
- :param observation_space: (gym.spaces.Space) Obervation space - :param action_space: (gym.spaces.Space) Action space - :param net_arch: ([int]) Network architecture - :param features_extractor: (nn.Module) Network to extract features + :param observation_space: Observation space + :param action_space: Action space + :param net_arch: Network architecture + :param features_extractor: Network to extract features (a CNN when using images, a nn.Flatten() layer otherwise) - :param features_dim: (int) Number of features - :param activation_fn: (Type[nn.Module]) Activation function - :param use_sde: (bool) Whether to use State Dependent Exploration or not - :param log_std_init: (float) Initial value for the log standard deviation - :param full_std: (bool) Whether to use (n_features x n_actions) parameters + :param features_dim: Number of features + :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE. - :param sde_net_arch: ([int]) Network architecture for extracting features + :param sde_net_arch: Network architecture for extracting features when using gSDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features. - :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. - :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. 
- :param normalize_images: (bool) Whether to normalize images or not, + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) """ @@ -131,7 +131,7 @@ class Actor(BasePolicy): but is slightly different when using ``expln`` function (cf StateDependentNoiseDistribution doc). - :return: (th.Tensor) + :return: """ msg = "get_std() is only available when using gSDE" assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg @@ -141,7 +141,7 @@ class Actor(BasePolicy): """ Sample new weights for the exploration matrix, when using gSDE. - :param batch_size: (int) + :param batch_size: """ msg = "reset_noise() is only available when using gSDE" assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg @@ -151,8 +151,8 @@ class Actor(BasePolicy): """ Get the parameters for the action distribution. - :param obs: (th.Tensor) - :return: (Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]) + :param obs: + :return: Mean, standard deviation and optional keyword arguments. """ features = self.extract_features(obs) @@ -183,31 +183,19 @@ class Actor(BasePolicy): def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: return self.forward(observation, deterministic) - def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor: - """ - Evaluate actions according to the current policy, - given the observations. Only useful when using SDE. - :param obs: (th.Tensor) - :param actions: (th.Tensor) - :return: (th.Tensor) log likelihood of taking those actions - """ - mean_actions, log_std, kwargs = self.get_action_dist_params(obs) - self.action_dist.proba_distribution(mean_actions, log_std, **kwargs) - return self.action_dist.log_prob(actions) - class Critic(BaseModel): """ Critic network (q-value function) for TQC. 
- :param observation_space: (gym.spaces.Space) Obervation space - :param action_space: (gym.spaces.Space) Action space - :param net_arch: ([int]) Network architecture - :param features_extractor: (nn.Module) Network to extract features + :param observation_space: Observation space + :param action_space: Action space + :param net_arch: Network architecture + :param features_extractor: Network to extract features (a CNN when using images, a nn.Flatten() layer otherwise) - :param features_dim: (int) Number of features - :param activation_fn: (Type[nn.Module]) Activation function - :param normalize_images: (bool) Whether to normalize images or not, + :param features_dim: Number of features + :param activation_fn: Activation function + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) """ @@ -257,28 +245,28 @@ class TQCPolicy(BasePolicy): """ Policy class (with both actor and critic) for TQC. - :param observation_space: (gym.spaces.Space) Observation space - :param action_space: (gym.spaces.Space) Action space - :param lr_schedule: (callable) Learning rate schedule (could be constant) - :param net_arch: (Optional[List[int]]) The specification of the policy and value networks. - :param activation_fn: (Type[nn.Module]) Activation function - :param use_sde: (bool) Whether to use State Dependent Exploration or not - :param log_std_init: (float) Initial value for the log standard deviation - :param sde_net_arch: ([int]) Network architecture for extracting features + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. 
+ :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param sde_net_arch: Network architecture for extracting features when using gSDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features. - :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. - :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. - :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. - :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments to pass to the feature extractor. - :param normalize_images: (bool) Whether to normalize images or not, + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default - :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ @@ -388,7 +376,7 @@ class TQCPolicy(BasePolicy): """ Sample new weights for the exploration matrix, when using gSDE. 
- :param batch_size: (int) + :param batch_size: """ self.actor.reset_noise(batch_size=batch_size) @@ -412,26 +400,26 @@ class CnnPolicy(TQCPolicy): """ Policy class (with both actor and critic) for TQC. - :param observation_space: (gym.spaces.Space) Observation space - :param action_space: (gym.spaces.Space) Action space - :param lr_schedule: (callable) Learning rate schedule (could be constant) - :param net_arch: (Optional[List[int]]) The specification of the policy and value networks. - :param activation_fn: (Type[nn.Module]) Activation function - :param use_sde: (bool) Whether to use State Dependent Exploration or not - :param log_std_init: (float) Initial value for the log standard deviation - :param sde_net_arch: ([int]) Network architecture for extracting features + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param sde_net_arch: Network architecture for extracting features when using gSDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features. - :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. - :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. - :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. 
- :param normalize_images: (bool) Whether to normalize images or not, + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param features_extractor_class: Features extractor to use. + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default - :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 18fc902..ea6818e 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -7,7 +7,6 @@ from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback from stable_baselines3.common.utils import polyak_update -from tqdm import tqdm from sb3_contrib.tqc.policies import TQCPolicy @@ -15,48 +14,48 @@ from sb3_contrib.tqc.policies import TQCPolicy class TQC(OffPolicyAlgorithm): """ - Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics + Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics. Paper: https://arxiv.org/abs/2005.04269 - :param policy: (TQCPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...) - :param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str) - :param learning_rate: (float or callable) learning rate for adam optimizer, + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) 
+ :param env: The environment to learn from (if registered in Gym, can be str) - :param learning_rate: (float or callable) learning rate for adam optimizer, + :param learning_rate: learning rate for adam optimizer, the same learning rate will be used for all networks (Q-Values, Actor and Value function) it can be a function of the current progress remaining (from 1 to 0) - :param buffer_size: (int) size of the replay buffer - :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts - :param batch_size: (int) Minibatch size for each gradient update - :param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1) - :param gamma: (float) the discount factor - :param train_freq: (int) Update the model every ``train_freq`` steps. - :param gradient_steps: (int) How many gradient update after each step - :param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes. + :param buffer_size: size of the replay buffer + :param learning_starts: how many steps of the model to collect transitions for before learning starts + :param batch_size: Minibatch size for each gradient update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) + :param gamma: the discount factor + :param train_freq: Update the model every ``train_freq`` steps. + :param gradient_steps: How many gradient updates to do after each step + :param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes. Note that this cannot be used at the same time as ``train_freq`` - :param action_noise: (ActionNoise) the action noise type (None by default), this can help + :param action_noise: the action noise type (None by default), this can help for hard exploration problem. Cf common.noise for the different action noise type. - :param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer + :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 - :param ent_coef: (str or float) Entropy regularization coefficient. (Equivalent to + :param ent_coef: Entropy regularization coefficient. (Equivalent to inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off. Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value) - :param target_update_interval: (int) update the target network every ``target_network_update_freq`` + :param target_update_interval: update the target network every ``target_update_interval`` gradient steps. - :param target_entropy: (str or float) target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``) - :param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE) + :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``) + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False) - :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) - :param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling + :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling during the warm up phase (before learning starts) - :param create_eval_env: (bool) Whether to create a second environment that will be + :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically.
(Only available when passing string for the environment) - :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation - :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug - :param seed: (int) Seed for the pseudo random generators - :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run. + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible. - :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param _init_setup_model: Whether or not to build the network at the creation of the instance """ def __init__( @@ -274,171 +273,6 @@ class TQC(OffPolicyAlgorithm): if len(ent_coef_losses) > 0: logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) - def pretrain( - self, - gradient_steps: int, - batch_size: int = 64, - n_action_samples: int = -1, - target_update_interval: int = 1, - tau: float = 0.005, - strategy: str = "exp", - reduce: str = "mean", - exp_temperature: float = 1.0, - off_policy_update_freq: int = -1, - ) -> None: - """ - Pretrain with Critic Regularized Regression (CRR) - Paper: https://arxiv.org/abs/2006.15134 - """ - # Update optimizers learning rate - optimizers = [self.actor.optimizer, self.critic.optimizer] - if self.ent_coef_optimizer is not None: - optimizers += [self.ent_coef_optimizer] - - # Update learning rate according to lr schedule - self._update_learning_rate(optimizers) - - actor_losses, critic_losses = [], [] - - for gradient_step in tqdm(range(gradient_steps)): - - if off_policy_update_freq > 0 and gradient_step % off_policy_update_freq == 0: - self.train(gradient_steps=1, batch_size=batch_size) - continue - - # 
Sample replay buffer - replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) - - # We need to sample because `log_std` may have changed between two gradient steps - if self.use_sde: - self.actor.reset_noise() - - # Action by the current actor for the sampled state - _, log_prob = self.actor.action_log_prob(replay_data.observations) - log_prob = log_prob.reshape(-1, 1) - - ent_coef_loss = None - if self.ent_coef_optimizer is not None: - # Important: detach the variable from the graph - # so we don't change it with other losses - # see https://github.com/rail-berkeley/softlearning/issues/60 - ent_coef = th.exp(self.log_ent_coef.detach()) - ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean() - else: - ent_coef = self.ent_coef_tensor - - self.replay_buffer.ent_coef = ent_coef.item() - - # Optimize entropy coefficient, also called - # entropy temperature or alpha in the paper - if ent_coef_loss is not None: - self.ent_coef_optimizer.zero_grad() - ent_coef_loss.backward() - self.ent_coef_optimizer.step() - - with th.no_grad(): - top_quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.critic.n_critics - # Select action according to policy - next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations) - # Compute and cut quantiles at the next state - # batch x nets x quantiles - next_z = self.critic_target(replay_data.next_observations, next_actions) - sorted_z, _ = th.sort(next_z.reshape(batch_size, -1)) - sorted_z_part = sorted_z[:, : self.critic.quantiles_total - top_quantiles_to_drop] - - target_q = sorted_z_part - ent_coef * next_log_prob.reshape(-1, 1) - # td error + entropy term - q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q - - # Get current Q estimates - # using action from the replay buffer - current_z = self.critic(replay_data.observations, replay_data.actions) - # Compute critic loss - critic_loss = 
self.quantile_huber_loss(current_z, q_backup) - critic_losses.append(critic_loss.item()) - - # Optimize the critic - self.critic.optimizer.zero_grad() - critic_loss.backward() - self.critic.optimizer.step() - - if strategy == "bc": - # Behavior cloning - weight = 1 - else: - # Tensor version: TODO: check that the reshape works as expected - # cleaner but not faster on cpu for large batch size - # with th.no_grad(): - # # Q-value for the action in the buffer - # qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True) - # # Create tensor to avoid loop - # # Note: For SDE, we need to sample several matrices - # obs_ = replay_data.observations.repeat(n_action_samples, 1) - # if self.use_sde: - # self.actor.reset_noise(batch_size * n_action_samples) - # actions_pi, _ = self.actor.action_log_prob(obs_) - # qf_pi = self.critic(obs_, actions_pi.detach()).mean(2).mean(1, keepdim=True) - # # Agregate: reduce mean or reduce max - # if reduce == "max": - # _, qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).max(axis=0) - # else: - # qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).mean(axis=0) - with th.no_grad(): - qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True) - - # Use the mean (as done in AWAC, cf rlkit) - if n_action_samples == -1: - actions_pi = self.actor.forward(replay_data.observations, deterministic=True) - qf_agg = self.critic(replay_data.observations, actions_pi).mean(2).mean(1, keepdim=True) - else: - qf_agg = None - for _ in range(n_action_samples): - if self.use_sde: - self.actor.reset_noise() - actions_pi, _ = self.actor.action_log_prob(replay_data.observations) - - qf_pi = self.critic(replay_data.observations, actions_pi.detach()).mean(2).mean(1, keepdim=True) - if qf_agg is None: - if reduce == "max": - qf_agg = qf_pi - else: - qf_agg = qf_pi / n_action_samples - else: - if reduce == "max": - qf_agg = th.max(qf_pi, qf_agg) - else: - qf_agg += qf_pi 
/ n_action_samples - - advantage = qf_buffer - qf_agg - if strategy == "binary": - # binary advantage - weight = advantage > 0 - else: - # exp advantage - exp_clip = 20.0 - weight = th.clamp(th.exp(advantage / exp_temperature), 0.0, exp_clip) - - # Log prob by the current actor for the sampled state and action - log_prob = self.actor.evaluate_actions(replay_data.observations, replay_data.actions) - log_prob = log_prob.reshape(-1, 1) - - # weigthed regression loss (close to policy gradient loss) - actor_loss = (-log_prob * weight).mean() - # actor_loss = ((actions_pi - replay_data.actions * weight) ** 2).mean() - actor_losses.append(actor_loss.item()) - - # Optimize the actor - self.actor.optimizer.zero_grad() - actor_loss.backward() - self.actor.optimizer.step() - - # Update target networks - if gradient_step % target_update_interval == 0: - polyak_update(self.critic.parameters(), self.critic_target.parameters(), tau) - - if self.use_sde: - print(f"std={(self.actor.get_std()).mean().item()}") - def learn( self, total_timesteps: int, @@ -469,7 +303,7 @@ class TQC(OffPolicyAlgorithm): Returns the names of the parameters that should be excluded by default when saving the model. 
- :return: (List[str]) List of parameters that should be excluded from save + :return: List of parameters that should be excluded from save """ # Exclude aliases return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"] diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 63aaf3d..37f1777 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -0.9.0a2 +0.10.0a0 diff --git a/setup.py b/setup.py index 41055cd..f5adfad 100644 --- a/setup.py +++ b/setup.py @@ -23,11 +23,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3[tests,docs]>=0.9.0a0", - # For progress bar when using CRR - "tqdm" - # Enable CMA - # "cma", + "stable_baselines3[tests,docs]>=0.9.0", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", diff --git a/tests/test_run.py b/tests/test_run.py index 2cbcf41..c2bc941 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -14,41 +14,27 @@ def test_tqc(ent_coef): create_eval_env=True, ent_coef=ent_coef, ) - model.learn(total_timesteps=1000, eval_freq=500) + model.learn(total_timesteps=500, eval_freq=250) @pytest.mark.parametrize("n_critics", [1, 3]) def test_n_critics(n_critics): # Test TQC with different number of critics model = TQC( - "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1 + "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], n_critics=n_critics), learning_starts=100, verbose=1 ) - model.learn(total_timesteps=1000) + model.learn(total_timesteps=500) -# "CartPole-v1" -# @pytest.mark.parametrize("env_id", ["MountainCarContinuous-v0"]) -# def test_cmaes(env_id): -# if CMAES is None: -# return -# model = CMAES("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[64]), verbose=1, create_eval_env=True) -# 
model.learn(total_timesteps=50000, eval_freq=10000) - - -@pytest.mark.parametrize("strategy", ["exp", "bc", "binary"]) -@pytest.mark.parametrize("reduce", ["mean", "max"]) -def test_crr(tmp_path, strategy, reduce): +def test_sde(): model = TQC( "MlpPolicy", "Pendulum-v0", - policy_kwargs=dict(net_arch=[64]), - learning_starts=1000, + policy_kwargs=dict(net_arch=[64], sde_net_arch=[8]), + use_sde=True, + learning_starts=100, verbose=1, - create_eval_env=True, - action_noise=None, - use_sde=False, ) - - model.learn(total_timesteps=1000, eval_freq=0) - for n_action_samples in [1, 2, -1]: - model.pretrain(gradient_steps=32, batch_size=32, n_action_samples=n_action_samples, strategy=strategy, reduce=reduce) + model.learn(total_timesteps=500) + model.policy.reset_noise() + model.policy.actor.get_std() diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 4415267..9a73c77 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -1,6 +1,5 @@ import os import pathlib -import warnings from collections import OrderedDict from copy import deepcopy @@ -11,7 +10,6 @@ import torch as th from stable_baselines3 import DQN from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox -from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv @@ -45,7 +43,7 @@ def test_save_load(tmp_path, model_class): # create model model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1) - model.learn(total_timesteps=500, eval_freq=250) + model.learn(total_timesteps=500) env.reset() observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) @@ -154,7 +152,7 @@ def test_save_load(tmp_path, model_class): assert np.allclose(selected_actions, new_selected_actions, 1e-4) # check if 
learn still works - model.learn(total_timesteps=1000, eval_freq=500) + model.learn(total_timesteps=500) del model @@ -177,17 +175,17 @@ def test_set_env(model_class): # create model model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16])) # learn - model.learn(total_timesteps=1000, eval_freq=500) + model.learn(total_timesteps=300) # change env model.set_env(env2) # learn again - model.learn(total_timesteps=1000, eval_freq=500) + model.learn(total_timesteps=300) # change env test wrapping model.set_env(env3) # learn again - model.learn(total_timesteps=1000, eval_freq=500) + model.learn(total_timesteps=300) @pytest.mark.parametrize("model_class", MODEL_LIST) @@ -247,45 +245,6 @@ def test_save_load_replay_buffer(tmp_path, model_class): ) -@pytest.mark.parametrize("model_class", [TQC]) -@pytest.mark.parametrize("optimize_memory_usage", [False, True]) -def test_warn_buffer(recwarn, model_class, optimize_memory_usage): - """ - When using memory efficient replay buffer, - a warning must be emitted when calling `.learn()` - multiple times. 
- See https://github.com/DLR-RM/stable-baselines3/issues/46 - """ - # remove gym warnings - warnings.filterwarnings(action="ignore", category=DeprecationWarning) - warnings.filterwarnings(action="ignore", category=UserWarning, module="gym") - - model = model_class( - "MlpPolicy", - select_env(model_class), - buffer_size=100, - optimize_memory_usage=optimize_memory_usage, - policy_kwargs=dict(net_arch=[64]), - learning_starts=10, - ) - - model.learn(150) - - model.learn(150, reset_num_timesteps=False) - - # Check that there is no warning - assert len(recwarn) == 0 - - model.learn(150) - - if optimize_memory_usage: - assert len(recwarn) == 1 - warning = recwarn.pop(UserWarning) - assert "The last trajectory in the replay buffer will be truncated" in str(warning.message) - else: - assert len(recwarn) == 0 - - @pytest.mark.parametrize("model_class", MODEL_LIST) @pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"]) def test_save_load_policy(tmp_path, model_class, policy_str): @@ -309,7 +268,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str): # create model model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs) - model.learn(total_timesteps=500, eval_freq=250) + model.learn(total_timesteps=500) env.reset() observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) @@ -375,52 +334,3 @@ def test_save_load_policy(tmp_path, model_class, policy_str): os.remove(tmp_path / "policy.pkl") if actor_class is not None: os.remove(tmp_path / "actor.pkl") - - -@pytest.mark.parametrize("pathtype", [str, pathlib.Path]) -def test_open_file_str_pathlib(tmp_path, pathtype): - # check that suffix isn't added because we used open_path first - with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1: - save_to_pkl(fp1, "foo") - assert fp1.closed - with pytest.warns(None) as record: - assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo" - assert not record - - # test custom suffix - 
with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1: - save_to_pkl(fp1, "foo") - assert fp1.closed - with pytest.warns(None) as record: - assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo" - assert not record - - # test without suffix - with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1: - save_to_pkl(fp1, "foo") - assert fp1.closed - with pytest.warns(None) as record: - assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo" - assert not record - - # test that a warning is raised when the path doesn't exist - with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1: - save_to_pkl(fp1, "foo") - assert fp1.closed - with pytest.warns(None) as record: - assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo" - assert len(record) == 0 - - with pytest.warns(None) as record: - assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo" - assert len(record) == 1 - - fp = pathlib.Path(f"{tmp_path}/t2").open("w") - fp.write("rubbish") - fp.close() - # test that a warning is only raised when verbose = 0 - with pytest.warns(None) as record: - open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close() - open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close() - open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close() - assert len(record) == 1