Construct tensors directly on GPUs (#128)
* Replace `to(device)` with `device=device` and `float()` with `dtype=th.float32`
* Update changelog
* Fix type checking

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
ab8684f469
commit
9cf8b5076f
|
|
@ -39,6 +39,7 @@ Others:
|
||||||
- Exposed modules in `__init__.py` with `__all__` attribute (@ZikangXiong)
|
- Exposed modules in `__init__.py` with `__all__` attribute (@ZikangXiong)
|
||||||
- Removed ignores on Flake8 F401 (@ZikangXiong)
|
- Removed ignores on Flake8 F401 (@ZikangXiong)
|
||||||
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
|
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
|
||||||
|
- Tensors are now constructed directly on the target device
|
||||||
|
|
||||||
Release 1.6.2 (2022-10-10)
|
Release 1.6.2 (2022-10-10)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import gym
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.preprocessing import get_action_dim
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
from stable_baselines3.common.torch_layers import create_mlp
|
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -68,6 +68,8 @@ class ARSPolicy(BasePolicy):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def forward(self, obs: th.Tensor) -> th.Tensor:
|
def forward(self, obs: th.Tensor) -> th.Tensor:
|
||||||
|
# Make mypy happy:
|
||||||
|
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
|
||||||
|
|
||||||
features = self.extract_features(obs, self.features_extractor)
|
features = self.extract_features(obs, self.features_extractor)
|
||||||
if isinstance(self.action_space, gym.spaces.Box):
|
if isinstance(self.action_space, gym.spaces.Box):
|
||||||
|
|
|
||||||
|
|
@ -403,8 +403,10 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
|
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
# Convert to PyTorch tensors
|
# Convert to PyTorch tensors
|
||||||
states = th.tensor(state[0]).float().to(self.device), th.tensor(state[1]).float().to(self.device)
|
states = th.tensor(state[0], dtype=th.float32, device=self.device), th.tensor(
|
||||||
episode_starts = th.tensor(episode_start).float().to(self.device)
|
state[1], dtype=th.float32, device=self.device
|
||||||
|
)
|
||||||
|
episode_starts = th.tensor(episode_start, dtype=th.float32, device=self.device)
|
||||||
actions, states = self._predict(
|
actions, states = self._predict(
|
||||||
observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic
|
observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -163,12 +163,12 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
# hidden and cell states for actor and critic
|
# hidden and cell states for actor and critic
|
||||||
self._last_lstm_states = RNNStates(
|
self._last_lstm_states = RNNStates(
|
||||||
(
|
(
|
||||||
th.zeros(single_hidden_state_shape).to(self.device),
|
th.zeros(single_hidden_state_shape, device=self.device),
|
||||||
th.zeros(single_hidden_state_shape).to(self.device),
|
th.zeros(single_hidden_state_shape, device=self.device),
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
th.zeros(single_hidden_state_shape).to(self.device),
|
th.zeros(single_hidden_state_shape, device=self.device),
|
||||||
th.zeros(single_hidden_state_shape).to(self.device),
|
th.zeros(single_hidden_state_shape, device=self.device),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -239,7 +239,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
# Convert to pytorch tensor or to TensorDict
|
# Convert to pytorch tensor or to TensorDict
|
||||||
obs_tensor = obs_as_tensor(self._last_obs, self.device)
|
obs_tensor = obs_as_tensor(self._last_obs, self.device)
|
||||||
episode_starts = th.tensor(self._last_episode_starts).float().to(self.device)
|
episode_starts = th.tensor(self._last_episode_starts, dtype=th.float32, device=self.device)
|
||||||
actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)
|
actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)
|
||||||
|
|
||||||
actions = actions.cpu().numpy()
|
actions = actions.cpu().numpy()
|
||||||
|
|
@ -281,7 +281,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
lstm_states.vf[1][:, idx : idx + 1, :].contiguous(),
|
lstm_states.vf[1][:, idx : idx + 1, :].contiguous(),
|
||||||
)
|
)
|
||||||
# terminal_lstm_state = None
|
# terminal_lstm_state = None
|
||||||
episode_starts = th.tensor([False]).float().to(self.device)
|
episode_starts = th.tensor([False], dtype=th.float32, device=self.device)
|
||||||
terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0]
|
terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0]
|
||||||
rewards[idx] += self.gamma * terminal_value
|
rewards[idx] += self.gamma * terminal_value
|
||||||
|
|
||||||
|
|
@ -301,7 +301,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
# Compute value for the last timestep
|
# Compute value for the last timestep
|
||||||
episode_starts = th.tensor(dones).float().to(self.device)
|
episode_starts = th.tensor(dones, dtype=th.float32, device=self.device)
|
||||||
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts)
|
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts)
|
||||||
|
|
||||||
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
|
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
|
||||||
|
|
|
||||||
|
|
@ -175,7 +175,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
# Force conversion to float
|
# Force conversion to float
|
||||||
# this will throw an error if a malformed string (different from 'auto')
|
# this will throw an error if a malformed string (different from 'auto')
|
||||||
# is passed
|
# is passed
|
||||||
self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device)
|
self.ent_coef_tensor = th.tensor(float(self.ent_coef), device=self.device)
|
||||||
|
|
||||||
def _create_aliases(self) -> None:
|
def _create_aliases(self) -> None:
|
||||||
self.actor = self.policy.actor
|
self.actor = self.policy.actor
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue