Merge branch 'master' into contrib

This commit is contained in:
Anssi "Miffyli" Kanervisto 2020-10-15 02:07:46 +03:00
commit 8d3570ae5f
25 changed files with 860 additions and 389 deletions

View File

@ -29,15 +29,16 @@ jobs:
python -m pip install --upgrade pip python -m pip install --upgrade pip
# cpu version of pytorch # cpu version of pytorch
pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
# Install dependencies for docs and tests
pip install stable_baselines3[extra,tests,docs] pip install stable_baselines3[extra,tests,docs]
# Install master version # Install master version
pip install git+https://github.com/DLR-RM/stable-baselines3 pip install git+https://github.com/DLR-RM/stable-baselines3
pip install . pip install .
# Use headless version # Use headless version
pip install opencv-python-headless pip install opencv-python-headless
# - name: Build the doc - name: Build the doc
# run: | run: |
# make doc make doc
- name: Type check - name: Type check
run: | run: |
make type make type

View File

@ -28,6 +28,11 @@ check-codestyle:
commit-checks: format type lint commit-checks: format type lint
doc:
cd docs && make html
spelling:
cd docs && make spelling
# PyPi package release # PyPi package release
release: release:
@ -41,4 +46,4 @@ test-release:
python setup.py bdist_wheel python setup.py bdist_wheel
twine upload --repository-url https://test.pypi.org/legacy/ dist/* twine upload --repository-url https://test.pypi.org/legacy/ dist/*
.PHONY: lint format check-codestyle commit-checks .PHONY: lint format check-codestyle commit-checks doc spelling

View File

@ -18,3 +18,19 @@ Implemented:
``` ```
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
``` ```
## Citing the Project
To cite this repository in publications (please cite SB3 directly):
```
@misc{stable-baselines3,
author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
title = {Stable Baselines3},
year = {2019},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}},
}
```

20
docs/Makefile Normal file
View File

@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS = -W # make warnings fatal
SPHINXBUILD = sphinx-build
SPHINXPROJ = StableBaselines
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

25
docs/README.md Normal file
View File

@ -0,0 +1,25 @@
## Stable Baselines3 Documentation
This folder contains documentation for the RL baselines.
### Build the Documentation
#### Install Sphinx and Theme
```
pip install sphinx sphinx-autobuild sphinx-rtd-theme
```
#### Building the Docs
In the `docs/` folder:
```
make html
```
if you want to rebuild the documentation each time a file is changed:
```
sphinx-autobuild . _build/html
```

61
docs/_static/css/baselines_theme.css vendored Normal file
View File

@ -0,0 +1,61 @@
/* Main colors adapted from pytorch doc */
:root{
--main-bg-color: #343A40;
--link-color: #FD7E14;
}
/* Header fonts */
h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
}
/* Docs background */
.wy-side-nav-search{
background-color: var(--main-bg-color);
}
/* Mobile version */
.wy-nav-top{
background-color: var(--main-bg-color);
}
/* Change link colors (except for the menu) */
a {
color: var(--link-color);
}
a:hover {
color: #4F778F;
}
.wy-menu a {
color: #b3b3b3;
}
.wy-menu a:hover {
color: #b3b3b3;
}
a.icon.icon-home {
color: #b3b3b3;
}
.version{
color: var(--link-color) !important;
}
/* Make code blocks have a background */
.codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] {
background: #f8f8f8;;
}
/* Change style of types in the docstrings .rst-content .field-list */
.field-list .xref.py.docutils, .field-list code.docutils, .field-list .docutils.literal.notranslate
{
border: None;
padding-left: 0;
padding-right: 0;
color: #404040;
}

1
docs/_static/img/colab-badge.svg vendored Normal file
View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="117" height="20"><linearGradient id="b" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="a"><rect width="117" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#a)"><path fill="#555" d="M0 0h30v20H0z"/><path fill="#007ec6" d="M30 0h87v20H30z"/><path fill="url(#b)" d="M0 0h117v20H0z"/></g><g fill="#fff" text-anchor="middle" font-family="DejaVu Sans,Verdana,Geneva,sans-serif" font-size="110"><svg x="4px" y="0px" width="22px" height="20px" viewBox="-2 0 28 24" style="background-color: #fff;border-radius: 1px;"><path style="fill:#ef9008;" d="M1.977,16.77c-2.667-2.277-2.605-7.079,0-9.357C2.919,8.057,3.522,9.075,4.49,9.691c-1.152,1.6-1.146,3.201-0.004,4.803C3.522,15.111,2.918,16.126,1.977,16.77z"/><path style="fill:#fdba18;" d="M12.257,17.114c-1.767-1.633-2.485-3.658-2.118-6.02c0.451-2.91,2.139-4.893,4.946-5.678c2.565-0.718,4.964-0.217,6.878,1.819c-0.884,0.743-1.707,1.547-2.434,2.446C18.488,8.827,17.319,8.435,16,8.856c-2.404,0.767-3.046,3.241-1.494,5.644c-0.241,0.275-0.493,0.541-0.721,0.826C13.295,15.939,12.511,16.3,12.257,17.114z"/><path style="fill:#ef9008;" d="M19.529,9.682c0.727-0.899,1.55-1.703,2.434-2.446c2.703,2.783,2.701,7.031-0.005,9.764c-2.648,2.674-6.936,2.725-9.701,0.115c0.254-0.814,1.038-1.175,1.528-1.788c0.228-0.285,0.48-0.552,0.721-0.826c1.053,0.916,2.254,1.268,3.6,0.83C20.502,14.551,21.151,11.927,19.529,9.682z"/><path style="fill:#fdba18;" d="M4.49,9.691C3.522,9.075,2.919,8.057,1.977,7.413c2.209-2.398,5.721-2.942,8.476-1.355c0.555,0.32,0.719,0.606,0.285,1.128c-0.157,0.188-0.258,0.422-0.391,0.631c-0.299,0.47-0.509,1.067-0.929,1.371C8.933,9.539,8.523,8.847,8.021,8.746C6.673,8.475,5.509,8.787,4.49,9.691z"/><path style="fill:#fdba18;" 
d="M1.977,16.77c0.941-0.644,1.545-1.659,2.509-2.277c1.373,1.152,2.85,1.433,4.45,0.499c0.332-0.194,0.503-0.088,0.673,0.19c0.386,0.635,0.753,1.285,1.181,1.89c0.34,0.48,0.222,0.715-0.253,1.006C7.84,19.73,4.205,19.188,1.977,16.77z"/></svg><text x="245" y="140" transform="scale(.1)" textLength="30"> </text><text x="725" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="770">Open in Colab</text><text x="725" y="140" transform="scale(.1)" textLength="770">Open in Colab</text></g> </svg>

