Construct tensors directly on GPUs (#128)
* `to(device)` to `device=device` and `float()` to `dtype=th.float32`
* Update changelog
* Fix type checking

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
ab8684f469
commit
9cf8b5076f
|
|
@@ -39,6 +39,7 @@ Others:
|
|||
- Exposed modules in `__init__.py` with `__all__` attribute (@ZikangXiong)
|
||||
- Removed ignores on Flake8 F401 (@ZikangXiong)
|
||||
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
|
||||
- Set tensors construction directly on the device
|
||||
|
||||
Release 1.6.2 (2022-10-10)
|
||||
--------------------------
|
||||
|
|
|
|||
|
|
@@ -4,7 +4,7 @@ import gym
|
|||
import torch as th
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.torch_layers import create_mlp
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
|
||||
from torch import nn
|
||||
|
||||
|
||||
|
|
@@ -68,6 +68,8 @@ class ARSPolicy(BasePolicy):
|
|||
return data
|
||||
|
||||
def forward(self, obs: th.Tensor) -> th.Tensor:
|
||||
# Make mypy happy:
|
||||
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
|
||||
|
||||
features = self.extract_features(obs, self.features_extractor)
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
|
|
|
|||
|
|
@@ -403,8 +403,10 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
|
||||
with th.no_grad():
|
||||
# Convert to PyTorch tensors
|
||||
states = th.tensor(state[0]).float().to(self.device), th.tensor(state[1]).float().to(self.device)
|
||||
episode_starts = th.tensor(episode_start).float().to(self.device)
|
||||
states = th.tensor(state[0], dtype=th.float32, device=self.device), th.tensor(
|
||||
state[1], dtype=th.float32, device=self.device
|
||||
)
|
||||
episode_starts = th.tensor(episode_start, dtype=th.float32, device=self.device)
|
||||
actions, states = self._predict(
|
||||
observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic
|
||||
)
|
||||
|
|
|
|||
|
|
@@ -163,12 +163,12 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
# hidden and cell states for actor and critic
|
||||
self._last_lstm_states = RNNStates(
|
||||
(
|
||||
th.zeros(single_hidden_state_shape).to(self.device),
|
||||
th.zeros(single_hidden_state_shape).to(self.device),
|
||||
th.zeros(single_hidden_state_shape, device=self.device),
|
||||
th.zeros(single_hidden_state_shape, device=self.device),
|
||||
),
|
||||
(
|
||||
th.zeros(single_hidden_state_shape).to(self.device),
|
||||
th.zeros(single_hidden_state_shape).to(self.device),
|
||||
th.zeros(single_hidden_state_shape, device=self.device),
|
||||
th.zeros(single_hidden_state_shape, device=self.device),
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@@ -239,7 +239,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
with th.no_grad():
|
||||
# Convert to pytorch tensor or to TensorDict
|
||||
obs_tensor = obs_as_tensor(self._last_obs, self.device)
|
||||
episode_starts = th.tensor(self._last_episode_starts).float().to(self.device)
|
||||
episode_starts = th.tensor(self._last_episode_starts, dtype=th.float32, device=self.device)
|
||||
actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)
|
||||
|
||||
actions = actions.cpu().numpy()
|
||||
|
|
@@ -281,7 +281,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
lstm_states.vf[1][:, idx : idx + 1, :].contiguous(),
|
||||
)
|
||||
# terminal_lstm_state = None
|
||||
episode_starts = th.tensor([False]).float().to(self.device)
|
||||
episode_starts = th.tensor([False], dtype=th.float32, device=self.device)
|
||||
terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0]
|
||||
rewards[idx] += self.gamma * terminal_value
|
||||
|
||||
|
|
@@ -301,7 +301,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
|
||||
with th.no_grad():
|
||||
# Compute value for the last timestep
|
||||
episode_starts = th.tensor(dones).float().to(self.device)
|
||||
episode_starts = th.tensor(dones, dtype=th.float32, device=self.device)
|
||||
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts)
|
||||
|
||||
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
|
||||
|
|
|
|||
|
|
@@ -175,7 +175,7 @@ class TQC(OffPolicyAlgorithm):
|
|||
# Force conversion to float
|
||||
# this will throw an error if a malformed string (different from 'auto')
|
||||
# is passed
|
||||
self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device)
|
||||
self.ent_coef_tensor = th.tensor(float(self.ent_coef), device=self.device)
|
||||
|
||||
def _create_aliases(self) -> None:
|
||||
self.actor = self.policy.actor
|
||||
|
|
|
|||
Loading…
Reference in New Issue