diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 62770eb..228ee93 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -29,15 +29,16 @@ jobs:
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ # Install dependencies for docs and tests
pip install stable_baselines3[extra,tests,docs]
# Install master version
pip install git+https://github.com/DLR-RM/stable-baselines3
pip install .
# Use headless version
pip install opencv-python-headless
- # - name: Build the doc
- # run: |
- # make doc
+ - name: Build the doc
+ run: |
+ make doc
- name: Type check
run: |
make type
diff --git a/Makefile b/Makefile
index 639d266..d740b60 100644
--- a/Makefile
+++ b/Makefile
@@ -28,6 +28,11 @@ check-codestyle:
commit-checks: format type lint
+doc:
+ cd docs && make html
+
+spelling:
+ cd docs && make spelling
# PyPi package release
release:
@@ -41,4 +46,4 @@ test-release:
python setup.py bdist_wheel
twine upload --repository-url https://test.pypi.org/legacy/ dist/*
-.PHONY: lint format check-codestyle commit-checks
+.PHONY: lint format check-codestyle commit-checks doc spelling
diff --git a/README.md b/README.md
index 36cdb24..7280244 100644
--- a/README.md
+++ b/README.md
@@ -18,3 +18,19 @@ Implemented:
```
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
```
+
+
+## Citing the Project
+
+To cite this repository in publications (please cite SB3 directly):
+
+```
+@misc{stable-baselines3,
+ author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
+ title = {Stable Baselines3},
+ year = {2019},
+ publisher = {GitHub},
+ journal = {GitHub repository},
+ howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}},
+}
+```
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000..47f98cd
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line.
+SPHINXOPTS = -W # make warnings fatal
+SPHINXBUILD = sphinx-build
+SPHINXPROJ = StableBaselines
+SOURCEDIR = .
+BUILDDIR = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
\ No newline at end of file
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 0000000..1427a79
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1,25 @@
+## Stable Baselines3 Documentation
+
+This folder contains documentation for the RL baselines.
+
+
+### Build the Documentation
+
+#### Install Sphinx and Theme
+
+```
+pip install sphinx sphinx-autobuild sphinx-rtd-theme
+```
+
+#### Building the Docs
+
+In the `docs/` folder:
+```
+make html
+```
+
+If you want to rebuild the documentation automatically each time a file is changed:
+
+```
+sphinx-autobuild . _build/html
+```
diff --git a/docs/_static/css/baselines_theme.css b/docs/_static/css/baselines_theme.css
new file mode 100644
index 0000000..450864e
--- /dev/null
+++ b/docs/_static/css/baselines_theme.css
@@ -0,0 +1,61 @@
+/* Main colors adapted from pytorch doc */
+:root{
+ --main-bg-color: #343A40;
+ --link-color: #FD7E14;
+}
+
+/* Header fonts */
+h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
+ font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
+}
+
+
+/* Docs background */
+.wy-side-nav-search{
+ background-color: var(--main-bg-color);
+}
+
+/* Mobile version */
+.wy-nav-top{
+ background-color: var(--main-bg-color);
+}
+
+/* Change link colors (except for the menu) */
+a {
+ color: var(--link-color);
+}
+
+a:hover {
+ color: #4F778F;
+}
+
+.wy-menu a {
+ color: #b3b3b3;
+}
+
+.wy-menu a:hover {
+ color: #b3b3b3;
+}
+
+a.icon.icon-home {
+ color: #b3b3b3;
+}
+
+.version{
+ color: var(--link-color) !important;
+}
+
+
+/* Make code blocks have a background */
+.codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] {
+ background: #f8f8f8;
+}
+
+/* Change style of types in the docstrings .rst-content .field-list */
+.field-list .xref.py.docutils, .field-list code.docutils, .field-list .docutils.literal.notranslate
+{
+ border: None;
+ padding-left: 0;
+ padding-right: 0;
+ color: #404040;
+}
diff --git a/docs/_static/img/colab-badge.svg b/docs/_static/img/colab-badge.svg
new file mode 100644
index 0000000..c08066e
--- /dev/null
+++ b/docs/_static/img/colab-badge.svg
@@ -0,0 +1 @@
+
diff --git a/docs/_static/img/colab.svg b/docs/_static/img/colab.svg
new file mode 100644
index 0000000..c2d30e9
--- /dev/null
+++ b/docs/_static/img/colab.svg
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/_static/img/logo.png b/docs/_static/img/logo.png
new file mode 100644
index 0000000..7cec1c3
Binary files /dev/null and b/docs/_static/img/logo.png differ
diff --git a/docs/_static/img/logo.png~ b/docs/_static/img/logo.png~
new file mode 100644
index 0000000..ef592f5
Binary files /dev/null and b/docs/_static/img/logo.png~ differ
diff --git a/docs/conda_env.yml b/docs/conda_env.yml
new file mode 100644
index 0000000..35e604d
--- /dev/null
+++ b/docs/conda_env.yml
@@ -0,0 +1,18 @@
+name: root
+channels:
+ - pytorch
+ - defaults
+dependencies:
+ - cpuonly=1.0=0
+ - pip=20.2
+ - python=3.6
+ - pytorch=1.5.0=py3.6_cpu_0
+ - pip:
+ - gym>=0.17.2
+ - cloudpickle
+ - opencv-python-headless
+ - pandas
+ - numpy
+ - matplotlib
+ - sphinx_autodoc_typehints
+ - stable-baselines3>=0.9.0
diff --git a/docs/conf.py b/docs/conf.py
new file mode 100644
index 0000000..ed733df
--- /dev/null
+++ b/docs/conf.py
@@ -0,0 +1,219 @@
+# -*- coding: utf-8 -*-
+#
+# Configuration file for the Sphinx documentation builder.
+#
+# This file does only contain a selection of the most common options. For a
+# full list see the documentation:
+# http://www.sphinx-doc.org/en/master/config
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+import os
+import sys
+from unittest.mock import MagicMock
+
+# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
+# PyEnchant.
+try:
+ import sphinxcontrib.spelling # noqa: F401
+
+ enable_spell_check = True
+except ImportError:
+ enable_spell_check = False
+
+# source code directory, relative to this file, for sphinx-autobuild
+sys.path.insert(0, os.path.abspath(".."))
+
+
+class Mock(MagicMock):
+ __subclasses__ = []
+
+ @classmethod
+ def __getattr__(cls, name):
+ return MagicMock()
+
+
+# Mock modules that requires C modules
+# Note: because of that we cannot test examples using CI
+# 'torch', 'torch.nn', 'torch.nn.functional',
+# DO not mock modules for now, we will need to do that for read the docs later
+MOCK_MODULES = []
+sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
+
+# Read version from file
+version_file = os.path.join(os.path.dirname(__file__), "../sb3_contrib", "version.txt")
+with open(version_file, "r") as file_handler:
+ __version__ = file_handler.read().strip()
+
+# -- Project information -----------------------------------------------------
+
+project = "Stable Baselines3 - Contrib"
+copyright = "2020, Stable Baselines3"
+author = "Stable Baselines3 Contributors"
+
+# The short X.Y version
+version = "master (" + __version__ + ")"
+# The full version, including alpha/beta/rc tags
+release = __version__
+
+
+# -- General configuration ---------------------------------------------------
+
+# If your documentation needs a minimal Sphinx version, state it here.
+#
+# needs_sphinx = '1.0'
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ "sphinx.ext.autodoc",
+ "sphinx_autodoc_typehints",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.mathjax",
+ "sphinx.ext.ifconfig",
+ "sphinx.ext.viewcode",
+ # 'sphinx.ext.intersphinx',
+ # 'sphinx.ext.doctest'
+]
+
+if enable_spell_check:
+ extensions.append("sphinxcontrib.spelling")
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+# The suffix(es) of source filenames.
+# You can specify multiple suffix as a list of string:
+#
+# source_suffix = ['.rst', '.md']
+source_suffix = ".rst"
+
+# The master toctree document.
+master_doc = "index"
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = None
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path .
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = "sphinx"
+
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+
+# Fix for read the docs
+on_rtd = os.environ.get("READTHEDOCS") == "True"
+if on_rtd:
+ html_theme = "default"
+else:
+ html_theme = "sphinx_rtd_theme"
+
+html_logo = "_static/img/logo.png"
+
+
+def setup(app):
+ app.add_css_file("css/baselines_theme.css")
+
+
+# Theme options are theme-specific and customize the look and feel of a theme
+# further. For a list of options available for each theme, see the
+# documentation.
+#
+# html_theme_options = {}
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ["_static"]
+
+# Custom sidebar templates, must be a dictionary that maps document names
+# to template names.
+#
+# The default sidebars (for documents that don't match any pattern) are
+# defined by theme itself. Builtin themes are using these templates by
+# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
+# 'searchbox.html']``.
+#
+# html_sidebars = {}
+
+
+# -- Options for HTMLHelp output ---------------------------------------------
+
+# Output file base name for HTML help builder.
+htmlhelp_basename = "StableBaselines3doc"
+
+
+# -- Options for LaTeX output ------------------------------------------------
+
+latex_elements = {
+ # The paper size ('letterpaper' or 'a4paper').
+ #
+ # 'papersize': 'letterpaper',
+ # The font size ('10pt', '11pt' or '12pt').
+ #
+ # 'pointsize': '10pt',
+ # Additional stuff for the LaTeX preamble.
+ #
+ # 'preamble': '',
+ # Latex figure (float) alignment
+ #
+ # 'figure_align': 'htbp',
+}
+
+# Grouping the document tree into LaTeX files. List of tuples
+# (source start file, target name, title,
+# author, documentclass [howto, manual, or own class]).
+latex_documents = [
+ (master_doc, "StableBaselines3.tex", "Stable Baselines3 Documentation", "Stable Baselines3 Contributors", "manual"),
+]
+
+
+# -- Options for manual page output ------------------------------------------
+
+# One entry per manual page. List of tuples
+# (source start file, name, description, authors, manual section).
+man_pages = [(master_doc, "stablebaselines3", "Stable Baselines3 Documentation", [author], 1)]
+
+
+# -- Options for Texinfo output ----------------------------------------------
+
+# Grouping the document tree into Texinfo files. List of tuples
+# (source start file, target name, title, author,
+# dir menu entry, description, category)
+texinfo_documents = [
+ (
+ master_doc,
+ "StableBaselines3",
+ "Stable Baselines3 Documentation",
+ author,
+ "StableBaselines3",
+ "One line description of project.",
+ "Miscellaneous",
+ ),
+]
+
+
+# -- Extension configuration -------------------------------------------------
+
+# Example configuration for intersphinx: refer to the Python standard library.
+# intersphinx_mapping = {
+# 'python': ('https://docs.python.org/3/', None),
+# 'numpy': ('http://docs.scipy.org/doc/numpy/', None),
+# 'torch': ('http://pytorch.org/docs/master/', None),
+# }
diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst
new file mode 100644
index 0000000..1d863ea
--- /dev/null
+++ b/docs/guide/algos.rst
@@ -0,0 +1,30 @@
+RL Algorithms
+=============
+
+This table displays the RL algorithms that are implemented in the Stable Baselines3 contrib project,
+along with some useful characteristics: support for discrete/continuous actions, multiprocessing.
+
+..
+.. ============ =========== ============ ================= =============== ================
+.. Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
+.. ============ =========== ============ ================= =============== ================
+.. A2C ✔️ ✔️ ✔️ ✔️ ✔️
+.. DDPG ✔️ ❌ ❌ ❌ ❌
+.. DQN ❌ ✔️ ❌ ❌ ❌
+.. PPO ✔️ ✔️ ✔️ ✔️ ✔️
+.. SAC ✔️ ❌ ❌ ❌ ❌
+.. TD3 ✔️ ❌ ❌ ❌ ❌
+.. ============ =========== ============ ================= =============== ================
+
+
+.. .. note::
+.. Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm.
+..
+.. Actions ``gym.spaces``:
+..
+.. - ``Box``: A N-dimensional box that contains every point in the action
+.. space.
+.. - ``Discrete``: A list of possible actions, where each timestep only
+.. one of the actions can be used.
+.. - ``MultiDiscrete``: A list of possible actions, where each timestep only one action of each discrete set can be used.
+.. - ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination.
diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst
new file mode 100644
index 0000000..65b2d08
--- /dev/null
+++ b/docs/guide/examples.rst
@@ -0,0 +1,40 @@
+.. _examples:
+
+Examples
+========
+
+WIP
+
+.. PyBullet: Normalizing input features
+.. ------------------------------------
+..
+.. Normalizing input features may be essential to successful training of an RL agent
+.. (by default, images are scaled but not other types of input),
+.. for instance when training on `PyBullet `__ environments. For that, a wrapper exists and
+.. will compute a running average and standard deviation of input features (it can do the same for rewards).
+..
+
+.. .. note::
+..
+.. you need to install pybullet with ``pip install pybullet``
+..
+..
+.. .. image:: ../_static/img/colab-badge.svg
+.. :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
+..
+..
+.. .. code-block:: python
+..
+.. import gym
+.. import pybullet_envs
+..
+.. from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
+.. from stable_baselines3 import PPO
+..
+.. env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
+.. # Automatically normalize the input features and reward
+.. env = VecNormalize(env, norm_obs=True, norm_reward=True,
+.. clip_obs=10.)
+..
+.. model = PPO('MlpPolicy', env)
+.. model.learn(total_timesteps=2000)
diff --git a/docs/guide/install.rst b/docs/guide/install.rst
new file mode 100644
index 0000000..dcffed8
--- /dev/null
+++ b/docs/guide/install.rst
@@ -0,0 +1,37 @@
+.. _install:
+
+Installation
+============
+
+Prerequisites
+-------------
+
+Please read `Stable-Baselines3 installation guide `_ first.
+
+
+Stable Release
+~~~~~~~~~~~~~~
+To install Stable Baselines3 contrib with pip, execute:
+
+.. code-block:: bash
+
+ pip install sb3-contrib
+
+
+Bleeding-edge version
+---------------------
+
+.. code-block:: bash
+
+ pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/
+
+
+Development version
+-------------------
+
+To contribute to Stable-Baselines3, with support for running tests and building the documentation.
+
+.. code-block:: bash
+
+ git clone https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/ && cd stable-baselines3-contrib
+ pip install -e .
diff --git a/docs/index.rst b/docs/index.rst
new file mode 100644
index 0000000..b77a85f
--- /dev/null
+++ b/docs/index.rst
@@ -0,0 +1,78 @@
+.. Stable Baselines3 documentation master file, created by
+ sphinx-quickstart on Thu Sep 26 11:06:54 2019.
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+Welcome to Stable Baselines3 Contrib docs!
+==========================================
+
+Contrib package for `Stable Baselines3 `_ - Experimental code.
+
+
+Github repository: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
+
+SB3 repository: https://github.com/DLR-RM/stable-baselines3
+
+RL Baselines3 Zoo (collection of pre-trained agents): https://github.com/DLR-RM/rl-baselines3-zoo
+
+RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and do hyperparameter tuning.
+
+
+.. toctree::
+ :maxdepth: 2
+ :caption: User Guide
+
+ guide/install
+ guide/algos
+ guide/examples
+
+
+.. .. toctree::
+.. :maxdepth: 1
+.. :caption: RL Algorithms
+..
+.. modules/a2c
+
+.. .. toctree::
+.. :maxdepth: 1
+.. :caption: Common
+..
+.. common/atari_wrappers
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Misc
+
+ misc/changelog
+
+
+Citing Stable Baselines3
+------------------------
+To cite this project in publications:
+
+.. code-block:: bibtex
+
+ @misc{stable-baselines3,
+ author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
+ title = {Stable Baselines3},
+ year = {2019},
+ publisher = {GitHub},
+ journal = {GitHub repository},
+ howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}},
+ }
+
+Contributing
+------------
+
+To anyone interested in making the RL baselines better, there are still some improvements
+that need to be done.
+You can check issues in the `repo `_.
+
+If you want to contribute, please read `CONTRIBUTING.md `_ first.
+
+Indices and tables
+-------------------
+
+* :ref:`genindex`
+* :ref:`search`
+* :ref:`modindex`
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 0000000..22b5fff
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,36 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=.
+set BUILDDIR=_build
+set SPHINXPROJ=StableBaselines
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+
+:end
+popd
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
new file mode 100644
index 0000000..6cb9a38
--- /dev/null
+++ b/docs/misc/changelog.rst
@@ -0,0 +1,42 @@
+.. _changelog:
+
+Changelog
+==========
+
+
+Pre-Release 0.10.0a0 (WIP)
+------------------------------
+
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+
+New Features:
+^^^^^^^^^^^^^
+
+Bug Fixes:
+^^^^^^^^^^
+
+Deprecations:
+^^^^^^^^^^^^^
+
+Others:
+^^^^^^^
+
+Documentation:
+^^^^^^^^^^^^^^
+
+
+Maintainers
+-----------
+
+Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_), `Ashley Hill`_ (aka @hill-a),
+`Maximilian Ernestus`_ (aka @erniejunior), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_).
+
+.. _Ashley Hill: https://github.com/hill-a
+.. _Antonin Raffin: https://araffin.github.io/
+.. _Maximilian Ernestus: https://github.com/erniejunior
+.. _Adam Gleave: https://gleave.me/
+.. _@araffin: https://github.com/araffin
+.. _@AdamGleave: https://github.com/adamgleave
+.. _Anssi Kanervisto: https://github.com/Miffyli
+.. _@Miffyli: https://github.com/Miffyli
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
new file mode 100644
index 0000000..42669bf
--- /dev/null
+++ b/docs/spelling_wordlist.txt
@@ -0,0 +1,121 @@
+py
+env
+atari
+argparse
+Argparse
+TensorFlow
+feedforward
+envs
+VecEnv
+pretrain
+petrained
+tf
+th
+nn
+np
+str
+mujoco
+cpu
+ndarray
+ndarrays
+timestep
+timesteps
+stepsize
+dataset
+adam
+fn
+normalisation
+Kullback
+Leibler
+boolean
+deserialized
+pretrained
+minibatch
+subprocesses
+ArgumentParser
+Tensorflow
+Gaussian
+approximator
+minibatches
+hyperparameters
+hyperparameter
+vectorized
+rl
+colab
+dataloader
+npz
+datasets
+vf
+logits
+num
+Utils
+backpropagate
+prepend
+NaN
+preprocessing
+Cloudpickle
+async
+multiprocess
+tensorflow
+mlp
+cnn
+neglogp
+tanh
+coef
+repo
+Huber
+params
+ppo
+arxiv
+Arxiv
+func
+DQN
+Uhlenbeck
+Ornstein
+multithread
+cancelled
+Tensorboard
+parallelize
+customising
+serializable
+Multiprocessed
+cartpole
+toolset
+lstm
+rescale
+ffmpeg
+avconv
+unnormalized
+Github
+pre
+preprocess
+backend
+attr
+preprocess
+Antonin
+Raffin
+araffin
+Homebrew
+Numpy
+Theano
+rollout
+kfac
+Piecewise
+csv
+nvidia
+visdom
+tensorboard
+preprocessed
+namespace
+sklearn
+GoalEnv
+Torchy
+pytorch
+dicts
+optimizers
+Deprecations
+forkserver
+cuda
+Polyak
+gSDE
+rollouts
diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py
index e58cb11..f728c3a 100644
--- a/sb3_contrib/tqc/policies.py
+++ b/sb3_contrib/tqc/policies.py
@@ -17,25 +17,25 @@ class Actor(BasePolicy):
"""
Actor network (policy) for TQC.
- :param observation_space: (gym.spaces.Space) Obervation space
- :param action_space: (gym.spaces.Space) Action space
- :param net_arch: ([int]) Network architecture
- :param features_extractor: (nn.Module) Network to extract features
+ :param observation_space: Observation space
+ :param action_space: Action space
+ :param net_arch: Network architecture
+ :param features_extractor: Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise)
- :param features_dim: (int) Number of features
- :param activation_fn: (Type[nn.Module]) Activation function
- :param use_sde: (bool) Whether to use State Dependent Exploration or not
- :param log_std_init: (float) Initial value for the log standard deviation
- :param full_std: (bool) Whether to use (n_features x n_actions) parameters
+ :param features_dim: Number of features
+ :param activation_fn: Activation function
+ :param use_sde: Whether to use State Dependent Exploration or not
+ :param log_std_init: Initial value for the log standard deviation
+ :param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE.
- :param sde_net_arch: ([int]) Network architecture for extracting features
+ :param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
- :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
+ :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
- :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
- :param normalize_images: (bool) Whether to normalize images or not,
+ :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
+ :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
"""
@@ -131,7 +131,7 @@ class Actor(BasePolicy):
but is slightly different when using ``expln`` function
(cf StateDependentNoiseDistribution doc).
- :return: (th.Tensor)
+ :return:
"""
msg = "get_std() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
@@ -141,7 +141,7 @@ class Actor(BasePolicy):
"""
Sample new weights for the exploration matrix, when using gSDE.
- :param batch_size: (int)
+ :param batch_size:
"""
msg = "reset_noise() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
@@ -151,8 +151,8 @@ class Actor(BasePolicy):
"""
Get the parameters for the action distribution.
- :param obs: (th.Tensor)
- :return: (Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]])
+ :param obs:
+ :return:
Mean, standard deviation and optional keyword arguments.
"""
features = self.extract_features(obs)
@@ -183,31 +183,19 @@ class Actor(BasePolicy):
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self.forward(observation, deterministic)
- def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
- """
- Evaluate actions according to the current policy,
- given the observations. Only useful when using SDE.
- :param obs: (th.Tensor)
- :param actions: (th.Tensor)
- :return: (th.Tensor) log likelihood of taking those actions
- """
- mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
- self.action_dist.proba_distribution(mean_actions, log_std, **kwargs)
- return self.action_dist.log_prob(actions)
-
class Critic(BaseModel):
"""
Critic network (q-value function) for TQC.
- :param observation_space: (gym.spaces.Space) Obervation space
- :param action_space: (gym.spaces.Space) Action space
- :param net_arch: ([int]) Network architecture
- :param features_extractor: (nn.Module) Network to extract features
+ :param observation_space: Observation space
+ :param action_space: Action space
+ :param net_arch: Network architecture
+ :param features_extractor: Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise)
- :param features_dim: (int) Number of features
- :param activation_fn: (Type[nn.Module]) Activation function
- :param normalize_images: (bool) Whether to normalize images or not,
+ :param features_dim: Number of features
+ :param activation_fn: Activation function
+ :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
"""
@@ -257,28 +245,28 @@ class TQCPolicy(BasePolicy):
"""
Policy class (with both actor and critic) for TQC.
- :param observation_space: (gym.spaces.Space) Observation space
- :param action_space: (gym.spaces.Space) Action space
- :param lr_schedule: (callable) Learning rate schedule (could be constant)
- :param net_arch: (Optional[List[int]]) The specification of the policy and value networks.
- :param activation_fn: (Type[nn.Module]) Activation function
- :param use_sde: (bool) Whether to use State Dependent Exploration or not
- :param log_std_init: (float) Initial value for the log standard deviation
- :param sde_net_arch: ([int]) Network architecture for extracting features
+ :param observation_space: Observation space
+ :param action_space: Action space
+ :param lr_schedule: Learning rate schedule (could be constant)
+ :param net_arch: The specification of the policy and value networks.
+ :param activation_fn: Activation function
+ :param use_sde: Whether to use State Dependent Exploration or not
+ :param log_std_init: Initial value for the log standard deviation
+ :param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
- :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
+ :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
- :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
- :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
- :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
+ :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
+ :param features_extractor_class: Features extractor to use.
+ :param features_extractor_kwargs: Keyword arguments
to pass to the feature extractor.
- :param normalize_images: (bool) Whether to normalize images or not,
+ :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
- :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
+ :param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
- :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
+ :param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
@@ -388,7 +376,7 @@ class TQCPolicy(BasePolicy):
"""
Sample new weights for the exploration matrix, when using gSDE.
- :param batch_size: (int)
+ :param batch_size:
"""
self.actor.reset_noise(batch_size=batch_size)
@@ -412,26 +400,26 @@ class CnnPolicy(TQCPolicy):
"""
Policy class (with both actor and critic) for TQC.
- :param observation_space: (gym.spaces.Space) Observation space
- :param action_space: (gym.spaces.Space) Action space
- :param lr_schedule: (callable) Learning rate schedule (could be constant)
- :param net_arch: (Optional[List[int]]) The specification of the policy and value networks.
- :param activation_fn: (Type[nn.Module]) Activation function
- :param use_sde: (bool) Whether to use State Dependent Exploration or not
- :param log_std_init: (float) Initial value for the log standard deviation
- :param sde_net_arch: ([int]) Network architecture for extracting features
+ :param observation_space: Observation space
+ :param action_space: Action space
+ :param lr_schedule: Learning rate schedule (could be constant)
+ :param net_arch: The specification of the policy and value networks.
+ :param activation_fn: Activation function
+ :param use_sde: Whether to use State Dependent Exploration or not
+ :param log_std_init: Initial value for the log standard deviation
+ :param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
- :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
+ :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
- :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
- :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
- :param normalize_images: (bool) Whether to normalize images or not,
+ :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
+ :param features_extractor_class: Features extractor to use.
+ :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
- :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
+ :param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
- :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
+ :param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py
index 18fc902..ea6818e 100644
--- a/sb3_contrib/tqc/tqc.py
+++ b/sb3_contrib/tqc/tqc.py
@@ -7,7 +7,6 @@ from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import polyak_update
-from tqdm import tqdm
from sb3_contrib.tqc.policies import TQCPolicy
@@ -15,48 +14,48 @@ from sb3_contrib.tqc.policies import TQCPolicy
class TQC(OffPolicyAlgorithm):
"""
- Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics
+ Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics.
Paper: https://arxiv.org/abs/2005.04269
- :param policy: (TQCPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
- :param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str)
- :param learning_rate: (float or callable) learning rate for adam optimizer,
+ :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
+ :param env: The environment to learn from (if registered in Gym, can be str)
+ :param learning_rate: learning rate for the Adam optimizer,
the same learning rate will be used for all networks (Q-Values, Actor and Value function)
it can be a function of the current progress remaining (from 1 to 0)
- :param buffer_size: (int) size of the replay buffer
- :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
- :param batch_size: (int) Minibatch size for each gradient update
- :param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1)
- :param gamma: (float) the discount factor
- :param train_freq: (int) Update the model every ``train_freq`` steps.
- :param gradient_steps: (int) How many gradient update after each step
- :param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes.
+ :param buffer_size: size of the replay buffer
+ :param learning_starts: how many steps of the model to collect transitions for before learning starts
+ :param batch_size: Minibatch size for each gradient update
+ :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
+ :param gamma: the discount factor
+ :param train_freq: Update the model every ``train_freq`` steps.
+ :param gradient_steps: How many gradient updates after each step
+ :param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes.
Note that this cannot be used at the same time as ``train_freq``
- :param action_noise: (ActionNoise) the action noise type (None by default), this can help
+ :param action_noise: the action noise type (None by default), this can help
for hard exploration problem. Cf common.noise for the different action noise type.
- :param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer
+ :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
- :param ent_coef: (str or float) Entropy regularization coefficient. (Equivalent to
+ :param ent_coef: Entropy regularization coefficient. (Equivalent to
inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off.
Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value)
- :param target_update_interval: (int) update the target network every ``target_network_update_freq``
+ :param target_update_interval: update the target network every ``target_network_update_freq``
gradient steps.
- :param target_entropy: (str or float) target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
- :param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE)
+ :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
+ :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False)
- :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE
+ :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
- :param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling
+ :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
during the warm up phase (before learning starts)
- :param create_eval_env: (bool) Whether to create a second environment that will be
+ :param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
- :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
- :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug
- :param seed: (int) Seed for the pseudo random generators
- :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
+ :param policy_kwargs: additional arguments to be passed to the policy on creation
+ :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
+ :param seed: Seed for the pseudo random generators
+ :param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
- :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
+ :param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
def __init__(
@@ -274,171 +273,6 @@ class TQC(OffPolicyAlgorithm):
if len(ent_coef_losses) > 0:
logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
- def pretrain(
- self,
- gradient_steps: int,
- batch_size: int = 64,
- n_action_samples: int = -1,
- target_update_interval: int = 1,
- tau: float = 0.005,
- strategy: str = "exp",
- reduce: str = "mean",
- exp_temperature: float = 1.0,
- off_policy_update_freq: int = -1,
- ) -> None:
- """
- Pretrain with Critic Regularized Regression (CRR)
- Paper: https://arxiv.org/abs/2006.15134
- """
- # Update optimizers learning rate
- optimizers = [self.actor.optimizer, self.critic.optimizer]
- if self.ent_coef_optimizer is not None:
- optimizers += [self.ent_coef_optimizer]
-
- # Update learning rate according to lr schedule
- self._update_learning_rate(optimizers)
-
- actor_losses, critic_losses = [], []
-
- for gradient_step in tqdm(range(gradient_steps)):
-
- if off_policy_update_freq > 0 and gradient_step % off_policy_update_freq == 0:
- self.train(gradient_steps=1, batch_size=batch_size)
- continue
-
- # Sample replay buffer
- replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
-
- # We need to sample because `log_std` may have changed between two gradient steps
- if self.use_sde:
- self.actor.reset_noise()
-
- # Action by the current actor for the sampled state
- _, log_prob = self.actor.action_log_prob(replay_data.observations)
- log_prob = log_prob.reshape(-1, 1)
-
- ent_coef_loss = None
- if self.ent_coef_optimizer is not None:
- # Important: detach the variable from the graph
- # so we don't change it with other losses
- # see https://github.com/rail-berkeley/softlearning/issues/60
- ent_coef = th.exp(self.log_ent_coef.detach())
- ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
- else:
- ent_coef = self.ent_coef_tensor
-
- self.replay_buffer.ent_coef = ent_coef.item()
-
- # Optimize entropy coefficient, also called
- # entropy temperature or alpha in the paper
- if ent_coef_loss is not None:
- self.ent_coef_optimizer.zero_grad()
- ent_coef_loss.backward()
- self.ent_coef_optimizer.step()
-
- with th.no_grad():
- top_quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.critic.n_critics
- # Select action according to policy
- next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
- # Compute and cut quantiles at the next state
- # batch x nets x quantiles
- next_z = self.critic_target(replay_data.next_observations, next_actions)
- sorted_z, _ = th.sort(next_z.reshape(batch_size, -1))
- sorted_z_part = sorted_z[:, : self.critic.quantiles_total - top_quantiles_to_drop]
-
- target_q = sorted_z_part - ent_coef * next_log_prob.reshape(-1, 1)
- # td error + entropy term
- q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
-
- # Get current Q estimates
- # using action from the replay buffer
- current_z = self.critic(replay_data.observations, replay_data.actions)
- # Compute critic loss
- critic_loss = self.quantile_huber_loss(current_z, q_backup)
- critic_losses.append(critic_loss.item())
-
- # Optimize the critic
- self.critic.optimizer.zero_grad()
- critic_loss.backward()
- self.critic.optimizer.step()
-
- if strategy == "bc":
- # Behavior cloning
- weight = 1
- else:
- # Tensor version: TODO: check that the reshape works as expected
- # cleaner but not faster on cpu for large batch size
- # with th.no_grad():
- # # Q-value for the action in the buffer
- # qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True)
- # # Create tensor to avoid loop
- # # Note: For SDE, we need to sample several matrices
- # obs_ = replay_data.observations.repeat(n_action_samples, 1)
- # if self.use_sde:
- # self.actor.reset_noise(batch_size * n_action_samples)
- # actions_pi, _ = self.actor.action_log_prob(obs_)
- # qf_pi = self.critic(obs_, actions_pi.detach()).mean(2).mean(1, keepdim=True)
- # # Agregate: reduce mean or reduce max
- # if reduce == "max":
- # _, qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).max(axis=0)
- # else:
- # qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).mean(axis=0)
- with th.no_grad():
- qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True)
-
- # Use the mean (as done in AWAC, cf rlkit)
- if n_action_samples == -1:
- actions_pi = self.actor.forward(replay_data.observations, deterministic=True)
- qf_agg = self.critic(replay_data.observations, actions_pi).mean(2).mean(1, keepdim=True)
- else:
- qf_agg = None
- for _ in range(n_action_samples):
- if self.use_sde:
- self.actor.reset_noise()
- actions_pi, _ = self.actor.action_log_prob(replay_data.observations)
-
- qf_pi = self.critic(replay_data.observations, actions_pi.detach()).mean(2).mean(1, keepdim=True)
- if qf_agg is None:
- if reduce == "max":
- qf_agg = qf_pi
- else:
- qf_agg = qf_pi / n_action_samples
- else:
- if reduce == "max":
- qf_agg = th.max(qf_pi, qf_agg)
- else:
- qf_agg += qf_pi / n_action_samples
-
- advantage = qf_buffer - qf_agg
- if strategy == "binary":
- # binary advantage
- weight = advantage > 0
- else:
- # exp advantage
- exp_clip = 20.0
- weight = th.clamp(th.exp(advantage / exp_temperature), 0.0, exp_clip)
-
- # Log prob by the current actor for the sampled state and action
- log_prob = self.actor.evaluate_actions(replay_data.observations, replay_data.actions)
- log_prob = log_prob.reshape(-1, 1)
-
- # weigthed regression loss (close to policy gradient loss)
- actor_loss = (-log_prob * weight).mean()
- # actor_loss = ((actions_pi - replay_data.actions * weight) ** 2).mean()
- actor_losses.append(actor_loss.item())
-
- # Optimize the actor
- self.actor.optimizer.zero_grad()
- actor_loss.backward()
- self.actor.optimizer.step()
-
- # Update target networks
- if gradient_step % target_update_interval == 0:
- polyak_update(self.critic.parameters(), self.critic_target.parameters(), tau)
-
- if self.use_sde:
- print(f"std={(self.actor.get_std()).mean().item()}")
-
def learn(
self,
total_timesteps: int,
@@ -469,7 +303,7 @@ class TQC(OffPolicyAlgorithm):
Returns the names of the parameters that should be excluded by default
when saving the model.
- :return: (List[str]) List of parameters that should be excluded from save
+ :return: List of parameters that should be excluded from save
"""
# Exclude aliases
return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]
diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt
index 63aaf3d..37f1777 100644
--- a/sb3_contrib/version.txt
+++ b/sb3_contrib/version.txt
@@ -1 +1 @@
-0.9.0a2
+0.10.0a0
diff --git a/setup.py b/setup.py
index 41055cd..f5adfad 100644
--- a/setup.py
+++ b/setup.py
@@ -23,11 +23,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
- "stable_baselines3[tests,docs]>=0.9.0a0",
- # For progress bar when using CRR
- "tqdm"
- # Enable CMA
- # "cma",
+ "stable_baselines3[tests,docs]>=0.9.0",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
diff --git a/tests/test_run.py b/tests/test_run.py
index 2cbcf41..c2bc941 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -14,41 +14,27 @@ def test_tqc(ent_coef):
create_eval_env=True,
ent_coef=ent_coef,
)
- model.learn(total_timesteps=1000, eval_freq=500)
+ model.learn(total_timesteps=500, eval_freq=250)
@pytest.mark.parametrize("n_critics", [1, 3])
def test_n_critics(n_critics):
# Test TQC with different number of critics
model = TQC(
- "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1
+ "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], n_critics=n_critics), learning_starts=100, verbose=1
)
- model.learn(total_timesteps=1000)
+ model.learn(total_timesteps=500)
-# "CartPole-v1"
-# @pytest.mark.parametrize("env_id", ["MountainCarContinuous-v0"])
-# def test_cmaes(env_id):
-# if CMAES is None:
-# return
-# model = CMAES("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[64]), verbose=1, create_eval_env=True)
-# model.learn(total_timesteps=50000, eval_freq=10000)
-
-
-@pytest.mark.parametrize("strategy", ["exp", "bc", "binary"])
-@pytest.mark.parametrize("reduce", ["mean", "max"])
-def test_crr(tmp_path, strategy, reduce):
+def test_sde():
model = TQC(
"MlpPolicy",
"Pendulum-v0",
- policy_kwargs=dict(net_arch=[64]),
- learning_starts=1000,
+ policy_kwargs=dict(net_arch=[64], sde_net_arch=[8]),
+ use_sde=True,
+ learning_starts=100,
verbose=1,
- create_eval_env=True,
- action_noise=None,
- use_sde=False,
)
-
- model.learn(total_timesteps=1000, eval_freq=0)
- for n_action_samples in [1, 2, -1]:
- model.pretrain(gradient_steps=32, batch_size=32, n_action_samples=n_action_samples, strategy=strategy, reduce=reduce)
+ model.learn(total_timesteps=500)
+ model.policy.reset_noise()
+ model.policy.actor.get_std()
diff --git a/tests/test_save_load.py b/tests/test_save_load.py
index 4415267..9a73c77 100644
--- a/tests/test_save_load.py
+++ b/tests/test_save_load.py
@@ -1,6 +1,5 @@
import os
import pathlib
-import warnings
from collections import OrderedDict
from copy import deepcopy
@@ -11,7 +10,6 @@ import torch as th
from stable_baselines3 import DQN
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
-from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv
@@ -45,7 +43,7 @@ def test_save_load(tmp_path, model_class):
# create model
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
- model.learn(total_timesteps=500, eval_freq=250)
+ model.learn(total_timesteps=500)
env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
@@ -154,7 +152,7 @@ def test_save_load(tmp_path, model_class):
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
# check if learn still works
- model.learn(total_timesteps=1000, eval_freq=500)
+ model.learn(total_timesteps=500)
del model
@@ -177,17 +175,17 @@ def test_set_env(model_class):
# create model
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]))
# learn
- model.learn(total_timesteps=1000, eval_freq=500)
+ model.learn(total_timesteps=300)
# change env
model.set_env(env2)
# learn again
- model.learn(total_timesteps=1000, eval_freq=500)
+ model.learn(total_timesteps=300)
# change env test wrapping
model.set_env(env3)
# learn again
- model.learn(total_timesteps=1000, eval_freq=500)
+ model.learn(total_timesteps=300)
@pytest.mark.parametrize("model_class", MODEL_LIST)
@@ -247,45 +245,6 @@ def test_save_load_replay_buffer(tmp_path, model_class):
)
-@pytest.mark.parametrize("model_class", [TQC])
-@pytest.mark.parametrize("optimize_memory_usage", [False, True])
-def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
- """
- When using memory efficient replay buffer,
- a warning must be emitted when calling `.learn()`
- multiple times.
- See https://github.com/DLR-RM/stable-baselines3/issues/46
- """
- # remove gym warnings
- warnings.filterwarnings(action="ignore", category=DeprecationWarning)
- warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")
-
- model = model_class(
- "MlpPolicy",
- select_env(model_class),
- buffer_size=100,
- optimize_memory_usage=optimize_memory_usage,
- policy_kwargs=dict(net_arch=[64]),
- learning_starts=10,
- )
-
- model.learn(150)
-
- model.learn(150, reset_num_timesteps=False)
-
- # Check that there is no warning
- assert len(recwarn) == 0
-
- model.learn(150)
-
- if optimize_memory_usage:
- assert len(recwarn) == 1
- warning = recwarn.pop(UserWarning)
- assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
- else:
- assert len(recwarn) == 0
-
-
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
def test_save_load_policy(tmp_path, model_class, policy_str):
@@ -309,7 +268,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
# create model
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
- model.learn(total_timesteps=500, eval_freq=250)
+ model.learn(total_timesteps=500)
env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
@@ -375,52 +334,3 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
os.remove(tmp_path / "policy.pkl")
if actor_class is not None:
os.remove(tmp_path / "actor.pkl")
-
-
-@pytest.mark.parametrize("pathtype", [str, pathlib.Path])
-def test_open_file_str_pathlib(tmp_path, pathtype):
- # check that suffix isn't added because we used open_path first
- with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1:
- save_to_pkl(fp1, "foo")
- assert fp1.closed
- with pytest.warns(None) as record:
- assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo"
- assert not record
-
- # test custom suffix
- with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1:
- save_to_pkl(fp1, "foo")
- assert fp1.closed
- with pytest.warns(None) as record:
- assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo"
- assert not record
-
- # test without suffix
- with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1:
- save_to_pkl(fp1, "foo")
- assert fp1.closed
- with pytest.warns(None) as record:
- assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo"
- assert not record
-
- # test that a warning is raised when the path doesn't exist
- with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1:
- save_to_pkl(fp1, "foo")
- assert fp1.closed
- with pytest.warns(None) as record:
- assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo"
- assert len(record) == 0
-
- with pytest.warns(None) as record:
- assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo"
- assert len(record) == 1
-
- fp = pathlib.Path(f"{tmp_path}/t2").open("w")
- fp.write("rubbish")
- fp.close()
- # test that a warning is only raised when verbose = 0
- with pytest.warns(None) as record:
- open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close()
- open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close()
- open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close()
- assert len(record) == 1