diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 564f426..0efe1b9 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -39,6 +39,7 @@ Others: - Exposed modules in `__init__.py` with `__all__` attribute (@ZikangXiong) - Removed ignores on Flake8 F401 (@ZikangXiong) - Upgraded GitHub CI/setup-python to v4 and checkout to v3 +- Create tensors directly on the target device instead of moving them after construction Release 1.6.2 (2022-10-10) -------------------------- diff --git a/sb3_contrib/ars/policies.py b/sb3_contrib/ars/policies.py index 4b1a677..b44e576 100644 --- a/sb3_contrib/ars/policies.py +++ b/sb3_contrib/ars/policies.py @@ -4,7 +4,7 @@ import gym import torch as th from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import get_action_dim -from stable_baselines3.common.torch_layers import create_mlp +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp from torch import nn @@ -68,6 +68,8 @@ class ARSPolicy(BasePolicy): return data def forward(self, obs: th.Tensor) -> th.Tensor: + # Make mypy happy: + assert isinstance(self.features_extractor, BaseFeaturesExtractor) features = self.extract_features(obs, self.features_extractor) if isinstance(self.action_space, gym.spaces.Box): diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 7cb5b37..953fafd 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -403,8 +403,10 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): with th.no_grad(): # Convert to PyTorch tensors - states = th.tensor(state[0]).float().to(self.device), th.tensor(state[1]).float().to(self.device) - episode_starts = th.tensor(episode_start).float().to(self.device) + states = th.tensor(state[0], dtype=th.float32, device=self.device), th.tensor( + state[1], dtype=th.float32, device=self.device + ) + episode_starts = th.tensor(episode_start, dtype=th.float32, 
device=self.device) actions, states = self._predict( observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic ) diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 12a8c7f..ee71471 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -163,12 +163,12 @@ class RecurrentPPO(OnPolicyAlgorithm): # hidden and cell states for actor and critic self._last_lstm_states = RNNStates( ( - th.zeros(single_hidden_state_shape).to(self.device), - th.zeros(single_hidden_state_shape).to(self.device), + th.zeros(single_hidden_state_shape, device=self.device), + th.zeros(single_hidden_state_shape, device=self.device), ), ( - th.zeros(single_hidden_state_shape).to(self.device), - th.zeros(single_hidden_state_shape).to(self.device), + th.zeros(single_hidden_state_shape, device=self.device), + th.zeros(single_hidden_state_shape, device=self.device), ), ) @@ -239,7 +239,7 @@ class RecurrentPPO(OnPolicyAlgorithm): with th.no_grad(): # Convert to pytorch tensor or to TensorDict obs_tensor = obs_as_tensor(self._last_obs, self.device) - episode_starts = th.tensor(self._last_episode_starts).float().to(self.device) + episode_starts = th.tensor(self._last_episode_starts, dtype=th.float32, device=self.device) actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts) actions = actions.cpu().numpy() @@ -281,7 +281,7 @@ class RecurrentPPO(OnPolicyAlgorithm): lstm_states.vf[1][:, idx : idx + 1, :].contiguous(), ) # terminal_lstm_state = None - episode_starts = th.tensor([False]).float().to(self.device) + episode_starts = th.tensor([False], dtype=th.float32, device=self.device) terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0] rewards[idx] += self.gamma * terminal_value @@ -301,7 +301,7 @@ class RecurrentPPO(OnPolicyAlgorithm): with th.no_grad(): # Compute value for the 
last timestep - episode_starts = th.tensor(dones).float().to(self.device) + episode_starts = th.tensor(dones, dtype=th.float32, device=self.device) values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts) rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 6ec91a1..17151e8 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -175,7 +175,7 @@ class TQC(OffPolicyAlgorithm): # Force conversion to float # this will throw an error if a malformed string (different from 'auto') # is passed - self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device) + self.ent_coef_tensor = th.tensor(float(self.ent_coef), device=self.device) def _create_aliases(self) -> None: self.actor = self.policy.actor