After

Width:  |  Height:  |  Size: 2.3 KiB

7
docs/_static/img/colab.svg vendored Normal file
View File

@ -0,0 +1,7 @@
<svg width="24px" height="15px" viewBox="0 0 24 15" version="1.1" xmlns="http://www.w3.org/2000/svg">
<path d="M1.977,11.77 C-0.69,9.493 -0.628,4.691 1.977,2.413 C2.919,3.057 3.522,4.075 4.49,4.691 C3.338,6.291 3.344,7.892 4.486,9.494 C3.522,10.111 2.918,11.126 1.977,11.77 Z" fill="#FDBA18"/>
<path d="M12.257,12.114 C10.49,10.481 9.772,8.456 10.139,6.094 C10.59,3.184 12.278,1.201 15.085,0.416 C17.65,-0.302 20.049,0.199 21.963,2.235 C21.079,2.978 20.256,3.782 19.529,4.681 C18.488,3.827 17.319,3.435 16,3.856 C13.596,4.623 12.954,7.097 14.506,9.5 C14.265,9.775 14.013,10.041 13.785,10.326 C13.295,10.939 12.511,11.3 12.257,12.114 Z" fill="#FCD93D"/>
<path d="M19.529,4.682 C20.256,3.783 21.079,2.979 21.963,2.236 C24.666,5.019 24.664,9.267 21.958,12 C19.31,14.674 15.022,14.725 12.257,12.115 C12.511,11.301 13.295,10.94 13.785,10.327 C14.013,10.042 14.265,9.775 14.506,9.501 C15.559,10.417 16.76,10.769 18.106,10.331 C20.502,9.551 21.151,6.927 19.529,4.682 Z" fill="#FDBA18"/>
<path d="M4.49,4.691 C3.522,4.075 2.919,3.057 1.977,2.413 C4.186,0.015 7.698,-0.529 10.453,1.058 C11.008,1.378 11.172,1.664 10.738,2.186 C10.581,2.374 10.48,2.608 10.347,2.817 C10.048,3.287 9.838,3.884 9.418,4.188 C8.933,4.539 8.523,3.847 8.021,3.746 C6.673,3.475 5.509,3.787 4.49,4.691 Z" fill="#FCD93D"/>
<path d="M1.977,11.77 C2.918,11.126 3.522,10.111 4.486,9.493 C5.859,10.645 7.336,10.926 8.936,9.992 C9.268,9.798 9.439,9.904 9.609,10.182 C9.995,10.817 10.362,11.467 10.79,12.072 C11.13,12.552 11.012,12.787 10.537,13.078 C7.84,14.73 4.205,14.188 1.977,11.77 Z" fill="#FCD93D"/>
</svg>

After

Width:  |  Height:  |  Size: 1.5 KiB

BIN
docs/_static/img/logo.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 156 KiB

BIN
docs/_static/img/logo.png~ vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 156 KiB

18
docs/conda_env.yml Normal file
View File

@ -0,0 +1,18 @@
name: root
channels:
- pytorch
- defaults
dependencies:
- cpuonly=1.0=0
- pip=20.2
- python=3.6
- pytorch=1.5.0=py3.6_cpu_0
- pip:
- gym>=0.17.2
- cloudpickle
- opencv-python-headless
- pandas
- numpy
- matplotlib
- sphinx_autodoc_typehints
- stable-baselines3>=0.9.0

219
docs/conf.py Normal file
View File

@ -0,0 +1,219 @@
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
from unittest.mock import MagicMock
# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
# PyEnchant.
try:
import sphinxcontrib.spelling # noqa: F401
enable_spell_check = True
except ImportError:
enable_spell_check = False
# source code directory, relative to this file, for sphinx-autobuild
sys.path.insert(0, os.path.abspath(".."))
class Mock(MagicMock):
__subclasses__ = []
@classmethod
def __getattr__(cls, name):
return MagicMock()
# Mock modules that requires C modules
# Note: because of that we cannot test examples using CI
# 'torch', 'torch.nn', 'torch.nn.functional',
# DO not mock modules for now, we will need to do that for read the docs later
MOCK_MODULES = []
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "../sb3_contrib", "version.txt")
with open(version_file, "r") as file_handler:
__version__ = file_handler.read().strip()
# -- Project information -----------------------------------------------------
project = "Stable Baselines3 - Contrib"
copyright = "2020, Stable Baselines3"
author = "Stable Baselines3 Contributors"
# The short X.Y version
version = "master (" + __version__ + " )"
# The full version, including alpha/beta/rc tags
release = __version__
# -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
"sphinx.ext.autodoc",
"sphinx_autodoc_typehints",
"sphinx.ext.autosummary",
"sphinx.ext.mathjax",
"sphinx.ext.ifconfig",
"sphinx.ext.viewcode",
# 'sphinx.ext.intersphinx',
# 'sphinx.ext.doctest'
]
if enable_spell_check:
extensions.append("sphinxcontrib.spelling")
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = ".rst"
# The master toctree document.
master_doc = "index"
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path .
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
# Fix for read the docs
on_rtd = os.environ.get("READTHEDOCS") == "True"
if on_rtd:
html_theme = "default"
else:
html_theme = "sphinx_rtd_theme"
html_logo = "_static/img/logo.png"
def setup(app):
app.add_css_file("css/baselines_theme.css")
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = "StableBaselines3doc"
# -- Options for LaTeX output ------------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, "StableBaselines3.tex", "Stable Baselines3 Documentation", "Stable Baselines3 Contributors", "manual"),
]
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "stablebaselines3", "Stable Baselines3 Documentation", [author], 1)]
# -- Options for Texinfo output ----------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(
master_doc,
"StableBaselines3",
"Stable Baselines3 Documentation",
author,
"StableBaselines3",
"One line description of project.",
"Miscellaneous",
),
]
# -- Extension configuration -------------------------------------------------
# Example configuration for intersphinx: refer to the Python standard library.
# intersphinx_mapping = {
# 'python': ('https://docs.python.org/3/', None),
# 'numpy': ('http://docs.scipy.org/doc/numpy/', None),
# 'torch': ('http://pytorch.org/docs/master/', None),
# }

30
docs/guide/algos.rst Normal file
View File

@ -0,0 +1,30 @@
RL Algorithms
=============
This table displays the RL algorithms that are implemented in the Stable Baselines3 contrib project,
along with some useful characteristics: support for discrete/continuous actions, multiprocessing.
..
.. ============ =========== ============ ================= =============== ================
.. Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
.. ============ =========== ============ ================= =============== ================
.. A2C ✔️ ✔️ ✔️ ✔️ ✔️
.. DDPG ✔️ ❌ ❌ ❌ ❌
.. DQN ❌ ✔️ ❌ ❌ ❌
.. PPO ✔️ ✔️ ✔️ ✔️ ✔️
.. SAC ✔️ ❌ ❌ ❌ ❌
.. TD3 ✔️ ❌ ❌ ❌ ❌
.. ============ =========== ============ ================= =============== ================
.. .. note::
.. Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm.
..
.. Actions ``gym.spaces``:
..
.. - ``Box``: A N-dimensional box that contains every point in the action
.. space.
.. - ``Discrete``: A list of possible actions, where each timestep only
.. one of the actions can be used.
.. - ``MultiDiscrete``: A list of possible actions, where each timestep only one action of each discrete set can be used.
.. - ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination.

40
docs/guide/examples.rst Normal file
View File

@ -0,0 +1,40 @@
.. _examples:
Examples
========
WIP
.. PyBullet: Normalizing input features
.. ------------------------------------
..
.. Normalizing input features may be essential to successful training of an RL agent
.. (by default, images are scaled but not other types of input),
.. for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`__ environments. For that, a wrapper exists and
.. will compute a running average and standard deviation of input features (it can do the same for rewards).
..
.. .. note::
..
.. you need to install pybullet with ``pip install pybullet``
..
..
.. .. image:: ../_static/img/colab-badge.svg
.. :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
..
..
.. .. code-block:: python
..
.. import gym
.. import pybullet_envs
..
.. from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
.. from stable_baselines3 import PPO
..
.. env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
.. # Automatically normalize the input features and reward
.. env = VecNormalize(env, norm_obs=True, norm_reward=True,
.. clip_obs=10.)
..
.. model = PPO('MlpPolicy', env)
.. model.learn(total_timesteps=2000)

37
docs/guide/install.rst Normal file
View File

@ -0,0 +1,37 @@
.. _install:
Installation
============
Prerequisites
-------------
Please read `Stable-Baselines3 installation guide <https://stable-baselines3.readthedocs.io/en/master/guide/install.html>`_ first.
Stable Release
~~~~~~~~~~~~~~
To install Stable Baselines3 contrib with pip, execute:
.. code-block:: bash
pip install sb3-contrib
Bleeding-edge version
---------------------
.. code-block:: bash
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/
Development version
-------------------
To contribute to Stable-Baselines3, with support for running tests and building the documentation.
.. code-block:: bash
git clone https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/ && cd stable-baselines3-contrib
pip install -e .

78
docs/index.rst Normal file
View File

@ -0,0 +1,78 @@
.. Stable Baselines3 documentation master file, created by
sphinx-quickstart on Thu Sep 26 11:06:54 2019.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Welcome to Stable Baselines3 Contrib docs!
==========================================
Contrib package for `Stable Baselines3 <https://github.com/DLR-RM/stable-baselines3>`_ - Experimental code.
Github repository: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
SB3 repository: https://github.com/DLR-RM/stable-baselines3
RL Baselines3 Zoo (collection of pre-trained agents): https://github.com/DLR-RM/rl-baselines3-zoo
RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and do hyperparameter tuning.
.. toctree::
:maxdepth: 2
:caption: User Guide
guide/install
guide/algos
guide/examples
.. .. toctree::
.. :maxdepth: 1
.. :caption: RL Algorithms
..
.. modules/a2c
.. .. toctree::
.. :maxdepth: 1
.. :caption: Common
..
.. common/atari_wrappers
.. toctree::
:maxdepth: 1
:caption: Misc
misc/changelog
Citing Stable Baselines3
------------------------
To cite this project in publications:
.. code-block:: bibtex
@misc{stable-baselines3,
author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
title = {Stable Baselines3},
year = {2019},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}},
}
Contributing
------------
For anyone interested in making the RL baselines better, there are still some improvements
that need to be done.
You can check issues in the `repo <https://github.com/DLR-RM/stable-baselines3/issues>`_.
If you want to contribute, please read `CONTRIBUTING.md <https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md>`_ first.
Indices and tables
-------------------
* :ref:`genindex`
* :ref:`search`
* :ref:`modindex`

36
docs/make.bat Normal file
View File

@ -0,0 +1,36 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
set SPHINXPROJ=StableBaselines
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
:end
popd

42
docs/misc/changelog.rst Normal file
View File

@ -0,0 +1,42 @@
.. _changelog:
Changelog
==========
Pre-Release 0.10.0a0 (WIP)
------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
New Features:
^^^^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
Documentation:
^^^^^^^^^^^^^^
Maintainers
-----------
Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_), `Ashley Hill`_ (aka @hill-a),
`Maximilian Ernestus`_ (aka @erniejunior), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_).
.. _Ashley Hill: https://github.com/hill-a
.. _Antonin Raffin: https://araffin.github.io/
.. _Maximilian Ernestus: https://github.com/erniejunior
.. _Adam Gleave: https://gleave.me/
.. _@araffin: https://github.com/araffin
.. _@AdamGleave: https://github.com/adamgleave
.. _Anssi Kanervisto: https://github.com/Miffyli
.. _@Miffyli: https://github.com/Miffyli

121
docs/spelling_wordlist.txt Normal file
View File

@ -0,0 +1,121 @@
py
env
atari
argparse
Argparse
TensorFlow
feedforward
envs
VecEnv
pretrain
petrained
tf
th
nn
np
str
mujoco
cpu
ndarray
ndarrays
timestep
timesteps
stepsize
dataset
adam
fn
normalisation
Kullback
Leibler
boolean
deserialized
pretrained
minibatch
subprocesses
ArgumentParser
Tensorflow
Gaussian
approximator
minibatches
hyperparameters
hyperparameter
vectorized
rl
colab
dataloader
npz
datasets
vf
logits
num
Utils
backpropagate
prepend
NaN
preprocessing
Cloudpickle
async
multiprocess
tensorflow
mlp
cnn
neglogp
tanh
coef
repo
Huber
params
ppo
arxiv
Arxiv
func
DQN
Uhlenbeck
Ornstein
multithread
cancelled
Tensorboard
parallelize
customising
serializable
Multiprocessed
cartpole
toolset
lstm
rescale
ffmpeg
avconv
unnormalized
Github
pre
preprocess
backend
attr
preprocess
Antonin
Raffin
araffin
Homebrew
Numpy
Theano
rollout
kfac
Piecewise
csv
nvidia
visdom
tensorboard
preprocessed
namespace
sklearn
GoalEnv
Torchy
pytorch
dicts
optimizers
Deprecations
forkserver
cuda
Polyak
gSDE
rollouts

View File

