Issue forms and pyproject.toml (#162)
* Issue forms and pyproject.toml * [ci skip] Fix typos * Fix isort config * Use secret link to download atari roms * Fix for mypy and update config * Upgrade SB3 and fix warnings * Fix doc build * Update Makefile * Lint first
This commit is contained in:
parent
376d9551de
commit
728c1c5b7f
|
|
@ -0,0 +1,72 @@
|
||||||
|
name: "\U0001F41B Bug Report"
|
||||||
|
description: Submit a bug report to help us improve Stable-Baselines3 Contrib
|
||||||
|
title: "[Bug]: bug title"
|
||||||
|
labels: ["bug"]
|
||||||
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
|
||||||
|
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
|
||||||
|
|
||||||
|
If your issue is related to a **custom gym environment**, please use the custom gym env template.
|
||||||
|
- type: textarea
|
||||||
|
id: description
|
||||||
|
attributes:
|
||||||
|
label: 🐛 Bug
|
||||||
|
description: A clear and concise description of what the bug is.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
- type: textarea
|
||||||
|
id: reproduce
|
||||||
|
attributes:
|
||||||
|
label: To Reproduce
|
||||||
|
description: |
|
||||||
|
Steps to reproduce the behavior. Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
|
||||||
|
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
|
||||||
|
value: |
|
||||||
|
```python
|
||||||
|
from stable_baselines3 import ...
|
||||||
|
from sb3_contrib import ...
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: traceback
|
||||||
|
attributes:
|
||||||
|
label: Relevant log output / Error message
|
||||||
|
description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks.
|
||||||
|
placeholder: "Traceback (most recent call last): File ..."
|
||||||
|
render: shell
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: system-info
|
||||||
|
attributes:
|
||||||
|
label: System Info
|
||||||
|
description: |
|
||||||
|
Describe the characteristic of your environment:
|
||||||
|
* Describe how the library was installed (pip, docker, source, ...)
|
||||||
|
* Stable-Baselines3 and sb3-contrib versions
|
||||||
|
* GPU models and configuration
|
||||||
|
* Python version
|
||||||
|
* PyTorch version
|
||||||
|
* Gym version
|
||||||
|
* Versions of any other relevant libraries
|
||||||
|
|
||||||
|
You can use `sb3.get_system_info()` to print relevant packages info:
|
||||||
|
```sh
|
||||||
|
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
|
||||||
|
```
|
||||||
|
- type: checkboxes
|
||||||
|
id: terms
|
||||||
|
attributes:
|
||||||
|
label: Checklist
|
||||||
|
options:
|
||||||
|
- label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo
|
||||||
|
required: true
|
||||||
|
- label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/)
|
||||||
|
required: true
|
||||||
|
- label: I have provided a minimal working example to reproduce the bug
|
||||||
|
required: true
|
||||||
|
- label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
|
||||||
|
required: true
|
||||||
|
|
@ -0,0 +1,108 @@
|
||||||
|
name: "\U0001F916 Custom Gym Environment Issue"
|
||||||
|
description: How to report an issue when using a custom Gym environment
|
||||||
|
labels: ["question", "custom gym env"]
|
||||||
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
|
||||||
|
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
|
||||||
|
|
||||||
|
**Please check your environment first using**:
|
||||||
|
```python
|
||||||
|
from stable_baselines3.common.env_checker import check_env
|
||||||
|
|
||||||
|
env = CustomEnv(arg1, ...)
|
||||||
|
# It will check your custom environment and output additional warnings if needed
|
||||||
|
check_env(env)
|
||||||
|
```
|
||||||
|
- type: textarea
|
||||||
|
id: description
|
||||||
|
attributes:
|
||||||
|
label: 🐛 Bug
|
||||||
|
description: A clear and concise description of what the bug is.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
- type: textarea
|
||||||
|
id: code-example
|
||||||
|
attributes:
|
||||||
|
label: Code example
|
||||||
|
description: |
|
||||||
|
Please try to provide a minimal example to reproduce the bug.
|
||||||
|
For a custom environment, you need to give at least the observation space, action space, `reset()` and `step()` methods (see working example below).
|
||||||
|
Error messages and stack traces are also helpful.
|
||||||
|
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
|
||||||
|
value: |
|
||||||
|
```python
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
from gym import spaces
|
||||||
|
|
||||||
|
from stable_baselines3 import A2C
|
||||||
|
from stable_baselines3.common.env_checker import check_env
|
||||||
|
|
||||||
|
|
||||||
|
class CustomEnv(gym.Env):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
|
||||||
|
self.action_space = spaces.Box(low=-1, high=1, shape=(6,))
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
return self.observation_space.sample()
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
obs = self.observation_space.sample()
|
||||||
|
reward = 1.0
|
||||||
|
done = False
|
||||||
|
info = {}
|
||||||
|
return obs, reward, done, info
|
||||||
|
|
||||||
|
env = CustomEnv()
|
||||||
|
check_env(env)
|
||||||
|
|
||||||
|
model = A2C("MlpPolicy", env, verbose=1).learn(1000)
|
||||||
|
```
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: traceback
|
||||||
|
attributes:
|
||||||
|
label: Relevant log output / Error message
|
||||||
|
description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks.
|
||||||
|
placeholder: "Traceback (most recent call last): File ..."
|
||||||
|
render: shell
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: system-info
|
||||||
|
attributes:
|
||||||
|
label: System Info
|
||||||
|
description: |
|
||||||
|
Describe the characteristic of your environment:
|
||||||
|
* Describe how the library was installed (pip, docker, source, ...)
|
||||||
|
* Stable-Baselines3 and sb3-contrib versions
|
||||||
|
* GPU models and configuration
|
||||||
|
* Python version
|
||||||
|
* PyTorch version
|
||||||
|
* Gym version
|
||||||
|
* Versions of any other relevant libraries
|
||||||
|
|
||||||
|
You can use `sb3.get_system_info()` to print relevant packages info:
|
||||||
|
```sh
|
||||||
|
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
|
||||||
|
```
|
||||||
|
- type: checkboxes
|
||||||
|
id: terms
|
||||||
|
attributes:
|
||||||
|
label: Checklist
|
||||||
|
options:
|
||||||
|
- label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo
|
||||||
|
required: true
|
||||||
|
- label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/)
|
||||||
|
required: true
|
||||||
|
- label: I have provided a minimal working example to reproduce the bug
|
||||||
|
required: true
|
||||||
|
- label: I have checked my env using the env checker
|
||||||
|
required: true
|
||||||
|
- label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
|
||||||
|
required: true
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
name: "\U0001F4DA Documentation"
|
||||||
|
description: Report an issue related to Stable-Baselines3 Contrib documentation
|
||||||
|
labels: ["documentation"]
|
||||||
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
|
||||||
|
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
|
||||||
|
- type: textarea
|
||||||
|
id: description
|
||||||
|
attributes:
|
||||||
|
label: 📚 Documentation
|
||||||
|
description: A clear and concise description of what should be improved in the documentation.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
- type: checkboxes
|
||||||
|
id: terms
|
||||||
|
attributes:
|
||||||
|
label: Checklist
|
||||||
|
options:
|
||||||
|
- label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo
|
||||||
|
required: true
|
||||||
|
- label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/)
|
||||||
|
required: true
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
name: "\U0001F680 Feature Request"
|
||||||
|
description: How to create an issue for requesting a feature
|
||||||
|
title: "[Feature Request] request title"
|
||||||
|
labels: ["enhancement"]
|
||||||
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
|
||||||
|
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
|
||||||
|
- type: textarea
|
||||||
|
id: description
|
||||||
|
attributes:
|
||||||
|
label: 🚀 Feature
|
||||||
|
description: A clear and concise description of the feature proposal.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
- type: textarea
|
||||||
|
id: motivation
|
||||||
|
attributes:
|
||||||
|
label: Motivation
|
||||||
|
description: Please outline the motivation for the proposal. Is your feature request related to a problem? e.g.,"I'm always frustrated when [...]". If this is related to another GitHub issue, please link here too.
|
||||||
|
- type: textarea
|
||||||
|
id: pitch
|
||||||
|
attributes:
|
||||||
|
label: Pitch
|
||||||
|
description: A clear and concise description of what you want to happen.
|
||||||
|
- type: textarea
|
||||||
|
id: alternatives
|
||||||
|
attributes:
|
||||||
|
label: Alternatives
|
||||||
|
description: A clear and concise description of any alternative solutions or features you've considered, if any.
|
||||||
|
- type: textarea
|
||||||
|
id: additional-context
|
||||||
|
attributes:
|
||||||
|
label: Additional context
|
||||||
|
description: Add any other context or screenshots about the feature request here.
|
||||||
|
- type: checkboxes
|
||||||
|
id: terms
|
||||||
|
attributes:
|
||||||
|
label: Checklist
|
||||||
|
options:
|
||||||
|
- label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo
|
||||||
|
required: true
|
||||||
|
|
@ -1,60 +0,0 @@
|
||||||
---
|
|
||||||
name: Issue Template
|
|
||||||
about: How to create an issue for this repository
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
|
|
||||||
Please post your question on [reddit](https://www.reddit.com/r/reinforcementlearning/) or [stack overflow](https://stackoverflow.com/) in that case.
|
|
||||||
|
|
||||||
If you have any questions, feel free to create an issue with the tag [question].
|
|
||||||
If you wish to suggest an enhancement or feature request, add the tag [feature request].
|
|
||||||
If you are submitting a bug report, please fill in the following details.
|
|
||||||
|
|
||||||
If your issue is related to a custom gym environment, please check it first using:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from stable_baselines3.common.env_checker import check_env
|
|
||||||
|
|
||||||
env = CustomEnv(arg1, ...)
|
|
||||||
# It will check your custom environment and output additional warnings if needed
|
|
||||||
check_env(env)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Describe the bug**
|
|
||||||
A clear and concise description of what the bug is.
|
|
||||||
|
|
||||||
**Code example**
|
|
||||||
Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
|
|
||||||
|
|
||||||
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks)
|
|
||||||
for both code and stack traces.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from stable_baselines3 import ...
|
|
||||||
from sb3_contrib import ...
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
```bash
|
|
||||||
Traceback (most recent call last): File ...
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
**System Info**
|
|
||||||
Describe the characteristic of your environment:
|
|
||||||
* Describe how the library was installed (pip, docker, source, ...)
|
|
||||||
* Stable-Baselines3 and sb3-contrib versions
|
|
||||||
* GPU models and configuration
|
|
||||||
* Python version
|
|
||||||
* PyTorch version
|
|
||||||
* Gym version
|
|
||||||
* Versions of any other relevant libraries
|
|
||||||
|
|
||||||
You can use `sb3.get_system_info()` to print relevant packages info:
|
|
||||||
```sh
|
|
||||||
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
|
|
||||||
```
|
|
||||||
|
|
||||||
**Additional context**
|
|
||||||
Add any other context about the problem here.
|
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
name: "❓ Question"
|
||||||
|
description: How to ask a question regarding Stable-Baselines3 Contrib
|
||||||
|
title: "[Question] question title"
|
||||||
|
labels: ["question"]
|
||||||
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
|
||||||
|
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
|
||||||
|
- type: textarea
|
||||||
|
id: question
|
||||||
|
attributes:
|
||||||
|
label: ❓ Question
|
||||||
|
description: Your question. This can be e.g. questions regarding confusing or unclear behaviour of functions or a question if X can be done using stable-baselines3. Make sure to check out the documentation first.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
- type: checkboxes
|
||||||
|
id: terms
|
||||||
|
attributes:
|
||||||
|
label: Checklist
|
||||||
|
options:
|
||||||
|
- label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo
|
||||||
|
required: true
|
||||||
|
- label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/)
|
||||||
|
required: true
|
||||||
|
- label: If code there is, it is minimal and working
|
||||||
|
required: true
|
||||||
|
- label: If code there is, it is formatted using the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
|
||||||
|
required: true
|
||||||
|
|
@ -14,6 +14,7 @@ jobs:
|
||||||
env:
|
env:
|
||||||
TERM: xterm-256color
|
TERM: xterm-256color
|
||||||
FORCE_COLOR: 1
|
FORCE_COLOR: 1
|
||||||
|
ATARI_ROMS: ${{ secrets.ATARI_ROMS }}
|
||||||
# Skip CI if [ci skip] in the commit message
|
# Skip CI if [ci skip] in the commit message
|
||||||
if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
|
if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
@ -32,24 +33,32 @@ jobs:
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
# cpu version of pytorch
|
# cpu version of pytorch
|
||||||
pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
|
||||||
|
# Install Atari Roms
|
||||||
|
pip install autorom
|
||||||
|
wget $ATARI_ROMS
|
||||||
|
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
|
||||||
|
AutoROM --accept-license --source-file Roms.tar.gz
|
||||||
|
|
||||||
# Install master version
|
# Install master version
|
||||||
# and dependencies for docs and tests
|
# and dependencies for docs and tests
|
||||||
pip install git+https://github.com/DLR-RM/stable-baselines3#egg=stable_baselines3[extra,tests,docs]
|
pip install git+https://github.com/DLR-RM/stable-baselines3#egg=stable_baselines3[extra,tests,docs]
|
||||||
pip install .
|
pip install .
|
||||||
# Use headless version
|
# Use headless version
|
||||||
pip install opencv-python-headless
|
pip install opencv-python-headless
|
||||||
|
|
||||||
|
- name: Lint with ruff
|
||||||
|
run: |
|
||||||
|
make lint
|
||||||
|
- name: Check codestyle
|
||||||
|
run: |
|
||||||
|
make check-codestyle
|
||||||
- name: Build the doc
|
- name: Build the doc
|
||||||
run: |
|
run: |
|
||||||
make doc
|
make doc
|
||||||
- name: Type check
|
- name: Type check
|
||||||
run: |
|
run: |
|
||||||
make type
|
make type
|
||||||
- name: Check codestyle
|
|
||||||
run: |
|
|
||||||
make check-codestyle
|
|
||||||
- name: Lint with flake8
|
|
||||||
run: |
|
|
||||||
make lint
|
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
run: |
|
run: |
|
||||||
make pytest
|
make pytest
|
||||||
|
|
|
||||||
13
Makefile
13
Makefile
|
|
@ -14,24 +14,17 @@ type: pytype mypy
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
# stop the build if there are Python syntax errors or undefined names
|
# stop the build if there are Python syntax errors or undefined names
|
||||||
# see https://lintlyci.github.io/Flake8Rules/
|
# see https://www.flake8rules.com/
|
||||||
flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics
|
|
||||||
# exit-zero treats all errors as warnings.
|
|
||||||
flake8 ${LINT_PATHS} --count --exit-zero --statistics
|
|
||||||
|
|
||||||
ruff:
|
|
||||||
# stop the build if there are Python syntax errors or undefined names
|
|
||||||
# see https://lintlyci.github.io/Flake8Rules/
|
|
||||||
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
|
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
|
||||||
# exit-zero treats all errors as warnings.
|
# exit-zero treats all errors as warnings.
|
||||||
ruff ${LINT_PATHS} --exit-zero --line-length 127
|
ruff ${LINT_PATHS} --exit-zero
|
||||||
|
|
||||||
|
|
||||||
format:
|
format:
|
||||||
# Sort imports
|
# Sort imports
|
||||||
isort ${LINT_PATHS}
|
isort ${LINT_PATHS}
|
||||||
# Reformat using black
|
# Reformat using black
|
||||||
black -l 127 ${LINT_PATHS}
|
black ${LINT_PATHS}
|
||||||
|
|
||||||
check-codestyle:
|
check-codestyle:
|
||||||
# Sort imports
|
# Sort imports
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Release 1.8.0a4 (WIP)
|
Release 1.8.0a8 (WIP)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
|
|
@ -22,6 +22,8 @@ Deprecations:
|
||||||
|
|
||||||
Others:
|
Others:
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
- Moved to pyproject.toml (except flake8)
|
||||||
|
- Added github issue forms
|
||||||
|
|
||||||
Documentation:
|
Documentation:
|
||||||
^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,64 @@
|
||||||
|
[tool.ruff]
|
||||||
|
# Same as Black.
|
||||||
|
line-length = 127
|
||||||
|
# Assume Python 3.7
|
||||||
|
target-version = "py37"
|
||||||
|
select = ["E", "F", "B", "UP", "C90", "RUF"]
|
||||||
|
ignore = []
|
||||||
|
|
||||||
|
[tool.ruff.mccabe]
|
||||||
|
# Unlike Flake8, ruff default to a complexity level of 10.
|
||||||
|
max-complexity = 15
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 127
|
line-length = 127
|
||||||
|
|
||||||
|
[tool.isort]
|
||||||
|
profile = "black"
|
||||||
|
line_length = 127
|
||||||
|
src_paths = ["sb3_contrib"]
|
||||||
|
|
||||||
[tool.pytype]
|
[tool.pytype]
|
||||||
inputs = ['sb3_contrib']
|
inputs = ['sb3_contrib']
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
ignore_missing_imports = true
|
||||||
|
follow_imports = "silent"
|
||||||
|
show_error_codes = true
|
||||||
|
exclude = """(?x)(
|
||||||
|
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
|
||||||
|
| sb3_contrib/ars/ars.py$
|
||||||
|
| sb3_contrib/qrdqn/qrdqn.py$
|
||||||
|
| sb3_contrib/qrdqn/policies.py$
|
||||||
|
| sb3_contrib/common/recurrent/policies.py$
|
||||||
|
| sb3_contrib/common/recurrent/buffers.py$
|
||||||
|
| sb3_contrib/common/maskable/distributions.py$
|
||||||
|
| sb3_contrib/common/maskable/callbacks.py$
|
||||||
|
| sb3_contrib/common/maskable/policies.py$
|
||||||
|
| sb3_contrib/common/maskable/buffers.py$
|
||||||
|
| sb3_contrib/common/envs/invalid_actions_env.py$
|
||||||
|
| sb3_contrib/common/vec_env/async_eval.py$
|
||||||
|
| sb3_contrib/tqc/tqc.py$
|
||||||
|
| sb3_contrib/tqc/policies.py$
|
||||||
|
| sb3_contrib/trpo/trpo.py$
|
||||||
|
| sb3_contrib/ppo_mask/ppo_mask.py$
|
||||||
|
| tests/test_train_eval_mode.py$
|
||||||
|
)"""
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
# Deterministic ordering for tests; useful for pytest-xdist.
|
||||||
|
env = [
|
||||||
|
"PYTHONHASHSEED=0"
|
||||||
|
]
|
||||||
|
|
||||||
|
filterwarnings = [
|
||||||
|
# Tensorboard warnings
|
||||||
|
"ignore::DeprecationWarning:tensorboard",
|
||||||
|
# Gym warnings
|
||||||
|
"ignore:Parameters to load are deprecated.:DeprecationWarning",
|
||||||
|
"ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning",
|
||||||
|
"ignore::UserWarning:gym",
|
||||||
|
]
|
||||||
|
markers = [
|
||||||
|
"slow: marks tests as slow (deselect with '-m \"not slow\"')"
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_suppo
|
||||||
from sb3_contrib.ppo_mask import MaskablePPO
|
from sb3_contrib.ppo_mask import MaskablePPO
|
||||||
|
|
||||||
|
|
||||||
def evaluate_policy( # noqa: C901
|
def evaluate_policy(
|
||||||
model: MaskablePPO,
|
model: MaskablePPO,
|
||||||
env: Union[gym.Env, VecEnv],
|
env: Union[gym.Env, VecEnv],
|
||||||
n_eval_episodes: int = 10,
|
n_eval_episodes: int = 10,
|
||||||
|
|
|
||||||
|
|
@ -229,7 +229,7 @@ class RecurrentRolloutBuffer(RolloutBuffer):
|
||||||
|
|
||||||
return RecurrentRolloutBufferSamples(
|
return RecurrentRolloutBufferSamples(
|
||||||
# (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
|
# (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
|
||||||
observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size,) + self.obs_shape),
|
observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size, *self.obs_shape)),
|
||||||
actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]),
|
actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]),
|
||||||
old_values=self.pad_and_flatten(self.values[batch_inds]),
|
old_values=self.pad_and_flatten(self.values[batch_inds]),
|
||||||
old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]),
|
old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]),
|
||||||
|
|
|
||||||
|
|
@ -149,9 +149,9 @@ def flat_grad(
|
||||||
|
|
||||||
:param output: functional output to compute the gradient for
|
:param output: functional output to compute the gradient for
|
||||||
:param parameters: sequence of ``Parameter``
|
:param parameters: sequence of ``Parameter``
|
||||||
:param retain_graph: – If ``False``, the graph used to compute the grad will be freed.
|
:param retain_graph: If ``False``, the graph used to compute the grad will be freed.
|
||||||
Defaults to the value of ``create_graph``.
|
Defaults to the value of ``create_graph``.
|
||||||
:param create_graph: – If ``True``, graph of the derivative will be constructed,
|
:param create_graph: If ``True``, graph of the derivative will be constructed,
|
||||||
allowing to compute higher order derivative products. Default: ``False``.
|
allowing to compute higher order derivative products. Default: ``False``.
|
||||||
:return: Tensor containing the flattened gradients
|
:return: Tensor containing the flattened gradients
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -264,7 +264,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _excluded_save_params(self) -> List[str]:
|
def _excluded_save_params(self) -> List[str]:
|
||||||
return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"]
|
return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"] # noqa: RUF005
|
||||||
|
|
||||||
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
||||||
state_dicts = ["policy", "policy.optimizer"]
|
state_dicts = ["policy", "policy.optimizer"]
|
||||||
|
|
|
||||||
|
|
@ -302,7 +302,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
|
|
||||||
def _excluded_save_params(self) -> List[str]:
|
def _excluded_save_params(self) -> List[str]:
|
||||||
# Exclude aliases
|
# Exclude aliases
|
||||||
return super()._excluded_save_params() + ["actor", "critic", "critic_target"]
|
return super()._excluded_save_params() + ["actor", "critic", "critic_target"] # noqa: RUF005
|
||||||
|
|
||||||
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
||||||
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
|
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.8.0a4
|
1.8.0a8
|
||||||
|
|
|
||||||
66
setup.cfg
66
setup.cfg
|
|
@ -1,66 +0,0 @@
|
||||||
[metadata]
|
|
||||||
# This includes the license file in the wheel.
|
|
||||||
license_files = LICENSE
|
|
||||||
|
|
||||||
[tool:pytest]
|
|
||||||
# Deterministic ordering for tests; useful for pytest-xdist.
|
|
||||||
env =
|
|
||||||
PYTHONHASHSEED=0
|
|
||||||
filterwarnings =
|
|
||||||
# Tensorboard warnings
|
|
||||||
ignore::DeprecationWarning:tensorboard
|
|
||||||
# Gym warnings
|
|
||||||
ignore:Parameters to load are deprecated.:DeprecationWarning
|
|
||||||
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
|
|
||||||
ignore::UserWarning:gym
|
|
||||||
markers =
|
|
||||||
slow: marks tests as slow (deselect with '-m "not slow"')
|
|
||||||
|
|
||||||
[mypy]
|
|
||||||
ignore_missing_imports = True
|
|
||||||
follow_imports = silent
|
|
||||||
show_error_codes = True
|
|
||||||
exclude = (?x)(
|
|
||||||
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
|
|
||||||
| sb3_contrib/ars/ars.py$
|
|
||||||
| sb3_contrib/qrdqn/qrdqn.py$
|
|
||||||
| sb3_contrib/qrdqn/policies.py$
|
|
||||||
| sb3_contrib/common/recurrent/policies.py$
|
|
||||||
| sb3_contrib/common/recurrent/buffers.py$
|
|
||||||
| sb3_contrib/common/maskable/distributions.py$
|
|
||||||
| sb3_contrib/common/maskable/callbacks.py$
|
|
||||||
| sb3_contrib/common/maskable/policies.py$
|
|
||||||
| sb3_contrib/common/maskable/buffers.py$
|
|
||||||
| sb3_contrib/common/envs/invalid_actions_env.py$
|
|
||||||
| sb3_contrib/common/vec_env/async_eval.py$
|
|
||||||
| sb3_contrib/tqc/tqc.py$
|
|
||||||
| sb3_contrib/tqc/policies.py$
|
|
||||||
| sb3_contrib/trpo/trpo.py$
|
|
||||||
| sb3_contrib/ppo_mask/ppo_mask.py$
|
|
||||||
| tests/test_train_eval_mode.py$
|
|
||||||
)
|
|
||||||
|
|
||||||
[flake8]
|
|
||||||
# line breaks before and after binary operators
|
|
||||||
ignore = W503,W504,E203,E231
|
|
||||||
|
|
||||||
exclude =
|
|
||||||
# No need to traverse our git directory
|
|
||||||
.git,
|
|
||||||
# There's no value in checking cache directories
|
|
||||||
__pycache__,
|
|
||||||
# Don't check the doc
|
|
||||||
docs/
|
|
||||||
# This contains our built documentation
|
|
||||||
build,
|
|
||||||
# This contains builds of flake8 that we don't want to check
|
|
||||||
dist
|
|
||||||
*.egg-info
|
|
||||||
max-complexity = 15
|
|
||||||
# The GitHub editor is 127 chars wide
|
|
||||||
max-line-length = 127
|
|
||||||
|
|
||||||
[isort]
|
|
||||||
profile = black
|
|
||||||
line_length = 127
|
|
||||||
src_paths = sb3_contrib
|
|
||||||
9
setup.py
9
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.8.0a2",
|
"stable_baselines3>=1.8.0a8",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
@ -79,6 +79,13 @@ setup(
|
||||||
version=__version__,
|
version=__version__,
|
||||||
python_requires=">=3.7",
|
python_requires=">=3.7",
|
||||||
# PyPI package information.
|
# PyPI package information.
|
||||||
|
# PyPI package information.
|
||||||
|
project_urls={
|
||||||
|
"Code": "https://github.com/Stable-Baselines-Team/stable-baselines3-contrib",
|
||||||
|
"Documentation": "https://sb3-contrib.readthedocs.io/",
|
||||||
|
"Stable-Baselines3": "https://github.com/DLR-RM/stable-baselines3",
|
||||||
|
"RL-Zoo": "https://github.com/DLR-RM/rl-baselines3-zoo",
|
||||||
|
},
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"Programming Language :: Python :: 3.7",
|
"Programming Language :: Python :: 3.7",
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from sb3_contrib.common.wrappers import TimeFeatureWrapper
|
||||||
|
|
||||||
class CustomGoalEnv(gym.GoalEnv):
|
class CustomGoalEnv(gym.GoalEnv):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(CustomGoalEnv, self).__init__()
|
super().__init__()
|
||||||
self.observation_space = spaces.Dict(
|
self.observation_space = spaces.Dict(
|
||||||
{
|
{
|
||||||
"observation": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
|
"observation": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue