From b3e4ddd09abafe8cdfbdd64c3a014fc2d1d2d690 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 29 Nov 2022 10:41:07 +0100 Subject: [PATCH] Fix `sb3_contrib/common/recurrent/type_aliases.py` type hint (#121) * Update setup.cfg * Update changelog * Update type aliases --- docs/misc/changelog.rst | 1 + sb3_contrib/common/recurrent/type_aliases.py | 2 +- setup.cfg | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 1d7a288..41e40ba 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -28,6 +28,7 @@ Others: ^^^^^^^ - Fixed flake8 config - Fixed ``sb3_contrib/common/utils.py`` type hint +- Fixed ``sb3_contrib/common/recurrent/type_aliases.py`` type hint Release 1.6.2 (2022-10-10) -------------------------- diff --git a/sb3_contrib/common/recurrent/type_aliases.py b/sb3_contrib/common/recurrent/type_aliases.py index 1ae9a08..21ac0e0 100644 --- a/sb3_contrib/common/recurrent/type_aliases.py +++ b/sb3_contrib/common/recurrent/type_aliases.py @@ -21,7 +21,7 @@ class RecurrentRolloutBufferSamples(NamedTuple): mask: th.Tensor -class RecurrentDictRolloutBufferSamples(RecurrentRolloutBufferSamples): +class RecurrentDictRolloutBufferSamples(NamedTuple): observations: TensorDict actions: th.Tensor old_values: th.Tensor diff --git a/setup.cfg b/setup.cfg index dad73ae..4004a6b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,7 +29,6 @@ exclude = (?x)( | sb3_contrib/ars/policies.py$ | sb3_contrib/qrdqn/qrdqn.py$ | sb3_contrib/qrdqn/policies.py$ - | sb3_contrib/common/recurrent/type_aliases.py$ | sb3_contrib/common/recurrent/policies.py$ | sb3_contrib/common/recurrent/buffers.py$ | sb3_contrib/common/maskable/distributions.py$