@ -17,25 +17,25 @@ class Actor(BasePolicy):
""" """
Actor network (policy) for TQC. Actor network (policy) for TQC.
:param observation_space: (gym.spaces.Space) Observation space :param observation_space: Observation space
:param action_space: (gym.spaces.Space) Action space :param action_space: Action space
:param net_arch: ([int]) Network architecture :param net_arch: Network architecture
:param features_extractor: (nn.Module) Network to extract features :param features_extractor: Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise) (a CNN when using images, a nn.Flatten() layer otherwise)
:param features_dim: (int) Number of features :param features_dim: Number of features
:param activation_fn: (Type[nn.Module]) Activation function :param activation_fn: Activation function
:param use_sde: (bool) Whether to use State Dependent Exploration or not :param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation :param log_std_init: Initial value for the log standard deviation
:param full_std: (bool) Whether to use (n_features x n_actions) parameters :param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE. for the std instead of only (n_features,) when using gSDE.
:param sde_net_arch: ([int]) Network architecture for extracting features :param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used. when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features. Pass an empty list to use the states as features.
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
:param normalize_images: (bool) Whether to normalize images or not, :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default) dividing by 255.0 (True by default)
""" """
@ -131,7 +131,7 @@ class Actor(BasePolicy):
but is slightly different when using ``expln`` function but is slightly different when using ``expln`` function
(cf StateDependentNoiseDistribution doc). (cf StateDependentNoiseDistribution doc).
:return: (th.Tensor) :return:
""" """
msg = "get_std() is only available when using gSDE" msg = "get_std() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
@ -141,7 +141,7 @@ class Actor(BasePolicy):
""" """
Sample new weights for the exploration matrix, when using gSDE. Sample new weights for the exploration matrix, when using gSDE.
:param batch_size: (int) :param batch_size:
""" """
msg = "reset_noise() is only available when using gSDE" msg = "reset_noise() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
@ -151,8 +151,8 @@ class Actor(BasePolicy):
""" """
Get the parameters for the action distribution. Get the parameters for the action distribution.
:param obs: (th.Tensor) :param obs:
:return: (Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]) :return:
Mean, standard deviation and optional keyword arguments. Mean, standard deviation and optional keyword arguments.
""" """
features = self.extract_features(obs) features = self.extract_features(obs)
@ -183,31 +183,19 @@ class Actor(BasePolicy):
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self.forward(observation, deterministic) return self.forward(observation, deterministic)
def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
"""
Evaluate actions according to the current policy,
given the observations. Only useful when using SDE.
:param obs: (th.Tensor)
:param actions: (th.Tensor)
:return: (th.Tensor) log likelihood of taking those actions
"""
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
self.action_dist.proba_distribution(mean_actions, log_std, **kwargs)
return self.action_dist.log_prob(actions)
class Critic(BaseModel): class Critic(BaseModel):
""" """
Critic network (q-value function) for TQC. Critic network (q-value function) for TQC.
:param observation_space: (gym.spaces.Space) Observation space :param observation_space: Observation space
:param action_space: (gym.spaces.Space) Action space :param action_space: Action space
:param net_arch: ([int]) Network architecture :param net_arch: Network architecture
:param features_extractor: (nn.Module) Network to extract features :param features_extractor: Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise) (a CNN when using images, a nn.Flatten() layer otherwise)
:param features_dim: (int) Number of features :param features_dim: Number of features
:param activation_fn: (Type[nn.Module]) Activation function :param activation_fn: Activation function
:param normalize_images: (bool) Whether to normalize images or not, :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default) dividing by 255.0 (True by default)
""" """
@ -257,28 +245,28 @@ class TQCPolicy(BasePolicy):
""" """
Policy class (with both actor and critic) for TQC. Policy class (with both actor and critic) for TQC.
:param observation_space: (gym.spaces.Space) Observation space :param observation_space: Observation space
:param action_space: (gym.spaces.Space) Action space :param action_space: Action space
:param lr_schedule: (callable) Learning rate schedule (could be constant) :param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: (Optional[List[int]]) The specification of the policy and value networks. :param net_arch: The specification of the policy and value networks.
:param activation_fn: (Type[nn.Module]) Activation function :param activation_fn: Activation function
:param use_sde: (bool) Whether to use State Dependent Exploration or not :param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation :param log_std_init: Initial value for the log standard deviation
:param sde_net_arch: ([int]) Network architecture for extracting features :param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used. when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features. Pass an empty list to use the states as features.
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. :param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments :param features_extractor_kwargs: Keyword arguments
to pass to the feature extractor. to pass to the feature extractor.
:param normalize_images: (bool) Whether to normalize images or not, :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default) dividing by 255.0 (True by default)
:param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, :param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default ``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, :param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer excluding the learning rate, to pass to the optimizer
""" """
@ -388,7 +376,7 @@ class TQCPolicy(BasePolicy):
""" """
Sample new weights for the exploration matrix, when using gSDE. Sample new weights for the exploration matrix, when using gSDE.
:param batch_size: (int) :param batch_size:
""" """
self.actor.reset_noise(batch_size=batch_size) self.actor.reset_noise(batch_size=batch_size)
@ -412,26 +400,26 @@ class CnnPolicy(TQCPolicy):
""" """
Policy class (with both actor and critic) for TQC. Policy class (with both actor and critic) for TQC.
:param observation_space: (gym.spaces.Space) Observation space :param observation_space: Observation space
:param action_space: (gym.spaces.Space) Action space :param action_space: Action space
:param lr_schedule: (callable) Learning rate schedule (could be constant) :param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: (Optional[List[int]]) The specification of the policy and value networks. :param net_arch: The specification of the policy and value networks.
:param activation_fn: (Type[nn.Module]) Activation function :param activation_fn: Activation function
:param use_sde: (bool) Whether to use State Dependent Exploration or not :param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation :param log_std_init: Initial value for the log standard deviation
:param sde_net_arch: ([int]) Network architecture for extracting features :param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used. when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features. Pass an empty list to use the states as features.
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. :param features_extractor_class: Features extractor to use.
:param normalize_images: (bool) Whether to normalize images or not, :param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default) dividing by 255.0 (True by default)
:param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, :param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default ``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, :param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer excluding the learning rate, to pass to the optimizer
""" """

