diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 2369702..44ffe2b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,22 @@ Changelog ========== +Release 1.6.1a0 (WIP) +------------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec) + +New Features: +^^^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + Release 1.6.0 (2022-07-11) ------------------------------- @@ -276,4 +292,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_) Contributors: ------------- -@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt +@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index a26f6fb..37b9e7f 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -254,7 +254,7 @@ class MaskableActorCriticPolicy(BasePolicy): if not vectorized_env: if state is not None: raise ValueError("Error: The environment must be vectorized when using recurrent policies.") - actions = actions[0] + actions = actions.squeeze(axis=0) return actions, None diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 16f1c20..2285baa 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -416,7 +416,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): # Remove batch dimension if needed if not vectorized_env: - actions = actions[0] + actions = actions.squeeze(axis=0) return actions, states diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index dc1e644..035e3b6 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.6.0 +1.6.1a0