Add TQC and base scripts

parent 8c282a0288
commit 0d9f2e229e
@ -0,0 +1,11 @@
[run]
branch = False
omit =
    tests/*
    setup.py

[report]
exclude_lines =
    pragma: no cover
    raise NotImplementedError()
    if typing.TYPE_CHECKING:
@ -0,0 +1,53 @@
---
name: Issue Template
about: How to create an issue for this repository

---

**Important Note: We do not offer technical support or consulting** and we don't answer personal questions by email.
Please post your question on [reddit](https://www.reddit.com/r/reinforcementlearning/) or [stack overflow](https://stackoverflow.com/) in that case.

If you have any questions, feel free to create an issue with the tag [question].
If you wish to suggest an enhancement or feature request, add the tag [feature request].
If you are submitting a bug report, please fill in the following details.

If your issue is related to a custom gym environment, please check it first using:

```python
from stable_baselines3.common.env_checker import check_env

env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)
```

**Describe the bug**
A clear and concise description of what the bug is.

**Code example**
Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.

Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks)
for both code and stack traces.

```python
from stable_baselines3 import ...

```

```bash
Traceback (most recent call last): File ...

```

**System Info**
Describe the characteristics of your environment:
 * Describe how the library was installed (pip, docker, source, ...)
 * GPU models and configuration
 * Python version
 * PyTorch version
 * Gym version
 * Versions of any other relevant libraries

**Additional context**
Add any other context about the problem here.
@ -0,0 +1,34 @@
<!--- Provide a general summary of your changes in the Title above -->

## Description
<!--- Describe your changes in detail -->

## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here. -->
<!--- You can use the syntax `closes #100` if this solves the issue #100 -->
- [ ] I have raised an issue to propose this change ([required](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) for new features and bug fixes)

## Types of changes
<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation (update in the documentation)

## Checklist:
<!--- Go over all the following points, and put an `x` in all the boxes that apply. -->
<!--- If you're unsure about any of these, don't hesitate to ask. We're here to help! -->
- [ ] I've read the [CONTRIBUTION](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) guide (**required**)
- [ ] I have updated the changelog accordingly (**required**).
- [ ] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
- [ ] I have updated the documentation accordingly.
- [ ] I have reformatted the code using `make format` (**required**)
- [ ] I have checked the codestyle using `make check-codestyle` and `make lint` (**required**)
- [ ] I have ensured `make pytest` and `make type` both pass. (**required**)


Note: we are using a maximum length of 127 characters per line

<!--- This Template is an edited version of the one from https://github.com/evilsocket/pwnagotchi/ -->
@ -0,0 +1,51 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: CI

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  build:
    # Skip CI if [ci skip] in the commit message
    if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.6, 3.7] # 3.8 not supported yet by pytype

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        # cpu version of pytorch
        pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
        pip install stable_baselines3[extra,tests,docs]
        # Install master version
        pip install git+https://github.com/DLR-RM/stable-baselines3
        # Use headless version
        pip install opencv-python-headless
    # - name: Build the doc
    #   run: |
    #     make doc
    - name: Type check
      run: |
        make type
    - name: Check codestyle
      run: |
        make check-codestyle
    - name: Lint with flake8
      run: |
        make lint
    - name: Test with pytest
      run: |
        make pytest
@ -0,0 +1,48 @@
|
|||
*.swp
|
||||
*.pyc
|
||||
*.pkl
|
||||
*.py~
|
||||
*.bak
|
||||
.pytest_cache
|
||||
.DS_Store
|
||||
.idea
|
||||
.coverage
|
||||
.coverage.*
|
||||
__pycache__/
|
||||
_build/
|
||||
*.npz
|
||||
*.pth
|
||||
.pytype/
|
||||
git_rewrite_commit_history.sh
|
||||
|
||||
# Setuptools distribution and build folders.
|
||||
/dist/
|
||||
/build
|
||||
keys/
|
||||
|
||||
# Virtualenv
|
||||
/env
|
||||
/venv
|
||||
|
||||
|
||||
*.sublime-project
|
||||
*.sublime-workspace
|
||||
|
||||
.idea
|
||||
|
||||
logs/
|
||||
|
||||
.ipynb_checkpoints
|
||||
ghostdriver.log
|
||||
|
||||
htmlcov
|
||||
|
||||
junk
|
||||
src
|
||||
|
||||
*.egg-info
|
||||
.cache
|
||||
*.lprof
|
||||
*.prof
|
||||
|
||||
MUJOCO_LOG.TXT
|
||||
|
|
@ -0,0 +1,12 @@
## Release 0.9.0a2 (WIP)

### Breaking Changes

### New Features
- Add TQC

### Bug fixes

### Documentation

### Other
@ -0,0 +1,44 @@
|
|||
SHELL=/bin/bash
|
||||
LINT_PATHS=sb3_contrib/ tests/ setup.py
|
||||
|
||||
pytest:
|
||||
./scripts/run_tests.sh
|
||||
|
||||
type:
|
||||
pytype
|
||||
|
||||
lint:
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
# see https://lintlyci.github.io/Flake8Rules/
|
||||
flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
# exit-zero treats all errors as warnings.
|
||||
flake8 ${LINT_PATHS} --count --exit-zero --statistics
|
||||
|
||||
format:
|
||||
# Sort imports
|
||||
isort ${LINT_PATHS}
|
||||
# Reformat using black
|
||||
black -l 127 ${LINT_PATHS}
|
||||
|
||||
check-codestyle:
|
||||
# Sort imports
|
||||
isort --check ${LINT_PATHS}
|
||||
# Reformat using black
|
||||
black --check -l 127 ${LINT_PATHS}
|
||||
|
||||
commit-checks: format type lint
|
||||
|
||||
|
||||
# PyPi package release
|
||||
release:
|
||||
python setup.py sdist
|
||||
python setup.py bdist_wheel
|
||||
twine upload dist/*
|
||||
|
||||
# Test PyPi package release
|
||||
test-release:
|
||||
python setup.py sdist
|
||||
python setup.py bdist_wheel
|
||||
twine upload --repository-url https://test.pypi.org/legacy/ dist/*
|
||||
|
||||
.PHONY: lint format check-codestyle commit-checks
|
||||
README.md (22 changed lines)
@ -1,2 +1,20 @@
# stable-baselines3-contrib
Contrib package for Stable-Baselines3 - Experimental code
[![CI](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/workflows/CI/badge.svg)](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/actions) [![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)

# Stable-Baselines3 - Contrib

Contrib package for [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) - Experimental code

**You need the master version of Stable Baselines3**:
```
pip install git+https://github.com/DLR-RM/stable-baselines3
```

Implemented:
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)


## Installation

```
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
```
@ -0,0 +1,9 @@
import os

# from sb3_contrib.cmaes import CMAES
from sb3_contrib.tqc import TQC

# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
with open(version_file, "r") as file_handler:
    __version__ = file_handler.read().strip()
@ -0,0 +1,2 @@
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy
from sb3_contrib.tqc.tqc import TQC
@ -0,0 +1,480 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import gym
import torch as th
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_features_extractor, register_policy
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
from torch import nn as nn

# Cap the standard deviation of the actor
LOG_STD_MAX = 2
LOG_STD_MIN = -20


class Actor(BasePolicy):
    """
    Actor network (policy) for TQC.

    :param observation_space: (gym.spaces.Space) Observation space
    :param action_space: (gym.spaces.Space) Action space
    :param net_arch: ([int]) Network architecture
    :param features_extractor: (nn.Module) Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: (int) Number of features
    :param activation_fn: (Type[nn.Module]) Activation function
    :param use_sde: (bool) Whether to use State Dependent Exploration or not
    :param log_std_init: (float) Initial value for the log standard deviation
    :param full_std: (bool) Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE.
    :param sde_net_arch: ([int]) Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
    :param normalize_images: (bool) Whether to normalize images or not,
        dividing by 255.0 (True by default)
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        net_arch: List[int],
        features_extractor: nn.Module,
        features_dim: int,
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        normalize_images: bool = True,
    ):
        super(Actor, self).__init__(
            observation_space,
            action_space,
            features_extractor=features_extractor,
            normalize_images=normalize_images,
            squash_output=True,
        )

        # Save arguments to re-create object at loading
        self.use_sde = use_sde
        self.sde_features_extractor = None
        self.sde_net_arch = sde_net_arch
        self.net_arch = net_arch
        self.features_dim = features_dim
        self.activation_fn = activation_fn
        self.log_std_init = log_std_init
        self.use_expln = use_expln
        self.full_std = full_std
        self.clip_mean = clip_mean

        action_dim = get_action_dim(self.action_space)
        latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
        self.latent_pi = nn.Sequential(*latent_pi_net)
        last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim

        if self.use_sde:
            latent_sde_dim = last_layer_dim
            # Separate features extractor for gSDE
            if sde_net_arch is not None:
                self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
                    features_dim, sde_net_arch, activation_fn
                )

            self.action_dist = StateDependentNoiseDistribution(
                action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
            )
            self.mu, self.log_std = self.action_dist.proba_distribution_net(
                latent_dim=last_layer_dim, latent_sde_dim=latent_sde_dim, log_std_init=log_std_init
            )
            # Avoid numerical issues by limiting the mean of the Gaussian
            # to be in [-clip_mean, clip_mean]
            if clip_mean > 0.0:
                self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
        else:
            self.action_dist = SquashedDiagGaussianDistribution(action_dim)
            self.mu = nn.Linear(last_layer_dim, action_dim)
            self.log_std = nn.Linear(last_layer_dim, action_dim)

    def _get_data(self) -> Dict[str, Any]:
        data = super()._get_data()

        data.update(
            dict(
                net_arch=self.net_arch,
                features_dim=self.features_dim,
                activation_fn=self.activation_fn,
                use_sde=self.use_sde,
                log_std_init=self.log_std_init,
                full_std=self.full_std,
                sde_net_arch=self.sde_net_arch,
                use_expln=self.use_expln,
                features_extractor=self.features_extractor,
                clip_mean=self.clip_mean,
            )
        )
        return data

    def get_std(self) -> th.Tensor:
        """
        Retrieve the standard deviation of the action distribution.
        Only useful when using gSDE.
        It corresponds to ``th.exp(log_std)`` in the normal case,
        but is slightly different when using the ``expln`` function
        (cf StateDependentNoiseDistribution doc).

        :return: (th.Tensor)
        """
        msg = "get_std() is only available when using gSDE"
        assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
        return self.action_dist.get_std(self.log_std)

    def reset_noise(self, batch_size: int = 1) -> None:
        """
        Sample new weights for the exploration matrix, when using gSDE.

        :param batch_size: (int)
        """
        msg = "reset_noise() is only available when using gSDE"
        assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
        self.action_dist.sample_weights(self.log_std, batch_size=batch_size)

    def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
        """
        Get the parameters for the action distribution.

        :param obs: (th.Tensor)
        :return: (Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]])
            Mean, standard deviation and optional keyword arguments.
        """
        features = self.extract_features(obs)
        latent_pi = self.latent_pi(features)
        mean_actions = self.mu(latent_pi)

        if self.use_sde:
            latent_sde = latent_pi
            if self.sde_features_extractor is not None:
                latent_sde = self.sde_features_extractor(features)
            return mean_actions, self.log_std, dict(latent_sde=latent_sde)
        # Unstructured exploration (original implementation)
        log_std = self.log_std(latent_pi)
        # Original implementation to cap the standard deviation
        log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        return mean_actions, log_std, {}

    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
        mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
        # Note: the action is squashed
        return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)

    def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
        # Return action and associated log prob
        return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)

    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
        return self.forward(observation, deterministic)

    def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
        """
        Evaluate actions according to the current policy,
        given the observations. Only useful when using SDE.

        :param obs: (th.Tensor)
        :param actions: (th.Tensor)
        :return: (th.Tensor) log likelihood of taking those actions
        """
        mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
        self.action_dist.proba_distribution(mean_actions, log_std, **kwargs)
        return self.action_dist.log_prob(actions)
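Since the actor is built with ``squash_output=True``, sampled actions are passed through ``tanh``, and the log-probability must include a change-of-variables correction. A minimal pure-Python sketch of that correction for a single 1D action (the function name and the numbers are illustrative, not part of the library):

```python
import math

def squashed_gaussian_log_prob(u: float, mean: float, log_std: float, eps: float = 1e-6) -> float:
    """Log-prob of a = tanh(u), where u ~ N(mean, exp(log_std)^2), 1D case."""
    std = math.exp(log_std)
    # Log density of the pre-squash Gaussian sample u
    gaussian_log_prob = -0.5 * (((u - mean) / std) ** 2 + 2 * log_std + math.log(2 * math.pi))
    # Change of variables: subtract log|d tanh(u)/du| = log(1 - tanh(u)^2)
    return gaussian_log_prob - math.log(1 - math.tanh(u) ** 2 + eps)

a = math.tanh(0.3)  # the executed action always lies in (-1, 1)
lp = squashed_gaussian_log_prob(0.3, mean=0.0, log_std=-1.0)
```

In the library this bookkeeping is handled internally by ``SquashedDiagGaussianDistribution`` (and by the gSDE distribution when ``use_sde=True``); the sketch only shows why the squashed log-prob differs from the plain Gaussian one.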
class Critic(BaseModel):
    """
    Critic network (q-value function) for TQC.

    :param observation_space: (gym.spaces.Space) Observation space
    :param action_space: (gym.spaces.Space) Action space
    :param net_arch: ([int]) Network architecture
    :param features_extractor: (nn.Module) Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: (int) Number of features
    :param activation_fn: (Type[nn.Module]) Activation function
    :param normalize_images: (bool) Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param n_quantiles: (int) Number of quantiles predicted by each critic network
    :param n_critics: (int) Number of critic networks
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        net_arch: List[int],
        features_extractor: nn.Module,
        features_dim: int,
        activation_fn: Type[nn.Module] = nn.ReLU,
        normalize_images: bool = True,
        n_quantiles: int = 25,
        n_critics: int = 2,
    ):
        super().__init__(
            observation_space,
            action_space,
            features_extractor=features_extractor,
            normalize_images=normalize_images,
        )

        action_dim = get_action_dim(self.action_space)

        self.q_networks = []
        self.n_quantiles = n_quantiles
        self.n_critics = n_critics
        self.quantiles_total = n_quantiles * n_critics

        for i in range(n_critics):
            qf_net = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn)
            qf_net = nn.Sequential(*qf_net)
            self.add_module(f"qf{i}", qf_net)
            self.q_networks.append(qf_net)

    def forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor:
        # Learn the features extractor using the policy loss only,
        # this is much faster
        with th.no_grad():
            features = self.extract_features(obs)
        qvalue_input = th.cat([features, action], dim=1)
        quantiles = th.stack(tuple(qf(qvalue_input) for qf in self.q_networks), dim=1)
        return quantiles
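The critic above produces ``n_critics × n_quantiles`` quantile estimates of the return distribution per sample. TQC's central idea (the "truncated mixture" of the paper title) is to pool these quantiles across critics, sort them, and drop the largest few before building the TD target, which curbs Q-value overestimation. A minimal pure-Python sketch of that truncation step (the function name and values are illustrative, not library API):

```python
def truncated_quantile_target(quantiles_per_net, drop_per_net):
    """Pool quantiles from all critic nets, sort ascending, and drop the
    top `drop_per_net * n_nets` atoms (the TQC truncation step)."""
    n_nets = len(quantiles_per_net)
    pooled = sorted(q for net in quantiles_per_net for q in net)
    n_drop = drop_per_net * n_nets
    return pooled[: len(pooled) - n_drop]

# Two critics with 5 quantiles each, dropping 1 per net, keeps the 8 smallest atoms
kept = truncated_quantile_target(
    [[1.0, 3.0, 5.0, 7.0, 9.0], [2.0, 4.0, 6.0, 8.0, 10.0]], drop_per_net=1
)
# kept == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
```

In the actual training loop this is done batched on tensors (the truncation happens in ``TQC.train``, not in the ``Critic`` module itself), but the arithmetic is the same.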
class TQCPolicy(BasePolicy):
    """
    Policy class (with both actor and critic) for TQC.

    :param observation_space: (gym.spaces.Space) Observation space
    :param action_space: (gym.spaces.Space) Action space
    :param lr_schedule: (callable) Learning rate schedule (could be constant)
    :param net_arch: (Optional[List[int]]) The specification of the policy and value networks.
    :param activation_fn: (Type[nn.Module]) Activation function
    :param use_sde: (bool) Whether to use State Dependent Exploration or not
    :param log_std_init: (float) Initial value for the log standard deviation
    :param sde_net_arch: ([int]) Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
    :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
    :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
        to pass to the features extractor.
    :param normalize_images: (bool) Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_quantiles: (int) Number of quantiles predicted by each critic network
    :param n_critics: (int) Number of critic networks
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Callable,
        net_arch: Optional[List[int]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        n_quantiles: int = 25,
        n_critics: int = 2,
    ):
        super(TQCPolicy, self).__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=True,
        )

        if net_arch is None:
            if features_extractor_class == FlattenExtractor:
                net_arch = [256, 256]
            else:
                net_arch = []

        # Create shared features extractor
        self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
        self.features_dim = self.features_extractor.features_dim

        self.net_arch = net_arch
        self.activation_fn = activation_fn
        self.net_args = {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "features_extractor": self.features_extractor,
            "features_dim": self.features_dim,
            "net_arch": self.net_arch,
            "activation_fn": self.activation_fn,
            "normalize_images": normalize_images,
        }
        self.actor_kwargs = self.net_args.copy()
        self.critic_kwargs = self.net_args.copy()
        sde_kwargs = {
            "use_sde": use_sde,
            "log_std_init": log_std_init,
            "sde_net_arch": sde_net_arch,
            "use_expln": use_expln,
            "clip_mean": clip_mean,
        }
        self.actor_kwargs.update(sde_kwargs)
        tqc_kwargs = {"n_quantiles": n_quantiles, "n_critics": n_critics}
        self.critic_kwargs.update(tqc_kwargs)
        self.actor, self.actor_target = None, None
        self.critic, self.critic_target = None, None

        self._build(lr_schedule)

    def _build(self, lr_schedule: Callable) -> None:
        self.actor = self.make_actor()
        self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

        self.critic = self.make_critic()
        self.critic_target = self.make_critic()
        self.critic_target.load_state_dict(self.critic.state_dict())
        # Do not optimize the shared features extractor with the critic loss,
        # otherwise there are gradient computation issues.
        # Another solution: duplicated features extractors, but that requires more memory and computation
        critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
        self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)

    def _get_data(self) -> Dict[str, Any]:
        data = super()._get_data()

        data.update(
            dict(
                net_arch=self.net_args["net_arch"],
                activation_fn=self.net_args["activation_fn"],
                use_sde=self.actor_kwargs["use_sde"],
                log_std_init=self.actor_kwargs["log_std_init"],
                sde_net_arch=self.actor_kwargs["sde_net_arch"],
                use_expln=self.actor_kwargs["use_expln"],
                clip_mean=self.actor_kwargs["clip_mean"],
                lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
            )
        )
        return data

    def reset_noise(self, batch_size: int = 1) -> None:
        """
        Sample new weights for the exploration matrix, when using gSDE.

        :param batch_size: (int)
        """
        self.actor.reset_noise(batch_size=batch_size)

    def make_actor(self) -> Actor:
        return Actor(**self.actor_kwargs).to(self.device)

    def make_critic(self) -> Critic:
        return Critic(**self.critic_kwargs).to(self.device)

    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
        return self._predict(obs, deterministic=deterministic)

    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
        return self.actor(observation, deterministic)


MlpPolicy = TQCPolicy


class CnnPolicy(TQCPolicy):
    """
    Policy class (with both actor and critic) for TQC, using a CNN features extractor.

    :param observation_space: (gym.spaces.Space) Observation space
    :param action_space: (gym.spaces.Space) Action space
    :param lr_schedule: (callable) Learning rate schedule (could be constant)
    :param net_arch: (Optional[List[int]]) The specification of the policy and value networks.
    :param activation_fn: (Type[nn.Module]) Activation function
    :param use_sde: (bool) Whether to use State Dependent Exploration or not
    :param log_std_init: (float) Initial value for the log standard deviation
    :param sde_net_arch: ([int]) Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
    :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
    :param normalize_images: (bool) Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Callable,
        net_arch: Optional[List[int]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        n_quantiles: int = 25,
        n_critics: int = 2,
    ):
        super(CnnPolicy, self).__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            use_sde,
            log_std_init,
            sde_net_arch,
            use_expln,
            clip_mean,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
            n_quantiles,
            n_critics,
        )


register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)
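In ``_build``, the ``critic_target`` starts as an exact copy of ``critic`` (via ``load_state_dict``); during training it then tracks the online critic with a Polyak (soft) update, ``target = (1 - tau) * target + tau * online``, controlled by the ``tau`` coefficient. A minimal pure-Python sketch of that update rule on flat parameter lists (names and values are illustrative):

```python
def polyak_update_sketch(online_params, target_params, tau):
    """Soft update: move each target parameter a fraction tau toward its online counterpart."""
    return [(1.0 - tau) * t + tau * o for o, t in zip(online_params, target_params)]

target = [0.0, 10.0]
online = [1.0, 0.0]
target = polyak_update_sketch(online, target, tau=0.1)
# target is now [0.1, 9.0]: each value moved 10% of the way toward the online value
```

The real implementation is ``stable_baselines3.common.utils.polyak_update``, which operates in-place on torch parameter iterables; the arithmetic is the same.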
@ -0,0 +1,492 @@
|
|||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
||||
from stable_baselines3.common.utils import polyak_update
|
||||
from tqdm import tqdm
|
||||
|
||||
from sb3_contrib.tqc.policies import TQCPolicy
|
||||
|
||||
|
||||
class TQC(OffPolicyAlgorithm):
    """
    Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics (TQC)
    Paper: https://arxiv.org/abs/2005.04269

    :param policy: (TQCPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: (float or callable) learning rate for the Adam optimizer,
        the same learning rate will be used for all networks (Q-Values, Actor and Value function),
        it can be a function of the current progress remaining (from 1 to 0)
    :param buffer_size: (int) size of the replay buffer
    :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
    :param batch_size: (int) Minibatch size for each gradient update
    :param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1)
    :param gamma: (float) the discount factor
    :param train_freq: (int) Update the model every ``train_freq`` steps.
    :param gradient_steps: (int) How many gradient updates to perform after each rollout step
    :param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes.
        Note that this cannot be used at the same time as ``train_freq``
    :param action_noise: (ActionNoise) the action noise type (None by default), this can help
        for hard exploration problems. Cf common.noise for the different action noise types.
    :param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param ent_coef: (str or float) Entropy regularization coefficient (equivalent to the
        inverse of the reward scale in the original SAC paper), controlling the exploration/exploitation trade-off.
        Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as the initial value)
    :param target_update_interval: (int) update the target network every ``target_update_interval``
        gradient steps.
    :param target_entropy: (str or float) target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
    :param top_quantiles_to_drop_per_net: (int) Number of the highest quantiles to drop per critic network
        when computing the truncated target
    :param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling
        during the warm up phase (before learning starts)
    :param create_eval_env: (bool) Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing a string for the environment)
    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
    :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: (int) Seed for the pseudo random generators
    :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
        Setting it to 'auto', the code will be run on the GPU if possible.
    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
    """
    def __init__(
        self,
        policy: Union[str, Type[TQCPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Callable] = 3e-4,
        buffer_size: int = int(1e6),
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: int = 1,
        gradient_steps: int = 1,
        n_episodes_rollout: int = -1,
        action_noise: Optional[ActionNoise] = None,
        optimize_memory_usage: bool = False,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        ent_coef: Union[str, float] = "auto",
        target_update_interval: int = 1,
        target_entropy: Union[str, float] = "auto",
        top_quantiles_to_drop_per_net: int = 2,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        use_sde_at_warmup: bool = False,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):

        super(TQC, self).__init__(
            policy,
            env,
            TQCPolicy,
            learning_rate,
            replay_buffer_class,
            replay_buffer_kwargs,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            n_episodes_rollout,
            action_noise,
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            create_eval_env=create_eval_env,
            seed=seed,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            use_sde_at_warmup=use_sde_at_warmup,
            optimize_memory_usage=optimize_memory_usage,
        )

        self.target_entropy = target_entropy
        self.log_ent_coef = None  # type: Optional[th.Tensor]
        # Entropy coefficient / Entropy temperature
        # Inverse of the reward scale
        self.ent_coef = ent_coef
        self.target_update_interval = target_update_interval
        self.ent_coef_optimizer = None
        self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net

        if _init_setup_model:
            self._setup_model()
    def _setup_model(self) -> None:
        super(TQC, self)._setup_model()
        self._create_aliases()
        self.replay_buffer.actor = self.actor
        self.replay_buffer.ent_coef = 0.0

        # Target entropy is used when learning the entropy coefficient
        if self.target_entropy == "auto":
            # automatically set target entropy if needed
            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)
        else:
            # Force conversion
            # this will also throw an error for unexpected string
            self.target_entropy = float(self.target_entropy)

        # The entropy coefficient or entropy can be learned automatically
        # see Automating Entropy Adjustment for Maximum Entropy RL section
        # of https://arxiv.org/abs/1812.05905
        if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"):
            # Default initial value of ent_coef when learned
            init_value = 1.0
            if "_" in self.ent_coef:
                init_value = float(self.ent_coef.split("_")[1])
                assert init_value > 0.0, "The initial value of ent_coef must be greater than 0"

            # Note: we optimize the log of the entropy coeff which is slightly different from the paper
            # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
            self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
            self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
        else:
            # Force conversion to float
            # this will throw an error if a malformed string (different from 'auto')
            # is passed
            self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device)
    def _create_aliases(self) -> None:
        self.actor = self.policy.actor
        self.critic = self.policy.critic
        self.critic_target = self.policy.critic_target

    @staticmethod
    def quantile_huber_loss(quantiles: th.Tensor, samples: th.Tensor) -> th.Tensor:
        # Shape of pairwise_delta: batch x nets x quantiles x samples
        pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
        abs_pairwise_delta = th.abs(pairwise_delta)
        # Huber loss (kappa = 1): quadratic for small errors, linear for large ones
        huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5)

        n_quantiles = quantiles.shape[2]
        # Quantile midpoints: tau_i = (i + 0.5) / n_quantiles
        tau = th.arange(n_quantiles, device=quantiles.device).float() / n_quantiles + 1 / 2 / n_quantiles
        loss = (th.abs(tau[None, None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean()
        return loss
    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        # Update optimizers' learning rate
        optimizers = [self.actor.optimizer, self.critic.optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers += [self.ent_coef_optimizer]

        # Update learning rate according to lr schedule
        self._update_learning_rate(optimizers)

        ent_coef_losses, ent_coefs = [], []
        actor_losses, critic_losses = [], []

        for gradient_step in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # We need to sample because `log_std` may have changed between two gradient steps
            if self.use_sde:
                self.actor.reset_noise()

            # Action by the current actor for the sampled state
            actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
            log_prob = log_prob.reshape(-1, 1)

            ent_coef_loss = None
            if self.ent_coef_optimizer is not None:
                # Important: detach the variable from the graph
                # so we don't change it with other losses
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = th.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
                ent_coef_losses.append(ent_coef_loss.item())
            else:
                ent_coef = self.ent_coef_tensor

            ent_coefs.append(ent_coef.item())
            self.replay_buffer.ent_coef = ent_coef.item()

            # Optimize entropy coefficient, also called
            # entropy temperature or alpha in the paper
            if ent_coef_loss is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            with th.no_grad():
                top_quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.critic.n_critics
                # Select action according to policy
                next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
                # Compute and cut quantiles at the next state
                # batch x nets x quantiles
                next_z = self.critic_target(replay_data.next_observations, next_actions)
                sorted_z, _ = th.sort(next_z.reshape(batch_size, -1))
                sorted_z_part = sorted_z[:, : self.critic.quantiles_total - top_quantiles_to_drop]

                target_q = sorted_z_part - ent_coef * next_log_prob.reshape(-1, 1)
                # td error + entropy term
                q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q

            # Get current Q estimates
            # using action from the replay buffer
            current_z = self.critic(replay_data.observations, replay_data.actions)
            # Compute critic loss
            critic_loss = self.quantile_huber_loss(current_z, q_backup)
            critic_losses.append(critic_loss.item())

            # Optimize the critic
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # Compute actor loss
            qf_pi = self.critic(replay_data.observations, actions_pi).mean(2).mean(1, keepdim=True)
            actor_loss = (ent_coef * log_prob - qf_pi).mean()
            actor_losses.append(actor_loss.item())

            # Optimize the actor
            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            # Update target networks
            if gradient_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)

        self._n_updates += gradient_steps

        logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        logger.record("train/ent_coef", np.mean(ent_coefs))
        logger.record("train/actor_loss", np.mean(actor_losses))
        logger.record("train/critic_loss", np.mean(critic_losses))
        if len(ent_coef_losses) > 0:
            logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
    def pretrain(
        self,
        gradient_steps: int,
        batch_size: int = 64,
        n_action_samples: int = -1,
        target_update_interval: int = 1,
        tau: float = 0.005,
        strategy: str = "exp",
        reduce: str = "mean",
        exp_temperature: float = 1.0,
        off_policy_update_freq: int = -1,
    ) -> None:
        """
        Pretrain with Critic Regularized Regression (CRR)
        Paper: https://arxiv.org/abs/2006.15134
        """
        # Update optimizers' learning rate
        optimizers = [self.actor.optimizer, self.critic.optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers += [self.ent_coef_optimizer]

        # Update learning rate according to lr schedule
        self._update_learning_rate(optimizers)

        actor_losses, critic_losses = [], []

        for gradient_step in tqdm(range(gradient_steps)):

            if off_policy_update_freq > 0 and gradient_step % off_policy_update_freq == 0:
                self.train(gradient_steps=1, batch_size=batch_size)
                continue

            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # We need to sample because `log_std` may have changed between two gradient steps
            if self.use_sde:
                self.actor.reset_noise()

            # Action by the current actor for the sampled state
            _, log_prob = self.actor.action_log_prob(replay_data.observations)
            log_prob = log_prob.reshape(-1, 1)

            ent_coef_loss = None
            if self.ent_coef_optimizer is not None:
                # Important: detach the variable from the graph
                # so we don't change it with other losses
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = th.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
            else:
                ent_coef = self.ent_coef_tensor

            self.replay_buffer.ent_coef = ent_coef.item()

            # Optimize entropy coefficient, also called
            # entropy temperature or alpha in the paper
            if ent_coef_loss is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            with th.no_grad():
                top_quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.critic.n_critics
                # Select action according to policy
                next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
                # Compute and cut quantiles at the next state
                # batch x nets x quantiles
                next_z = self.critic_target(replay_data.next_observations, next_actions)
                sorted_z, _ = th.sort(next_z.reshape(batch_size, -1))
                sorted_z_part = sorted_z[:, : self.critic.quantiles_total - top_quantiles_to_drop]

                target_q = sorted_z_part - ent_coef * next_log_prob.reshape(-1, 1)
                # td error + entropy term
                q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q

            # Get current Q estimates
            # using action from the replay buffer
            current_z = self.critic(replay_data.observations, replay_data.actions)
            # Compute critic loss
            critic_loss = self.quantile_huber_loss(current_z, q_backup)
            critic_losses.append(critic_loss.item())

            # Optimize the critic
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            if strategy == "bc":
                # Behavior cloning
                weight = 1
            else:
                # Tensor version: TODO: check that the reshape works as expected
                # cleaner but not faster on cpu for large batch size
                # with th.no_grad():
                #     # Q-value for the action in the buffer
                #     qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True)
                #     # Create tensor to avoid loop
                #     # Note: For SDE, we need to sample several matrices
                #     obs_ = replay_data.observations.repeat(n_action_samples, 1)
                #     if self.use_sde:
                #         self.actor.reset_noise(batch_size * n_action_samples)
                #     actions_pi, _ = self.actor.action_log_prob(obs_)
                #     qf_pi = self.critic(obs_, actions_pi.detach()).mean(2).mean(1, keepdim=True)
                #     # Aggregate: reduce mean or reduce max
                #     if reduce == "max":
                #         _, qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).max(axis=0)
                #     else:
                #         qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).mean(axis=0)
                with th.no_grad():
                    qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True)

                    # Use the mean (as done in AWAC, cf rlkit)
                    if n_action_samples == -1:
                        actions_pi = self.actor.forward(replay_data.observations, deterministic=True)
                        qf_agg = self.critic(replay_data.observations, actions_pi).mean(2).mean(1, keepdim=True)
                    else:
                        qf_agg = None
                        for _ in range(n_action_samples):
                            if self.use_sde:
                                self.actor.reset_noise()
                            actions_pi, _ = self.actor.action_log_prob(replay_data.observations)

                            qf_pi = self.critic(replay_data.observations, actions_pi.detach()).mean(2).mean(1, keepdim=True)
                            if qf_agg is None:
                                if reduce == "max":
                                    qf_agg = qf_pi
                                else:
                                    qf_agg = qf_pi / n_action_samples
                            else:
                                if reduce == "max":
                                    qf_agg = th.max(qf_pi, qf_agg)
                                else:
                                    qf_agg += qf_pi / n_action_samples

                    advantage = qf_buffer - qf_agg
                    if strategy == "binary":
                        # binary advantage
                        weight = advantage > 0
                    else:
                        # exp advantage
                        exp_clip = 20.0
                        weight = th.clamp(th.exp(advantage / exp_temperature), 0.0, exp_clip)

            # Log prob by the current actor for the sampled state and action
            log_prob = self.actor.evaluate_actions(replay_data.observations, replay_data.actions)
            log_prob = log_prob.reshape(-1, 1)

            # weighted regression loss (close to policy gradient loss)
            actor_loss = (-log_prob * weight).mean()
            # actor_loss = ((actions_pi - replay_data.actions * weight) ** 2).mean()
            actor_losses.append(actor_loss.item())

            # Optimize the actor
            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            # Update target networks
            if gradient_step % target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), tau)

        if self.use_sde:
            print(f"std={(self.actor.get_std()).mean().item()}")
    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "TQC",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> OffPolicyAlgorithm:

        return super(TQC, self).learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )

    def _excluded_save_params(self) -> List[str]:
        """
        Returns the names of the parameters that should be excluded by default
        when saving the model.

        :return: (List[str]) List of parameters that should be excluded from save
        """
        # Exclude aliases
        return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """
        cf base class
        """
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
        saved_pytorch_variables = ["log_ent_coef"]
        if self.ent_coef_optimizer is not None:
            state_dicts.append("ent_coef_optimizer")
        else:
            saved_pytorch_variables.append("ent_coef_tensor")
        return state_dicts, saved_pytorch_variables
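The static `quantile_huber_loss` above can be checked by hand. Below is a hedged NumPy re-implementation of the same math (a sketch mirroring the torch code, assuming the shape conventions `batch x nets x quantiles` for `quantiles` and `batch x samples` for `samples`), with a tiny worked example:

```python
import numpy as np

def quantile_huber_loss_np(quantiles, samples):
    # quantiles: (batch, nets, n_quantiles); samples: (batch, n_samples)
    pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
    abs_delta = np.abs(pairwise_delta)
    # Huber loss with kappa = 1: quadratic for small errors, linear for large ones
    huber = np.where(abs_delta > 1, abs_delta - 0.5, 0.5 * pairwise_delta ** 2)
    n_quantiles = quantiles.shape[2]
    # Quantile midpoints tau_i = (i + 0.5) / n_quantiles
    tau = (np.arange(n_quantiles) + 0.5) / n_quantiles
    # Asymmetric weighting |tau - 1{delta < 0}| of the Huber terms
    return np.mean(np.abs(tau[None, None, :, None] - (pairwise_delta < 0)) * huber)

# Two quantiles at 0.0 and 1.0, one target sample at 0.5:
# both deltas are +/-0.5 -> huber = 0.125, |tau - indicator| = 0.25 for each
quantiles = np.array([[[0.0, 1.0]]])
samples = np.array([[0.5]])
loss = quantile_huber_loss_np(quantiles, samples)  # 0.25 * 0.125 = 0.03125
```

The asymmetric weight is what makes each head converge to its own quantile rather than to the mean.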
@@ -0,0 +1 @@
0.9.0a2
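As a side note on the `pretrain` method above: the CRR actor weights reduce to a simple transform of the advantage. A minimal NumPy sketch of just that step (illustrative only; it assumes the `exp_clip = 20.0` value and the strict `advantage > 0` test used in the code):

```python
import numpy as np

def crr_weight(advantage, strategy="exp", exp_temperature=1.0, exp_clip=20.0):
    # advantage = Q(s, a_buffer) - aggregated Q(s, a_pi)
    if strategy == "binary":
        # 1 for buffer actions strictly better than the policy, 0 otherwise
        return (advantage > 0).astype(np.float64)
    # Exponential advantage weighting, clipped for numerical stability
    return np.clip(np.exp(advantage / exp_temperature), 0.0, exp_clip)

adv = np.array([-1.0, 0.0, 1.0, 10.0])
w_exp = crr_weight(adv, "exp")     # [exp(-1), 1.0, exp(1), 20.0] (last one clipped)
w_bin = crr_weight(adv, "binary")  # [0.0, 0.0, 1.0, 1.0]
```

The `bc` strategy corresponds to a constant weight of 1, i.e. plain behavior cloning.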
@@ -0,0 +1,2 @@
#!/bin/bash
python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v
@@ -0,0 +1,48 @@
[metadata]
# This includes the license file in the wheel.
license_file = LICENSE

[tool:pytest]
# Deterministic ordering for tests; useful for pytest-xdist.
env =
    PYTHONHASHSEED=0
filterwarnings =
    # Tensorboard/Tensorflow warnings
    ignore:inspect.getargspec:DeprecationWarning:tensorflow
    ignore:builtin type EagerTensor has no __module__ attribute:DeprecationWarning
    ignore:The binary mode of fromstring is deprecated:DeprecationWarning
    ignore::FutureWarning:tensorflow
    # Gym warnings
    ignore:Parameters to load are deprecated.:DeprecationWarning
    ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
    ignore::UserWarning:gym

[pytype]
inputs = sb3_contrib

[flake8]
# line breaks before and after binary operators
ignore = W503,W504,E203,E231
# Ignore import not used when aliases are defined
per-file-ignores =
    ./sb3_contrib/__init__.py:F401
    ./sb3_contrib/tqc/__init__.py:F401
exclude =
    # No need to traverse our git directory
    .git,
    # There's no value in checking cache directories
    __pycache__,
    # Don't check the doc
    docs/,
    # This contains our built documentation
    build,
    # This contains builds of flake8 that we don't want to check
    dist,
    *.egg-info
max-complexity = 15
# The GitHub editor is 127 chars wide
max-line-length = 127

[isort]
profile = black
line_length = 127
src_paths = sb3_contrib
@@ -0,0 +1,45 @@
import os

from setuptools import find_packages, setup

with open(os.path.join("sb3_contrib", "version.txt"), "r") as file_handler:
    __version__ = file_handler.read().strip()


long_description = """

# Stable-Baselines3 - Contrib

Contrib package for [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) - Experimental code

Implemented:
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)

"""  # noqa:E501


setup(
    name="sb3_contrib",
    packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
    package_data={"sb3_contrib": ["py.typed", "version.txt"]},
    install_requires=[
        "stable_baselines3[tests,docs]",
        # Enable CMA
        # "cma",
    ],
    description="Contrib package of Stable Baselines3, experimental code.",
    author="Antonin Raffin",
    url="https://github.com/Stable-Baselines-Team/stable-baselines3-contrib",
    author_email="antonin.raffin@dlr.de",
    keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning "
    "gym openai stable baselines toolbox python data-science",
    license="MIT",
    long_description=long_description,
    long_description_content_type="text/markdown",
    version=__version__,
)

# python setup.py sdist
# python setup.py bdist_wheel
# twine upload --repository-url https://test.pypi.org/legacy/ dist/*
# twine upload dist/*
@@ -0,0 +1,54 @@
import pytest

from sb3_contrib import TQC


@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
def test_tqc(ent_coef):
    model = TQC(
        "MlpPolicy",
        "Pendulum-v0",
        policy_kwargs=dict(net_arch=[64, 64]),
        learning_starts=100,
        verbose=1,
        create_eval_env=True,
        ent_coef=ent_coef,
    )
    model.learn(total_timesteps=1000, eval_freq=500)


@pytest.mark.parametrize("n_critics", [1, 3])
def test_n_critics(n_critics):
    # Test TQC with different numbers of critics
    model = TQC(
        "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1
    )
    model.learn(total_timesteps=1000)


# "CartPole-v1"
# @pytest.mark.parametrize("env_id", ["MountainCarContinuous-v0"])
# def test_cmaes(env_id):
#     if CMAES is None:
#         return
#     model = CMAES("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[64]), verbose=1, create_eval_env=True)
#     model.learn(total_timesteps=50000, eval_freq=10000)


@pytest.mark.parametrize("strategy", ["exp", "bc", "binary"])
@pytest.mark.parametrize("reduce", ["mean", "max"])
def test_crr(tmp_path, strategy, reduce):
    model = TQC(
        "MlpPolicy",
        "Pendulum-v0",
        policy_kwargs=dict(net_arch=[64]),
        learning_starts=1000,
        verbose=1,
        create_eval_env=True,
        action_noise=None,
        use_sde=False,
    )

    model.learn(total_timesteps=1000, eval_freq=0)
    for n_action_samples in [1, 2, -1]:
        model.pretrain(gradient_steps=32, batch_size=32, n_action_samples=n_action_samples, strategy=strategy, reduce=reduce)
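The quantile-truncation step exercised by the tests above (TQC's target drops the `top_quantiles_to_drop_per_net * n_critics` highest atoms pooled across all critics) can be sketched in NumPy; the numbers here are illustrative only, not from the repo:

```python
import numpy as np

n_critics, n_quantiles = 2, 3
top_quantiles_to_drop_per_net = 1

# Predicted target quantiles, shape (batch, nets, quantiles)
next_z = np.array([[[3.0, 1.0, 5.0],
                    [2.0, 6.0, 4.0]]])
batch_size = next_z.shape[0]

top_quantiles_to_drop = top_quantiles_to_drop_per_net * n_critics
# Pool the quantiles of all critics, sort them, and drop the largest ones
sorted_z = np.sort(next_z.reshape(batch_size, -1))
sorted_z_part = sorted_z[:, : n_critics * n_quantiles - top_quantiles_to_drop]
# sorted_z_part is [[1.0, 2.0, 3.0, 4.0]] -> the atoms 5.0 and 6.0 were truncated
```

Dropping only the top of the pooled distribution is what gives TQC its fine-grained control over overestimation bias, compared to the min-over-critics used by SAC/TD3.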
@ -0,0 +1,426 @@
|
|||
import os
|
||||
import pathlib
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch as th
|
||||
from stable_baselines3 import DQN
|
||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
||||
from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
|
||||
from stable_baselines3.common.utils import get_device
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
from sb3_contrib import TQC
|
||||
|
||||
MODEL_LIST = [TQC]
|
||||
|
||||
|
||||
def select_env(model_class: BaseAlgorithm) -> gym.Env:
|
||||
"""
|
||||
Selects an environment with the correct action space as DQN only supports discrete action space
|
||||
"""
|
||||
if model_class == DQN:
|
||||
return IdentityEnv(10)
|
||||
else:
|
||||
return IdentityEnvBox(10)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
def test_save_load(tmp_path, model_class):
|
||||
"""
|
||||
Test if 'save' and 'load' saves and loads model correctly
|
||||
and if 'get_parameters' and 'set_parameters' and work correctly.
|
||||
|
||||
''warning does not test function of optimizer parameter load
|
||||
|
||||
:param model_class: (BaseAlgorithm) A RL model
|
||||
"""
|
||||
|
||||
env = DummyVecEnv([lambda: select_env(model_class)])
|
||||
|
||||
# create model
|
||||
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
|
||||
model.learn(total_timesteps=500, eval_freq=250)
|
||||
|
||||
env.reset()
|
||||
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
|
||||
|
||||
# Get parameters of different objects
|
||||
# deepcopy to avoid referencing to tensors we are about to modify
|
||||
original_params = deepcopy(model.get_parameters())
|
||||
|
||||
# Test different error cases of set_parameters.
|
||||
# Test that invalid object names throw errors
|
||||
invalid_object_params = deepcopy(original_params)
|
||||
invalid_object_params["I_should_not_be_a_valid_object"] = "and_I_am_an_invalid_tensor"
|
||||
with pytest.raises(ValueError):
|
||||
model.set_parameters(invalid_object_params, exact_match=True)
|
||||
with pytest.raises(ValueError):
|
||||
model.set_parameters(invalid_object_params, exact_match=False)
|
||||
|
||||
# Test that exact_match catches when something was missed.
|
||||
missing_object_params = dict((k, v) for k, v in list(original_params.items())[:-1])
|
||||
with pytest.raises(ValueError):
|
||||
model.set_parameters(missing_object_params, exact_match=True)
|
||||
|
||||
# Test that exact_match catches when something inside state-dict
|
||||
# is missing but we have exact_match.
|
||||
missing_state_dict_tensor_params = {}
|
||||
for object_name in original_params:
|
||||
object_params = {}
|
||||
missing_state_dict_tensor_params[object_name] = object_params
|
||||
# Skip last item in state-dict
|
||||
for k, v in list(original_params[object_name].items())[:-1]:
|
||||
object_params[k] = v
|
||||
with pytest.raises(RuntimeError):
|
||||
# PyTorch load_state_dict throws RuntimeError if strict but
|
||||
# invalid state-dict.
|
||||
model.set_parameters(missing_state_dict_tensor_params, exact_match=True)
|
||||
|
||||
# Test that parameters do indeed change.
|
||||
random_params = {}
|
||||
for object_name, params in original_params.items():
|
||||
# Do not randomize optimizer parameters (custom layout)
|
||||
if "optim" in object_name:
|
||||
random_params[object_name] = params
|
||||
else:
|
||||
# Again, skip the last item in state-dict
|
||||
random_params[object_name] = OrderedDict(
|
||||
(param_name, th.rand_like(param)) for param_name, param in list(params.items())[:-1]
|
||||
)
|
||||
|
||||
# Update model parameters with the new random values
|
||||
model.set_parameters(random_params, exact_match=False)
|
||||
|
||||
new_params = model.get_parameters()
|
||||
# Check that all params except the final item in each state-dict are different.
|
||||
for object_name in original_params:
|
||||
# Skip optimizers (no valid comparison with just th.allclose)
|
||||
if "optim" in object_name:
|
||||
continue
|
||||
# state-dicts use ordered dictionaries, so key order
|
||||
# is guaranteed.
|
||||
last_key = list(original_params[object_name].keys())[-1]
|
||||
for k in original_params[object_name]:
|
||||
if k == last_key:
                # Should be same as before
                assert th.allclose(
                    original_params[object_name][k], new_params[object_name][k]
                ), "Parameter changed despite not being included in the loaded parameters."
            else:
                # Should be different
                assert not th.allclose(
                    original_params[object_name][k], new_params[object_name][k]
                ), "Parameters did not change as expected."

    params = new_params

    # get selected actions
    selected_actions, _ = model.predict(observations, deterministic=True)

    # Save model, then check loading
    model.save(tmp_path / "test_save.zip")
    del model

    # Check if the model loads as expected for every possible choice of device:
    for device in ["auto", "cpu", "cuda"]:
        model = model_class.load(str(tmp_path / "test_save.zip"), env=env, device=device)

        # check if the model was loaded to the correct device
        assert model.device.type == get_device(device).type
        assert model.policy.device.type == get_device(device).type

        # check if params are still the same after load
        new_params = model.get_parameters()

        # Check that all params are the same as before the save/load procedure
        for object_name in new_params:
            # Skip optimizers (no valid comparison with just th.allclose)
            if "optim" in object_name:
                continue
            for key in params[object_name]:
                assert new_params[object_name][key].device.type == get_device(device).type
                assert th.allclose(
                    params[object_name][key].to("cpu"), new_params[object_name][key].to("cpu")
                ), "Model parameters not the same after save and load."

        # check if model still selects the same actions
        new_selected_actions, _ = model.predict(observations, deterministic=True)
        assert np.allclose(selected_actions, new_selected_actions, 1e-4)

        # check if learn still works
        model.learn(total_timesteps=1000, eval_freq=500)

        del model

    # clear file from os
    os.remove(tmp_path / "test_save.zip")


@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_set_env(model_class):
    """
    Test if the set_env() method works correctly
    :param model_class: (BaseAlgorithm) An RL model
    """
    # use discrete for DQN
    env = DummyVecEnv([lambda: select_env(model_class)])
    env2 = DummyVecEnv([lambda: select_env(model_class)])
    env3 = select_env(model_class)

    # create model
    model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]))
    # learn
    model.learn(total_timesteps=1000, eval_freq=500)

    # change env
    model.set_env(env2)
    # learn again
    model.learn(total_timesteps=1000, eval_freq=500)

    # change env, test wrapping
    model.set_env(env3)
    # learn again
    model.learn(total_timesteps=1000, eval_freq=500)


@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_exclude_include_saved_params(tmp_path, model_class):
    """
    Test if the exclude and include parameters of save() work

    :param model_class: (BaseAlgorithm) An RL model
    """
    env = DummyVecEnv([lambda: select_env(model_class)])

    # create model, set verbose to 2, which is not the default
    model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=2)

    # Check if exclude works
    model.save(tmp_path / "test_save", exclude=["verbose"])
    del model
    model = model_class.load(str(tmp_path / "test_save.zip"))
    # check that verbose was not saved
    assert model.verbose != 2

    # set verbose to something different than the default setting
    model.verbose = 2
    # Check if include works
    model.save(tmp_path / "test_save", exclude=["verbose"], include=["verbose"])
    del model
    model = model_class.load(str(tmp_path / "test_save.zip"))
    assert model.verbose == 2

    # clear file from os
    os.remove(tmp_path / "test_save.zip")


@pytest.mark.parametrize("model_class", [TQC])
def test_save_load_replay_buffer(tmp_path, model_class):
    path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
    path.parent.mkdir(exist_ok=True, parents=True)  # to not raise a warning
    model = model_class("MlpPolicy", select_env(model_class), buffer_size=1000)
    model.learn(500)
    old_replay_buffer = deepcopy(model.replay_buffer)
    model.save_replay_buffer(path)
    model.replay_buffer = None
    model.load_replay_buffer(path)

    assert np.allclose(old_replay_buffer.observations, model.replay_buffer.observations)
    assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions)
    assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards)
    assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones)

    # test extending replay buffer
    model.replay_buffer.extend(
        old_replay_buffer.observations,
        old_replay_buffer.observations,
        old_replay_buffer.actions,
        old_replay_buffer.rewards,
        old_replay_buffer.dones,
    )


@pytest.mark.parametrize("model_class", [TQC])
@pytest.mark.parametrize("optimize_memory_usage", [False, True])
def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
    """
    When using the memory-efficient replay buffer,
    a warning must be emitted when calling `.learn()`
    multiple times.
    See https://github.com/DLR-RM/stable-baselines3/issues/46
    """
    # remove gym warnings
    warnings.filterwarnings(action="ignore", category=DeprecationWarning)
    warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")

    model = model_class(
        "MlpPolicy",
        select_env(model_class),
        buffer_size=100,
        optimize_memory_usage=optimize_memory_usage,
        policy_kwargs=dict(net_arch=[64]),
        learning_starts=10,
    )

    model.learn(150)

    model.learn(150, reset_num_timesteps=False)

    # Check that there is no warning
    assert len(recwarn) == 0

    model.learn(150)

    if optimize_memory_usage:
        assert len(recwarn) == 1
        warning = recwarn.pop(UserWarning)
        assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
    else:
        assert len(recwarn) == 0


@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
def test_save_load_policy(tmp_path, model_class, policy_str):
    """
    Test saving and loading policy only.

    :param model_class: (BaseAlgorithm) An RL model
    :param policy_str: (str) Name of the policy.
    """
    kwargs = {}
    if policy_str == "MlpPolicy":
        env = select_env(model_class)
    else:
        if model_class in [TQC]:
            # Avoid memory error when using replay buffer
            # Reduce the size of the features
            kwargs = dict(buffer_size=250)
        env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)

    env = DummyVecEnv([lambda: env])

    # create model
    model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
    model.learn(total_timesteps=500, eval_freq=250)

    env.reset()
    observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)

    policy = model.policy
    policy_class = policy.__class__
    actor, actor_class = None, None
    if model_class in [TQC]:
        actor = policy.actor
        actor_class = actor.__class__

    # Get dictionary of current parameters
    params = deepcopy(policy.state_dict())

    # Modify all parameters to be random values
    random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())

    # Update model parameters with the new random values
    policy.load_state_dict(random_params)

    new_params = policy.state_dict()
    # Check that all params are different now
    for k in params:
        assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."

    params = new_params

    # get selected actions
    selected_actions, _ = policy.predict(observations, deterministic=True)
    # Should also work with the actor only
    if actor is not None:
        selected_actions_actor, _ = actor.predict(observations, deterministic=True)

    # Save and load policy
    policy.save(tmp_path / "policy.pkl")
    # Save and load actor
    if actor is not None:
        actor.save(tmp_path / "actor.pkl")

    del policy, actor

    policy = policy_class.load(tmp_path / "policy.pkl")
    if actor_class is not None:
        actor = actor_class.load(tmp_path / "actor.pkl")

    # check if params are still the same after load
    new_params = policy.state_dict()

    # Check that all params are the same as before the save/load procedure
    for key in params:
        assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load."

    # check if model still selects the same actions
    new_selected_actions, _ = policy.predict(observations, deterministic=True)
    assert np.allclose(selected_actions, new_selected_actions, 1e-4)

    if actor_class is not None:
        new_selected_actions_actor, _ = actor.predict(observations, deterministic=True)
        assert np.allclose(selected_actions_actor, new_selected_actions_actor, 1e-4)
        assert np.allclose(selected_actions_actor, new_selected_actions, 1e-4)

    # clear file from os
    os.remove(tmp_path / "policy.pkl")
    if actor_class is not None:
        os.remove(tmp_path / "actor.pkl")


@pytest.mark.parametrize("pathtype", [str, pathlib.Path])
def test_open_file_str_pathlib(tmp_path, pathtype):
    # check that suffix isn't added because we used open_path first
    with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with pytest.warns(None) as record:
        assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo"
    assert not record

    # test custom suffix
    with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with pytest.warns(None) as record:
        assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo"
    assert not record

    # test without suffix
    with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with pytest.warns(None) as record:
        assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo"
    assert not record

    # test that a warning is raised when the path doesn't exist (only at verbose=2)
    with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with pytest.warns(None) as record:
        assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo"
        assert len(record) == 0

    with pytest.warns(None) as record:
        assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo"
        assert len(record) == 1

    fp = pathlib.Path(f"{tmp_path}/t2").open("w")
    fp.write("rubbish")
    fp.close()
    # test that a warning is only raised when verbose = 2
    with pytest.warns(None) as record:
        open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close()
        open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close()
        open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close()
        assert len(record) == 1