View File

@ -7,7 +7,6 @@ from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import polyak_update from stable_baselines3.common.utils import polyak_update
from tqdm import tqdm
from sb3_contrib.tqc.policies import TQCPolicy from sb3_contrib.tqc.policies import TQCPolicy
@ -15,48 +14,48 @@ from sb3_contrib.tqc.policies import TQCPolicy
class TQC(OffPolicyAlgorithm): class TQC(OffPolicyAlgorithm):
""" """
Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics.
Paper: https://arxiv.org/abs/2005.04269 Paper: https://arxiv.org/abs/2005.04269
:param policy: (TQCPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...) :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str) :param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: (float or callable) learning rate for adam optimizer, :param learning_rate: learning rate for adam optimizer,
the same learning rate will be used for all networks (Q-Values, Actor and Value function) the same learning rate will be used for all networks (Q-Values, Actor and Value function)
it can be a function of the current progress remaining (from 1 to 0) it can be a function of the current progress remaining (from 1 to 0)
:param buffer_size: (int) size of the replay buffer :param buffer_size: size of the replay buffer
:param learning_starts: (int) how many steps of the model to collect transitions for before learning starts :param learning_starts: how many steps of the model to collect transitions for before learning starts
:param batch_size: (int) Minibatch size for each gradient update :param batch_size: Minibatch size for each gradient update
:param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1) :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
:param gamma: (float) the discount factor :param gamma: the discount factor
:param train_freq: (int) Update the model every ``train_freq`` steps. :param train_freq: Update the model every ``train_freq`` steps.
:param gradient_steps: (int) How many gradient update after each step :param gradient_steps: How many gradient update after each step
:param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes. :param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes.
Note that this cannot be used at the same time as ``train_freq`` Note that this cannot be used at the same time as ``train_freq``
:param action_noise: (ActionNoise) the action noise type (None by default), this can help :param action_noise: the action noise type (None by default), this can help
for hard exploration problem. Cf common.noise for the different action noise type. for hard exploration problem. Cf common.noise for the different action noise type.
:param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity. at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param ent_coef: (str or float) Entropy regularization coefficient. (Equivalent to :param ent_coef: Entropy regularization coefficient. (Equivalent to
inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off. inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off.
Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value) Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value)
:param target_update_interval: (int) update the target network every ``target_network_update_freq`` :param target_update_interval: update the target network every ``target_network_update_freq``
gradient steps. gradient steps.
:param target_entropy: (str or float) target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``) :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
:param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE) :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False) instead of action noise exploration (default: False)
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout) Default: -1 (only sample at the beginning of the rollout)
:param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
during the warm up phase (before learning starts) during the warm up phase (before learning starts)
:param create_eval_env: (bool) Whether to create a second environment that will be :param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment) used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation :param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: (int) Seed for the pseudo random generators :param seed: Seed for the pseudo random generators
:param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run. :param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible. Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance :param _init_setup_model: Whether or not to build the network at the creation of the instance
""" """
def __init__( def __init__(
@ -274,171 +273,6 @@ class TQC(OffPolicyAlgorithm):
if len(ent_coef_losses) > 0: if len(ent_coef_losses) > 0:
logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
def pretrain(
self,
gradient_steps: int,
batch_size: int = 64,
n_action_samples: int = -1,
target_update_interval: int = 1,
tau: float = 0.005,
strategy: str = "exp",
reduce: str = "mean",
exp_temperature: float = 1.0,
off_policy_update_freq: int = -1,
) -> None:
"""
Pretrain with Critic Regularized Regression (CRR)
Paper: https://arxiv.org/abs/2006.15134
"""
# Update optimizers learning rate
optimizers = [self.actor.optimizer, self.critic.optimizer]
if self.ent_coef_optimizer is not None:
optimizers += [self.ent_coef_optimizer]
# Update learning rate according to lr schedule
self._update_learning_rate(optimizers)
actor_losses, critic_losses = [], []
for gradient_step in tqdm(range(gradient_steps)):
if off_policy_update_freq > 0 and gradient_step % off_policy_update_freq == 0:
self.train(gradient_steps=1, batch_size=batch_size)
continue
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
# We need to sample because `log_std` may have changed between two gradient steps
if self.use_sde:
self.actor.reset_noise()
# Action by the current actor for the sampled state
_, log_prob = self.actor.action_log_prob(replay_data.observations)
log_prob = log_prob.reshape(-1, 1)
ent_coef_loss = None
if self.ent_coef_optimizer is not None:
# Important: detach the variable from the graph
# so we don't change it with other losses
# see https://github.com/rail-berkeley/softlearning/issues/60
ent_coef = th.exp(self.log_ent_coef.detach())
ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
else:
ent_coef = self.ent_coef_tensor
self.replay_buffer.ent_coef = ent_coef.item()
# Optimize entropy coefficient, also called
# entropy temperature or alpha in the paper
if ent_coef_loss is not None:
self.ent_coef_optimizer.zero_grad()
ent_coef_loss.backward()
self.ent_coef_optimizer.step()
with th.no_grad():
top_quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.critic.n_critics
# Select action according to policy
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
# Compute and cut quantiles at the next state
# batch x nets x quantiles
next_z = self.critic_target(replay_data.next_observations, next_actions)
sorted_z, _ = th.sort(next_z.reshape(batch_size, -1))
sorted_z_part = sorted_z[:, : self.critic.quantiles_total - top_quantiles_to_drop]
target_q = sorted_z_part - ent_coef * next_log_prob.reshape(-1, 1)
# td error + entropy term
q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
# Get current Q estimates
# using action from the replay buffer
current_z = self.critic(replay_data.observations, replay_data.actions)
# Compute critic loss
critic_loss = self.quantile_huber_loss(current_z, q_backup)
critic_losses.append(critic_loss.item())
# Optimize the critic
self.critic.optimizer.zero_grad()
critic_loss.backward()
self.critic.optimizer.step()
if strategy == "bc":
# Behavior cloning
weight = 1
else:
# Tensor version: TODO: check that the reshape works as expected
# cleaner but not faster on cpu for large batch size
# with th.no_grad():
# # Q-value for the action in the buffer
# qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True)
# # Create tensor to avoid loop
# # Note: For SDE, we need to sample several matrices
# obs_ = replay_data.observations.repeat(n_action_samples, 1)
# if self.use_sde:
# self.actor.reset_noise(batch_size * n_action_samples)
# actions_pi, _ = self.actor.action_log_prob(obs_)
# qf_pi = self.critic(obs_, actions_pi.detach()).mean(2).mean(1, keepdim=True)
# # Agregate: reduce mean or reduce max
# if reduce == "max":
# _, qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).max(axis=0)
# else:
# qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).mean(axis=0)
with th.no_grad():
qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True)
# Use the mean (as done in AWAC, cf rlkit)
if n_action_samples == -1:
actions_pi = self.actor.forward(replay_data.observations, deterministic=True)
qf_agg = self.critic(replay_data.observations, actions_pi).mean(2).mean(1, keepdim=True)
else:
qf_agg = None
for _ in range(n_action_samples):
if self.use_sde:
self.actor.reset_noise()
actions_pi, _ = self.actor.action_log_prob(replay_data.observations)
qf_pi = self.critic(replay_data.observations, actions_pi.detach()).mean(2).mean(1, keepdim=True)
if qf_agg is None:
if reduce == "max":
qf_agg = qf_pi
else:
qf_agg = qf_pi / n_action_samples
else:
if reduce == "max":
qf_agg = th.max(qf_pi, qf_agg)
else:
qf_agg += qf_pi / n_action_samples
advantage = qf_buffer - qf_agg
if strategy == "binary":
# binary advantage
weight = advantage > 0
else:
# exp advantage
exp_clip = 20.0
weight = th.clamp(th.exp(advantage / exp_temperature), 0.0, exp_clip)
# Log prob by the current actor for the sampled state and action
log_prob = self.actor.evaluate_actions(replay_data.observations, replay_data.actions)
log_prob = log_prob.reshape(-1, 1)
# weigthed regression loss (close to policy gradient loss)
actor_loss = (-log_prob * weight).mean()
# actor_loss = ((actions_pi - replay_data.actions * weight) ** 2).mean()
actor_losses.append(actor_loss.item())
# Optimize the actor
self.actor.optimizer.zero_grad()
actor_loss.backward()
self.actor.optimizer.step()
# Update target networks
if gradient_step % target_update_interval == 0:
polyak_update(self.critic.parameters(), self.critic_target.parameters(), tau)
if self.use_sde:
print(f"std={(self.actor.get_std()).mean().item()}")
def learn( def learn(
self, self,
total_timesteps: int, total_timesteps: int,
@ -469,7 +303,7 @@ class TQC(OffPolicyAlgorithm):
Returns the names of the parameters that should be excluded by default Returns the names of the parameters that should be excluded by default
when saving the model. when saving the model.
:return: (List[str]) List of parameters that should be excluded from save :return: List of parameters that should be excluded from save
""" """
# Exclude aliases # Exclude aliases
return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"] return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]

View File

@ -1 +1 @@
0.9.0a2 0.10.0a0

View File

@ -23,11 +23,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")], packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]}, package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[ install_requires=[
"stable_baselines3[tests,docs]>=0.9.0a0", "stable_baselines3[tests,docs]>=0.9.0",
# For progress bar when using CRR
"tqdm"
# Enable CMA
# "cma",
], ],
description="Contrib package of Stable Baselines3, experimental code.", description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin", author="Antonin Raffin",

View File

@ -14,41 +14,27 @@ def test_tqc(ent_coef):
create_eval_env=True, create_eval_env=True,
ent_coef=ent_coef, ent_coef=ent_coef,
) )
model.learn(total_timesteps=1000, eval_freq=500) model.learn(total_timesteps=500, eval_freq=250)
@pytest.mark.parametrize("n_critics", [1, 3]) @pytest.mark.parametrize("n_critics", [1, 3])
def test_n_critics(n_critics): def test_n_critics(n_critics):
# Test TQC with different number of critics # Test TQC with different number of critics
model = TQC( model = TQC(
"MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1 "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], n_critics=n_critics), learning_starts=100, verbose=1
) )
model.learn(total_timesteps=1000) model.learn(total_timesteps=500)
# "CartPole-v1" def test_sde():
# @pytest.mark.parametrize("env_id", ["MountainCarContinuous-v0"])
# def test_cmaes(env_id):
# if CMAES is None:
# return
# model = CMAES("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[64]), verbose=1, create_eval_env=True)
# model.learn(total_timesteps=50000, eval_freq=10000)
@pytest.mark.parametrize("strategy", ["exp", "bc", "binary"])
@pytest.mark.parametrize("reduce", ["mean", "max"])
def test_crr(tmp_path, strategy, reduce):
model = TQC( model = TQC(
"MlpPolicy", "MlpPolicy",
"Pendulum-v0", "Pendulum-v0",
policy_kwargs=dict(net_arch=[64]), policy_kwargs=dict(net_arch=[64], sde_net_arch=[8]),
learning_starts=1000, use_sde=True,
learning_starts=100,
verbose=1, verbose=1,
create_eval_env=True,
action_noise=None,
use_sde=False,
) )
model.learn(total_timesteps=500)
model.learn(total_timesteps=1000, eval_freq=0) model.policy.reset_noise()
for n_action_samples in [1, 2, -1]: model.policy.actor.get_std()
model.pretrain(gradient_steps=32, batch_size=32, n_action_samples=n_action_samples, strategy=strategy, reduce=reduce)

View File

@ -1,6 +1,5 @@
import os import os
import pathlib import pathlib
import warnings
from collections import OrderedDict from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
@ -11,7 +10,6 @@ import torch as th
from stable_baselines3 import DQN from stable_baselines3 import DQN
from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
from stable_baselines3.common.utils import get_device from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv from stable_baselines3.common.vec_env import DummyVecEnv
@ -45,7 +43,7 @@ def test_save_load(tmp_path, model_class):
# create model # create model
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1) model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
model.learn(total_timesteps=500, eval_freq=250) model.learn(total_timesteps=500)
env.reset() env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
@ -154,7 +152,7 @@ def test_save_load(tmp_path, model_class):
assert np.allclose(selected_actions, new_selected_actions, 1e-4) assert np.allclose(selected_actions, new_selected_actions, 1e-4)
# check if learn still works # check if learn still works
model.learn(total_timesteps=1000, eval_freq=500) model.learn(total_timesteps=500)
del model del model
@ -177,17 +175,17 @@ def test_set_env(model_class):
# create model # create model
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16])) model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]))
# learn # learn
model.learn(total_timesteps=1000, eval_freq=500) model.learn(total_timesteps=300)
# change env # change env
model.set_env(env2) model.set_env(env2)
# learn again # learn again
model.learn(total_timesteps=1000, eval_freq=500) model.learn(total_timesteps=300)
# change env test wrapping # change env test wrapping
model.set_env(env3) model.set_env(env3)
# learn again # learn again
model.learn(total_timesteps=1000, eval_freq=500) model.learn(total_timesteps=300)
@pytest.mark.parametrize("model_class", MODEL_LIST) @pytest.mark.parametrize("model_class", MODEL_LIST)
@ -247,45 +245,6 @@ def test_save_load_replay_buffer(tmp_path, model_class):
) )
@pytest.mark.parametrize("model_class", [TQC])
@pytest.mark.parametrize("optimize_memory_usage", [False, True])
def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
"""
When using memory efficient replay buffer,
a warning must be emitted when calling `.learn()`
multiple times.
See https://github.com/DLR-RM/stable-baselines3/issues/46
"""
# remove gym warnings
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")
model = model_class(
"MlpPolicy",
select_env(model_class),
buffer_size=100,
optimize_memory_usage=optimize_memory_usage,
policy_kwargs=dict(net_arch=[64]),
learning_starts=10,
)
model.learn(150)
model.learn(150, reset_num_timesteps=False)
# Check that there is no warning
assert len(recwarn) == 0
model.learn(150)
if optimize_memory_usage:
assert len(recwarn) == 1
warning = recwarn.pop(UserWarning)
assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
else:
assert len(recwarn) == 0
@pytest.mark.parametrize("model_class", MODEL_LIST) @pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"]) @pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
def test_save_load_policy(tmp_path, model_class, policy_str): def test_save_load_policy(tmp_path, model_class, policy_str):
@ -309,7 +268,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
# create model # create model
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs) model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
model.learn(total_timesteps=500, eval_freq=250) model.learn(total_timesteps=500)
env.reset() env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
@ -375,52 +334,3 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
os.remove(tmp_path / "policy.pkl") os.remove(tmp_path / "policy.pkl")
if actor_class is not None: if actor_class is not None:
os.remove(tmp_path / "actor.pkl") os.remove(tmp_path / "actor.pkl")
@pytest.mark.parametrize("pathtype", [str, pathlib.Path])
def test_open_file_str_pathlib(tmp_path, pathtype):
# check that suffix isn't added because we used open_path first
with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo"
assert not record
# test custom suffix
with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo"
assert not record
# test without suffix
with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo"
assert not record
# test that a warning is raised when the path doesn't exist
with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo"
assert len(record) == 0
with pytest.warns(None) as record:
assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo"
assert len(record) == 1
fp = pathlib.Path(f"{tmp_path}/t2").open("w")
fp.write("rubbish")
fp.close()
# test that a warning is only raised when verbose = 0
with pytest.warns(None) as record:
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close()
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close()
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close()
assert len(record) == 1