Issue forms and pyproject.toml (#162)
* Issue forms and pyproject.toml * [ci skip] Fix typos * Fix isort config * Use secret link to download atari roms * Fix for mypy and update config * Upgrade SB3 and fix warnings * Fix doc build * Update Makefile * Lint first
This commit is contained in:
parent
376d9551de
commit
728c1c5b7f
|
|
@ -0,0 +1,72 @@
|
|||
name: "\U0001F41B Bug Report"
|
||||
description: Submit a bug report to help us improve Stable-Baselines3 Contrib
|
||||
title: "[Bug]: bug title"
|
||||
labels: ["bug"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
|
||||
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
|
||||
|
||||
If your issue is related to a **custom gym environment**, please use the custom gym env template.
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: 🐛 Bug
|
||||
description: A clear and concise description of what the bug is.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: reproduce
|
||||
attributes:
|
||||
label: To Reproduce
|
||||
description: |
|
||||
Steps to reproduce the behavior. Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
|
||||
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
|
||||
value: |
|
||||
```python
|
||||
from stable_baselines3 import ...
|
||||
from sb3_contrib import ...
|
||||
|
||||
```
|
||||
|
||||
- type: textarea
|
||||
id: traceback
|
||||
attributes:
|
||||
label: Relevant log output / Error message
|
||||
description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks.
|
||||
placeholder: "Traceback (most recent call last): File ..."
|
||||
render: shell
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
attributes:
|
||||
label: System Info
|
||||
description: |
|
||||
Describe the characteristic of your environment:
|
||||
* Describe how the library was installed (pip, docker, source, ...)
|
||||
* Stable-Baselines3 and sb3-contrib versions
|
||||
* GPU models and configuration
|
||||
* Python version
|
||||
* PyTorch version
|
||||
* Gym version
|
||||
* Versions of any other relevant libraries
|
||||
|
||||
You can use `sb3.get_system_info()` to print relevant packages info:
|
||||
```sh
|
||||
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
|
||||
```
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo
|
||||
required: true
|
||||
- label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/)
|
||||
required: true
|
||||
- label: I have provided a minimal working example to reproduce the bug
|
||||
required: true
|
||||
- label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
|
||||
required: true
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
name: "\U0001F916 Custom Gym Environment Issue"
|
||||
description: How to report an issue when using a custom Gym environment
|
||||
labels: ["question", "custom gym env"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
|
||||
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
|
||||
|
||||
**Please check your environment first using**:
|
||||
```python
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
|
||||
env = CustomEnv(arg1, ...)
|
||||
# It will check your custom environment and output additional warnings if needed
|
||||
check_env(env)
|
||||
```
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: 🐛 Bug
|
||||
description: A clear and concise description of what the bug is.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: code-example
|
||||
attributes:
|
||||
label: Code example
|
||||
description: |
|
||||
Please try to provide a minimal example to reproduce the bug.
|
||||
For a custom environment, you need to give at least the observation space, action space, `reset()` and `step()` methods (see working example below).
|
||||
Error messages and stack traces are also helpful.
|
||||
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
|
||||
value: |
|
||||
```python
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
|
||||
|
||||
class CustomEnv(gym.Env):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(6,))
|
||||
|
||||
def reset(self):
|
||||
return self.observation_space.sample()
|
||||
|
||||
def step(self, action):
|
||||
obs = self.observation_space.sample()
|
||||
reward = 1.0
|
||||
done = False
|
||||
info = {}
|
||||
return obs, reward, done, info
|
||||
|
||||
env = CustomEnv()
|
||||
check_env(env)
|
||||
|
||||
model = A2C("MlpPolicy", env, verbose=1).learn(1000)
|
||||
```
|
||||
|
||||
- type: textarea
|
||||
id: traceback
|
||||
attributes:
|
||||
label: Relevant log output / Error message
|
||||
description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks.
|
||||
placeholder: "Traceback (most recent call last): File ..."
|
||||
render: shell
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
attributes:
|
||||
label: System Info
|
||||
description: |
|
||||
Describe the characteristic of your environment:
|
||||
* Describe how the library was installed (pip, docker, source, ...)
|
||||
* Stable-Baselines3 and sb3-contrib versions
|
||||
* GPU models and configuration
|
||||
* Python version
|
||||
* PyTorch version
|
||||
* Gym version
|
||||
* Versions of any other relevant libraries
|
||||
|
||||
You can use `sb3.get_system_info()` to print relevant packages info:
|
||||
```sh
|
||||
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
|
||||
```
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo
|
||||
required: true
|
||||
- label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/)
|
||||
required: true
|
||||
- label: I have provided a minimal working example to reproduce the bug
|
||||
required: true
|
||||
- label: I have checked my env using the env checker
|
||||
required: true
|
||||
- label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
|
||||
required: true
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
name: "\U0001F4DA Documentation"
|
||||
description: Report an issue related to Stable-Baselines3 Contrib documentation
|
||||
labels: ["documentation"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
|
||||
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: 📚 Documentation
|
||||
description: A clear and concise description of what should be improved in the documentation.
|
||||
validations:
|
||||
required: true
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo
|
||||
required: true
|
||||
- label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/)
|
||||
required: true
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
name: "\U0001F680 Feature Request"
|
||||
description: How to create an issue for requesting a feature
|
||||
title: "[Feature Request] request title"
|
||||
labels: ["enhancement"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
|
||||
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: 🚀 Feature
|
||||
description: A clear and concise description of the feature proposal.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: motivation
|
||||
attributes:
|
||||
label: Motivation
|
||||
description: Please outline the motivation for the proposal. Is your feature request related to a problem? e.g.,"I'm always frustrated when [...]". If this is related to another GitHub issue, please link here too.
|
||||
- type: textarea
|
||||
id: pitch
|
||||
attributes:
|
||||
label: Pitch
|
||||
description: A clear and concise description of what you want to happen.
|
||||
- type: textarea
|
||||
id: alternatives
|
||||
attributes:
|
||||
label: Alternatives
|
||||
description: A clear and concise description of any alternative solutions or features you've considered, if any.
|
||||
- type: textarea
|
||||
id: additional-context
|
||||
attributes:
|
||||
label: Additional context
|
||||
description: Add any other context or screenshots about the feature request here.
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo
|
||||
required: true
|
||||
|
|
@ -1,60 +0,0 @@
|
|||
---
|
||||
name: Issue Template
|
||||
about: How to create an issue for this repository
|
||||
|
||||
---
|
||||
|
||||
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
|
||||
Please post your question on [reddit](https://www.reddit.com/r/reinforcementlearning/) or [stack overflow](https://stackoverflow.com/) in that case.
|
||||
|
||||
If you have any questions, feel free to create an issue with the tag [question].
|
||||
If you wish to suggest an enhancement or feature request, add the tag [feature request].
|
||||
If you are submitting a bug report, please fill in the following details.
|
||||
|
||||
If your issue is related to a custom gym environment, please check it first using:
|
||||
|
||||
```python
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
|
||||
env = CustomEnv(arg1, ...)
|
||||
# It will check your custom environment and output additional warnings if needed
|
||||
check_env(env)
|
||||
```
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**Code example**
|
||||
Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
|
||||
|
||||
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks)
|
||||
for both code and stack traces.
|
||||
|
||||
```python
|
||||
from stable_baselines3 import ...
|
||||
from sb3_contrib import ...
|
||||
|
||||
```
|
||||
|
||||
```bash
|
||||
Traceback (most recent call last): File ...
|
||||
|
||||
```
|
||||
|
||||
**System Info**
|
||||
Describe the characteristic of your environment:
|
||||
* Describe how the library was installed (pip, docker, source, ...)
|
||||
* Stable-Baselines3 and sb3-contrib versions
|
||||
* GPU models and configuration
|
||||
* Python version
|
||||
* PyTorch version
|
||||
* Gym version
|
||||
* Versions of any other relevant libraries
|
||||
|
||||
You can use `sb3.get_system_info()` to print relevant packages info:
|
||||
```sh
|
||||
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
|
||||
```
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
name: "❓ Question"
|
||||
description: How to ask a question regarding Stable-Baselines3 Contrib
|
||||
title: "[Question] question title"
|
||||
labels: ["question"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
|
||||
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
|
||||
- type: textarea
|
||||
id: question
|
||||
attributes:
|
||||
label: ❓ Question
|
||||
description: Your question. This can be e.g. questions regarding confusing or unclear behaviour of functions or a question if X can be done using stable-baselines3. Make sure to check out the documentation first.
|
||||
validations:
|
||||
required: true
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo
|
||||
required: true
|
||||
- label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/)
|
||||
required: true
|
||||
- label: If code there is, it is minimal and working
|
||||
required: true
|
||||
- label: If code there is, it is formatted using the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
|
||||
required: true
|
||||
|
|
@ -14,6 +14,7 @@ jobs:
|
|||
env:
|
||||
TERM: xterm-256color
|
||||
FORCE_COLOR: 1
|
||||
ATARI_ROMS: ${{ secrets.ATARI_ROMS }}
|
||||
# Skip CI if [ci skip] in the commit message
|
||||
if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
|
||||
runs-on: ubuntu-latest
|
||||
|
|
@ -32,24 +33,32 @@ jobs:
|
|||
python -m pip install --upgrade pip
|
||||
# cpu version of pytorch
|
||||
pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
|
||||
# Install Atari Roms
|
||||
pip install autorom
|
||||
wget $ATARI_ROMS
|
||||
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
|
||||
AutoROM --accept-license --source-file Roms.tar.gz
|
||||
|
||||
# Install master version
|
||||
# and dependencies for docs and tests
|
||||
pip install git+https://github.com/DLR-RM/stable-baselines3#egg=stable_baselines3[extra,tests,docs]
|
||||
pip install .
|
||||
# Use headless version
|
||||
pip install opencv-python-headless
|
||||
|
||||
- name: Lint with ruff
|
||||
run: |
|
||||
make lint
|
||||
- name: Check codestyle
|
||||
run: |
|
||||
make check-codestyle
|
||||
- name: Build the doc
|
||||
run: |
|
||||
make doc
|
||||
- name: Type check
|
||||
run: |
|
||||
make type
|
||||
- name: Check codestyle
|
||||
run: |
|
||||
make check-codestyle
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
make lint
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make pytest
|
||||
|
|
|
|||
13
Makefile
13
Makefile
|
|
@ -14,24 +14,17 @@ type: pytype mypy
|
|||
|
||||
lint:
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
# see https://lintlyci.github.io/Flake8Rules/
|
||||
flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
# exit-zero treats all errors as warnings.
|
||||
flake8 ${LINT_PATHS} --count --exit-zero --statistics
|
||||
|
||||
ruff:
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
# see https://lintlyci.github.io/Flake8Rules/
|
||||
# see https://www.flake8rules.com/
|
||||
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
|
||||
# exit-zero treats all errors as warnings.
|
||||
ruff ${LINT_PATHS} --exit-zero --line-length 127
|
||||
ruff ${LINT_PATHS} --exit-zero
|
||||
|
||||
|
||||
format:
|
||||
# Sort imports
|
||||
isort ${LINT_PATHS}
|
||||
# Reformat using black
|
||||
black -l 127 ${LINT_PATHS}
|
||||
black ${LINT_PATHS}
|
||||
|
||||
check-codestyle:
|
||||
# Sort imports
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.8.0a4 (WIP)
|
||||
Release 1.8.0a8 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -22,6 +22,8 @@ Deprecations:
|
|||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Moved to pyproject.toml (except flake8)
|
||||
- Added github issue forms
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -1,5 +1,64 @@
|
|||
[tool.ruff]
|
||||
# Same as Black.
|
||||
line-length = 127
|
||||
# Assume Python 3.7
|
||||
target-version = "py37"
|
||||
select = ["E", "F", "B", "UP", "C90", "RUF"]
|
||||
ignore = []
|
||||
|
||||
[tool.ruff.mccabe]
|
||||
# Unlike Flake8, ruff default to a complexity level of 10.
|
||||
max-complexity = 15
|
||||
|
||||
[tool.black]
|
||||
line-length = 127
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 127
|
||||
src_paths = ["sb3_contrib"]
|
||||
|
||||
[tool.pytype]
|
||||
inputs = ['sb3_contrib']
|
||||
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = true
|
||||
follow_imports = "silent"
|
||||
show_error_codes = true
|
||||
exclude = """(?x)(
|
||||
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
|
||||
| sb3_contrib/ars/ars.py$
|
||||
| sb3_contrib/qrdqn/qrdqn.py$
|
||||
| sb3_contrib/qrdqn/policies.py$
|
||||
| sb3_contrib/common/recurrent/policies.py$
|
||||
| sb3_contrib/common/recurrent/buffers.py$
|
||||
| sb3_contrib/common/maskable/distributions.py$
|
||||
| sb3_contrib/common/maskable/callbacks.py$
|
||||
| sb3_contrib/common/maskable/policies.py$
|
||||
| sb3_contrib/common/maskable/buffers.py$
|
||||
| sb3_contrib/common/envs/invalid_actions_env.py$
|
||||
| sb3_contrib/common/vec_env/async_eval.py$
|
||||
| sb3_contrib/tqc/tqc.py$
|
||||
| sb3_contrib/tqc/policies.py$
|
||||
| sb3_contrib/trpo/trpo.py$
|
||||
| sb3_contrib/ppo_mask/ppo_mask.py$
|
||||
| tests/test_train_eval_mode.py$
|
||||
)"""
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# Deterministic ordering for tests; useful for pytest-xdist.
|
||||
env = [
|
||||
"PYTHONHASHSEED=0"
|
||||
]
|
||||
|
||||
filterwarnings = [
|
||||
# Tensorboard warnings
|
||||
"ignore::DeprecationWarning:tensorboard",
|
||||
# Gym warnings
|
||||
"ignore:Parameters to load are deprecated.:DeprecationWarning",
|
||||
"ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning",
|
||||
"ignore::UserWarning:gym",
|
||||
]
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_suppo
|
|||
from sb3_contrib.ppo_mask import MaskablePPO
|
||||
|
||||
|
||||
def evaluate_policy( # noqa: C901
|
||||
def evaluate_policy(
|
||||
model: MaskablePPO,
|
||||
env: Union[gym.Env, VecEnv],
|
||||
n_eval_episodes: int = 10,
|
||||
|
|
|
|||
|
|
@ -229,7 +229,7 @@ class RecurrentRolloutBuffer(RolloutBuffer):
|
|||
|
||||
return RecurrentRolloutBufferSamples(
|
||||
# (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
|
||||
observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size,) + self.obs_shape),
|
||||
observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size, *self.obs_shape)),
|
||||
actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]),
|
||||
old_values=self.pad_and_flatten(self.values[batch_inds]),
|
||||
old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]),
|
||||
|
|
|
|||
|
|
@ -149,9 +149,9 @@ def flat_grad(
|
|||
|
||||
:param output: functional output to compute the gradient for
|
||||
:param parameters: sequence of ``Parameter``
|
||||
:param retain_graph: – If ``False``, the graph used to compute the grad will be freed.
|
||||
:param retain_graph: If ``False``, the graph used to compute the grad will be freed.
|
||||
Defaults to the value of ``create_graph``.
|
||||
:param create_graph: – If ``True``, graph of the derivative will be constructed,
|
||||
:param create_graph: If ``True``, graph of the derivative will be constructed,
|
||||
allowing to compute higher order derivative products. Default: ``False``.
|
||||
:return: Tensor containing the flattened gradients
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -264,7 +264,7 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
)
|
||||
|
||||
def _excluded_save_params(self) -> List[str]:
|
||||
return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"]
|
||||
return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"] # noqa: RUF005
|
||||
|
||||
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
||||
state_dicts = ["policy", "policy.optimizer"]
|
||||
|
|
|
|||
|
|
@ -302,7 +302,7 @@ class TQC(OffPolicyAlgorithm):
|
|||
|
||||
def _excluded_save_params(self) -> List[str]:
|
||||
# Exclude aliases
|
||||
return super()._excluded_save_params() + ["actor", "critic", "critic_target"]
|
||||
return super()._excluded_save_params() + ["actor", "critic", "critic_target"] # noqa: RUF005
|
||||
|
||||
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
||||
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.8.0a4
|
||||
1.8.0a8
|
||||
|
|
|
|||
66
setup.cfg
66
setup.cfg
|
|
@ -1,66 +0,0 @@
|
|||
[metadata]
|
||||
# This includes the license file in the wheel.
|
||||
license_files = LICENSE
|
||||
|
||||
[tool:pytest]
|
||||
# Deterministic ordering for tests; useful for pytest-xdist.
|
||||
env =
|
||||
PYTHONHASHSEED=0
|
||||
filterwarnings =
|
||||
# Tensorboard warnings
|
||||
ignore::DeprecationWarning:tensorboard
|
||||
# Gym warnings
|
||||
ignore:Parameters to load are deprecated.:DeprecationWarning
|
||||
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
|
||||
ignore::UserWarning:gym
|
||||
markers =
|
||||
slow: marks tests as slow (deselect with '-m "not slow"')
|
||||
|
||||
[mypy]
|
||||
ignore_missing_imports = True
|
||||
follow_imports = silent
|
||||
show_error_codes = True
|
||||
exclude = (?x)(
|
||||
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
|
||||
| sb3_contrib/ars/ars.py$
|
||||
| sb3_contrib/qrdqn/qrdqn.py$
|
||||
| sb3_contrib/qrdqn/policies.py$
|
||||
| sb3_contrib/common/recurrent/policies.py$
|
||||
| sb3_contrib/common/recurrent/buffers.py$
|
||||
| sb3_contrib/common/maskable/distributions.py$
|
||||
| sb3_contrib/common/maskable/callbacks.py$
|
||||
| sb3_contrib/common/maskable/policies.py$
|
||||
| sb3_contrib/common/maskable/buffers.py$
|
||||
| sb3_contrib/common/envs/invalid_actions_env.py$
|
||||
| sb3_contrib/common/vec_env/async_eval.py$
|
||||
| sb3_contrib/tqc/tqc.py$
|
||||
| sb3_contrib/tqc/policies.py$
|
||||
| sb3_contrib/trpo/trpo.py$
|
||||
| sb3_contrib/ppo_mask/ppo_mask.py$
|
||||
| tests/test_train_eval_mode.py$
|
||||
)
|
||||
|
||||
[flake8]
|
||||
# line breaks before and after binary operators
|
||||
ignore = W503,W504,E203,E231
|
||||
|
||||
exclude =
|
||||
# No need to traverse our git directory
|
||||
.git,
|
||||
# There's no value in checking cache directories
|
||||
__pycache__,
|
||||
# Don't check the doc
|
||||
docs/
|
||||
# This contains our built documentation
|
||||
build,
|
||||
# This contains builds of flake8 that we don't want to check
|
||||
dist
|
||||
*.egg-info
|
||||
max-complexity = 15
|
||||
# The GitHub editor is 127 chars wide
|
||||
max-line-length = 127
|
||||
|
||||
[isort]
|
||||
profile = black
|
||||
line_length = 127
|
||||
src_paths = sb3_contrib
|
||||
9
setup.py
9
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=1.8.0a2",
|
||||
"stable_baselines3>=1.8.0a8",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
@ -79,6 +79,13 @@ setup(
|
|||
version=__version__,
|
||||
python_requires=">=3.7",
|
||||
# PyPI package information.
|
||||
# PyPI package information.
|
||||
project_urls={
|
||||
"Code": "https://github.com/Stable-Baselines-Team/stable-baselines3-contrib",
|
||||
"Documentation": "https://sb3-contrib.readthedocs.io/",
|
||||
"Stable-Baselines3": "https://github.com/DLR-RM/stable-baselines3",
|
||||
"RL-Zoo": "https://github.com/DLR-RM/rl-baselines3-zoo",
|
||||
},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from sb3_contrib.common.wrappers import TimeFeatureWrapper
|
|||
|
||||
class CustomGoalEnv(gym.GoalEnv):
|
||||
def __init__(self):
|
||||
super(CustomGoalEnv, self).__init__()
|
||||
super().__init__()
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"observation": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
|
||||
|
|
|
|||
Loading…
Reference in New Issue