Construct tensors directly on GPUs (#128)

* `to(device)` to `device=device` and `float()` to `dtype=th.float32`

* Update changelog

* Fix type checking

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Quentin Gallouédec 2022-12-23 00:44:25 +01:00 committed by GitHub
parent ab8684f469
commit 9cf8b5076f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 16 additions and 11 deletions

View File

@ -39,6 +39,7 @@ Others:
- Exposed modules in `__init__.py` with `__all__` attribute (@ZikangXiong) - Exposed modules in `__init__.py` with `__all__` attribute (@ZikangXiong)
- Removed ignores on Flake8 F401 (@ZikangXiong) - Removed ignores on Flake8 F401 (@ZikangXiong)
- Upgraded GitHub CI/setup-python to v4 and checkout to v3 - Upgraded GitHub CI/setup-python to v4 and checkout to v3
- Construct tensors directly on the target device instead of moving them after creation
Release 1.6.2 (2022-10-10) Release 1.6.2 (2022-10-10)
-------------------------- --------------------------

View File

@ -4,7 +4,7 @@ import gym
import torch as th import torch as th
from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import create_mlp from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
from torch import nn from torch import nn
@ -68,6 +68,8 @@ class ARSPolicy(BasePolicy):
return data return data
def forward(self, obs: th.Tensor) -> th.Tensor: def forward(self, obs: th.Tensor) -> th.Tensor:
# Make mypy happy:
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
features = self.extract_features(obs, self.features_extractor) features = self.extract_features(obs, self.features_extractor)
if isinstance(self.action_space, gym.spaces.Box): if isinstance(self.action_space, gym.spaces.Box):

View File

@ -403,8 +403,10 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
with th.no_grad(): with th.no_grad():
# Convert to PyTorch tensors # Convert to PyTorch tensors
states = th.tensor(state[0]).float().to(self.device), th.tensor(state[1]).float().to(self.device) states = th.tensor(state[0], dtype=th.float32, device=self.device), th.tensor(
episode_starts = th.tensor(episode_start).float().to(self.device) state[1], dtype=th.float32, device=self.device
)
episode_starts = th.tensor(episode_start, dtype=th.float32, device=self.device)
actions, states = self._predict( actions, states = self._predict(
observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic
) )

View File

@ -163,12 +163,12 @@ class RecurrentPPO(OnPolicyAlgorithm):
# hidden and cell states for actor and critic # hidden and cell states for actor and critic
self._last_lstm_states = RNNStates( self._last_lstm_states = RNNStates(
( (
th.zeros(single_hidden_state_shape).to(self.device), th.zeros(single_hidden_state_shape, device=self.device),
th.zeros(single_hidden_state_shape).to(self.device), th.zeros(single_hidden_state_shape, device=self.device),
), ),
( (
th.zeros(single_hidden_state_shape).to(self.device), th.zeros(single_hidden_state_shape, device=self.device),
th.zeros(single_hidden_state_shape).to(self.device), th.zeros(single_hidden_state_shape, device=self.device),
), ),
) )
@ -239,7 +239,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
with th.no_grad(): with th.no_grad():
# Convert to pytorch tensor or to TensorDict # Convert to pytorch tensor or to TensorDict
obs_tensor = obs_as_tensor(self._last_obs, self.device) obs_tensor = obs_as_tensor(self._last_obs, self.device)
episode_starts = th.tensor(self._last_episode_starts).float().to(self.device) episode_starts = th.tensor(self._last_episode_starts, dtype=th.float32, device=self.device)
actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts) actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)
actions = actions.cpu().numpy() actions = actions.cpu().numpy()
@ -281,7 +281,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
lstm_states.vf[1][:, idx : idx + 1, :].contiguous(), lstm_states.vf[1][:, idx : idx + 1, :].contiguous(),
) )
# terminal_lstm_state = None # terminal_lstm_state = None
episode_starts = th.tensor([False]).float().to(self.device) episode_starts = th.tensor([False], dtype=th.float32, device=self.device)
terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0] terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0]
rewards[idx] += self.gamma * terminal_value rewards[idx] += self.gamma * terminal_value
@ -301,7 +301,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
with th.no_grad(): with th.no_grad():
# Compute value for the last timestep # Compute value for the last timestep
episode_starts = th.tensor(dones).float().to(self.device) episode_starts = th.tensor(dones, dtype=th.float32, device=self.device)
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts) values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts)
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

View File

@ -175,7 +175,7 @@ class TQC(OffPolicyAlgorithm):
# Force conversion to float # Force conversion to float
# this will throw an error if a malformed string (different from 'auto') # this will throw an error if a malformed string (different from 'auto')
# is passed # is passed
self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device) self.ent_coef_tensor = th.tensor(float(self.ent_coef), device=self.device)
def _create_aliases(self) -> None: def _create_aliases(self) -> None:
self.actor = self.policy.actor self.actor = self.policy.actor