Fix returned type in predict (#88)

* actions[0] -> actions.squeeze(0)

* Update changelog

* Update changelog

* Update version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Quentin Gallouédec 2022-07-18 11:49:03 +02:00 committed by GitHub
parent c9d621b816
commit 3cbd2429be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 4 deletions

View File

@ -3,6 +3,22 @@
Changelog
==========
Release 1.6.1a0 (WIP)
-------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
New Features:
^^^^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
Deprecations:
^^^^^^^^^^^^^
Release 1.6.0 (2022-07-11)
-------------------------------
@ -276,4 +292,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
Contributors:
-------------
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec

View File

@ -254,7 +254,7 @@ class MaskableActorCriticPolicy(BasePolicy):
if not vectorized_env:
if state is not None:
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
actions = actions[0]
actions = actions.squeeze(axis=0)
return actions, None

View File

@ -416,7 +416,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
# Remove batch dimension if needed
if not vectorized_env:
actions = actions[0]
actions = actions.squeeze(axis=0)
return actions, states

View File

@ -1 +1 @@
1.6.0
1.6.1a0