diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 5c379f4..9c12ef2 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -9,6 +9,8 @@ Release 1.6.1a2 (WIP) Breaking Changes: ^^^^^^^^^^^^^^^^^ - Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec) +- Upgraded to Stable-Baselines3 >= 1.6.1a2 + New Features: ^^^^^^^^^^^^^ @@ -24,6 +26,11 @@ Bug Fixes: Deprecations: ^^^^^^^^^^^^^ +Others: +^^^^^^^ +- Changed the default buffer device from ``"cpu"`` to ``"auto"`` + + Release 1.6.0 (2022-07-11) ------------------------------- diff --git a/sb3_contrib/common/maskable/buffers.py b/sb3_contrib/common/maskable/buffers.py index adfbef7..1a43539 100644 --- a/sb3_contrib/common/maskable/buffers.py +++ b/sb3_contrib/common/maskable/buffers.py @@ -35,16 +35,25 @@ class MaskableRolloutBuffer(RolloutBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param device: + :param device: PyTorch device :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to classic advantage when set to 1. :param gamma: Discount factor :param n_envs: Number of parallel environments """ - def __init__(self, *args, **kwargs): + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) self.action_masks = None - super().__init__(*args, **kwargs) def reset(self) -> None: if isinstance(self.action_space, spaces.Discrete): @@ -127,7 +136,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param device: + :param device: PyTorch device :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to classic advantage when set to 1. :param gamma: Discount factor @@ -139,7 +148,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer): buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 88ff425..5385668 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -116,7 +116,7 @@ class RecurrentRolloutBuffer(RolloutBuffer): observation_space: spaces.Space, action_space: spaces.Space, hidden_state_shape: Tuple[int, int, int, int], - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, @@ -262,7 +262,7 @@ class RecurrentDictRolloutBuffer(DictRolloutBuffer): observation_space: spaces.Space, action_space: spaces.Space, hidden_state_shape: Tuple[int, int, int, int], - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, diff --git a/setup.py b/setup.py index bcc2c3c..7b3ab98 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.6.0", + "stable_baselines3>=1.6.1a2", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin",