Support `device="auto"`for buffers and set it as default value (#98)
* Default device for buffer is auto * `device=auto` in ARS * Undo ARS change * Update changelog * Update min SB3 version Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
049f5a16e9
commit
7993b75781
|
|
@ -9,6 +9,8 @@ Release 1.6.1a2 (WIP)
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^
|
||||||
- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
|
- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
|
||||||
|
- Upgraded to Stable-Baselines3 >= 1.6.1a2
|
||||||
|
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
@ -24,6 +26,11 @@ Bug Fixes:
|
||||||
Deprecations:
|
Deprecations:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Others:
|
||||||
|
^^^^^^^
|
||||||
|
- Changed the default buffer device from ``"cpu"`` to ``"auto"``
|
||||||
|
|
||||||
|
|
||||||
Release 1.6.0 (2022-07-11)
|
Release 1.6.0 (2022-07-11)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,16 +35,25 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
||||||
:param buffer_size: Max number of element in the buffer
|
:param buffer_size: Max number of element in the buffer
|
||||||
:param observation_space: Observation space
|
:param observation_space: Observation space
|
||||||
:param action_space: Action space
|
:param action_space: Action space
|
||||||
:param device:
|
:param device: PyTorch device
|
||||||
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
|
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
|
||||||
Equivalent to classic advantage when set to 1.
|
Equivalent to classic advantage when set to 1.
|
||||||
:param gamma: Discount factor
|
:param gamma: Discount factor
|
||||||
:param n_envs: Number of parallel environments
|
:param n_envs: Number of parallel environments
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
buffer_size: int,
|
||||||
|
observation_space: spaces.Space,
|
||||||
|
action_space: spaces.Space,
|
||||||
|
device: Union[th.device, str] = "auto",
|
||||||
|
gae_lambda: float = 1,
|
||||||
|
gamma: float = 0.99,
|
||||||
|
n_envs: int = 1,
|
||||||
|
):
|
||||||
|
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
|
||||||
self.action_masks = None
|
self.action_masks = None
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
if isinstance(self.action_space, spaces.Discrete):
|
if isinstance(self.action_space, spaces.Discrete):
|
||||||
|
|
@ -127,7 +136,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
||||||
:param buffer_size: Max number of element in the buffer
|
:param buffer_size: Max number of element in the buffer
|
||||||
:param observation_space: Observation space
|
:param observation_space: Observation space
|
||||||
:param action_space: Action space
|
:param action_space: Action space
|
||||||
:param device:
|
:param device: PyTorch device
|
||||||
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
|
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
|
||||||
Equivalent to classic advantage when set to 1.
|
Equivalent to classic advantage when set to 1.
|
||||||
:param gamma: Discount factor
|
:param gamma: Discount factor
|
||||||
|
|
@ -139,7 +148,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
||||||
buffer_size: int,
|
buffer_size: int,
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
device: Union[th.device, str] = "cpu",
|
device: Union[th.device, str] = "auto",
|
||||||
gae_lambda: float = 1,
|
gae_lambda: float = 1,
|
||||||
gamma: float = 0.99,
|
gamma: float = 0.99,
|
||||||
n_envs: int = 1,
|
n_envs: int = 1,
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,7 @@ class RecurrentRolloutBuffer(RolloutBuffer):
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
hidden_state_shape: Tuple[int, int, int, int],
|
hidden_state_shape: Tuple[int, int, int, int],
|
||||||
device: Union[th.device, str] = "cpu",
|
device: Union[th.device, str] = "auto",
|
||||||
gae_lambda: float = 1,
|
gae_lambda: float = 1,
|
||||||
gamma: float = 0.99,
|
gamma: float = 0.99,
|
||||||
n_envs: int = 1,
|
n_envs: int = 1,
|
||||||
|
|
@ -262,7 +262,7 @@ class RecurrentDictRolloutBuffer(DictRolloutBuffer):
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
hidden_state_shape: Tuple[int, int, int, int],
|
hidden_state_shape: Tuple[int, int, int, int],
|
||||||
device: Union[th.device, str] = "cpu",
|
device: Union[th.device, str] = "auto",
|
||||||
gae_lambda: float = 1,
|
gae_lambda: float = 1,
|
||||||
gamma: float = 0.99,
|
gamma: float = 0.99,
|
||||||
n_envs: int = 1,
|
n_envs: int = 1,
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.6.0",
|
"stable_baselines3>=1.6.1a2",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue