Fix reshape LSTM states (#112)

* Fix LSTM states reshape

* Fix warnings and update changelog

* Remove unused variable

* Fix runtime error when using n_lstm_layers > 1
This commit is contained in:
Antonin RAFFIN 2022-10-26 18:03:45 +02:00 committed by GitHub
parent c75ad7dd58
commit a9735b9f31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 30 additions and 22 deletions

View File

@ -3,7 +3,7 @@
Changelog Changelog
========== ==========
Release 1.7.0a0 (WIP) Release 1.7.0a1 (WIP)
-------------------------- --------------------------
Breaking Changes: Breaking Changes:
@ -17,6 +17,8 @@ New Features:
Bug Fixes: Bug Fixes:
^^^^^^^^^^ ^^^^^^^^^^
- Fixed a bug in ``RecurrentPPO`` where the lstm states were incorrectly reshaped for ``n_lstm_layers > 1`` (thanks @kolbytn)
- Fixed ``RuntimeError: rnn: hx is not contiguous`` while predicting terminal values for ``RecurrentPPO`` when ``n_lstm_layers > 1``
Deprecations: Deprecations:
^^^^^^^^^^^^^ ^^^^^^^^^^^^^

View File

@ -206,7 +206,6 @@ class RecurrentRolloutBuffer(RolloutBuffer):
self.episode_starts[batch_inds], env_change[batch_inds], self.device self.episode_starts[batch_inds], env_change[batch_inds], self.device
) )
n_layers = self.hidden_states_pi.shape[1]
# Number of sequences # Number of sequences
n_seq = len(self.seq_start_indices) n_seq = len(self.seq_start_indices)
max_length = self.pad(self.actions[batch_inds]).shape[1] max_length = self.pad(self.actions[batch_inds]).shape[1]
@ -214,17 +213,19 @@ class RecurrentRolloutBuffer(RolloutBuffer):
# We retrieve the lstm hidden states that will allow # We retrieve the lstm hidden states that will allow
# to properly initialize the LSTM at the beginning of each sequence # to properly initialize the LSTM at the beginning of each sequence
lstm_states_pi = ( lstm_states_pi = (
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) # 1. (n_envs * n_steps, n_layers, dim) -> (batch_size, n_layers, dim)
self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), # 2. (batch_size, n_layers, dim) -> (n_seq, n_layers, dim)
self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), # 3. (n_seq, n_layers, dim) -> (n_layers, n_seq, dim)
self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
) )
lstm_states_vf = ( lstm_states_vf = (
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
) )
lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1])) lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1])) lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())
return RecurrentRolloutBufferSamples( return RecurrentRolloutBufferSamples(
# (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
@ -349,24 +350,23 @@ class RecurrentDictRolloutBuffer(DictRolloutBuffer):
self.episode_starts[batch_inds], env_change[batch_inds], self.device self.episode_starts[batch_inds], env_change[batch_inds], self.device
) )
n_layers = self.hidden_states_pi.shape[1]
n_seq = len(self.seq_start_indices) n_seq = len(self.seq_start_indices)
max_length = self.pad(self.actions[batch_inds]).shape[1] max_length = self.pad(self.actions[batch_inds]).shape[1]
padded_batch_size = n_seq * max_length padded_batch_size = n_seq * max_length
# We retrieve the lstm hidden states that will allow # We retrieve the lstm hidden states that will allow
# to properly initialize the LSTM at the beginning of each sequence # to properly initialize the LSTM at the beginning of each sequence
lstm_states_pi = ( lstm_states_pi = (
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
) )
lstm_states_vf = ( lstm_states_vf = (
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
) )
lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1])) lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1])) lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())
observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()} observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()}
observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()} observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()}

View File

@ -277,8 +277,8 @@ class RecurrentPPO(OnPolicyAlgorithm):
terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
with th.no_grad(): with th.no_grad():
terminal_lstm_state = ( terminal_lstm_state = (
lstm_states.vf[0][:, idx : idx + 1, :], lstm_states.vf[0][:, idx : idx + 1, :].contiguous(),
lstm_states.vf[1][:, idx : idx + 1, :], lstm_states.vf[1][:, idx : idx + 1, :].contiguous(),
) )
# terminal_lstm_state = None # terminal_lstm_state = None
episode_starts = th.tensor([False]).float().to(self.device) episode_starts = th.tensor([False]).float().to(self.device)

View File

@ -1 +1 @@
1.7.0a0 1.7.0a1

View File

@ -66,11 +66,13 @@ class CartPoleNoVelEnv(CartPoleEnv):
enable_critic_lstm=True, enable_critic_lstm=True,
lstm_hidden_size=4, lstm_hidden_size=4,
lstm_kwargs=dict(dropout=0.5), lstm_kwargs=dict(dropout=0.5),
n_lstm_layers=2,
), ),
dict( dict(
enable_critic_lstm=False, enable_critic_lstm=False,
lstm_hidden_size=4, lstm_hidden_size=4,
lstm_kwargs=dict(dropout=0.5), lstm_kwargs=dict(dropout=0.5),
n_lstm_layers=2,
), ),
], ],
) )
@ -95,11 +97,13 @@ def test_cnn(policy_kwargs):
enable_critic_lstm=True, enable_critic_lstm=True,
lstm_hidden_size=4, lstm_hidden_size=4,
lstm_kwargs=dict(dropout=0.5), lstm_kwargs=dict(dropout=0.5),
n_lstm_layers=2,
), ),
dict( dict(
enable_critic_lstm=False, enable_critic_lstm=False,
lstm_hidden_size=4, lstm_hidden_size=4,
lstm_kwargs=dict(dropout=0.5), lstm_kwargs=dict(dropout=0.5),
n_lstm_layers=2,
), ),
], ],
) )
@ -162,11 +166,13 @@ def test_run_sde():
enable_critic_lstm=True, enable_critic_lstm=True,
lstm_hidden_size=4, lstm_hidden_size=4,
lstm_kwargs=dict(dropout=0.5), lstm_kwargs=dict(dropout=0.5),
n_lstm_layers=2,
), ),
dict( dict(
enable_critic_lstm=False, enable_critic_lstm=False,
lstm_hidden_size=4, lstm_hidden_size=4,
lstm_kwargs=dict(dropout=0.5), lstm_kwargs=dict(dropout=0.5),
n_lstm_layers=2,
), ),
], ],
) )