From 3cbd2429be7cf94d40c72e1e8801743345999ace Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 18 Jul 2022 11:49:03 +0200 Subject: [PATCH] Fix returned type in predict (#88) * actions[0] -> actions.squeeze(0) * Update changelog * Update changelog * Update version Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 18 +++++++++++++++++- sb3_contrib/common/maskable/policies.py | 2 +- sb3_contrib/common/recurrent/policies.py | 2 +- sb3_contrib/version.txt | 2 +- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 2369702..44ffe2b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,22 @@ Changelog ========== +Release 1.6.1a0 (WIP) +------------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec) + +New Features: +^^^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + Release 1.6.0 (2022-07-11) ------------------------------- @@ -276,4 +292,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_) Contributors: ------------- -@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt +@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index a26f6fb..37b9e7f 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -254,7 +254,7 @@ class MaskableActorCriticPolicy(BasePolicy): if not vectorized_env: if state is not None: raise ValueError("Error: The environment must be vectorized when using recurrent policies.") - actions = actions[0] + actions = actions.squeeze(axis=0) return actions, None diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 16f1c20..2285baa 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -416,7 +416,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): # Remove batch dimension if needed if not vectorized_env: - actions = actions[0] + actions = actions.squeeze(axis=0) return actions, states diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index dc1e644..035e3b6 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.6.0 +1.6.1a0