Support `device="auto"` for buffers and set it as default value (#98)

* Default device for buffer is auto

* `device=auto` in ARS

* Undo ARS change

* Update changelog

* Update min SB3 version

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Quentin Gallouédec 2022-08-24 09:48:18 +02:00 committed by GitHub
parent 049f5a16e9
commit 7993b75781
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 24 additions and 8 deletions

View File

@ -9,6 +9,8 @@ Release 1.6.1a2 (WIP)
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
- Upgraded to Stable-Baselines3 >= 1.6.1a2
New Features:
^^^^^^^^^^^^^
@ -24,6 +26,11 @@ Bug Fixes:
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
- Changed the default buffer device from ``"cpu"`` to ``"auto"``
Release 1.6.0 (2022-07-11)
-------------------------------

View File

@ -35,16 +35,25 @@ class MaskableRolloutBuffer(RolloutBuffer):
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device:
:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
:param n_envs: Number of parallel environments
"""
def __init__(self, *args, **kwargs):
def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
self.action_masks = None
super().__init__(*args, **kwargs)
def reset(self) -> None:
if isinstance(self.action_space, spaces.Discrete):
@ -127,7 +136,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device:
:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
@ -139,7 +148,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "cpu",
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,

View File

@ -116,7 +116,7 @@ class RecurrentRolloutBuffer(RolloutBuffer):
observation_space: spaces.Space,
action_space: spaces.Space,
hidden_state_shape: Tuple[int, int, int, int],
device: Union[th.device, str] = "cpu",
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
@ -262,7 +262,7 @@ class RecurrentDictRolloutBuffer(DictRolloutBuffer):
observation_space: spaces.Space,
action_space: spaces.Space,
hidden_state_shape: Tuple[int, int, int, int],
device: Union[th.device, str] = "cpu",
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,

View File

@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.6.0",
"stable_baselines3>=1.6.1a2",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",