From 728c1c5b7f1eefd627187975d27c11fcd1a5482d Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 11 Mar 2023 22:57:45 +0100 Subject: [PATCH] Issue forms and pyproject.toml (#162) * Issue forms and pyproject.toml * [ci skip] Fix typos * Fix isort config * Use secret link to download atari roms * Fix for mypy and update config * Upgrade SB3 and fix warnings * Fix doc build * Update Makefile * Lint first --- .github/ISSUE_TEMPLATE/bug_report.yml | 72 ++++++++++++++ .github/ISSUE_TEMPLATE/custom_env.yml | 108 +++++++++++++++++++++ .github/ISSUE_TEMPLATE/documentation.yml | 25 +++++ .github/ISSUE_TEMPLATE/feature_request.yml | 44 +++++++++ .github/ISSUE_TEMPLATE/issue-template.md | 60 ------------ .github/ISSUE_TEMPLATE/question.yml | 30 ++++++ .github/workflows/ci.yml | 21 ++-- Makefile | 13 +-- docs/misc/changelog.rst | 4 +- pyproject.toml | 59 +++++++++++ sb3_contrib/common/maskable/evaluation.py | 2 +- sb3_contrib/common/recurrent/buffers.py | 2 +- sb3_contrib/common/utils.py | 4 +- sb3_contrib/qrdqn/qrdqn.py | 2 +- sb3_contrib/tqc/tqc.py | 2 +- sb3_contrib/version.txt | 2 +- setup.cfg | 66 ------------- setup.py | 9 +- tests/wrappers/test_time_feature.py | 2 +- 19 files changed, 375 insertions(+), 152 deletions(-) create mode 100644 .github/ISSUE_TEMPLATE/bug_report.yml create mode 100644 .github/ISSUE_TEMPLATE/custom_env.yml create mode 100644 .github/ISSUE_TEMPLATE/documentation.yml create mode 100644 .github/ISSUE_TEMPLATE/feature_request.yml delete mode 100644 .github/ISSUE_TEMPLATE/issue-template.md create mode 100644 .github/ISSUE_TEMPLATE/question.yml delete mode 100644 setup.cfg diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 0000000..4ad1f39 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,72 @@ +name: "\U0001F41B Bug Report" +description: Submit a bug report to help us improve Stable-Baselines3 Contrib +title: "[Bug]: bug title" +labels: ["bug"] +body: + - type: markdown + attributes: + value: | + **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. + Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case. + + If your issue is related to a **custom gym environment**, please use the custom gym env template. + - type: textarea + id: description + attributes: + label: 🐛 Bug + description: A clear and concise description of what the bug is. + validations: + required: true + - type: textarea + id: reproduce + attributes: + label: To Reproduce + description: | + Steps to reproduce the behavior. Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful. + Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces. + value: | + ```python + from stable_baselines3 import ... + from sb3_contrib import ... + + ``` + + - type: textarea + id: traceback + attributes: + label: Relevant log output / Error message + description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks. + placeholder: "Traceback (most recent call last): File ..." + render: shell + + - type: textarea + id: system-info + attributes: + label: System Info + description: | + Describe the characteristic of your environment: + * Describe how the library was installed (pip, docker, source, ...) + * Stable-Baselines3 and sb3-contrib versions + * GPU models and configuration + * Python version + * PyTorch version + * Gym version + * Versions of any other relevant libraries + + You can use `sb3.get_system_info()` to print relevant packages info: + ```sh + python -c 'import stable_baselines3 as sb3; sb3.get_system_info()' + ``` + - type: checkboxes + id: terms + attributes: + label: Checklist + options: + - label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo + required: true + - label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/) + required: true + - label: I have provided a minimal working example to reproduce the bug + required: true + - label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces. + required: true diff --git a/.github/ISSUE_TEMPLATE/custom_env.yml b/.github/ISSUE_TEMPLATE/custom_env.yml new file mode 100644 index 0000000..c35211c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/custom_env.yml @@ -0,0 +1,108 @@ +name: "\U0001F916 Custom Gym Environment Issue" +description: How to report an issue when using a custom Gym environment +labels: ["question", "custom gym env"] +body: + - type: markdown + attributes: + value: | + **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. + Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case. + + **Please check your environment first using**: + ```python + from stable_baselines3.common.env_checker import check_env + + env = CustomEnv(arg1, ...) + # It will check your custom environment and output additional warnings if needed + check_env(env) + ``` + - type: textarea + id: description + attributes: + label: 🐛 Bug + description: A clear and concise description of what the bug is. + validations: + required: true + - type: textarea + id: code-example + attributes: + label: Code example + description: | + Please try to provide a minimal example to reproduce the bug. + For a custom environment, you need to give at least the observation space, action space, `reset()` and `step()` methods (see working example below). + Error messages and stack traces are also helpful. + Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces. + value: | + ```python + import gym + import numpy as np + from gym import spaces + + from stable_baselines3 import A2C + from stable_baselines3.common.env_checker import check_env + + + class CustomEnv(gym.Env): + + def __init__(self): + super().__init__() + self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,)) + self.action_space = spaces.Box(low=-1, high=1, shape=(6,)) + + def reset(self): + return self.observation_space.sample() + + def step(self, action): + obs = self.observation_space.sample() + reward = 1.0 + done = False + info = {} + return obs, reward, done, info + + env = CustomEnv() + check_env(env) + + model = A2C("MlpPolicy", env, verbose=1).learn(1000) + ``` + + - type: textarea + id: traceback + attributes: + label: Relevant log output / Error message + description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks. + placeholder: "Traceback (most recent call last): File ..." + render: shell + + - type: textarea + id: system-info + attributes: + label: System Info + description: | + Describe the characteristic of your environment: + * Describe how the library was installed (pip, docker, source, ...) + * Stable-Baselines3 and sb3-contrib versions + * GPU models and configuration + * Python version + * PyTorch version + * Gym version + * Versions of any other relevant libraries + + You can use `sb3.get_system_info()` to print relevant packages info: + ```sh + python -c 'import stable_baselines3 as sb3; sb3.get_system_info()' + ``` + - type: checkboxes + id: terms + attributes: + label: Checklist + options: + - label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo + required: true + - label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/) + required: true + - label: I have provided a minimal working example to reproduce the bug + required: true + - label: I have checked my env using the env checker + required: true + - label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces. + required: true diff --git a/.github/ISSUE_TEMPLATE/documentation.yml b/.github/ISSUE_TEMPLATE/documentation.yml new file mode 100644 index 0000000..14be0f9 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/documentation.yml @@ -0,0 +1,25 @@ +name: "\U0001F4DA Documentation" +description: Report an issue related to Stable-Baselines3 Contrib documentation +labels: ["documentation"] +body: + - type: markdown + attributes: + value: | + **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. + Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case. + - type: textarea + id: description + attributes: + label: 📚 Documentation + description: A clear and concise description of what should be improved in the documentation. + validations: + required: true + - type: checkboxes + id: terms + attributes: + label: Checklist + options: + - label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo + required: true + - label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/) + required: true diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000..7114686 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,44 @@ +name: "\U0001F680 Feature Request" +description: How to create an issue for requesting a feature +title: "[Feature Request] request title" +labels: ["enhancement"] +body: + - type: markdown + attributes: + value: | + **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. + Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case. + - type: textarea + id: description + attributes: + label: 🚀 Feature + description: A clear and concise description of the feature proposal. + validations: + required: true + - type: textarea + id: motivation + attributes: + label: Motivation + description: Please outline the motivation for the proposal. Is your feature request related to a problem? e.g.,"I'm always frustrated when [...]". If this is related to another GitHub issue, please link here too. + - type: textarea + id: pitch + attributes: + label: Pitch + description: A clear and concise description of what you want to happen. + - type: textarea + id: alternatives + attributes: + label: Alternatives + description: A clear and concise description of any alternative solutions or features you've considered, if any. + - type: textarea + id: additional-context + attributes: + label: Additional context + description: Add any other context or screenshots about the feature request here. + - type: checkboxes + id: terms + attributes: + label: Checklist + options: + - label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo + required: true diff --git a/.github/ISSUE_TEMPLATE/issue-template.md b/.github/ISSUE_TEMPLATE/issue-template.md deleted file mode 100644 index e185110..0000000 --- a/.github/ISSUE_TEMPLATE/issue-template.md +++ /dev/null @@ -1,60 +0,0 @@ ---- -name: Issue Template -about: How to create an issue for this repository - ---- - -**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. -Please post your question on [reddit](https://www.reddit.com/r/reinforcementlearning/) or [stack overflow](https://stackoverflow.com/) in that case. - -If you have any questions, feel free to create an issue with the tag [question]. -If you wish to suggest an enhancement or feature request, add the tag [feature request]. -If you are submitting a bug report, please fill in the following details. - -If your issue is related to a custom gym environment, please check it first using: - -```python -from stable_baselines3.common.env_checker import check_env - -env = CustomEnv(arg1, ...) -# It will check your custom environment and output additional warnings if needed -check_env(env) -``` - -**Describe the bug** -A clear and concise description of what the bug is. - -**Code example** -Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful. - -Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) -for both code and stack traces. - -```python -from stable_baselines3 import ... -from sb3_contrib import ... - -``` - -```bash -Traceback (most recent call last): File ... - -``` - -**System Info** -Describe the characteristic of your environment: - * Describe how the library was installed (pip, docker, source, ...) - * Stable-Baselines3 and sb3-contrib versions - * GPU models and configuration - * Python version - * PyTorch version - * Gym version - * Versions of any other relevant libraries - - You can use `sb3.get_system_info()` to print relevant packages info: - ```sh - python -c 'import stable_baselines3 as sb3; sb3.get_system_info()' - ``` - -**Additional context** -Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/question.yml b/.github/ISSUE_TEMPLATE/question.yml new file mode 100644 index 0000000..cf45eff --- /dev/null +++ b/.github/ISSUE_TEMPLATE/question.yml @@ -0,0 +1,30 @@ +name: "❓ Question" +description: How to ask a question regarding Stable-Baselines3 Contrib +title: "[Question] question title" +labels: ["question"] +body: + - type: markdown + attributes: + value: | + **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. + Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case. + - type: textarea + id: question + attributes: + label: ❓ Question + description: Your question. This can be e.g. questions regarding confusing or unclear behaviour of functions or a question if X can be done using stable-baselines3. Make sure to check out the documentation first. + validations: + required: true + - type: checkboxes + id: terms + attributes: + label: Checklist + options: + - label: I have checked that there is no similar [issue](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues) in the repo + required: true + - label: I have read the [documentation](https://sb3-contrib.readthedocs.io/en/master/) + required: true + - label: If code there is, it is minimal and working + required: true + - label: If code there is, it is formatted using the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces. + required: true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e4b7958..ffdfda8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,6 +14,7 @@ jobs: env: TERM: xterm-256color FORCE_COLOR: 1 + ATARI_ROMS: ${{ secrets.ATARI_ROMS }} # Skip CI if [ci skip] in the commit message if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')" runs-on: ubuntu-latest @@ -32,24 +33,32 @@ jobs: python -m pip install --upgrade pip # cpu version of pytorch pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html + + # Install Atari Roms + pip install autorom + wget $ATARI_ROMS + base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz + AutoROM --accept-license --source-file Roms.tar.gz + # Install master version # and dependencies for docs and tests pip install git+https://github.com/DLR-RM/stable-baselines3#egg=stable_baselines3[extra,tests,docs] pip install . # Use headless version pip install opencv-python-headless + + - name: Lint with ruff + run: | + make lint + - name: Check codestyle + run: | + make check-codestyle - name: Build the doc run: | make doc - name: Type check run: | make type - - name: Check codestyle - run: | - make check-codestyle - - name: Lint with flake8 - run: | - make lint - name: Test with pytest run: | make pytest diff --git a/Makefile b/Makefile index 4b5d230..d3151b9 100644 --- a/Makefile +++ b/Makefile @@ -14,24 +14,17 @@ type: pytype mypy lint: # stop the build if there are Python syntax errors or undefined names - # see https://lintlyci.github.io/Flake8Rules/ - flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. - flake8 ${LINT_PATHS} --count --exit-zero --statistics - -ruff: - # stop the build if there are Python syntax errors or undefined names - # see https://lintlyci.github.io/Flake8Rules/ + # see https://www.flake8rules.com/ ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source # exit-zero treats all errors as warnings. - ruff ${LINT_PATHS} --exit-zero --line-length 127 + ruff ${LINT_PATHS} --exit-zero format: # Sort imports isort ${LINT_PATHS} # Reformat using black - black -l 127 ${LINT_PATHS} + black ${LINT_PATHS} check-codestyle: # Sort imports diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 14493ac..8a77467 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 1.8.0a4 (WIP) +Release 1.8.0a8 (WIP) -------------------------- Breaking Changes: @@ -22,6 +22,8 @@ Deprecations: Others: ^^^^^^^ +- Moved to pyproject.toml (except flake8) +- Added github issue forms Documentation: ^^^^^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index e16c432..8fdd02b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,64 @@ +[tool.ruff] +# Same as Black. +line-length = 127 +# Assume Python 3.7 +target-version = "py37" +select = ["E", "F", "B", "UP", "C90", "RUF"] +ignore = [] + +[tool.ruff.mccabe] +# Unlike Flake8, ruff default to a complexity level of 10. +max-complexity = 15 + [tool.black] line-length = 127 +[tool.isort] +profile = "black" +line_length = 127 +src_paths = ["sb3_contrib"] + [tool.pytype] inputs = ['sb3_contrib'] + +[tool.mypy] +ignore_missing_imports = true +follow_imports = "silent" +show_error_codes = true +exclude = """(?x)( + sb3_contrib/ppo_recurrent/ppo_recurrent.py$ + | sb3_contrib/ars/ars.py$ + | sb3_contrib/qrdqn/qrdqn.py$ + | sb3_contrib/qrdqn/policies.py$ + | sb3_contrib/common/recurrent/policies.py$ + | sb3_contrib/common/recurrent/buffers.py$ + | sb3_contrib/common/maskable/distributions.py$ + | sb3_contrib/common/maskable/callbacks.py$ + | sb3_contrib/common/maskable/policies.py$ + | sb3_contrib/common/maskable/buffers.py$ + | sb3_contrib/common/envs/invalid_actions_env.py$ + | sb3_contrib/common/vec_env/async_eval.py$ + | sb3_contrib/tqc/tqc.py$ + | sb3_contrib/tqc/policies.py$ + | sb3_contrib/trpo/trpo.py$ + | sb3_contrib/ppo_mask/ppo_mask.py$ + | tests/test_train_eval_mode.py$ + )""" + +[tool.pytest.ini_options] +# Deterministic ordering for tests; useful for pytest-xdist. +env = [ + "PYTHONHASHSEED=0" +] + +filterwarnings = [ + # Tensorboard warnings + "ignore::DeprecationWarning:tensorboard", + # Gym warnings + "ignore:Parameters to load are deprecated.:DeprecationWarning", + "ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning", + "ignore::UserWarning:gym", +] +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')" +] diff --git a/sb3_contrib/common/maskable/evaluation.py b/sb3_contrib/common/maskable/evaluation.py index 4790d5c..8a9fba8 100644 --- a/sb3_contrib/common/maskable/evaluation.py +++ b/sb3_contrib/common/maskable/evaluation.py @@ -10,7 +10,7 @@ from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_suppo from sb3_contrib.ppo_mask import MaskablePPO -def evaluate_policy( # noqa: C901 +def evaluate_policy( model: MaskablePPO, env: Union[gym.Env, VecEnv], n_eval_episodes: int = 10, diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 23d487f..1cb7b44 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -229,7 +229,7 @@ class RecurrentRolloutBuffer(RolloutBuffer): return RecurrentRolloutBufferSamples( # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) - observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size,) + self.obs_shape), + observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size, *self.obs_shape)), actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]), old_values=self.pad_and_flatten(self.values[batch_inds]), old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), diff --git a/sb3_contrib/common/utils.py b/sb3_contrib/common/utils.py index 8066ccc..73ab0b1 100644 --- a/sb3_contrib/common/utils.py +++ b/sb3_contrib/common/utils.py @@ -149,9 +149,9 @@ def flat_grad( :param output: functional output to compute the gradient for :param parameters: sequence of ``Parameter`` - :param retain_graph: – If ``False``, the graph used to compute the grad will be freed. + :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to the value of ``create_graph``. - :param create_graph: – If ``True``, graph of the derivative will be constructed, + :param create_graph: If ``True``, graph of the derivative will be constructed, allowing to compute higher order derivative products. Default: ``False``. :return: Tensor containing the flattened gradients """ diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index f9f0a40..a17b4e1 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -264,7 +264,7 @@ class QRDQN(OffPolicyAlgorithm): ) def _excluded_save_params(self) -> List[str]: - return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"] + return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"] # noqa: RUF005 def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: state_dicts = ["policy", "policy.optimizer"] diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 8e19fa2..b6ea3cd 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -302,7 +302,7 @@ class TQC(OffPolicyAlgorithm): def _excluded_save_params(self) -> List[str]: # Exclude aliases - return super()._excluded_save_params() + ["actor", "critic", "critic_target"] + return super()._excluded_save_params() + ["actor", "critic", "critic_target"] # noqa: RUF005 def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index e8175d3..8daa30f 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.8.0a4 +1.8.0a8 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index b446458..0000000 --- a/setup.cfg +++ /dev/null @@ -1,66 +0,0 @@ -[metadata] -# This includes the license file in the wheel. -license_files = LICENSE - -[tool:pytest] -# Deterministic ordering for tests; useful for pytest-xdist. -env = - PYTHONHASHSEED=0 -filterwarnings = - # Tensorboard warnings - ignore::DeprecationWarning:tensorboard - # Gym warnings - ignore:Parameters to load are deprecated.:DeprecationWarning - ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning - ignore::UserWarning:gym -markers = - slow: marks tests as slow (deselect with '-m "not slow"') - -[mypy] -ignore_missing_imports = True -follow_imports = silent -show_error_codes = True -exclude = (?x)( - sb3_contrib/ppo_recurrent/ppo_recurrent.py$ - | sb3_contrib/ars/ars.py$ - | sb3_contrib/qrdqn/qrdqn.py$ - | sb3_contrib/qrdqn/policies.py$ - | sb3_contrib/common/recurrent/policies.py$ - | sb3_contrib/common/recurrent/buffers.py$ - | sb3_contrib/common/maskable/distributions.py$ - | sb3_contrib/common/maskable/callbacks.py$ - | sb3_contrib/common/maskable/policies.py$ - | sb3_contrib/common/maskable/buffers.py$ - | sb3_contrib/common/envs/invalid_actions_env.py$ - | sb3_contrib/common/vec_env/async_eval.py$ - | sb3_contrib/tqc/tqc.py$ - | sb3_contrib/tqc/policies.py$ - | sb3_contrib/trpo/trpo.py$ - | sb3_contrib/ppo_mask/ppo_mask.py$ - | tests/test_train_eval_mode.py$ - ) - -[flake8] -# line breaks before and after binary operators -ignore = W503,W504,E203,E231 - -exclude = - # No need to traverse our git directory - .git, - # There's no value in checking cache directories - __pycache__, - # Don't check the doc - docs/ - # This contains our built documentation - build, - # This contains builds of flake8 that we don't want to check - dist - *.egg-info -max-complexity = 15 -# The GitHub editor is 127 chars wide -max-line-length = 127 - -[isort] -profile = black -line_length = 127 -src_paths = sb3_contrib diff --git a/setup.py b/setup.py index 0e48c57..2c74282 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.8.0a2", + "stable_baselines3>=1.8.0a8", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", @@ -79,6 +79,13 @@ setup( version=__version__, python_requires=">=3.7", # PyPI package information. + # PyPI package information. + project_urls={ + "Code": "https://github.com/Stable-Baselines-Team/stable-baselines3-contrib", + "Documentation": "https://sb3-contrib.readthedocs.io/", + "Stable-Baselines3": "https://github.com/DLR-RM/stable-baselines3", + "RL-Zoo": "https://github.com/DLR-RM/rl-baselines3-zoo", + }, classifiers=[ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.7", diff --git a/tests/wrappers/test_time_feature.py b/tests/wrappers/test_time_feature.py index 6b43a5b..bfbf33f 100644 --- a/tests/wrappers/test_time_feature.py +++ b/tests/wrappers/test_time_feature.py @@ -10,7 +10,7 @@ from sb3_contrib.common.wrappers import TimeFeatureWrapper class CustomGoalEnv(gym.GoalEnv): def __init__(self): - super(CustomGoalEnv, self).__init__() + super().__init__() self.observation_space = spaces.Dict( { "observation": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),