Support `device="auto"`for buffers and set it as default value (#98)
* Default device for buffer is auto * `device=auto` in ARS * Undo ARS change * Update changelog * Update min SB3 version Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
049f5a16e9
commit
7993b75781
|
|
@ -9,6 +9,8 @@ Release 1.6.1a2 (WIP)
|
|||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
|
||||
- Upgraded to Stable-Baselines3 >= 1.6.1a2
|
||||
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -24,6 +26,11 @@ Bug Fixes:
|
|||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Changed the default buffer device from ``"cpu"`` to ``"auto"``
|
||||
|
||||
|
||||
Release 1.6.0 (2022-07-11)
|
||||
-------------------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -35,16 +35,25 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
|||
:param buffer_size: Max number of element in the buffer
|
||||
:param observation_space: Observation space
|
||||
:param action_space: Action space
|
||||
:param device:
|
||||
:param device: PyTorch device
|
||||
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
|
||||
Equivalent to classic advantage when set to 1.
|
||||
:param gamma: Discount factor
|
||||
:param n_envs: Number of parallel environments
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
buffer_size: int,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
device: Union[th.device, str] = "auto",
|
||||
gae_lambda: float = 1,
|
||||
gamma: float = 0.99,
|
||||
n_envs: int = 1,
|
||||
):
|
||||
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
|
||||
self.action_masks = None
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def reset(self) -> None:
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
|
|
@ -127,7 +136,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
|||
:param buffer_size: Max number of element in the buffer
|
||||
:param observation_space: Observation space
|
||||
:param action_space: Action space
|
||||
:param device:
|
||||
:param device: PyTorch device
|
||||
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
|
||||
Equivalent to classic advantage when set to 1.
|
||||
:param gamma: Discount factor
|
||||
|
|
@ -139,7 +148,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
|||
buffer_size: int,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
device: Union[th.device, str] = "cpu",
|
||||
device: Union[th.device, str] = "auto",
|
||||
gae_lambda: float = 1,
|
||||
gamma: float = 0.99,
|
||||
n_envs: int = 1,
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ class RecurrentRolloutBuffer(RolloutBuffer):
|
|||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
hidden_state_shape: Tuple[int, int, int, int],
|
||||
device: Union[th.device, str] = "cpu",
|
||||
device: Union[th.device, str] = "auto",
|
||||
gae_lambda: float = 1,
|
||||
gamma: float = 0.99,
|
||||
n_envs: int = 1,
|
||||
|
|
@ -262,7 +262,7 @@ class RecurrentDictRolloutBuffer(DictRolloutBuffer):
|
|||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
hidden_state_shape: Tuple[int, int, int, int],
|
||||
device: Union[th.device, str] = "cpu",
|
||||
device: Union[th.device, str] = "auto",
|
||||
gae_lambda: float = 1,
|
||||
gamma: float = 0.99,
|
||||
n_envs: int = 1,
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=1.6.0",
|
||||
"stable_baselines3>=1.6.1a2",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
Loading…
Reference in New Issue