From 9cf8b5076f862313e3a1db069b67cba768fa69d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 23 Dec 2022 00:44:25 +0100 Subject: [PATCH] Construct tensors directly on GPUs (#128) * `to(device)` to `device=device` and `float()` to `dtype=th.float32` * Update changelog * Fix type checking Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 1 + sb3_contrib/ars/policies.py | 4 +++- sb3_contrib/common/recurrent/policies.py | 6 ++++-- sb3_contrib/ppo_recurrent/ppo_recurrent.py | 14 +++++++------- sb3_contrib/tqc/tqc.py | 2 +- 5 files changed, 16 insertions(+), 11 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 564f426..0efe1b9 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -39,6 +39,7 @@ Others: - Exposed modules in `__init__.py` with `__all__` attribute (@ZikangXiong) - Removed ignores on Flake8 F401 (@ZikangXiong) - Upgraded GitHub CI/setup-python to v4 and checkout to v3 +- Set tensors construction directly on the device Release 1.6.2 (2022-10-10) -------------------------- diff --git a/sb3_contrib/ars/policies.py b/sb3_contrib/ars/policies.py index 4b1a677..b44e576 100644 --- a/sb3_contrib/ars/policies.py +++ b/sb3_contrib/ars/policies.py @@ -4,7 +4,7 @@ import gym import torch as th from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import get_action_dim -from stable_baselines3.common.torch_layers import create_mlp +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp from torch import nn @@ -68,6 +68,8 @@ class ARSPolicy(BasePolicy): return data def forward(self, obs: th.Tensor) -> th.Tensor: + # Make mypy happy: + assert isinstance(self.features_extractor, BaseFeaturesExtractor) features = self.extract_features(obs, self.features_extractor) if isinstance(self.action_space, gym.spaces.Box): diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 7cb5b37..953fafd 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -403,8 +403,10 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): with th.no_grad(): # Convert to PyTorch tensors - states = th.tensor(state[0]).float().to(self.device), th.tensor(state[1]).float().to(self.device) - episode_starts = th.tensor(episode_start).float().to(self.device) + states = th.tensor(state[0], dtype=th.float32, device=self.device), th.tensor( + state[1], dtype=th.float32, device=self.device + ) + episode_starts = th.tensor(episode_start, dtype=th.float32, device=self.device) actions, states = self._predict( observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic ) diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 12a8c7f..ee71471 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -163,12 +163,12 @@ class RecurrentPPO(OnPolicyAlgorithm): # hidden and cell states for actor and critic self._last_lstm_states = RNNStates( ( - th.zeros(single_hidden_state_shape).to(self.device), - th.zeros(single_hidden_state_shape).to(self.device), + th.zeros(single_hidden_state_shape, device=self.device), + th.zeros(single_hidden_state_shape, device=self.device), ), ( - th.zeros(single_hidden_state_shape).to(self.device), - th.zeros(single_hidden_state_shape).to(self.device), + th.zeros(single_hidden_state_shape, device=self.device), + th.zeros(single_hidden_state_shape, device=self.device), ), ) @@ -239,7 +239,7 @@ class RecurrentPPO(OnPolicyAlgorithm): with th.no_grad(): # Convert to pytorch tensor or to TensorDict obs_tensor = obs_as_tensor(self._last_obs, self.device) - episode_starts = th.tensor(self._last_episode_starts).float().to(self.device) + episode_starts = th.tensor(self._last_episode_starts, dtype=th.float32, device=self.device) actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts) actions = actions.cpu().numpy() @@ -281,7 +281,7 @@ class RecurrentPPO(OnPolicyAlgorithm): lstm_states.vf[1][:, idx : idx + 1, :].contiguous(), ) # terminal_lstm_state = None - episode_starts = th.tensor([False]).float().to(self.device) + episode_starts = th.tensor([False], dtype=th.float32, device=self.device) terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0] rewards[idx] += self.gamma * terminal_value @@ -301,7 +301,7 @@ class RecurrentPPO(OnPolicyAlgorithm): with th.no_grad(): # Compute value for the last timestep - episode_starts = th.tensor(dones).float().to(self.device) + episode_starts = th.tensor(dones, dtype=th.float32, device=self.device) values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts) rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 6ec91a1..17151e8 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -175,7 +175,7 @@ class TQC(OffPolicyAlgorithm): # Force conversion to float # this will throw an error if a malformed string (different from 'auto') # is passed - self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device) + self.ent_coef_tensor = th.tensor(float(self.ent_coef), device=self.device) def _create_aliases(self) -> None: self.actor = self.policy.actor