Issue forms and pyproject.toml (#162)

* Issue forms and pyproject.toml

* [ci skip] Fix typos

* Fix isort config

* Use secret link to download atari roms

* Fix for mypy and update config

* Upgrade SB3 and fix warnings

* Fix doc build

* Update Makefile

* Lint first
Antonin RAFFIN 2023-03-11 22:57:45 +01:00 committed by GitHub
parent 376d9551de
commit 728c1c5b7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 375 additions and 152 deletions

.github/ISSUE_TEMPLATE/bug_report.yml vendored Normal file

@@ -0,0 +1,72 @@
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve Stable-Baselines3 Contrib
title: "[Bug]: bug title"
labels: ["bug"]
body:
- type: markdown
attributes:
value: |
**Important Note: We do not offer technical support or consulting**, and we do not answer personal questions by email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/), or [Stack Overflow](https://stackoverflow.com/) instead.
If your issue is related to a **custom gym environment**, please use the custom gym env template.
- type: textarea
id: description
attributes:
label: 🐛 Bug
description: A clear and concise description of what the bug is.
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: To Reproduce
description: |
Steps to reproduce the behavior. Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
value: |
```python
from stable_baselines3 import ...
from sb3_contrib import ...
```
- type: textarea
id: traceback
attributes:
label: Relevant log output / Error message
description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks.
placeholder: "Traceback (most recent call last): File ..."
render: shell
- type: textarea
id: system-info
attributes:
label: System Info
description: |
Describe the characteristics of your environment:
* Describe how the library was installed (pip, docker, source, ...)
* Stable-Baselines3 and sb3-contrib versions
* GPU models and configuration
* Python version
* PyTorch version
* Gym version
* Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print info about relevant packages:
```sh
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
```
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo
required: true
- label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/)
required: true
- label: I have provided a minimal working example to reproduce the bug
required: true
- label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
required: true

.github/ISSUE_TEMPLATE/custom_env.yml vendored Normal file

@@ -0,0 +1,108 @@
name: "\U0001F916 Custom Gym Environment Issue"
description: How to report an issue when using a custom Gym environment
labels: ["question", "custom gym env"]
body:
- type: markdown
attributes:
value: |
**Important Note: We do not offer technical support or consulting**, and we do not answer personal questions by email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/), or [Stack Overflow](https://stackoverflow.com/) instead.
**Please check your environment first using**:
```python
from stable_baselines3.common.env_checker import check_env
env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)
```
- type: textarea
id: description
attributes:
label: 🐛 Bug
description: A clear and concise description of what the bug is.
validations:
required: true
- type: textarea
id: code-example
attributes:
label: Code example
description: |
Please try to provide a minimal example to reproduce the bug.
For a custom environment, you need to provide at least the observation space, the action space, and the `reset()` and `step()` methods (see the working example below).
Error messages and stack traces are also helpful.
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
value: |
```python
import gym
import numpy as np
from gym import spaces

from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env


class CustomEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
        self.action_space = spaces.Box(low=-1, high=1, shape=(6,))

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        obs = self.observation_space.sample()
        reward = 1.0
        done = False
        info = {}
        return obs, reward, done, info


env = CustomEnv()
check_env(env)
model = A2C("MlpPolicy", env, verbose=1).learn(1000)
```
- type: textarea
id: traceback
attributes:
label: Relevant log output / Error message
description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks.
placeholder: "Traceback (most recent call last): File ..."
render: shell
- type: textarea
id: system-info
attributes:
label: System Info
description: |
Describe the characteristics of your environment:
* Describe how the library was installed (pip, docker, source, ...)
* Stable-Baselines3 and sb3-contrib versions
* GPU models and configuration
* Python version
* PyTorch version
* Gym version
* Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print info about relevant packages:
```sh
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
```
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo
required: true
- label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/)
required: true
- label: I have provided a minimal working example to reproduce the bug
required: true
- label: I have checked my env using the env checker
required: true
- label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
required: true


@@ -0,0 +1,25 @@
name: "\U0001F4DA Documentation"
description: Report an issue related to Stable-Baselines3 Contrib documentation
labels: ["documentation"]
body:
- type: markdown
attributes:
value: |
**Important Note: We do not offer technical support or consulting**, and we do not answer personal questions by email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/), or [Stack Overflow](https://stackoverflow.com/) instead.
- type: textarea
id: description
attributes:
label: 📚 Documentation
description: A clear and concise description of what should be improved in the documentation.
validations:
required: true
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo
required: true
- label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/)
required: true


@@ -0,0 +1,44 @@
name: "\U0001F680 Feature Request"
description: How to create an issue for requesting a feature
title: "[Feature Request] request title"
labels: ["enhancement"]
body:
- type: markdown
attributes:
value: |
**Important Note: We do not offer technical support or consulting**, and we do not answer personal questions by email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/), or [Stack Overflow](https://stackoverflow.com/) instead.
- type: textarea
id: description
attributes:
label: 🚀 Feature
description: A clear and concise description of the feature proposal.
validations:
required: true
- type: textarea
id: motivation
attributes:
label: Motivation
description: Please outline the motivation for the proposal. Is your feature request related to a problem? E.g., "I'm always frustrated when [...]". If this is related to another GitHub issue, please link it here too.
- type: textarea
id: pitch
attributes:
label: Pitch
description: A clear and concise description of what you want to happen.
- type: textarea
id: alternatives
attributes:
label: Alternatives
description: A clear and concise description of any alternative solutions or features you've considered, if any.
- type: textarea
id: additional-context
attributes:
label: Additional context
description: Add any other context or screenshots about the feature request here.
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo
required: true


@@ -1,60 +0,0 @@
---
name: Issue Template
about: How to create an issue for this repository
---
**Important Note: We do not offer technical support or consulting**, and we do not answer personal questions by email.
Please post your question on [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) instead.
If you have any questions, feel free to create an issue with the tag [question].
If you wish to suggest an enhancement or feature request, add the tag [feature request].
If you are submitting a bug report, please fill in the following details.
If your issue is related to a custom gym environment, please check it first using:
```python
from stable_baselines3.common.env_checker import check_env
env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)
```
**Describe the bug**
A clear and concise description of what the bug is.
**Code example**
Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks)
for both code and stack traces.
```python
from stable_baselines3 import ...
from sb3_contrib import ...
```
```bash
Traceback (most recent call last): File ...
```
**System Info**
Describe the characteristics of your environment:
* Describe how the library was installed (pip, docker, source, ...)
* Stable-Baselines3 and sb3-contrib versions
* GPU models and configuration
* Python version
* PyTorch version
* Gym version
* Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print info about relevant packages:
```sh
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
```
**Additional context**
Add any other context about the problem here.

.github/ISSUE_TEMPLATE/question.yml vendored Normal file

@@ -0,0 +1,30 @@
name: "❓ Question"
description: How to ask a question regarding Stable-Baselines3 Contrib
title: "[Question] question title"
labels: ["question"]
body:
- type: markdown
attributes:
value: |
**Important Note: We do not offer technical support or consulting**, and we do not answer personal questions by email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/), or [Stack Overflow](https://stackoverflow.com/) instead.
- type: textarea
id: question
attributes:
label: ❓ Question
description: Your question. This can be, for example, a question about confusing or unclear behaviour of a function, or about whether X can be done using stable-baselines3. Make sure to check the documentation first.
validations:
required: true
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo
required: true
- label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/)
required: true
- label: Any code I have included is minimal and working
required: true
- label: Any code I have included is formatted using the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
required: true


@@ -14,6 +14,7 @@ jobs:
env:
TERM: xterm-256color
FORCE_COLOR: 1
ATARI_ROMS: ${{ secrets.ATARI_ROMS }}
# Skip CI if [ci skip] in the commit message
if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
runs-on: ubuntu-latest
@@ -32,24 +33,32 @@ jobs:
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
# Install Atari Roms
pip install autorom
wget $ATARI_ROMS
base64 Roms.tar.gz.b64 --decode > Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
# Install master version
# and dependencies for docs and tests
pip install git+https://github.com/DLR-RM/stable-baselines3#egg=stable_baselines3[extra,tests,docs]
pip install .
# Use headless version
pip install opencv-python-headless
- name: Lint with ruff
run: |
make lint
- name: Check codestyle
run: |
make check-codestyle
- name: Build the doc
run: |
make doc
- name: Type check
run: |
make type
- name: Check codestyle
run: |
make check-codestyle
- name: Lint with flake8
run: |
make lint
- name: Test with pytest
run: |
make pytest


@@ -14,24 +14,17 @@ type: pytype mypy
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://lintlyci.github.io/Flake8Rules/
flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings.
flake8 ${LINT_PATHS} --count --exit-zero --statistics
ruff:
# stop the build if there are Python syntax errors or undefined names
# see https://lintlyci.github.io/Flake8Rules/
# see https://www.flake8rules.com/
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
# exit-zero treats all errors as warnings.
ruff ${LINT_PATHS} --exit-zero --line-length 127
ruff ${LINT_PATHS} --exit-zero
format:
# Sort imports
isort ${LINT_PATHS}
# Reformat using black
black -l 127 ${LINT_PATHS}
black ${LINT_PATHS}
check-codestyle:
# Sort imports


@@ -3,7 +3,7 @@
Changelog
==========
Release 1.8.0a4 (WIP)
Release 1.8.0a8 (WIP)
--------------------------
Breaking Changes:
@@ -22,6 +22,8 @@ Deprecations:
Others:
^^^^^^^
- Moved to pyproject.toml (except flake8)
- Added GitHub issue forms
Documentation:
^^^^^^^^^^^^^^


@@ -1,5 +1,64 @@
[tool.ruff]
# Same as Black.
line-length = 127
# Assume Python 3.7
target-version = "py37"
select = ["E", "F", "B", "UP", "C90", "RUF"]
ignore = []
[tool.ruff.mccabe]
# Unlike Flake8, ruff default to a complexity level of 10.
max-complexity = 15
[tool.black]
line-length = 127
[tool.isort]
profile = "black"
line_length = 127
src_paths = ["sb3_contrib"]
[tool.pytype]
inputs = ['sb3_contrib']
[tool.mypy]
ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
| sb3_contrib/ars/ars.py$
| sb3_contrib/qrdqn/qrdqn.py$
| sb3_contrib/qrdqn/policies.py$
| sb3_contrib/common/recurrent/policies.py$
| sb3_contrib/common/recurrent/buffers.py$
| sb3_contrib/common/maskable/distributions.py$
| sb3_contrib/common/maskable/callbacks.py$
| sb3_contrib/common/maskable/policies.py$
| sb3_contrib/common/maskable/buffers.py$
| sb3_contrib/common/envs/invalid_actions_env.py$
| sb3_contrib/common/vec_env/async_eval.py$
| sb3_contrib/tqc/tqc.py$
| sb3_contrib/tqc/policies.py$
| sb3_contrib/trpo/trpo.py$
| sb3_contrib/ppo_mask/ppo_mask.py$
| tests/test_train_eval_mode.py$
)"""
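The `(?x)` prefix in the exclude pattern above enables verbose-regex mode, so the whitespace and line breaks inside the triple-quoted string are ignored when mypy matches file paths. A quick sanity check of that behaviour with Python's `re` module (a sketch using a subset of the pattern, not how mypy applies it internally):

```python
import re

# Subset of the mypy exclude pattern, in the same (?x) verbose syntax.
pattern = re.compile(
    r"""(?x)(
        sb3_contrib/tqc/tqc\.py$
      | tests/test_train_eval_mode\.py$
    )"""
)

# Excluded files match; files not listed are still type-checked.
assert pattern.search("sb3_contrib/tqc/tqc.py") is not None
assert pattern.search("tests/test_train_eval_mode.py") is not None
assert pattern.search("sb3_contrib/trpo/trpo.py") is None
```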
[tool.pytest.ini_options]
# Deterministic ordering for tests; useful for pytest-xdist.
env = [
"PYTHONHASHSEED=0"
]
filterwarnings = [
# Tensorboard warnings
"ignore::DeprecationWarning:tensorboard",
# Gym warnings
"ignore:Parameters to load are deprecated.:DeprecationWarning",
"ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning",
"ignore::UserWarning:gym",
]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')"
]
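For reference, the `slow` marker registered above would be used like this in a test file (hypothetical test names; run `pytest -m "not slow"` to skip the marked tests):

```python
import pytest


@pytest.mark.slow
def test_full_training_run():
    # Deselected when running: pytest -m "not slow"
    assert sum(range(1000)) == 499500


def test_fast_sanity_check():
    # Not marked as slow, so it always runs.
    assert 1 + 1 == 2
```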


@@ -10,7 +10,7 @@ from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_suppo
from sb3_contrib.ppo_mask import MaskablePPO
def evaluate_policy( # noqa: C901
def evaluate_policy(
model: MaskablePPO,
env: Union[gym.Env, VecEnv],
n_eval_episodes: int = 10,


@@ -229,7 +229,7 @@ class RecurrentRolloutBuffer(RolloutBuffer):
return RecurrentRolloutBufferSamples(
# (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size,) + self.obs_shape),
observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size, *self.obs_shape)),
actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]),
old_values=self.pad_and_flatten(self.values[batch_inds]),
old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]),
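The change above is driven by ruff's RUF005 rule, which prefers iterable unpacking over concatenation when building a new tuple; both forms produce the same shape argument for `reshape`. A standalone sketch with hypothetical shapes:

```python
import numpy as np

padded_batch_size = 6  # hypothetical n_seq * max_length
obs_shape = (14,)      # hypothetical observation shape

# Old style (tuple concatenation) and new style (unpacking, RUF005)
# build the same tuple.
assert (padded_batch_size,) + obs_shape == (padded_batch_size, *obs_shape) == (6, 14)

# Either tuple works as the target shape for reshape:
flat = np.zeros((padded_batch_size * obs_shape[0],))
assert flat.reshape((padded_batch_size, *obs_shape)).shape == (6, 14)
```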


@@ -149,9 +149,9 @@ def flat_grad(
:param output: functional output to compute the gradient for
:param parameters: sequence of ``Parameter``
:param retain_graph: If ``False``, the graph used to compute the grad will be freed.
Defaults to the value of ``create_graph``.
:param create_graph: If ``True``, a graph of the derivative will be constructed,
allowing higher-order derivative products to be computed. Default: ``False``.
:return: Tensor containing the flattened gradients
"""
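The `retain_graph`/`create_graph` semantics documented above follow `torch.autograd.grad`: `create_graph=True` builds a differentiable graph of the gradient computation itself, which is what enables the higher-order derivative products TRPO needs. A minimal sketch, assuming PyTorch is available:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x**3  # dy/dx = 3x^2, d2y/dx2 = 6x

# create_graph=True keeps a differentiable graph of the gradient
# (and implicitly retains the original graph as well).
(grad,) = torch.autograd.grad(y, x, create_graph=True)
assert grad.item() == 12.0  # 3 * 2^2

# Differentiate the gradient again for the second derivative:
(grad2,) = torch.autograd.grad(grad, x)
assert grad2.item() == 12.0  # 6 * 2
```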


@@ -264,7 +264,7 @@ class QRDQN(OffPolicyAlgorithm):
)
def _excluded_save_params(self) -> List[str]:
return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"]
return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"] # noqa: RUF005
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "policy.optimizer"]
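RUF005 also flags list concatenation like the one above in favor of unpacking (`[*a, *b]`); here the authors keep the concatenation and silence the rule with `# noqa: RUF005` instead. For plain lists the two forms are equivalent:

```python
base = ["policy", "policy.optimizer"]  # stands in for the super() result
extra = ["quantile_net", "quantile_net_target"]

# Concatenation (flagged by RUF005) vs. unpacking (preferred by the rule):
assert base + extra == [*base, *extra]
assert base + extra == ["policy", "policy.optimizer", "quantile_net", "quantile_net_target"]
```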


@@ -302,7 +302,7 @@ class TQC(OffPolicyAlgorithm):
def _excluded_save_params(self) -> List[str]:
# Exclude aliases
return super()._excluded_save_params() + ["actor", "critic", "critic_target"]
return super()._excluded_save_params() + ["actor", "critic", "critic_target"] # noqa: RUF005
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]


@@ -1 +1 @@
1.8.0a4
1.8.0a8


@@ -1,66 +0,0 @@
[metadata]
# This includes the license file in the wheel.
license_files = LICENSE
[tool:pytest]
# Deterministic ordering for tests; useful for pytest-xdist.
env =
PYTHONHASHSEED=0
filterwarnings =
# Tensorboard warnings
ignore::DeprecationWarning:tensorboard
# Gym warnings
ignore:Parameters to load are deprecated.:DeprecationWarning
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
ignore::UserWarning:gym
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
[mypy]
ignore_missing_imports = True
follow_imports = silent
show_error_codes = True
exclude = (?x)(
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
| sb3_contrib/ars/ars.py$
| sb3_contrib/qrdqn/qrdqn.py$
| sb3_contrib/qrdqn/policies.py$
| sb3_contrib/common/recurrent/policies.py$
| sb3_contrib/common/recurrent/buffers.py$
| sb3_contrib/common/maskable/distributions.py$
| sb3_contrib/common/maskable/callbacks.py$
| sb3_contrib/common/maskable/policies.py$
| sb3_contrib/common/maskable/buffers.py$
| sb3_contrib/common/envs/invalid_actions_env.py$
| sb3_contrib/common/vec_env/async_eval.py$
| sb3_contrib/tqc/tqc.py$
| sb3_contrib/tqc/policies.py$
| sb3_contrib/trpo/trpo.py$
| sb3_contrib/ppo_mask/ppo_mask.py$
| tests/test_train_eval_mode.py$
)
[flake8]
# line breaks before and after binary operators
ignore = W503,W504,E203,E231
exclude =
# No need to traverse our git directory
.git,
# There's no value in checking cache directories
__pycache__,
# Don't check the doc
docs/
# This contains our built documentation
build,
# This contains builds of flake8 that we don't want to check
dist
*.egg-info
max-complexity = 15
# The GitHub editor is 127 chars wide
max-line-length = 127
[isort]
profile = black
line_length = 127
src_paths = sb3_contrib


@@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.8.0a2",
"stable_baselines3>=1.8.0a8",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
@@ -79,6 +79,13 @@ setup(
version=__version__,
python_requires=">=3.7",
# PyPI package information.
project_urls={
"Code": "https://github.com/Stable-Baselines-Team/stable-baselines3-contrib",
"Documentation": "https://sb3-contrib.readthedocs.io/",
"Stable-Baselines3": "https://github.com/DLR-RM/stable-baselines3",
"RL-Zoo": "https://github.com/DLR-RM/rl-baselines3-zoo",
},
classifiers=[
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",


@@ -10,7 +10,7 @@ from sb3_contrib.common.wrappers import TimeFeatureWrapper
class CustomGoalEnv(gym.GoalEnv):
def __init__(self):
super(CustomGoalEnv, self).__init__()
super().__init__()
self.observation_space = spaces.Dict(
{
"observation": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),