Add TQC and base scripts

Antonin RAFFIN 2020-09-25 12:47:45 +02:00
parent 8c282a0288
commit 0d9f2e229e
19 changed files with 1832 additions and 2 deletions

.coveragerc Normal file

@@ -0,0 +1,11 @@
[run]
branch = False
omit =
tests/*
setup.py
[report]
exclude_lines =
pragma: no cover
raise NotImplementedError()
if typing.TYPE_CHECKING:


@@ -0,0 +1,53 @@
---
name: Issue Template
about: How to create an issue for this repository
---
**Important Note: We do not offer technical support or consulting** and do not answer personal questions by email.
Please post your question on [reddit](https://www.reddit.com/r/reinforcementlearning/) or [stack overflow](https://stackoverflow.com/) in that case.
If you have any questions, feel free to create an issue with the tag [question].
If you wish to suggest an enhancement or feature request, add the tag [feature request].
If you are submitting a bug report, please fill in the following details.
If your issue is related to a custom gym environment, please check it first using:
```python
from stable_baselines3.common.env_checker import check_env
env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)
```
**Describe the bug**
A clear and concise description of what the bug is.
**Code example**
Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks)
for both code and stack traces.
```python
from stable_baselines3 import ...
```
```bash
Traceback (most recent call last): File ...
```
**System Info**
Describe the characteristics of your environment:
* Describe how the library was installed (pip, docker, source, ...)
* GPU models and configuration
* Python version
* PyTorch version
* Gym version
* Versions of any other relevant libraries
**Additional context**
Add any other context about the problem here.

.github/PULL_REQUEST_TEMPLATE.md vendored Normal file

@@ -0,0 +1,34 @@
<!--- Provide a general summary of your changes in the Title above -->
## Description
<!--- Describe your changes in detail -->
## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here. -->
<!--- You can use the syntax `closes #100` if this solves the issue #100 -->
- [ ] I have raised an issue to propose this change ([required](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) for new features and bug fixes)
## Types of changes
<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation (update in the documentation)
## Checklist:
<!--- Go over all the following points, and put an `x` in all the boxes that apply. -->
<!--- If you're unsure about any of these, don't hesitate to ask. We're here to help! -->
- [ ] I've read the [CONTRIBUTION](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) guide (**required**)
- [ ] I have updated the changelog accordingly (**required**).
- [ ] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
- [ ] I have updated the documentation accordingly.
- [ ] I have reformatted the code using `make format` (**required**)
- [ ] I have checked the codestyle using `make check-codestyle` and `make lint` (**required**)
- [ ] I have ensured `make pytest` and `make type` both pass. (**required**)
Note: we are using a maximum length of 127 characters per line
<!--- This Template is an edited version of the one from https://github.com/evilsocket/pwnagotchi/ -->

.github/workflows/ci.yml vendored Normal file

@@ -0,0 +1,51 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: CI
on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
jobs:
build:
# Skip CI if [ci skip] in the commit message
if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7] # 3.8 not supported yet by pytype
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install stable_baselines3[extra,tests,docs]
# Install master version
pip install git+https://github.com/DLR-RM/stable-baselines3
# Use headless version
pip install opencv-python-headless
# - name: Build the doc
# run: |
# make doc
- name: Type check
run: |
make type
- name: Check codestyle
run: |
make check-codestyle
- name: Lint with flake8
run: |
make lint
- name: Test with pytest
run: |
make pytest

.gitignore vendored Normal file

@@ -0,0 +1,48 @@
*.swp
*.pyc
*.pkl
*.py~
*.bak
.pytest_cache
.DS_Store
.idea
.coverage
.coverage.*
__pycache__/
_build/
*.npz
*.pth
.pytype/
git_rewrite_commit_history.sh
# Setuptools distribution and build folders.
/dist/
/build
keys/
# Virtualenv
/env
/venv
*.sublime-project
*.sublime-workspace
.idea
logs/
.ipynb_checkpoints
ghostdriver.log
htmlcov
junk
src
*.egg-info
.cache
*.lprof
*.prof
MUJOCO_LOG.TXT

CHANGELOG.md Normal file

@@ -0,0 +1,12 @@
## Release 0.9.0a2 (WIP)
### Breaking Changes
### New Features
- Add TQC
### Bug fixes
### Documentation
### Other

Makefile Normal file

@@ -0,0 +1,44 @@
SHELL=/bin/bash
LINT_PATHS=sb3_contrib/ tests/ setup.py
pytest:
./scripts/run_tests.sh
type:
pytype
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://lintlyci.github.io/Flake8Rules/
flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings.
flake8 ${LINT_PATHS} --count --exit-zero --statistics
format:
# Sort imports
isort ${LINT_PATHS}
# Reformat using black
black -l 127 ${LINT_PATHS}
check-codestyle:
# Sort imports
isort --check ${LINT_PATHS}
# Reformat using black
black --check -l 127 ${LINT_PATHS}
commit-checks: format type lint
# PyPi package release
release:
python setup.py sdist
python setup.py bdist_wheel
twine upload dist/*
# Test PyPi package release
test-release:
python setup.py sdist
python setup.py bdist_wheel
twine upload --repository-url https://test.pypi.org/legacy/ dist/*
.PHONY: lint format check-codestyle commit-checks


@@ -1,2 +1,20 @@
# stable-baselines3-contrib
Contrib package for Stable-Baselines3 - Experimental code
[![CI](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/workflows/ci/badge.svg)](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/actions) [![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
# Stable-Baselines3 - Contrib
Contrib package for [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) - Experimental code
**You need the master version of Stable Baselines3**:
```
pip install git+https://github.com/DLR-RM/stable-baselines3
```
Implemented:
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
## Installation
```
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
```

sb3_contrib/__init__.py Normal file

@@ -0,0 +1,9 @@
import os
# from sb3_contrib.cmaes import CMAES
from sb3_contrib.tqc import TQC
# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
with open(version_file, "r") as file_handler:
__version__ = file_handler.read().strip()

sb3_contrib/py.typed Normal file


@@ -0,0 +1,2 @@
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy
from sb3_contrib.tqc.tqc import TQC

sb3_contrib/tqc/policies.py Normal file

@@ -0,0 +1,480 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
import gym
import torch as th
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_features_extractor, register_policy
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
from torch import nn as nn
# CAP the standard deviation of the actor
LOG_STD_MAX = 2
LOG_STD_MIN = -20
class Actor(BasePolicy):
"""
Actor network (policy) for TQC.
:param observation_space: (gym.spaces.Space) Observation space
:param action_space: (gym.spaces.Space) Action space
:param net_arch: ([int]) Network architecture
:param features_extractor: (nn.Module) Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise)
:param features_dim: (int) Number of features
:param activation_fn: (Type[nn.Module]) Activation function
:param use_sde: (bool) Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE.
:param sde_net_arch: ([int]) Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: (bool) Use the ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf. paper). It keeps the variance
above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
net_arch: List[int],
features_extractor: nn.Module,
features_dim: int,
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
full_std: bool = True,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
clip_mean: float = 2.0,
normalize_images: bool = True,
):
super(Actor, self).__init__(
observation_space,
action_space,
features_extractor=features_extractor,
normalize_images=normalize_images,
squash_output=True,
)
# Save arguments to re-create object at loading
self.use_sde = use_sde
self.sde_features_extractor = None
self.sde_net_arch = sde_net_arch
self.net_arch = net_arch
self.features_dim = features_dim
self.activation_fn = activation_fn
self.log_std_init = log_std_init
self.sde_net_arch = sde_net_arch
self.use_expln = use_expln
self.full_std = full_std
self.clip_mean = clip_mean
action_dim = get_action_dim(self.action_space)
latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
self.latent_pi = nn.Sequential(*latent_pi_net)
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
if self.use_sde:
latent_sde_dim = last_layer_dim
# Separate feature extractor for gSDE
if sde_net_arch is not None:
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
features_dim, sde_net_arch, activation_fn
)
self.action_dist = StateDependentNoiseDistribution(
action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
)
self.mu, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=last_layer_dim, latent_sde_dim=latent_sde_dim, log_std_init=log_std_init
)
# Avoid numerical issues by limiting the mean of the Gaussian
# to be in [-clip_mean, clip_mean]
if clip_mean > 0.0:
self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
else:
self.action_dist = SquashedDiagGaussianDistribution(action_dim)
self.mu = nn.Linear(last_layer_dim, action_dim)
self.log_std = nn.Linear(last_layer_dim, action_dim)
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
data.update(
dict(
net_arch=self.net_arch,
features_dim=self.features_dim,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
full_std=self.full_std,
sde_net_arch=self.sde_net_arch,
use_expln=self.use_expln,
features_extractor=self.features_extractor,
clip_mean=self.clip_mean,
)
)
return data
def get_std(self) -> th.Tensor:
"""
Retrieve the standard deviation of the action distribution.
Only useful when using gSDE.
It corresponds to ``th.exp(log_std)`` in the normal case,
but is slightly different when using ``expln`` function
(cf StateDependentNoiseDistribution doc).
:return: (th.Tensor)
"""
msg = "get_std() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
return self.action_dist.get_std(self.log_std)
def reset_noise(self, batch_size: int = 1) -> None:
"""
Sample new weights for the exploration matrix, when using gSDE.
:param batch_size: (int)
"""
msg = "reset_noise() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
"""
Get the parameters for the action distribution.
:param obs: (th.Tensor)
:return: (Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]])
Mean, standard deviation and optional keyword arguments.
"""
features = self.extract_features(obs)
latent_pi = self.latent_pi(features)
mean_actions = self.mu(latent_pi)
if self.use_sde:
latent_sde = latent_pi
if self.sde_features_extractor is not None:
latent_sde = self.sde_features_extractor(features)
return mean_actions, self.log_std, dict(latent_sde=latent_sde)
# Unstructured exploration (Original implementation)
log_std = self.log_std(latent_pi)
# Original Implementation to cap the standard deviation
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
return mean_actions, log_std, {}
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# Note: the action is squashed
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# return action and associated log prob
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self.forward(observation, deterministic)
def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
"""
Evaluate actions according to the current policy,
given the observations. Only useful when using gSDE.
:param obs: (th.Tensor)
:param actions: (th.Tensor)
:return: (th.Tensor) log likelihood of taking those actions
"""
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
self.action_dist.proba_distribution(mean_actions, log_std, **kwargs)
return self.action_dist.log_prob(actions)
class Critic(BaseModel):
"""
Critic network (q-value function) for TQC.
:param observation_space: (gym.spaces.Space) Observation space
:param action_space: (gym.spaces.Space) Action space
:param net_arch: ([int]) Network architecture
:param features_extractor: (nn.Module) Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise)
:param features_dim: (int) Number of features
:param activation_fn: (Type[nn.Module]) Activation function
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
net_arch: List[int],
features_extractor: nn.Module,
features_dim: int,
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
n_quantiles: int = 25,
n_critics: int = 2,
):
super().__init__(
observation_space,
action_space,
features_extractor=features_extractor,
normalize_images=normalize_images,
)
action_dim = get_action_dim(self.action_space)
self.q_networks = []
self.n_quantiles = n_quantiles
self.n_critics = n_critics
self.quantiles_total = n_quantiles * n_critics
for i in range(n_critics):
qf_net = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn)
qf_net = nn.Sequential(*qf_net)
self.add_module(f"qf{i}", qf_net)
self.q_networks.append(qf_net)
def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]:
# Learn the features extractor using the policy loss only
# this is much faster
with th.no_grad():
features = self.extract_features(obs)
qvalue_input = th.cat([features, action], dim=1)
quantiles = th.stack(tuple(qf(qvalue_input) for qf in self.q_networks), dim=1)
return quantiles
class TQCPolicy(BasePolicy):
"""
Policy class (with both actor and critic) for TQC.
:param observation_space: (gym.spaces.Space) Observation space
:param action_space: (gym.spaces.Space) Action space
:param lr_schedule: (callable) Learning rate schedule (could be constant)
:param net_arch: (Optional[List[int]]) The specification of the policy and value networks.
:param activation_fn: (Type[nn.Module]) Activation function
:param use_sde: (bool) Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation
:param sde_net_arch: ([int]) Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: (bool) Use the ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf. paper). It keeps the variance
above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
to pass to the feature extractor.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Callable,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
clip_mean: float = 2.0,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_quantiles: int = 25,
n_critics: int = 2,
):
super(TQCPolicy, self).__init__(
observation_space,
action_space,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
squash_output=True,
)
if net_arch is None:
if features_extractor_class == FlattenExtractor:
net_arch = [256, 256]
else:
net_arch = []
# Create shared features extractor
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.net_arch = net_arch
self.activation_fn = activation_fn
self.net_args = {
"observation_space": self.observation_space,
"action_space": self.action_space,
"features_extractor": self.features_extractor,
"features_dim": self.features_dim,
"net_arch": self.net_arch,
"activation_fn": self.activation_fn,
"normalize_images": normalize_images,
}
self.actor_kwargs = self.net_args.copy()
self.critic_kwargs = self.net_args.copy()
sde_kwargs = {
"use_sde": use_sde,
"log_std_init": log_std_init,
"sde_net_arch": sde_net_arch,
"use_expln": use_expln,
"clip_mean": clip_mean,
}
self.actor_kwargs.update(sde_kwargs)
tqc_kwargs = {"n_quantiles": n_quantiles, "n_critics": n_critics}
self.critic_kwargs.update(tqc_kwargs)
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None
self._build(lr_schedule)
def _build(self, lr_schedule: Callable) -> None:
self.actor = self.make_actor()
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
self.critic = self.make_critic()
self.critic_target = self.make_critic()
self.critic_target.load_state_dict(self.critic.state_dict())
# Do not optimize the shared feature extractor with the critic loss
# otherwise, there are gradient computation issues
# Another solution: having duplicated features extractor but requires more memory and computation
critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
data.update(
dict(
net_arch=self.net_args["net_arch"],
activation_fn=self.net_args["activation_fn"],
use_sde=self.actor_kwargs["use_sde"],
log_std_init=self.actor_kwargs["log_std_init"],
sde_net_arch=self.actor_kwargs["sde_net_arch"],
use_expln=self.actor_kwargs["use_expln"],
clip_mean=self.actor_kwargs["clip_mean"],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs,
)
)
return data
def reset_noise(self, batch_size: int = 1) -> None:
"""
Sample new weights for the exploration matrix, when using gSDE.
:param batch_size: (int)
"""
self.actor.reset_noise(batch_size=batch_size)
def make_actor(self) -> Actor:
return Actor(**self.actor_kwargs).to(self.device)
def make_critic(self) -> Critic:
return Critic(**self.critic_kwargs).to(self.device)
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self.actor(observation, deterministic)
MlpPolicy = TQCPolicy
class CnnPolicy(TQCPolicy):
"""
Policy class (with both actor and critic) for TQC.
:param observation_space: (gym.spaces.Space) Observation space
:param action_space: (gym.spaces.Space) Action space
:param lr_schedule: (callable) Learning rate schedule (could be constant)
:param net_arch: (Optional[List[int]]) The specification of the policy and value networks.
:param activation_fn: (Type[nn.Module]) Activation function
:param use_sde: (bool) Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation
:param sde_net_arch: ([int]) Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: (bool) Use the ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf. paper). It keeps the variance
above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Callable,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
clip_mean: float = 2.0,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_quantiles: int = 25,
n_critics: int = 2,
):
super(CnnPolicy, self).__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
use_sde,
log_std_init,
sde_net_arch,
use_expln,
clip_mean,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
n_quantiles,
n_critics,
)
register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)
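The critic in this file returns one quantile vector per network, stacked along a new dimension. A small standalone sketch of that stacking (the layer and batch sizes here are illustrative assumptions):

```python
import torch as th
from torch import nn

# Illustrative sizes (assumptions): 2 critics, 5 quantiles, 8-dim input, batch of 4
n_critics, n_quantiles, input_dim, batch = 2, 5, 8, 4
q_networks = [nn.Linear(input_dim, n_quantiles) for _ in range(n_critics)]
qvalue_input = th.randn(batch, input_dim)  # stands in for concatenated features and actions
# Same stacking as Critic.forward: one quantile vector per network
quantiles = th.stack(tuple(qf(qvalue_input) for qf in q_networks), dim=1)
print(quantiles.shape)  # (batch, n_critics, n_quantiles)
```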

sb3_contrib/tqc/tqc.py Normal file

@@ -0,0 +1,492 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch as th
from stable_baselines3.common import logger
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import polyak_update
from tqdm import tqdm
from sb3_contrib.tqc.policies import TQCPolicy
class TQC(OffPolicyAlgorithm):
"""
Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics
Paper: https://arxiv.org/abs/2005.04269
:param policy: (TQCPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str)
:param learning_rate: (float or callable) learning rate for adam optimizer,
the same learning rate will be used for all networks (Q-Values, Actor and Value function)
it can be a function of the current progress remaining (from 1 to 0)
:param buffer_size: (int) size of the replay buffer
:param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
:param batch_size: (int) Minibatch size for each gradient update
:param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1)
:param gamma: (float) the discount factor
:param train_freq: (int) Update the model every ``train_freq`` steps.
:param gradient_steps: (int) How many gradient updates to perform after each step
:param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes.
Note that this cannot be used at the same time as ``train_freq``
:param action_noise: (ActionNoise) the action noise type (None by default), this can help
for hard exploration problems. Cf. common.noise for the different action noise types.
:param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param ent_coef: (str or float) Entropy regularization coefficient. (Equivalent to
the inverse of the reward scale in the original SAC paper.) It controls the exploration/exploitation trade-off.
Set it to 'auto' to learn it automatically (and 'auto_0.1' to use 0.1 as the initial value)
:param target_update_interval: (int) update the target network every ``target_update_interval``
gradient steps.
:param target_entropy: (str or float) target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
:param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling
during the warm up phase (before learning starts)
:param create_eval_env: (bool) Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
:param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug
:param seed: (int) Seed for the pseudo random generators
:param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
"""
def __init__(
self,
policy: Union[str, Type[TQCPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Callable] = 3e-4,
buffer_size: int = int(1e6),
learning_starts: int = 100,
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: int = 1,
gradient_steps: int = 1,
n_episodes_rollout: int = -1,
action_noise: Optional[ActionNoise] = None,
optimize_memory_usage: bool = False,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
ent_coef: Union[str, float] = "auto",
target_update_interval: int = 1,
target_entropy: Union[str, float] = "auto",
top_quantiles_to_drop_per_net: int = 2,
use_sde: bool = False,
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Dict[str, Any] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super(TQC, self).__init__(
policy,
env,
TQCPolicy,
learning_rate,
replay_buffer_class,
replay_buffer_kwargs,
buffer_size,
learning_starts,
batch_size,
tau,
gamma,
train_freq,
gradient_steps,
n_episodes_rollout,
action_noise,
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
verbose=verbose,
device=device,
create_eval_env=create_eval_env,
seed=seed,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
use_sde_at_warmup=use_sde_at_warmup,
optimize_memory_usage=optimize_memory_usage,
)
self.target_entropy = target_entropy
self.log_ent_coef = None # type: Optional[th.Tensor]
# Entropy coefficient / Entropy temperature
# Inverse of the reward scale
self.ent_coef = ent_coef
self.target_update_interval = target_update_interval
self.ent_coef_optimizer = None
self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net
if _init_setup_model:
self._setup_model()
def _setup_model(self) -> None:
super(TQC, self)._setup_model()
self._create_aliases()
self.replay_buffer.actor = self.actor
self.replay_buffer.ent_coef = 0.0
# Target entropy is used when learning the entropy coefficient
if self.target_entropy == "auto":
# automatically set target entropy if needed
self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)
else:
# Force conversion
# this will also throw an error for unexpected string
self.target_entropy = float(self.target_entropy)
# The entropy coefficient or entropy can be learned automatically
# see Automating Entropy Adjustment for Maximum Entropy RL section
# of https://arxiv.org/abs/1812.05905
if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"):
# Default initial value of ent_coef when learned
init_value = 1.0
if "_" in self.ent_coef:
init_value = float(self.ent_coef.split("_")[1])
assert init_value > 0.0, "The initial value of ent_coef must be greater than 0"
# Note: we optimize the log of the entropy coeff which is slightly different from the paper
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
else:
# Force conversion to float
# this will throw an error if a malformed string (different from 'auto')
# is passed
self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device)
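The branches above (auto target entropy, `"auto"` / `"auto_<init>"` / fixed-float entropy coefficient) can be summarized in a small standalone sketch. The function name and return layout below are illustrative only, not part of the class:

```python
import numpy as np

def parse_ent_coef(ent_coef, action_shape):
    # Mirror of the setup logic above, outside the class.
    # Heuristic from the SAC paper: target entropy = -dim(A)
    target_entropy = -float(np.prod(action_shape))
    if isinstance(ent_coef, str) and ent_coef.startswith("auto"):
        learn_ent_coef = True
        init_value = 1.0
        if "_" in ent_coef:
            # "auto_0.1" -> learned coefficient starting at 0.1
            init_value = float(ent_coef.split("_")[1])
        assert init_value > 0.0, "The initial value of ent_coef must be greater than 0"
    else:
        learn_ent_coef = False
        # Raises ValueError for malformed strings (anything but 'auto...')
        init_value = float(ent_coef)
    return target_entropy, learn_ent_coef, init_value

print(parse_ent_coef("auto_0.1", (6,)))  # (-6.0, True, 0.1)
```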
def _create_aliases(self) -> None:
self.actor = self.policy.actor
self.critic = self.policy.critic
self.critic_target = self.policy.critic_target
@staticmethod
def quantile_huber_loss(quantiles: th.Tensor, samples: th.Tensor) -> th.Tensor:
# batch x nets x quantiles x samples
pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
abs_pairwise_delta = th.abs(pairwise_delta)
huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5)
n_quantiles = quantiles.shape[2]
tau = th.arange(n_quantiles, device=quantiles.device).float() / n_quantiles + 1 / 2 / n_quantiles
loss = (th.abs(tau[None, None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean()
return loss
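As a sanity check, the loss above can be exercised on toy tensors. This is a standalone copy of the static method, assuming only `torch`; when the predicted quantiles match the target samples exactly, the loss is zero:

```python
import torch as th

def quantile_huber_loss(quantiles: th.Tensor, samples: th.Tensor) -> th.Tensor:
    # quantiles: (batch, n_nets, n_quantiles), samples: (batch, n_target_samples)
    pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
    abs_pairwise_delta = th.abs(pairwise_delta)
    # Huber loss with kappa = 1
    huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5)
    n_quantiles = quantiles.shape[2]
    # Quantile midpoints: tau_i = (i + 0.5) / n_quantiles
    tau = th.arange(n_quantiles, device=quantiles.device).float() / n_quantiles + 1 / 2 / n_quantiles
    return (th.abs(tau[None, None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean()

quantiles = th.zeros(4, 2, 25)  # batch=4, 2 critic nets, 25 quantiles each
assert quantile_huber_loss(quantiles, th.zeros(4, 10)).item() == 0.0  # perfect fit
assert quantile_huber_loss(quantiles, th.ones(4, 10)).item() > 0.0    # shifted targets
```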
def train(self, gradient_steps: int, batch_size: int = 64) -> None:
# Update optimizers learning rate
optimizers = [self.actor.optimizer, self.critic.optimizer]
if self.ent_coef_optimizer is not None:
optimizers += [self.ent_coef_optimizer]
# Update learning rate according to lr schedule
self._update_learning_rate(optimizers)
ent_coef_losses, ent_coefs = [], []
actor_losses, critic_losses = [], []
for gradient_step in range(gradient_steps):
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
# We need to sample because `log_std` may have changed between two gradient steps
if self.use_sde:
self.actor.reset_noise()
# Action by the current actor for the sampled state
actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
log_prob = log_prob.reshape(-1, 1)
ent_coef_loss = None
if self.ent_coef_optimizer is not None:
# Important: detach the variable from the graph
# so we don't change it with other losses
# see https://github.com/rail-berkeley/softlearning/issues/60
ent_coef = th.exp(self.log_ent_coef.detach())
ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
ent_coef_losses.append(ent_coef_loss.item())
else:
ent_coef = self.ent_coef_tensor
ent_coefs.append(ent_coef.item())
self.replay_buffer.ent_coef = ent_coef.item()
# Optimize entropy coefficient, also called
# entropy temperature or alpha in the paper
if ent_coef_loss is not None:
self.ent_coef_optimizer.zero_grad()
ent_coef_loss.backward()
self.ent_coef_optimizer.step()
with th.no_grad():
top_quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.critic.n_critics
# Select action according to policy
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
# Compute and cut quantiles at the next state
# batch x nets x quantiles
next_z = self.critic_target(replay_data.next_observations, next_actions)
sorted_z, _ = th.sort(next_z.reshape(batch_size, -1))
sorted_z_part = sorted_z[:, : self.critic.quantiles_total - top_quantiles_to_drop]
target_q = sorted_z_part - ent_coef * next_log_prob.reshape(-1, 1)
# td error + entropy term
q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
# Get current Q estimates
# using action from the replay buffer
current_z = self.critic(replay_data.observations, replay_data.actions)
# Compute critic loss
critic_loss = self.quantile_huber_loss(current_z, q_backup)
critic_losses.append(critic_loss.item())
# Optimize the critic
self.critic.optimizer.zero_grad()
critic_loss.backward()
self.critic.optimizer.step()
# Compute actor loss
qf_pi = self.critic(replay_data.observations, actions_pi).mean(2).mean(1, keepdim=True)
actor_loss = (ent_coef * log_prob - qf_pi).mean()
actor_losses.append(actor_loss.item())
# Optimize the actor
self.actor.optimizer.zero_grad()
actor_loss.backward()
self.actor.optimizer.step()
# Update target networks
if gradient_step % self.target_update_interval == 0:
polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
self._n_updates += gradient_steps
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
logger.record("train/ent_coef", np.mean(ent_coefs))
logger.record("train/actor_loss", np.mean(actor_losses))
logger.record("train/critic_loss", np.mean(critic_losses))
if len(ent_coef_losses) > 0:
logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
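The `polyak_update` call in the target-network step above performs the usual soft update theta_target <- tau * theta + (1 - tau) * theta_target. A minimal sketch of such a helper (the real one is imported from `stable_baselines3.common.utils`; this signature is an assumption for illustration):

```python
import torch as th

def polyak_update(params, target_params, tau: float) -> None:
    # Soft update: theta_target <- tau * theta + (1 - tau) * theta_target
    with th.no_grad():
        for param, target_param in zip(params, target_params):
            target_param.mul_(1 - tau)
            target_param.add_(tau * param)

source = [th.ones(2, 2)]
target = [th.zeros(2, 2)]
polyak_update(source, target, tau=0.005)
assert th.allclose(target[0], th.full((2, 2), 0.005))
```

With the default `tau=0.005`, the target network only moves 0.5% of the way toward the online critic per update, which stabilizes the TD targets.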
def pretrain(
self,
gradient_steps: int,
batch_size: int = 64,
n_action_samples: int = -1,
target_update_interval: int = 1,
tau: float = 0.005,
strategy: str = "exp",
reduce: str = "mean",
exp_temperature: float = 1.0,
off_policy_update_freq: int = -1,
) -> None:
"""
Pretrain with Critic Regularized Regression (CRR)
Paper: https://arxiv.org/abs/2006.15134
"""
# Update optimizers learning rate
optimizers = [self.actor.optimizer, self.critic.optimizer]
if self.ent_coef_optimizer is not None:
optimizers += [self.ent_coef_optimizer]
# Update learning rate according to lr schedule
self._update_learning_rate(optimizers)
actor_losses, critic_losses = [], []
for gradient_step in tqdm(range(gradient_steps)):
if off_policy_update_freq > 0 and gradient_step % off_policy_update_freq == 0:
self.train(gradient_steps=1, batch_size=batch_size)
continue
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
# We need to sample because `log_std` may have changed between two gradient steps
if self.use_sde:
self.actor.reset_noise()
# Action by the current actor for the sampled state
_, log_prob = self.actor.action_log_prob(replay_data.observations)
log_prob = log_prob.reshape(-1, 1)
ent_coef_loss = None
if self.ent_coef_optimizer is not None:
# Important: detach the variable from the graph
# so we don't change it with other losses
# see https://github.com/rail-berkeley/softlearning/issues/60
ent_coef = th.exp(self.log_ent_coef.detach())
ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
else:
ent_coef = self.ent_coef_tensor
self.replay_buffer.ent_coef = ent_coef.item()
# Optimize entropy coefficient, also called
# entropy temperature or alpha in the paper
if ent_coef_loss is not None:
self.ent_coef_optimizer.zero_grad()
ent_coef_loss.backward()
self.ent_coef_optimizer.step()
with th.no_grad():
top_quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.critic.n_critics
# Select action according to policy
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
# Compute and cut quantiles at the next state
# batch x nets x quantiles
next_z = self.critic_target(replay_data.next_observations, next_actions)
sorted_z, _ = th.sort(next_z.reshape(batch_size, -1))
sorted_z_part = sorted_z[:, : self.critic.quantiles_total - top_quantiles_to_drop]
target_q = sorted_z_part - ent_coef * next_log_prob.reshape(-1, 1)
# td error + entropy term
q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
# Get current Q estimates
# using action from the replay buffer
current_z = self.critic(replay_data.observations, replay_data.actions)
# Compute critic loss
critic_loss = self.quantile_huber_loss(current_z, q_backup)
critic_losses.append(critic_loss.item())
# Optimize the critic
self.critic.optimizer.zero_grad()
critic_loss.backward()
self.critic.optimizer.step()
if strategy == "bc":
# Behavior cloning
weight = 1
else:
# Tensor version: TODO: check that the reshape works as expected
# cleaner but not faster on cpu for large batch size
# with th.no_grad():
# # Q-value for the action in the buffer
# qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True)
# # Create tensor to avoid loop
# # Note: For SDE, we need to sample several matrices
# obs_ = replay_data.observations.repeat(n_action_samples, 1)
# if self.use_sde:
# self.actor.reset_noise(batch_size * n_action_samples)
# actions_pi, _ = self.actor.action_log_prob(obs_)
# qf_pi = self.critic(obs_, actions_pi.detach()).mean(2).mean(1, keepdim=True)
# # Aggregate: reduce mean or reduce max

# if reduce == "max":
# _, qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).max(axis=0)
# else:
# qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).mean(axis=0)
with th.no_grad():
qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True)
# Use the mean (as done in AWAC, cf rlkit)
if n_action_samples == -1:
actions_pi = self.actor.forward(replay_data.observations, deterministic=True)
qf_agg = self.critic(replay_data.observations, actions_pi).mean(2).mean(1, keepdim=True)
else:
qf_agg = None
for _ in range(n_action_samples):
if self.use_sde:
self.actor.reset_noise()
actions_pi, _ = self.actor.action_log_prob(replay_data.observations)
qf_pi = self.critic(replay_data.observations, actions_pi.detach()).mean(2).mean(1, keepdim=True)
if qf_agg is None:
if reduce == "max":
qf_agg = qf_pi
else:
qf_agg = qf_pi / n_action_samples
else:
if reduce == "max":
qf_agg = th.max(qf_pi, qf_agg)
else:
qf_agg += qf_pi / n_action_samples
advantage = qf_buffer - qf_agg
if strategy == "binary":
# binary advantage
weight = advantage > 0
else:
# exp advantage
exp_clip = 20.0
weight = th.clamp(th.exp(advantage / exp_temperature), 0.0, exp_clip)
# Log prob by the current actor for the sampled state and action
log_prob = self.actor.evaluate_actions(replay_data.observations, replay_data.actions)
log_prob = log_prob.reshape(-1, 1)
# weighted regression loss (close to a policy gradient loss)
actor_loss = (-log_prob * weight).mean()
# actor_loss = ((actions_pi - replay_data.actions * weight) ** 2).mean()
actor_losses.append(actor_loss.item())
# Optimize the actor
self.actor.optimizer.zero_grad()
actor_loss.backward()
self.actor.optimizer.step()
# Update target networks
if gradient_step % target_update_interval == 0:
polyak_update(self.critic.parameters(), self.critic_target.parameters(), tau)
if self.use_sde:
print(f"std={(self.actor.get_std()).mean().item()}")
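The advantage-weighting schemes used in the loop above ("binary" vs. "exp", with the hard-coded clip of 20.0) can be isolated into a small function for illustration. The name and defaults below are ours, not part of the class:

```python
import torch as th

def crr_weights(advantage: th.Tensor, strategy: str = "exp",
                exp_temperature: float = 1.0, exp_clip: float = 20.0) -> th.Tensor:
    # Advantage weights from CRR (https://arxiv.org/abs/2006.15134):
    # "binary" keeps only transitions whose buffer action beats the
    # policy's aggregated Q-value; "exp" scales them by a clipped
    # exponential of the advantage.
    if strategy == "binary":
        return (advantage > 0).float()
    return th.clamp(th.exp(advantage / exp_temperature), 0.0, exp_clip)

adv = th.tensor([[-1.0], [0.5], [30.0]])
assert crr_weights(adv, "binary").flatten().tolist() == [0.0, 1.0, 1.0]
assert crr_weights(adv, "exp")[2].item() == 20.0  # large advantage is clipped
```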
def learn(
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
tb_log_name: str = "TQC",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm:
return super(TQC, self).learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
eval_env=eval_env,
eval_freq=eval_freq,
n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
)
def _excluded_save_params(self) -> List[str]:
"""
Returns the names of the parameters that should be excluded by default
when saving the model.
:return: (List[str]) List of parameters that should be excluded from save
"""
# Exclude aliases
return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
"""
cf base class
"""
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
saved_pytorch_variables = ["log_ent_coef"]
if self.ent_coef_optimizer is not None:
state_dicts.append("ent_coef_optimizer")
else:
saved_pytorch_variables.append("ent_coef_tensor")
return state_dicts, saved_pytorch_variables

sb3_contrib/version.txt Normal file
@@ -0,0 +1 @@
0.9.0a2

scripts/run_tests.sh Normal file
@@ -0,0 +1,2 @@
#!/bin/bash
python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v

setup.cfg Normal file
@@ -0,0 +1,48 @@
[metadata]
# This includes the license file in the wheel.
license_file = LICENSE
[tool:pytest]
# Deterministic ordering for tests; useful for pytest-xdist.
env =
PYTHONHASHSEED=0
filterwarnings =
# Tensorboard/Tensorflow warnings
ignore:inspect.getargspec:DeprecationWarning:tensorflow
ignore:builtin type EagerTensor has no __module__ attribute:DeprecationWarning
ignore:The binary mode of fromstring is deprecated:DeprecationWarning
ignore::FutureWarning:tensorflow
# Gym warnings
ignore:Parameters to load are deprecated.:DeprecationWarning
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
ignore::UserWarning:gym
[pytype]
inputs = sb3_contrib
[flake8]
ignore = W503,W504,E203,E231 # line breaks before and after binary operators
# Ignore import not used when aliases are defined
per-file-ignores =
./sb3_contrib/__init__.py:F401
./sb3_contrib/tqc/__init__.py:F401
exclude =
# No need to traverse our git directory
.git,
# There's no value in checking cache directories
__pycache__,
# Don't check the doc
docs/
# This contains our built documentation
build,
# This contains builds of flake8 that we don't want to check
dist
*.egg-info
max-complexity = 15
# The GitHub editor is 127 chars wide
max-line-length = 127
[isort]
profile = black
line_length = 127
src_paths = sb3_contrib

setup.py Normal file
@@ -0,0 +1,45 @@
import os
from setuptools import find_packages, setup
with open(os.path.join("sb3_contrib", "version.txt"), "r") as file_handler:
__version__ = file_handler.read().strip()
long_description = """
# Stable-Baselines3 - Contrib
Contrib package for [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) - Experimental code
Implemented:
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
""" # noqa:E501
setup(
name="sb3_contrib",
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3[tests,docs]",
# Enable CMA
# "cma",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
url="https://github.com/Stable-Baselines-Team/stable-baselines3-contrib",
author_email="antonin.raffin@dlr.de",
keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning "
"gym openai stable baselines toolbox python data-science",
license="MIT",
long_description=long_description,
long_description_content_type="text/markdown",
version=__version__,
)
# python setup.py sdist
# python setup.py bdist_wheel
# twine upload --repository-url https://test.pypi.org/legacy/ dist/*
# twine upload dist/*

tests/test_run.py Normal file
@@ -0,0 +1,54 @@
import pytest
from sb3_contrib import TQC
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
def test_tqc(ent_coef):
model = TQC(
"MlpPolicy",
"Pendulum-v0",
policy_kwargs=dict(net_arch=[64, 64]),
learning_starts=100,
verbose=1,
create_eval_env=True,
ent_coef=ent_coef,
)
model.learn(total_timesteps=1000, eval_freq=500)
@pytest.mark.parametrize("n_critics", [1, 3])
def test_n_critics(n_critics):
# Test TQC with different number of critics
model = TQC(
"MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1
)
model.learn(total_timesteps=1000)
# "CartPole-v1"
# @pytest.mark.parametrize("env_id", ["MountainCarContinuous-v0"])
# def test_cmaes(env_id):
# if CMAES is None:
# return
# model = CMAES("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[64]), verbose=1, create_eval_env=True)
# model.learn(total_timesteps=50000, eval_freq=10000)
@pytest.mark.parametrize("strategy", ["exp", "bc", "binary"])
@pytest.mark.parametrize("reduce", ["mean", "max"])
def test_crr(tmp_path, strategy, reduce):
model = TQC(
"MlpPolicy",
"Pendulum-v0",
policy_kwargs=dict(net_arch=[64]),
learning_starts=1000,
verbose=1,
create_eval_env=True,
action_noise=None,
use_sde=False,
)
model.learn(total_timesteps=1000, eval_freq=0)
for n_action_samples in [1, 2, -1]:
model.pretrain(gradient_steps=32, batch_size=32, n_action_samples=n_action_samples, strategy=strategy, reduce=reduce)

tests/test_save_load.py Normal file
@@ -0,0 +1,426 @@
import os
import pathlib
import warnings
from collections import OrderedDict
from copy import deepcopy
import gym
import numpy as np
import pytest
import torch as th
from stable_baselines3 import DQN
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv
from sb3_contrib import TQC
MODEL_LIST = [TQC]
def select_env(model_class: BaseAlgorithm) -> gym.Env:
"""
Selects an environment with the correct action space, as DQN only supports a discrete action space
"""
if model_class == DQN:
return IdentityEnv(10)
else:
return IdentityEnvBox(10)
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_save_load(tmp_path, model_class):
"""
Test if 'save' and 'load' save and load the model correctly
and if 'get_parameters' and 'set_parameters' work correctly.
.. warning:: This does not test that optimizer parameters are loaded correctly.
:param model_class: (BaseAlgorithm) A RL model
"""
env = DummyVecEnv([lambda: select_env(model_class)])
# create model
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
model.learn(total_timesteps=500, eval_freq=250)
env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
# Get parameters of different objects
# deepcopy to avoid referencing to tensors we are about to modify
original_params = deepcopy(model.get_parameters())
# Test different error cases of set_parameters.
# Test that invalid object names throw errors
invalid_object_params = deepcopy(original_params)
invalid_object_params["I_should_not_be_a_valid_object"] = "and_I_am_an_invalid_tensor"
with pytest.raises(ValueError):
model.set_parameters(invalid_object_params, exact_match=True)
with pytest.raises(ValueError):
model.set_parameters(invalid_object_params, exact_match=False)
# Test that exact_match catches when something was missed.
missing_object_params = dict((k, v) for k, v in list(original_params.items())[:-1])
with pytest.raises(ValueError):
model.set_parameters(missing_object_params, exact_match=True)
# Test that exact_match catches when something inside state-dict
# is missing but we have exact_match.
missing_state_dict_tensor_params = {}
for object_name in original_params:
object_params = {}
missing_state_dict_tensor_params[object_name] = object_params
# Skip last item in state-dict
for k, v in list(original_params[object_name].items())[:-1]:
object_params[k] = v
with pytest.raises(RuntimeError):
# PyTorch load_state_dict throws RuntimeError if strict but
# invalid state-dict.
model.set_parameters(missing_state_dict_tensor_params, exact_match=True)
# Test that parameters do indeed change.
random_params = {}
for object_name, params in original_params.items():
# Do not randomize optimizer parameters (custom layout)
if "optim" in object_name:
random_params[object_name] = params
else:
# Again, skip the last item in state-dict
random_params[object_name] = OrderedDict(
(param_name, th.rand_like(param)) for param_name, param in list(params.items())[:-1]
)
# Update model parameters with the new random values
model.set_parameters(random_params, exact_match=False)
new_params = model.get_parameters()
# Check that all params except the final item in each state-dict are different.
for object_name in original_params:
# Skip optimizers (no valid comparison with just th.allclose)
if "optim" in object_name:
continue
# state-dicts use ordered dictionaries, so key order
# is guaranteed.
last_key = list(original_params[object_name].keys())[-1]
for k in original_params[object_name]:
if k == last_key:
# Should be same as before
assert th.allclose(
original_params[object_name][k], new_params[object_name][k]
), "Parameter changed despite not included in the loaded parameters."
else:
# Should be different
assert not th.allclose(
original_params[object_name][k], new_params[object_name][k]
), "Parameters did not change as expected."
params = new_params
# get selected actions
selected_actions, _ = model.predict(observations, deterministic=True)
# Check
model.save(tmp_path / "test_save.zip")
del model
# Check if the model loads as expected for every possible choice of device:
for device in ["auto", "cpu", "cuda"]:
model = model_class.load(str(tmp_path / "test_save.zip"), env=env, device=device)
# check if the model was loaded to the correct device
assert model.device.type == get_device(device).type
assert model.policy.device.type == get_device(device).type
# check if params are still the same after load
new_params = model.get_parameters()
# Check that all params are the same as before save load procedure now
for object_name in new_params:
# Skip optimizers (no valid comparison with just th.allclose)
if "optim" in object_name:
continue
for key in params[object_name]:
assert new_params[object_name][key].device.type == get_device(device).type
assert th.allclose(
params[object_name][key].to("cpu"), new_params[object_name][key].to("cpu")
), "Model parameters not the same after save and load."
# check if model still selects the same actions
new_selected_actions, _ = model.predict(observations, deterministic=True)
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
# check if learn still works
model.learn(total_timesteps=1000, eval_freq=500)
del model
# clear file from os
os.remove(tmp_path / "test_save.zip")
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_set_env(model_class):
"""
Test if the set_env function works correctly
:param model_class: (BaseAlgorithm) A RL model
"""
# use discrete for DQN
env = DummyVecEnv([lambda: select_env(model_class)])
env2 = DummyVecEnv([lambda: select_env(model_class)])
env3 = select_env(model_class)
# create model
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]))
# learn
model.learn(total_timesteps=1000, eval_freq=500)
# change env
model.set_env(env2)
# learn again
model.learn(total_timesteps=1000, eval_freq=500)
# change env test wrapping
model.set_env(env3)
# learn again
model.learn(total_timesteps=1000, eval_freq=500)
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_exclude_include_saved_params(tmp_path, model_class):
"""
Test if exclude and include parameters of save() work
:param model_class: (BaseAlgorithm) A RL model
"""
env = DummyVecEnv([lambda: select_env(model_class)])
# create model, set verbose as 2, which is not standard
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=2)
# Check if exclude works
model.save(tmp_path / "test_save", exclude=["verbose"])
del model
model = model_class.load(str(tmp_path / "test_save.zip"))
# check if verbose was not saved
assert model.verbose != 2
# set verbose to something different than the standard setting
model.verbose = 2
# Check if include works
model.save(tmp_path / "test_save", exclude=["verbose"], include=["verbose"])
del model
model = model_class.load(str(tmp_path / "test_save.zip"))
assert model.verbose == 2
# clear file from os
os.remove(tmp_path / "test_save.zip")
@pytest.mark.parametrize("model_class", [TQC])
def test_save_load_replay_buffer(tmp_path, model_class):
path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning
model = model_class("MlpPolicy", select_env(model_class), buffer_size=1000)
model.learn(500)
old_replay_buffer = deepcopy(model.replay_buffer)
model.save_replay_buffer(path)
model.replay_buffer = None
model.load_replay_buffer(path)
assert np.allclose(old_replay_buffer.observations, model.replay_buffer.observations)
assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions)
assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards)
assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones)
# test extending replay buffer
model.replay_buffer.extend(
old_replay_buffer.observations,
old_replay_buffer.observations,
old_replay_buffer.actions,
old_replay_buffer.rewards,
old_replay_buffer.dones,
)
@pytest.mark.parametrize("model_class", [TQC])
@pytest.mark.parametrize("optimize_memory_usage", [False, True])
def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
"""
When using memory efficient replay buffer,
a warning must be emitted when calling `.learn()`
multiple times.
See https://github.com/DLR-RM/stable-baselines3/issues/46
"""
# remove gym warnings
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")
model = model_class(
"MlpPolicy",
select_env(model_class),
buffer_size=100,
optimize_memory_usage=optimize_memory_usage,
policy_kwargs=dict(net_arch=[64]),
learning_starts=10,
)
model.learn(150)
model.learn(150, reset_num_timesteps=False)
# Check that there is no warning
assert len(recwarn) == 0
model.learn(150)
if optimize_memory_usage:
assert len(recwarn) == 1
warning = recwarn.pop(UserWarning)
assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
else:
assert len(recwarn) == 0
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
def test_save_load_policy(tmp_path, model_class, policy_str):
"""
Test saving and loading policy only.
:param model_class: (BaseAlgorithm) A RL model
:param policy_str: (str) Name of the policy.
"""
kwargs = {}
if policy_str == "MlpPolicy":
env = select_env(model_class)
else:
if model_class in [TQC]:
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(buffer_size=250)
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)
env = DummyVecEnv([lambda: env])
# create model
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
model.learn(total_timesteps=500, eval_freq=250)
env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
policy = model.policy
policy_class = policy.__class__
actor, actor_class = None, None
if model_class in [TQC]:
actor = policy.actor
actor_class = actor.__class__
# Get dictionary of current parameters
params = deepcopy(policy.state_dict())
# Modify all parameters to be random values
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
# Update model parameters with the new random values
policy.load_state_dict(random_params)
new_params = policy.state_dict()
# Check that all params are different now
for k in params:
assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
params = new_params
# get selected actions
selected_actions, _ = policy.predict(observations, deterministic=True)
# Should also work with the actor only
if actor is not None:
selected_actions_actor, _ = actor.predict(observations, deterministic=True)
# Save and load policy
policy.save(tmp_path / "policy.pkl")
# Save and load actor
if actor is not None:
actor.save(tmp_path / "actor.pkl")
del policy, actor
policy = policy_class.load(tmp_path / "policy.pkl")
if actor_class is not None:
actor = actor_class.load(tmp_path / "actor.pkl")
# check if params are still the same after load
new_params = policy.state_dict()
# Check that all params are the same as before save load procedure now
for key in params:
assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load."
# check if model still selects the same actions
new_selected_actions, _ = policy.predict(observations, deterministic=True)
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
if actor_class is not None:
new_selected_actions_actor, _ = actor.predict(observations, deterministic=True)
assert np.allclose(selected_actions_actor, new_selected_actions_actor, 1e-4)
assert np.allclose(selected_actions_actor, new_selected_actions, 1e-4)
# clear file from os
os.remove(tmp_path / "policy.pkl")
if actor_class is not None:
os.remove(tmp_path / "actor.pkl")
@pytest.mark.parametrize("pathtype", [str, pathlib.Path])
def test_open_file_str_pathlib(tmp_path, pathtype):
# check that suffix isn't added because we used open_path first
with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo"
assert not record
# test custom suffix
with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo"
assert not record
# test without suffix
with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo"
assert not record
# test that a warning is raised when the path doesn't exist
with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo"
assert len(record) == 0
with pytest.warns(None) as record:
assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo"
assert len(record) == 1
fp = pathlib.Path(f"{tmp_path}/t2").open("w")
fp.write("rubbish")
fp.close()
# test that a warning is only raised when verbose = 2
with pytest.warns(None) as record:
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close()
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close()
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close()
assert len(record) == 1