From ded9f65bfd525843675ffd76058f42fce9fe1e0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 29 Nov 2022 10:24:44 +0100 Subject: [PATCH] Fix `sb3_contrib/common/utils.py` type hint (#120) * Update contribution.md * New loop struct to make mypy happy * Update setup.cfg * Update changelog * Add return statement in cg solver for max_iter=0 Co-authored-by: Antonin Raffin --- CONTRIBUTING.md | 4 ++-- docs/misc/changelog.rst | 1 + sb3_contrib/common/utils.py | 2 ++ setup.cfg | 1 - 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1112e98..9013df7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -152,7 +152,7 @@ To run tests with `pytest`: make pytest ``` -Type checking with `pytype`: +Type checking with `pytype` and `mypy`: ``` make type @@ -165,7 +165,7 @@ make check-codestyle make lint ``` -To run `pytype`, `format` and `lint` in one command: +To run `type`, `format` and `lint` in one command: ``` make commit-checks ``` diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 7a9bdf1..1d7a288 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -27,6 +27,7 @@ Deprecations: Others: ^^^^^^^ - Fixed flake8 config +- Fixed ``sb3_contrib/common/utils.py`` type hint Release 1.6.2 (2022-10-10) -------------------------- diff --git a/sb3_contrib/common/utils.py b/sb3_contrib/common/utils.py index d380bb7..8066ccc 100644 --- a/sb3_contrib/common/utils.py +++ b/sb3_contrib/common/utils.py @@ -133,6 +133,8 @@ def conjugate_gradient_solver( beta = new_residual_squared_norm / residual_squared_norm residual_squared_norm = new_residual_squared_norm p = residual + beta * p + # Note: this return statement is only used when max_iter=0 + return x def flat_grad( diff --git a/setup.cfg b/setup.cfg index 75afed5..dad73ae 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,6 @@ exclude = (?x)( | sb3_contrib/common/recurrent/type_aliases.py$ | sb3_contrib/common/recurrent/policies.py$ | sb3_contrib/common/recurrent/buffers.py$ - | sb3_contrib/common/utils.py$ | sb3_contrib/common/maskable/distributions.py$ | sb3_contrib/common/maskable/callbacks.py$ | sb3_contrib/common/maskable/policies.py$