diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..a8fc2af --- /dev/null +++ b/.coveragerc @@ -0,0 +1,11 @@ +[run] +branch = False +omit = + tests/* + setup.py + +[report] +exclude_lines = + pragma: no cover + raise NotImplementedError() + if typing.TYPE_CHECKING: diff --git a/.github/ISSUE_TEMPLATE/issue-template.md b/.github/ISSUE_TEMPLATE/issue-template.md new file mode 100644 index 0000000..1268278 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/issue-template.md @@ -0,0 +1,53 @@ +--- +name: Issue Template +about: How to create an issue for this repository + +--- + +**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. +Please post your question on [reddit](https://www.reddit.com/r/reinforcementlearning/) or [stack overflow](https://stackoverflow.com/) in that case. + +If you have any questions, feel free to create an issue with the tag [question]. +If you wish to suggest an enhancement or feature request, add the tag [feature request]. +If you are submitting a bug report, please fill in the following details. + +If your issue is related to a custom gym environment, please check it first using: + +```python +from stable_baselines3.common.env_checker import check_env + +env = CustomEnv(arg1, ...) +# It will check your custom environment and output additional warnings if needed +check_env(env) +``` + +**Describe the bug** +A clear and concise description of what the bug is. + +**Code example** +Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful. + +Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) +for both code and stack traces. + +```python +from stable_baselines3 import ... + +``` + +```bash +Traceback (most recent call last): File ... + +``` + +**System Info** +Describe the characteristic of your environment: + * Describe how the library was installed (pip, docker, source, ...) + * GPU models and configuration + * Python version + * PyTorch version + * Gym version + * Versions of any other relevant libraries + +**Additional context** +Add any other context about the problem here. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..2b35d61 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,34 @@ + + +## Description + + +## Motivation and Context + + + +- [ ] I have raised an issue to propose this change ([required](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) for new features and bug fixes) + +## Types of changes + +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to change) +- [ ] Documentation (update in the documentation) + +## Checklist: + + +- [ ] I've read the [CONTRIBUTION](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) guide (**required**) +- [ ] I have updated the changelog accordingly (**required**). +- [ ] My change requires a change to the documentation. +- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*). +- [ ] I have updated the documentation accordingly. +- [ ] I have reformatted the code using `make format` (**required**) +- [ ] I have checked the codestyle using `make check-codestyle` and `make lint` (**required**) +- [ ] I have ensured `make pytest` and `make type` both pass. (**required**) + + +Note: we are using a maximum length of 127 characters per line + + diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..1342908 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,51 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: CI + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + build: + # Skip CI if [ci skip] in the commit message + if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')" + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.6, 3.7] # 3.8 not supported yet by pytype + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # cpu version of pytorch + pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install stable_baselines3[extra,tests,docs] + # Install master version + pip install git+https://github.com/DLR-RM/stable-baselines3 + # Use headless version + pip install opencv-python-headless + # - name: Build the doc + # run: | + # make doc + - name: Type check + run: | + make type + - name: Check codestyle + run: | + make check-codestyle + - name: Lint with flake8 + run: | + make lint + - name: Test with pytest + run: | + make pytest diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9f54889 --- /dev/null +++ b/.gitignore @@ -0,0 +1,48 @@ +*.swp +*.pyc +*.pkl +*.py~ +*.bak +.pytest_cache +.DS_Store +.idea +.coverage +.coverage.* +__pycache__/ +_build/ +*.npz +*.pth +.pytype/ +git_rewrite_commit_history.sh + +# Setuptools distribution and build folders. +/dist/ +/build +keys/ + +# Virtualenv +/env +/venv + + +*.sublime-project +*.sublime-workspace + +.idea + +logs/ + +.ipynb_checkpoints +ghostdriver.log + +htmlcov + +junk +src + +*.egg-info +.cache +*.lprof +*.prof + +MUJOCO_LOG.TXT diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..c14db75 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,12 @@ +## Release 0.9.0a2 (WIP) + +### Breaking Changes + +### New Features +- Add TQC + +### Bug fixes + +### Documentation + +### Other diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..639d266 --- /dev/null +++ b/Makefile @@ -0,0 +1,44 @@ +SHELL=/bin/bash +LINT_PATHS=sb3_contrib/ tests/ setup.py + +pytest: + ./scripts/run_tests.sh + +type: + pytype + +lint: + # stop the build if there are Python syntax errors or undefined names + # see https://lintlyci.github.io/Flake8Rules/ + flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. + flake8 ${LINT_PATHS} --count --exit-zero --statistics + +format: + # Sort imports + isort ${LINT_PATHS} + # Reformat using black + black -l 127 ${LINT_PATHS} + +check-codestyle: + # Sort imports + isort --check ${LINT_PATHS} + # Reformat using black + black --check -l 127 ${LINT_PATHS} + +commit-checks: format type lint + + +# PyPi package release +release: + python setup.py sdist + python setup.py bdist_wheel + twine upload dist/* + +# Test PyPi package release +test-release: + python setup.py sdist + python setup.py bdist_wheel + twine upload --repository-url https://test.pypi.org/legacy/ dist/* + +.PHONY: lint format check-codestyle commit-checks diff --git a/README.md b/README.md index 06ea91b..3bcc589 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,20 @@ -# stable-baselines3-contrib -Contrib package for Stable-Baselines3 - Experimental code +[![CI](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/workflows/ci/badge.svg)](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/actions)[![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) + +# Stable-Baselines3 - Contrib + +Contrib package for [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) - Experimental code + +**You need the master version of Stable Baselines3**: +``` +pip install git+https://github.com/DLR-RM/stable-baselines3 +``` + +Implemented: +- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269) + + +## Installation + +``` +pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib +``` diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py new file mode 100644 index 0000000..c66cfec --- /dev/null +++ b/sb3_contrib/__init__.py @@ -0,0 +1,9 @@ +import os + +# from sb3_contrib.cmaes import CMAES +from sb3_contrib.tqc import TQC + +# Read version from file +version_file = os.path.join(os.path.dirname(__file__), "version.txt") +with open(version_file, "r") as file_handler: + __version__ = file_handler.read().strip() diff --git a/sb3_contrib/py.typed b/sb3_contrib/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/sb3_contrib/tqc/__init__.py b/sb3_contrib/tqc/__init__.py new file mode 100644 index 0000000..ef4b6bd --- /dev/null +++ b/sb3_contrib/tqc/__init__.py @@ -0,0 +1,2 @@ +from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy +from sb3_contrib.tqc.tqc import TQC diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py new file mode 100644 index 0000000..e58cb11 --- /dev/null +++ b/sb3_contrib/tqc/policies.py @@ -0,0 +1,480 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Type + +import gym +import torch as th +from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution +from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_features_extractor, register_policy +from stable_baselines3.common.preprocessing import get_action_dim +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp +from torch import nn as nn + +# CAP the standard deviation of the actor +LOG_STD_MAX = 2 +LOG_STD_MIN = -20 + + +class Actor(BasePolicy): + """ + Actor network (policy) for TQC. + + :param observation_space: (gym.spaces.Space) Obervation space + :param action_space: (gym.spaces.Space) Action space + :param net_arch: ([int]) Network architecture + :param features_extractor: (nn.Module) Network to extract features + (a CNN when using images, a nn.Flatten() layer otherwise) + :param features_dim: (int) Number of features + :param activation_fn: (Type[nn.Module]) Activation function + :param use_sde: (bool) Whether to use State Dependent Exploration or not + :param log_std_init: (float) Initial value for the log standard deviation + :param full_std: (bool) Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE. + :param sde_net_arch: ([int]) Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. + :param normalize_images: (bool) Whether to normalize images or not, + dividing by 255.0 (True by default) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + net_arch: List[int], + features_extractor: nn.Module, + features_dim: int, + activation_fn: Type[nn.Module] = nn.ReLU, + use_sde: bool = False, + log_std_init: float = -3, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + clip_mean: float = 2.0, + normalize_images: bool = True, + ): + super(Actor, self).__init__( + observation_space, + action_space, + features_extractor=features_extractor, + normalize_images=normalize_images, + squash_output=True, + ) + + # Save arguments to re-create object at loading + self.use_sde = use_sde + self.sde_features_extractor = None + self.sde_net_arch = sde_net_arch + self.net_arch = net_arch + self.features_dim = features_dim + self.activation_fn = activation_fn + self.log_std_init = log_std_init + self.sde_net_arch = sde_net_arch + self.use_expln = use_expln + self.full_std = full_std + self.clip_mean = clip_mean + + action_dim = get_action_dim(self.action_space) + latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn) + self.latent_pi = nn.Sequential(*latent_pi_net) + last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim + + if self.use_sde: + latent_sde_dim = last_layer_dim + # Separate feature extractor for gSDE + if sde_net_arch is not None: + self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor( + features_dim, sde_net_arch, activation_fn + ) + + self.action_dist = StateDependentNoiseDistribution( + action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True + ) + self.mu, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=last_layer_dim, latent_sde_dim=latent_sde_dim, log_std_init=log_std_init + ) + # Avoid numerical issues by limiting the mean of the Gaussian + # to be in [-clip_mean, clip_mean] + if clip_mean > 0.0: + self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean)) + else: + self.action_dist = SquashedDiagGaussianDistribution(action_dim) + self.mu = nn.Linear(last_layer_dim, action_dim) + self.log_std = nn.Linear(last_layer_dim, action_dim) + + def _get_data(self) -> Dict[str, Any]: + data = super()._get_data() + + data.update( + dict( + net_arch=self.net_arch, + features_dim=self.features_dim, + activation_fn=self.activation_fn, + use_sde=self.use_sde, + log_std_init=self.log_std_init, + full_std=self.full_std, + sde_net_arch=self.sde_net_arch, + use_expln=self.use_expln, + features_extractor=self.features_extractor, + clip_mean=self.clip_mean, + ) + ) + return data + + def get_std(self) -> th.Tensor: + """ + Retrieve the standard deviation of the action distribution. + Only useful when using gSDE. + It corresponds to ``th.exp(log_std)`` in the normal case, + but is slightly different when using ``expln`` function + (cf StateDependentNoiseDistribution doc). + + :return: (th.Tensor) + """ + msg = "get_std() is only available when using gSDE" + assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg + return self.action_dist.get_std(self.log_std) + + def reset_noise(self, batch_size: int = 1) -> None: + """ + Sample new weights for the exploration matrix, when using gSDE. + + :param batch_size: (int) + """ + msg = "reset_noise() is only available when using gSDE" + assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg + self.action_dist.sample_weights(self.log_std, batch_size=batch_size) + + def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]: + """ + Get the parameters for the action distribution. + + :param obs: (th.Tensor) + :return: (Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]) + Mean, standard deviation and optional keyword arguments. + """ + features = self.extract_features(obs) + latent_pi = self.latent_pi(features) + mean_actions = self.mu(latent_pi) + + if self.use_sde: + latent_sde = latent_pi + if self.sde_features_extractor is not None: + latent_sde = self.sde_features_extractor(features) + return mean_actions, self.log_std, dict(latent_sde=latent_sde) + # Unstructured exploration (Original implementation) + log_std = self.log_std(latent_pi) + # Original Implementation to cap the standard deviation + log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) + return mean_actions, log_std, {} + + def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: + mean_actions, log_std, kwargs = self.get_action_dist_params(obs) + # Note: the action is squashed + return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs) + + def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + mean_actions, log_std, kwargs = self.get_action_dist_params(obs) + # return action and associated log prob + return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs) + + def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + return self.forward(observation, deterministic) + + def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor: + """ + Evaluate actions according to the current policy, + given the observations. Only useful when using SDE. + :param obs: (th.Tensor) + :param actions: (th.Tensor) + :return: (th.Tensor) log likelihood of taking those actions + """ + mean_actions, log_std, kwargs = self.get_action_dist_params(obs) + self.action_dist.proba_distribution(mean_actions, log_std, **kwargs) + return self.action_dist.log_prob(actions) + + +class Critic(BaseModel): + """ + Critic network (q-value function) for TQC. + + :param observation_space: (gym.spaces.Space) Obervation space + :param action_space: (gym.spaces.Space) Action space + :param net_arch: ([int]) Network architecture + :param features_extractor: (nn.Module) Network to extract features + (a CNN when using images, a nn.Flatten() layer otherwise) + :param features_dim: (int) Number of features + :param activation_fn: (Type[nn.Module]) Activation function + :param normalize_images: (bool) Whether to normalize images or not, + dividing by 255.0 (True by default) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + net_arch: List[int], + features_extractor: nn.Module, + features_dim: int, + activation_fn: Type[nn.Module] = nn.ReLU, + normalize_images: bool = True, + n_quantiles: int = 25, + n_critics: int = 2, + ): + super().__init__( + observation_space, + action_space, + features_extractor=features_extractor, + normalize_images=normalize_images, + ) + + action_dim = get_action_dim(self.action_space) + + self.q_networks = [] + self.n_quantiles = n_quantiles + self.n_critics = n_critics + self.quantiles_total = n_quantiles * n_critics + + for i in range(n_critics): + qf_net = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn) + qf_net = nn.Sequential(*qf_net) + self.add_module(f"qf{i}", qf_net) + self.q_networks.append(qf_net) + + def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]: + # Learn the features extractor using the policy loss only + # this is much faster + with th.no_grad(): + features = self.extract_features(obs) + qvalue_input = th.cat([features, action], dim=1) + quantiles = th.stack(tuple(qf(qvalue_input) for qf in self.q_networks), dim=1) + return quantiles + + +class TQCPolicy(BasePolicy): + """ + Policy class (with both actor and critic) for TQC. + + :param observation_space: (gym.spaces.Space) Observation space + :param action_space: (gym.spaces.Space) Action space + :param lr_schedule: (callable) Learning rate schedule (could be constant) + :param net_arch: (Optional[List[int]]) The specification of the policy and value networks. + :param activation_fn: (Type[nn.Module]) Activation function + :param use_sde: (bool) Whether to use State Dependent Exploration or not + :param log_std_init: (float) Initial value for the log standard deviation + :param sde_net_arch: ([int]) Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. + :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. + :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments + to pass to the feature extractor. + :param normalize_images: (bool) Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Callable, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + use_sde: bool = False, + log_std_init: float = -3, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + clip_mean: float = 2.0, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_quantiles: int = 25, + n_critics: int = 2, + ): + super(TQCPolicy, self).__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=True, + ) + + if net_arch is None: + if features_extractor_class == FlattenExtractor: + net_arch = [256, 256] + else: + net_arch = [] + + # Create shared features extractor + self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) + self.features_dim = self.features_extractor.features_dim + + self.net_arch = net_arch + self.activation_fn = activation_fn + self.net_args = { + "observation_space": self.observation_space, + "action_space": self.action_space, + "features_extractor": self.features_extractor, + "features_dim": self.features_dim, + "net_arch": self.net_arch, + "activation_fn": self.activation_fn, + "normalize_images": normalize_images, + } + self.actor_kwargs = self.net_args.copy() + self.critic_kwargs = self.net_args.copy() + sde_kwargs = { + "use_sde": use_sde, + "log_std_init": log_std_init, + "sde_net_arch": sde_net_arch, + "use_expln": use_expln, + "clip_mean": clip_mean, + } + self.actor_kwargs.update(sde_kwargs) + tqc_kwargs = {"n_quantiles": n_quantiles, "n_critics": n_critics} + self.critic_kwargs.update(tqc_kwargs) + self.actor, self.actor_target = None, None + self.critic, self.critic_target = None, None + + self._build(lr_schedule) + + def _build(self, lr_schedule: Callable) -> None: + self.actor = self.make_actor() + self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + self.critic = self.make_critic() + self.critic_target = self.make_critic() + self.critic_target.load_state_dict(self.critic.state_dict()) + # Do not optimize the shared feature extractor with the critic loss + # otherwise, there are gradient computation issues + # Another solution: having duplicated features extractor but requires more memory and computation + critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name] + self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs) + + def _get_data(self) -> Dict[str, Any]: + data = super()._get_data() + + data.update( + dict( + net_arch=self.net_args["net_arch"], + activation_fn=self.net_args["activation_fn"], + use_sde=self.actor_kwargs["use_sde"], + log_std_init=self.actor_kwargs["log_std_init"], + sde_net_arch=self.actor_kwargs["sde_net_arch"], + use_expln=self.actor_kwargs["use_expln"], + clip_mean=self.actor_kwargs["clip_mean"], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) + return data + + def reset_noise(self, batch_size: int = 1) -> None: + """ + Sample new weights for the exploration matrix, when using gSDE. + + :param batch_size: (int) + """ + self.actor.reset_noise(batch_size=batch_size) + + def make_actor(self) -> Actor: + return Actor(**self.actor_kwargs).to(self.device) + + def make_critic(self) -> Critic: + return Critic(**self.critic_kwargs).to(self.device) + + def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: + return self._predict(obs, deterministic=deterministic) + + def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + return self.actor(observation, deterministic) + + +MlpPolicy = TQCPolicy + + +class CnnPolicy(TQCPolicy): + """ + Policy class (with both actor and critic) for TQC. + + :param observation_space: (gym.spaces.Space) Observation space + :param action_space: (gym.spaces.Space) Action space + :param lr_schedule: (callable) Learning rate schedule (could be constant) + :param net_arch: (Optional[List[int]]) The specification of the policy and value networks. + :param activation_fn: (Type[nn.Module]) Activation function + :param use_sde: (bool) Whether to use State Dependent Exploration or not + :param log_std_init: (float) Initial value for the log standard deviation + :param sde_net_arch: ([int]) Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. + :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. + :param normalize_images: (bool) Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Callable, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + use_sde: bool = False, + log_std_init: float = -3, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + clip_mean: float = 2.0, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_quantiles: int = 25, + n_critics: int = 2, + ): + super(CnnPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + use_sde, + log_std_init, + sde_net_arch, + use_expln, + clip_mean, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + n_quantiles, + n_critics, + ) + + +register_policy("MlpPolicy", MlpPolicy) +register_policy("CnnPolicy", CnnPolicy) diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py new file mode 100644 index 0000000..f44f67c --- /dev/null +++ b/sb3_contrib/tqc/tqc.py @@ -0,0 +1,492 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch as th +from stable_baselines3.common import logger +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.noise import ActionNoise +from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback +from stable_baselines3.common.utils import polyak_update +from tqdm import tqdm + +from sb3_contrib.tqc.policies import TQCPolicy + + +class TQC(OffPolicyAlgorithm): + """ + + Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics + Paper: https://arxiv.org/abs/2005.04269 + + :param policy: (TQCPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str) + :param learning_rate: (float or callable) learning rate for adam optimizer, + the same learning rate will be used for all networks (Q-Values, Actor and Value function) + it can be a function of the current progress remaining (from 1 to 0) + :param buffer_size: (int) size of the replay buffer + :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts + :param batch_size: (int) Minibatch size for each gradient update + :param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1) + :param gamma: (float) the discount factor + :param train_freq: (int) Update the model every ``train_freq`` steps. + :param gradient_steps: (int) How many gradient update after each step + :param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes. + Note that this cannot be used at the same time as ``train_freq`` + :param action_noise: (ActionNoise) the action noise type (None by default), this can help + for hard exploration problem. Cf common.noise for the different action noise type. + :param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer + at a cost of more complexity. + See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 + :param ent_coef: (str or float) Entropy regularization coefficient. (Equivalent to + inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off. + Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value) + :param target_update_interval: (int) update the target network every ``target_network_update_freq`` + gradient steps. + :param target_entropy: (str or float) target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``) + :param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE) + instead of action noise exploration (default: False) + :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE + Default: -1 (only sample at the beginning of the rollout) + :param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling + during the warm up phase (before learning starts) + :param create_eval_env: (bool) Whether to create a second environment that will be + used for evaluating the agent periodically. (Only available when passing string for the environment) + :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation + :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug + :param seed: (int) Seed for the pseudo random generators + :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. + :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + """ + + def __init__( + self, + policy: Union[str, Type[TQCPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Callable] = 3e-4, + buffer_size: int = int(1e6), + learning_starts: int = 100, + batch_size: int = 256, + tau: float = 0.005, + gamma: float = 0.99, + train_freq: int = 1, + gradient_steps: int = 1, + n_episodes_rollout: int = -1, + action_noise: Optional[ActionNoise] = None, + optimize_memory_usage: bool = False, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + ent_coef: Union[str, float] = "auto", + target_update_interval: int = 1, + target_entropy: Union[str, float] = "auto", + top_quantiles_to_drop_per_net: int = 2, + use_sde: bool = False, + sde_sample_freq: int = -1, + use_sde_at_warmup: bool = False, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Dict[str, Any] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + + super(TQC, self).__init__( + policy, + env, + TQCPolicy, + learning_rate, + replay_buffer_class, + replay_buffer_kwargs, + buffer_size, + learning_starts, + batch_size, + tau, + gamma, + train_freq, + gradient_steps, + n_episodes_rollout, + action_noise, + policy_kwargs=policy_kwargs, + tensorboard_log=tensorboard_log, + verbose=verbose, + device=device, + create_eval_env=create_eval_env, + seed=seed, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + use_sde_at_warmup=use_sde_at_warmup, + optimize_memory_usage=optimize_memory_usage, + ) + + self.target_entropy = target_entropy + self.log_ent_coef = None # type: Optional[th.Tensor] + # Entropy coefficient / Entropy temperature + # Inverse of the reward scale + self.ent_coef = ent_coef + self.target_update_interval = target_update_interval + self.ent_coef_optimizer = None + self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + super(TQC, self)._setup_model() + self._create_aliases() + self.replay_buffer.actor = self.actor + self.replay_buffer.ent_coef = 0.0 + + # Target entropy is used when learning the entropy coefficient + if self.target_entropy == "auto": + # automatically set target entropy if needed + self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) + else: + # Force conversion + # this will also throw an error for unexpected string + self.target_entropy = float(self.target_entropy) + + # The entropy coefficient or entropy can be learned automatically + # see Automating Entropy Adjustment for Maximum Entropy RL section + # of https://arxiv.org/abs/1812.05905 + if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"): + # Default initial value of ent_coef when learned + init_value = 1.0 + if "_" in self.ent_coef: + init_value = float(self.ent_coef.split("_")[1]) + assert init_value > 0.0, "The initial value of ent_coef must be greater than 0" + + # Note: we optimize the log of the entropy coeff which is slightly different from the paper + # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 + self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True) + self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1)) + else: + # Force conversion to float + # this will throw an error if a malformed string (different from 'auto') + # is passed + self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device) + + def _create_aliases(self) -> None: + self.actor = self.policy.actor + self.critic = self.policy.critic + self.critic_target = self.policy.critic_target + + @staticmethod + def quantile_huber_loss(quantiles: th.Tensor, samples: th.Tensor) -> th.Tensor: + # batch x nets x quantiles x samples + pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None] + abs_pairwise_delta = th.abs(pairwise_delta) + huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5) + + n_quantiles = quantiles.shape[2] + tau = th.arange(n_quantiles, device=quantiles.device).float() / n_quantiles + 1 / 2 / n_quantiles + loss = (th.abs(tau[None, None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean() + return loss + + def train(self, gradient_steps: int, batch_size: int = 64) -> None: + # Update optimizers learning rate + optimizers = [self.actor.optimizer, self.critic.optimizer] + if self.ent_coef_optimizer is not None: + optimizers += [self.ent_coef_optimizer] + + # Update learning rate according to lr schedule + self._update_learning_rate(optimizers) + + ent_coef_losses, ent_coefs = [], [] + actor_losses, critic_losses = [], [] + + for gradient_step in range(gradient_steps): + # Sample replay buffer + replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) + + # We need to sample because `log_std` may have changed between two gradient steps + if self.use_sde: + self.actor.reset_noise() + + # Action by the current actor for the sampled state + actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations) + log_prob = log_prob.reshape(-1, 1) + + ent_coef_loss = None + if self.ent_coef_optimizer is not None: + # Important: detach the variable from the graph + # so we don't change it with other losses + # see https://github.com/rail-berkeley/softlearning/issues/60 + ent_coef = th.exp(self.log_ent_coef.detach()) + ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean() + ent_coef_losses.append(ent_coef_loss.item()) + else: + ent_coef = self.ent_coef_tensor + + ent_coefs.append(ent_coef.item()) + self.replay_buffer.ent_coef = ent_coef.item() + + # Optimize entropy coefficient, also called + # entropy temperature or alpha in the paper + if ent_coef_loss is not None: + self.ent_coef_optimizer.zero_grad() + ent_coef_loss.backward() + self.ent_coef_optimizer.step() + + with th.no_grad(): + top_quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.critic.n_critics + # Select action according to policy + next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations) + # Compute and cut quantiles at the next state + # batch x nets x quantiles + next_z = self.critic_target(replay_data.next_observations, next_actions) + sorted_z, _ = th.sort(next_z.reshape(batch_size, -1)) + sorted_z_part = sorted_z[:, : self.critic.quantiles_total - top_quantiles_to_drop] + + target_q = sorted_z_part - ent_coef * next_log_prob.reshape(-1, 1) + # td error + entropy term + q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q + + # Get current Q estimates + # using action from the replay buffer + current_z = self.critic(replay_data.observations, replay_data.actions) + # Compute critic loss + critic_loss = self.quantile_huber_loss(current_z, q_backup) + critic_losses.append(critic_loss.item()) + + # Optimize the critic + self.critic.optimizer.zero_grad() + critic_loss.backward() + self.critic.optimizer.step() + + # Compute actor loss + qf_pi = self.critic(replay_data.observations, actions_pi).mean(2).mean(1, keepdim=True) + actor_loss = (ent_coef * log_prob - qf_pi).mean() + actor_losses.append(actor_loss.item()) + + # Optimize the actor + self.actor.optimizer.zero_grad() + actor_loss.backward() + self.actor.optimizer.step() + + # Update target networks + if gradient_step % self.target_update_interval == 0: + polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau) + + self._n_updates += gradient_steps + + logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + logger.record("train/ent_coef", np.mean(ent_coefs)) + logger.record("train/actor_loss", np.mean(actor_losses)) + logger.record("train/critic_loss", np.mean(critic_losses)) + if len(ent_coef_losses) > 0: + logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) + + def pretrain( + self, + gradient_steps: int, + batch_size: int = 64, + n_action_samples: int = -1, + target_update_interval: int = 1, + tau: float = 0.005, + strategy: str = "exp", + reduce: str = "mean", + exp_temperature: float = 1.0, + off_policy_update_freq: int = -1, + ) -> None: + """ + Pretrain with Critic Regularized Regression (CRR) + Paper: https://arxiv.org/abs/2006.15134 + """ + # Update optimizers learning rate + optimizers = [self.actor.optimizer, self.critic.optimizer] + if self.ent_coef_optimizer is not None: + optimizers += [self.ent_coef_optimizer] + + # Update learning rate according to lr schedule + self._update_learning_rate(optimizers) + + actor_losses, critic_losses = [], [] + + for gradient_step in tqdm(range(gradient_steps)): + + if off_policy_update_freq > 0 and gradient_step % off_policy_update_freq == 0: + self.train(gradient_steps=1, batch_size=batch_size) + continue + + # Sample replay buffer + replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) + + # We need to sample because `log_std` may have changed between two gradient steps + if self.use_sde: + self.actor.reset_noise() + + # Action by the current actor for the sampled state + _, log_prob = self.actor.action_log_prob(replay_data.observations) + log_prob = log_prob.reshape(-1, 1) + + ent_coef_loss = None + if self.ent_coef_optimizer is not None: + # Important: detach the variable from the graph + # so we don't change it with other losses + # see https://github.com/rail-berkeley/softlearning/issues/60 + ent_coef = th.exp(self.log_ent_coef.detach()) + ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean() + else: + ent_coef = self.ent_coef_tensor + + self.replay_buffer.ent_coef = ent_coef.item() + + # Optimize entropy coefficient, also called + # entropy temperature or alpha in the paper + if ent_coef_loss is not None: + self.ent_coef_optimizer.zero_grad() + ent_coef_loss.backward() + self.ent_coef_optimizer.step() + + with th.no_grad(): + top_quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.critic.n_critics + # Select action according to policy + next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations) + # Compute and cut quantiles at the next state + # batch x nets x quantiles + next_z = self.critic_target(replay_data.next_observations, next_actions) + sorted_z, _ = th.sort(next_z.reshape(batch_size, -1)) + sorted_z_part = sorted_z[:, : self.critic.quantiles_total - top_quantiles_to_drop] + + target_q = sorted_z_part - ent_coef * next_log_prob.reshape(-1, 1) + # td error + entropy term + q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q + + # Get current Q estimates + # using action from the replay buffer + current_z = self.critic(replay_data.observations, replay_data.actions) + # Compute critic loss + critic_loss = self.quantile_huber_loss(current_z, q_backup) + critic_losses.append(critic_loss.item()) + + # Optimize the critic + self.critic.optimizer.zero_grad() + critic_loss.backward() + self.critic.optimizer.step() + + if strategy == "bc": + # Behavior cloning + weight = 1 + else: + # Tensor version: TODO: check that the reshape works as expected + # cleaner but not faster on cpu for large batch size + # with th.no_grad(): + # # Q-value for the action in the buffer + # qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True) + # # Create tensor to avoid loop + # # Note: For SDE, we need to sample several matrices + # obs_ = replay_data.observations.repeat(n_action_samples, 1) + # if self.use_sde: + # self.actor.reset_noise(batch_size * n_action_samples) + # actions_pi, _ = self.actor.action_log_prob(obs_) + # qf_pi = self.critic(obs_, actions_pi.detach()).mean(2).mean(1, keepdim=True) + # # Agregate: reduce mean or reduce max + # if reduce == "max": + # _, qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).max(axis=0) + # else: + # qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).mean(axis=0) + with th.no_grad(): + qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True) + + # Use the mean (as done in AWAC, cf rlkit) + if n_action_samples == -1: + actions_pi = self.actor.forward(replay_data.observations, deterministic=True) + qf_agg = self.critic(replay_data.observations, actions_pi).mean(2).mean(1, keepdim=True) + else: + qf_agg = None + for _ in range(n_action_samples): + if self.use_sde: + self.actor.reset_noise() + actions_pi, _ = self.actor.action_log_prob(replay_data.observations) + + qf_pi = self.critic(replay_data.observations, actions_pi.detach()).mean(2).mean(1, keepdim=True) + if qf_agg is None: + if reduce == "max": + qf_agg = qf_pi + else: + qf_agg = qf_pi / n_action_samples + else: + if reduce == "max": + qf_agg = th.max(qf_pi, qf_agg) + else: + qf_agg += qf_pi / n_action_samples + + advantage = qf_buffer - qf_agg + if strategy == "binary": + # binary advantage + weight = advantage > 0 + else: + # exp advantage + exp_clip = 20.0 + weight = th.clamp(th.exp(advantage / exp_temperature), 0.0, exp_clip) + + # Log prob by the current actor for the sampled state and action + log_prob = self.actor.evaluate_actions(replay_data.observations, replay_data.actions) + log_prob = log_prob.reshape(-1, 1) + + # weigthed regression loss (close to policy gradient loss) + actor_loss = (-log_prob * weight).mean() + # actor_loss = ((actions_pi - replay_data.actions * weight) ** 2).mean() + actor_losses.append(actor_loss.item()) + + # Optimize the actor + self.actor.optimizer.zero_grad() + actor_loss.backward() + self.actor.optimizer.step() + + # Update target networks + if gradient_step % target_update_interval == 0: + polyak_update(self.critic.parameters(), self.critic_target.parameters(), tau) + + if self.use_sde: + print(f"std={(self.actor.get_std()).mean().item()}") + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "TQC", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> OffPolicyAlgorithm: + + return super(TQC, self).learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + eval_env=eval_env, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, + eval_log_path=eval_log_path, + reset_num_timesteps=reset_num_timesteps, + ) + + def _excluded_save_params(self) -> List[str]: + """ + Returns the names of the parameters that should be excluded by default + when saving the model. + + :return: (List[str]) List of parameters that should be excluded from save + """ + # Exclude aliases + return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"] + + def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: + """ + cf base class + """ + state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] + saved_pytorch_variables = ["log_ent_coef"] + if self.ent_coef_optimizer is not None: + state_dicts.append("ent_coef_optimizer") + else: + saved_pytorch_variables.append("ent_coef_tensor") + return state_dicts, saved_pytorch_variables diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt new file mode 100644 index 0000000..63aaf3d --- /dev/null +++ b/sb3_contrib/version.txt @@ -0,0 +1 @@ +0.9.0a2 diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh new file mode 100644 index 0000000..258e302 --- /dev/null +++ b/scripts/run_tests.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..951fc85 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,48 @@ +[metadata] +# This includes the license file in the wheel. +license_file = LICENSE + +[tool:pytest] +# Deterministic ordering for tests; useful for pytest-xdist. +env = + PYTHONHASHSEED=0 +filterwarnings = + # Tensorboard/Tensorflow warnings + ignore:inspect.getargspec:DeprecationWarning:tensorflow + ignore:builtin type EagerTensor has no __module__ attribute:DeprecationWarning + ignore:The binary mode of fromstring is deprecated:DeprecationWarning + ignore::FutureWarning:tensorflow + # Gym warnings + ignore:Parameters to load are deprecated.:DeprecationWarning + ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning + ignore::UserWarning:gym + +[pytype] +inputs = sb3_contrib + +[flake8] +ignore = W503,W504,E203,E231 # line breaks before and after binary operators +# Ignore import not used when aliases are defined +per-file-ignores = + ./sb3_contrib/__init__.py:F401 + ./sb3_contrib/tqc/__init__.py:F401 +exclude = + # No need to traverse our git directory + .git, + # There's no value in checking cache directories + __pycache__, + # Don't check the doc + docs/ + # This contains our built documentation + build, + # This contains builds of flake8 that we don't want to check + dist + *.egg-info +max-complexity = 15 +# The GitHub editor is 127 chars wide +max-line-length = 127 + +[isort] +profile = black +line_length = 127 +src_paths = sb3_contrib diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..b7f6f3e --- /dev/null +++ b/setup.py @@ -0,0 +1,45 @@ +import os + +from setuptools import find_packages, setup + +with open(os.path.join("sb3_contrib", "version.txt"), "r") as file_handler: + __version__ = file_handler.read().strip() + + +long_description = """ + +# Stable-Baselines3 - Contrib + +Contrib package for [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) - Experimental code + +Implemented: +- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269) + +""" # noqa:E501 + + +setup( + name="sb3_contrib", + packages=[package for package in find_packages() if package.startswith("sb3_contrib")], + package_data={"sb3_contrib": ["py.typed", "version.txt"]}, + install_requires=[ + "stable_baselines3[tests,docs]", + # Enable CMA + # "cma", + ], + description="Contrib package of Stable Baselines3, experimental code.", + author="Antonin Raffin", + url="https://github.com/Stable-Baselines-Team/stable-baselines3-contrib", + author_email="antonin.raffin@dlr.de", + keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning " + "gym openai stable baselines toolbox python data-science", + license="MIT", + long_description=long_description, + long_description_content_type="text/markdown", + version=__version__, +) + +# python setup.py sdist +# python setup.py bdist_wheel +# twine upload --repository-url https://test.pypi.org/legacy/ dist/* +# twine upload dist/* diff --git a/tests/test_run.py b/tests/test_run.py new file mode 100644 index 0000000..2cbcf41 --- /dev/null +++ b/tests/test_run.py @@ -0,0 +1,54 @@ +import pytest + +from sb3_contrib import TQC + + +@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"]) +def test_tqc(ent_coef): + model = TQC( + "MlpPolicy", + "Pendulum-v0", + policy_kwargs=dict(net_arch=[64, 64]), + learning_starts=100, + verbose=1, + create_eval_env=True, + ent_coef=ent_coef, + ) + model.learn(total_timesteps=1000, eval_freq=500) + + +@pytest.mark.parametrize("n_critics", [1, 3]) +def test_n_critics(n_critics): + # Test TQC with different number of critics + model = TQC( + "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1 + ) + model.learn(total_timesteps=1000) + + +# "CartPole-v1" +# @pytest.mark.parametrize("env_id", ["MountainCarContinuous-v0"]) +# def test_cmaes(env_id): +# if CMAES is None: +# return +# model = CMAES("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[64]), verbose=1, create_eval_env=True) +# model.learn(total_timesteps=50000, eval_freq=10000) + + +@pytest.mark.parametrize("strategy", ["exp", "bc", "binary"]) +@pytest.mark.parametrize("reduce", ["mean", "max"]) +def test_crr(tmp_path, strategy, reduce): + model = TQC( + "MlpPolicy", + "Pendulum-v0", + policy_kwargs=dict(net_arch=[64]), + learning_starts=1000, + verbose=1, + create_eval_env=True, + action_noise=None, + use_sde=False, + ) + + model.learn(total_timesteps=1000, eval_freq=0) + for n_action_samples in [1, 2, -1]: + model.pretrain(gradient_steps=32, batch_size=32, n_action_samples=n_action_samples, strategy=strategy, reduce=reduce) diff --git a/tests/test_save_load.py b/tests/test_save_load.py new file mode 100644 index 0000000..4415267 --- /dev/null +++ b/tests/test_save_load.py @@ -0,0 +1,426 @@ +import os +import pathlib +import warnings +from collections import OrderedDict +from copy import deepcopy + +import gym +import numpy as np +import pytest +import torch as th +from stable_baselines3 import DQN +from stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox +from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl +from stable_baselines3.common.utils import get_device +from stable_baselines3.common.vec_env import DummyVecEnv + +from sb3_contrib import TQC + +MODEL_LIST = [TQC] + + +def select_env(model_class: BaseAlgorithm) -> gym.Env: + """ + Selects an environment with the correct action space as DQN only supports discrete action space + """ + if model_class == DQN: + return IdentityEnv(10) + else: + return IdentityEnvBox(10) + + +@pytest.mark.parametrize("model_class", MODEL_LIST) +def test_save_load(tmp_path, model_class): + """ + Test if 'save' and 'load' saves and loads model correctly + and if 'get_parameters' and 'set_parameters' and work correctly. + + ''warning does not test function of optimizer parameter load + + :param model_class: (BaseAlgorithm) A RL model + """ + + env = DummyVecEnv([lambda: select_env(model_class)]) + + # create model + model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1) + model.learn(total_timesteps=500, eval_freq=250) + + env.reset() + observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) + + # Get parameters of different objects + # deepcopy to avoid referencing to tensors we are about to modify + original_params = deepcopy(model.get_parameters()) + + # Test different error cases of set_parameters. + # Test that invalid object names throw errors + invalid_object_params = deepcopy(original_params) + invalid_object_params["I_should_not_be_a_valid_object"] = "and_I_am_an_invalid_tensor" + with pytest.raises(ValueError): + model.set_parameters(invalid_object_params, exact_match=True) + with pytest.raises(ValueError): + model.set_parameters(invalid_object_params, exact_match=False) + + # Test that exact_match catches when something was missed. + missing_object_params = dict((k, v) for k, v in list(original_params.items())[:-1]) + with pytest.raises(ValueError): + model.set_parameters(missing_object_params, exact_match=True) + + # Test that exact_match catches when something inside state-dict + # is missing but we have exact_match. + missing_state_dict_tensor_params = {} + for object_name in original_params: + object_params = {} + missing_state_dict_tensor_params[object_name] = object_params + # Skip last item in state-dict + for k, v in list(original_params[object_name].items())[:-1]: + object_params[k] = v + with pytest.raises(RuntimeError): + # PyTorch load_state_dict throws RuntimeError if strict but + # invalid state-dict. + model.set_parameters(missing_state_dict_tensor_params, exact_match=True) + + # Test that parameters do indeed change. + random_params = {} + for object_name, params in original_params.items(): + # Do not randomize optimizer parameters (custom layout) + if "optim" in object_name: + random_params[object_name] = params + else: + # Again, skip the last item in state-dict + random_params[object_name] = OrderedDict( + (param_name, th.rand_like(param)) for param_name, param in list(params.items())[:-1] + ) + + # Update model parameters with the new random values + model.set_parameters(random_params, exact_match=False) + + new_params = model.get_parameters() + # Check that all params except the final item in each state-dict are different. + for object_name in original_params: + # Skip optimizers (no valid comparison with just th.allclose) + if "optim" in object_name: + continue + # state-dicts use ordered dictionaries, so key order + # is guaranteed. + last_key = list(original_params[object_name].keys())[-1] + for k in original_params[object_name]: + if k == last_key: + # Should be same as before + assert th.allclose( + original_params[object_name][k], new_params[object_name][k] + ), "Parameter changed despite not included in the loaded parameters." + else: + # Should be different + assert not th.allclose( + original_params[object_name][k], new_params[object_name][k] + ), "Parameters did not change as expected." + + params = new_params + + # get selected actions + selected_actions, _ = model.predict(observations, deterministic=True) + + # Check + model.save(tmp_path / "test_save.zip") + del model + + # Check if the model loads as expected for every possible choice of device: + for device in ["auto", "cpu", "cuda"]: + model = model_class.load(str(tmp_path / "test_save.zip"), env=env, device=device) + + # check if the model was loaded to the correct device + assert model.device.type == get_device(device).type + assert model.policy.device.type == get_device(device).type + + # check if params are still the same after load + new_params = model.get_parameters() + + # Check that all params are the same as before save load procedure now + for object_name in new_params: + # Skip optimizers (no valid comparison with just th.allclose) + if "optim" in object_name: + continue + for key in params[object_name]: + assert new_params[object_name][key].device.type == get_device(device).type + assert th.allclose( + params[object_name][key].to("cpu"), new_params[object_name][key].to("cpu") + ), "Model parameters not the same after save and load." + + # check if model still selects the same actions + new_selected_actions, _ = model.predict(observations, deterministic=True) + assert np.allclose(selected_actions, new_selected_actions, 1e-4) + + # check if learn still works + model.learn(total_timesteps=1000, eval_freq=500) + + del model + + # clear file from os + os.remove(tmp_path / "test_save.zip") + + +@pytest.mark.parametrize("model_class", MODEL_LIST) +def test_set_env(model_class): + """ + Test if set_env function does work correct + :param model_class: (BaseAlgorithm) A RL model + """ + + # use discrete for DQN + env = DummyVecEnv([lambda: select_env(model_class)]) + env2 = DummyVecEnv([lambda: select_env(model_class)]) + env3 = select_env(model_class) + + # create model + model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16])) + # learn + model.learn(total_timesteps=1000, eval_freq=500) + + # change env + model.set_env(env2) + # learn again + model.learn(total_timesteps=1000, eval_freq=500) + + # change env test wrapping + model.set_env(env3) + # learn again + model.learn(total_timesteps=1000, eval_freq=500) + + +@pytest.mark.parametrize("model_class", MODEL_LIST) +def test_exclude_include_saved_params(tmp_path, model_class): + """ + Test if exclude and include parameters of save() work + + :param model_class: (BaseAlgorithm) A RL model + """ + env = DummyVecEnv([lambda: select_env(model_class)]) + + # create model, set verbose as 2, which is not standard + model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=2) + + # Check if exclude works + model.save(tmp_path / "test_save", exclude=["verbose"]) + del model + model = model_class.load(str(tmp_path / "test_save.zip")) + # check if verbose was not saved + assert model.verbose != 2 + + # set verbose as something different then standard settings + model.verbose = 2 + # Check if include works + model.save(tmp_path / "test_save", exclude=["verbose"], include=["verbose"]) + del model + model = model_class.load(str(tmp_path / "test_save.zip")) + assert model.verbose == 2 + + # clear file from os + os.remove(tmp_path / "test_save.zip") + + +@pytest.mark.parametrize("model_class", [TQC]) +def test_save_load_replay_buffer(tmp_path, model_class): + path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl") + path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning + model = model_class("MlpPolicy", select_env(model_class), buffer_size=1000) + model.learn(500) + old_replay_buffer = deepcopy(model.replay_buffer) + model.save_replay_buffer(path) + model.replay_buffer = None + model.load_replay_buffer(path) + + assert np.allclose(old_replay_buffer.observations, model.replay_buffer.observations) + assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions) + assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards) + assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones) + + # test extending replay buffer + model.replay_buffer.extend( + old_replay_buffer.observations, + old_replay_buffer.observations, + old_replay_buffer.actions, + old_replay_buffer.rewards, + old_replay_buffer.dones, + ) + + +@pytest.mark.parametrize("model_class", [TQC]) +@pytest.mark.parametrize("optimize_memory_usage", [False, True]) +def test_warn_buffer(recwarn, model_class, optimize_memory_usage): + """ + When using memory efficient replay buffer, + a warning must be emitted when calling `.learn()` + multiple times. + See https://github.com/DLR-RM/stable-baselines3/issues/46 + """ + # remove gym warnings + warnings.filterwarnings(action="ignore", category=DeprecationWarning) + warnings.filterwarnings(action="ignore", category=UserWarning, module="gym") + + model = model_class( + "MlpPolicy", + select_env(model_class), + buffer_size=100, + optimize_memory_usage=optimize_memory_usage, + policy_kwargs=dict(net_arch=[64]), + learning_starts=10, + ) + + model.learn(150) + + model.learn(150, reset_num_timesteps=False) + + # Check that there is no warning + assert len(recwarn) == 0 + + model.learn(150) + + if optimize_memory_usage: + assert len(recwarn) == 1 + warning = recwarn.pop(UserWarning) + assert "The last trajectory in the replay buffer will be truncated" in str(warning.message) + else: + assert len(recwarn) == 0 + + +@pytest.mark.parametrize("model_class", MODEL_LIST) +@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"]) +def test_save_load_policy(tmp_path, model_class, policy_str): + """ + Test saving and loading policy only. + + :param model_class: (BaseAlgorithm) A RL model + :param policy_str: (str) Name of the policy. + """ + kwargs = {} + if policy_str == "MlpPolicy": + env = select_env(model_class) + else: + if model_class in [TQC]: + # Avoid memory error when using replay buffer + # Reduce the size of the features + kwargs = dict(buffer_size=250) + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN) + + env = DummyVecEnv([lambda: env]) + + # create model + model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs) + model.learn(total_timesteps=500, eval_freq=250) + + env.reset() + observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) + + policy = model.policy + policy_class = policy.__class__ + actor, actor_class = None, None + if model_class in [TQC]: + actor = policy.actor + actor_class = actor.__class__ + + # Get dictionary of current parameters + params = deepcopy(policy.state_dict()) + + # Modify all parameters to be random values + random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + + # Update model parameters with the new random values + policy.load_state_dict(random_params) + + new_params = policy.state_dict() + # Check that all params are different now + for k in params: + assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected." + + params = new_params + + # get selected actions + selected_actions, _ = policy.predict(observations, deterministic=True) + # Should also work with the actor only + if actor is not None: + selected_actions_actor, _ = actor.predict(observations, deterministic=True) + + # Save and load policy + policy.save(tmp_path / "policy.pkl") + # Save and load actor + if actor is not None: + actor.save(tmp_path / "actor.pkl") + + del policy, actor + + policy = policy_class.load(tmp_path / "policy.pkl") + if actor_class is not None: + actor = actor_class.load(tmp_path / "actor.pkl") + + # check if params are still the same after load + new_params = policy.state_dict() + + # Check that all params are the same as before save load procedure now + for key in params: + assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load." + + # check if model still selects the same actions + new_selected_actions, _ = policy.predict(observations, deterministic=True) + assert np.allclose(selected_actions, new_selected_actions, 1e-4) + + if actor_class is not None: + new_selected_actions_actor, _ = actor.predict(observations, deterministic=True) + assert np.allclose(selected_actions_actor, new_selected_actions_actor, 1e-4) + assert np.allclose(selected_actions_actor, new_selected_actions, 1e-4) + + # clear file from os + os.remove(tmp_path / "policy.pkl") + if actor_class is not None: + os.remove(tmp_path / "actor.pkl") + + +@pytest.mark.parametrize("pathtype", [str, pathlib.Path]) +def test_open_file_str_pathlib(tmp_path, pathtype): + # check that suffix isn't added because we used open_path first + with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1: + save_to_pkl(fp1, "foo") + assert fp1.closed + with pytest.warns(None) as record: + assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo" + assert not record + + # test custom suffix + with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1: + save_to_pkl(fp1, "foo") + assert fp1.closed + with pytest.warns(None) as record: + assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo" + assert not record + + # test without suffix + with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1: + save_to_pkl(fp1, "foo") + assert fp1.closed + with pytest.warns(None) as record: + assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo" + assert not record + + # test that a warning is raised when the path doesn't exist + with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1: + save_to_pkl(fp1, "foo") + assert fp1.closed + with pytest.warns(None) as record: + assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo" + assert len(record) == 0 + + with pytest.warns(None) as record: + assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo" + assert len(record) == 1 + + fp = pathlib.Path(f"{tmp_path}/t2").open("w") + fp.write("rubbish") + fp.close() + # test that a warning is only raised when verbose = 0 + with pytest.warns(None) as record: + open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close() + open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close() + open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close() + assert len(record) == 1