Fix reshape LSTM states (#112)
* Fix LSTM states reshape * Fix warnings and update changelog * Remove unused variable * Fix runtime error when using n_lstm_layers > 1
This commit is contained in:
parent
c75ad7dd58
commit
a9735b9f31
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.7.0a0 (WIP)
|
||||
Release 1.7.0a1 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -17,6 +17,8 @@ New Features:
|
|||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed a bug in ``RecurrentPPO`` where the lstm states were incorrectly reshaped for ``n_lstm_layers > 1`` (thanks @kolbytn)
|
||||
- Fixed ``RuntimeError: rnn: hx is not contiguous`` while predicting terminal values for ``RecurrentPPO`` when ``n_lstm_layers > 1``
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -206,7 +206,6 @@ class RecurrentRolloutBuffer(RolloutBuffer):
|
|||
self.episode_starts[batch_inds], env_change[batch_inds], self.device
|
||||
)
|
||||
|
||||
n_layers = self.hidden_states_pi.shape[1]
|
||||
# Number of sequences
|
||||
n_seq = len(self.seq_start_indices)
|
||||
max_length = self.pad(self.actions[batch_inds]).shape[1]
|
||||
|
|
@ -214,17 +213,19 @@ class RecurrentRolloutBuffer(RolloutBuffer):
|
|||
# We retrieve the lstm hidden states that will allow
|
||||
# to properly initialize the LSTM at the beginning of each sequence
|
||||
lstm_states_pi = (
|
||||
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
|
||||
self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
||||
self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
||||
# 1. (n_envs * n_steps, n_layers, dim) -> (batch_size, n_layers, dim)
|
||||
# 2. (batch_size, n_layers, dim) -> (n_seq, n_layers, dim)
|
||||
# 3. (n_seq, n_layers, dim) -> (n_layers, n_seq, dim)
|
||||
self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||
self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||
)
|
||||
lstm_states_vf = (
|
||||
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
|
||||
self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
||||
self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
||||
# (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
|
||||
self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||
self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||
)
|
||||
lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1]))
|
||||
lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1]))
|
||||
lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
|
||||
lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())
|
||||
|
||||
return RecurrentRolloutBufferSamples(
|
||||
# (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
|
||||
|
|
@ -349,24 +350,23 @@ class RecurrentDictRolloutBuffer(DictRolloutBuffer):
|
|||
self.episode_starts[batch_inds], env_change[batch_inds], self.device
|
||||
)
|
||||
|
||||
n_layers = self.hidden_states_pi.shape[1]
|
||||
n_seq = len(self.seq_start_indices)
|
||||
max_length = self.pad(self.actions[batch_inds]).shape[1]
|
||||
padded_batch_size = n_seq * max_length
|
||||
# We retrieve the lstm hidden states that will allow
|
||||
# to properly initialize the LSTM at the beginning of each sequence
|
||||
lstm_states_pi = (
|
||||
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
|
||||
self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
||||
self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
||||
# (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
|
||||
self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||
self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||
)
|
||||
lstm_states_vf = (
|
||||
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
|
||||
self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
||||
self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
||||
# (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
|
||||
self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||
self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||
)
|
||||
lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1]))
|
||||
lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1]))
|
||||
lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
|
||||
lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())
|
||||
|
||||
observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()}
|
||||
observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()}
|
||||
|
|
|
|||
|
|
@ -277,8 +277,8 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
|
||||
with th.no_grad():
|
||||
terminal_lstm_state = (
|
||||
lstm_states.vf[0][:, idx : idx + 1, :],
|
||||
lstm_states.vf[1][:, idx : idx + 1, :],
|
||||
lstm_states.vf[0][:, idx : idx + 1, :].contiguous(),
|
||||
lstm_states.vf[1][:, idx : idx + 1, :].contiguous(),
|
||||
)
|
||||
# terminal_lstm_state = None
|
||||
episode_starts = th.tensor([False]).float().to(self.device)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.7.0a0
|
||||
1.7.0a1
|
||||
|
|
|
|||
|
|
@ -66,11 +66,13 @@ class CartPoleNoVelEnv(CartPoleEnv):
|
|||
enable_critic_lstm=True,
|
||||
lstm_hidden_size=4,
|
||||
lstm_kwargs=dict(dropout=0.5),
|
||||
n_lstm_layers=2,
|
||||
),
|
||||
dict(
|
||||
enable_critic_lstm=False,
|
||||
lstm_hidden_size=4,
|
||||
lstm_kwargs=dict(dropout=0.5),
|
||||
n_lstm_layers=2,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
|
@ -95,11 +97,13 @@ def test_cnn(policy_kwargs):
|
|||
enable_critic_lstm=True,
|
||||
lstm_hidden_size=4,
|
||||
lstm_kwargs=dict(dropout=0.5),
|
||||
n_lstm_layers=2,
|
||||
),
|
||||
dict(
|
||||
enable_critic_lstm=False,
|
||||
lstm_hidden_size=4,
|
||||
lstm_kwargs=dict(dropout=0.5),
|
||||
n_lstm_layers=2,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
|
@ -162,11 +166,13 @@ def test_run_sde():
|
|||
enable_critic_lstm=True,
|
||||
lstm_hidden_size=4,
|
||||
lstm_kwargs=dict(dropout=0.5),
|
||||
n_lstm_layers=2,
|
||||
),
|
||||
dict(
|
||||
enable_critic_lstm=False,
|
||||
lstm_hidden_size=4,
|
||||
lstm_kwargs=dict(dropout=0.5),
|
||||
n_lstm_layers=2,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue