Fix reshape LSTM states (#112)
* Fix LSTM states reshape * Fix warnings and update changelog * Remove unused variable * Fix runtime error when using n_lstm_layers > 1
This commit is contained in:
parent
c75ad7dd58
commit
a9735b9f31
|
|
@ -3,7 +3,7 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Release 1.7.0a0 (WIP)
|
Release 1.7.0a1 (WIP)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
|
|
@ -17,6 +17,8 @@ New Features:
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
- Fixed a bug in ``RecurrentPPO`` where the lstm states where incorrectly reshaped for ``n_lstm_layers > 1`` (thanks @kolbytn)
|
||||||
|
- Fixed ``RuntimeError: rnn: hx is not contiguous`` while predicting terminal values for ``RecurrentPPO`` when ``n_lstm_layers > 1``
|
||||||
|
|
||||||
Deprecations:
|
Deprecations:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -206,7 +206,6 @@ class RecurrentRolloutBuffer(RolloutBuffer):
|
||||||
self.episode_starts[batch_inds], env_change[batch_inds], self.device
|
self.episode_starts[batch_inds], env_change[batch_inds], self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
n_layers = self.hidden_states_pi.shape[1]
|
|
||||||
# Number of sequences
|
# Number of sequences
|
||||||
n_seq = len(self.seq_start_indices)
|
n_seq = len(self.seq_start_indices)
|
||||||
max_length = self.pad(self.actions[batch_inds]).shape[1]
|
max_length = self.pad(self.actions[batch_inds]).shape[1]
|
||||||
|
|
@ -214,17 +213,19 @@ class RecurrentRolloutBuffer(RolloutBuffer):
|
||||||
# We retrieve the lstm hidden states that will allow
|
# We retrieve the lstm hidden states that will allow
|
||||||
# to properly initialize the LSTM at the beginning of each sequence
|
# to properly initialize the LSTM at the beginning of each sequence
|
||||||
lstm_states_pi = (
|
lstm_states_pi = (
|
||||||
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
|
# 1. (n_envs * n_steps, n_layers, dim) -> (batch_size, n_layers, dim)
|
||||||
self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
# 2. (batch_size, n_layers, dim) -> (n_seq, n_layers, dim)
|
||||||
self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
# 3. (n_seq, n_layers, dim) -> (n_layers, n_seq, dim)
|
||||||
|
self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||||
|
self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||||
)
|
)
|
||||||
lstm_states_vf = (
|
lstm_states_vf = (
|
||||||
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
|
# (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
|
||||||
self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||||
self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||||
)
|
)
|
||||||
lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1]))
|
lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
|
||||||
lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1]))
|
lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())
|
||||||
|
|
||||||
return RecurrentRolloutBufferSamples(
|
return RecurrentRolloutBufferSamples(
|
||||||
# (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
|
# (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
|
||||||
|
|
@ -349,24 +350,23 @@ class RecurrentDictRolloutBuffer(DictRolloutBuffer):
|
||||||
self.episode_starts[batch_inds], env_change[batch_inds], self.device
|
self.episode_starts[batch_inds], env_change[batch_inds], self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
n_layers = self.hidden_states_pi.shape[1]
|
|
||||||
n_seq = len(self.seq_start_indices)
|
n_seq = len(self.seq_start_indices)
|
||||||
max_length = self.pad(self.actions[batch_inds]).shape[1]
|
max_length = self.pad(self.actions[batch_inds]).shape[1]
|
||||||
padded_batch_size = n_seq * max_length
|
padded_batch_size = n_seq * max_length
|
||||||
# We retrieve the lstm hidden states that will allow
|
# We retrieve the lstm hidden states that will allow
|
||||||
# to properly initialize the LSTM at the beginning of each sequence
|
# to properly initialize the LSTM at the beginning of each sequence
|
||||||
lstm_states_pi = (
|
lstm_states_pi = (
|
||||||
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
|
# (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
|
||||||
self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||||
self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||||
)
|
)
|
||||||
lstm_states_vf = (
|
lstm_states_vf = (
|
||||||
# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
|
# (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
|
||||||
self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||||
self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
|
self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
|
||||||
)
|
)
|
||||||
lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1]))
|
lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
|
||||||
lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1]))
|
lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())
|
||||||
|
|
||||||
observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()}
|
observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()}
|
||||||
observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()}
|
observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()}
|
||||||
|
|
|
||||||
|
|
@ -277,8 +277,8 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
|
terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
terminal_lstm_state = (
|
terminal_lstm_state = (
|
||||||
lstm_states.vf[0][:, idx : idx + 1, :],
|
lstm_states.vf[0][:, idx : idx + 1, :].contiguous(),
|
||||||
lstm_states.vf[1][:, idx : idx + 1, :],
|
lstm_states.vf[1][:, idx : idx + 1, :].contiguous(),
|
||||||
)
|
)
|
||||||
# terminal_lstm_state = None
|
# terminal_lstm_state = None
|
||||||
episode_starts = th.tensor([False]).float().to(self.device)
|
episode_starts = th.tensor([False]).float().to(self.device)
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.7.0a0
|
1.7.0a1
|
||||||
|
|
|
||||||
|
|
@ -66,11 +66,13 @@ class CartPoleNoVelEnv(CartPoleEnv):
|
||||||
enable_critic_lstm=True,
|
enable_critic_lstm=True,
|
||||||
lstm_hidden_size=4,
|
lstm_hidden_size=4,
|
||||||
lstm_kwargs=dict(dropout=0.5),
|
lstm_kwargs=dict(dropout=0.5),
|
||||||
|
n_lstm_layers=2,
|
||||||
),
|
),
|
||||||
dict(
|
dict(
|
||||||
enable_critic_lstm=False,
|
enable_critic_lstm=False,
|
||||||
lstm_hidden_size=4,
|
lstm_hidden_size=4,
|
||||||
lstm_kwargs=dict(dropout=0.5),
|
lstm_kwargs=dict(dropout=0.5),
|
||||||
|
n_lstm_layers=2,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -95,11 +97,13 @@ def test_cnn(policy_kwargs):
|
||||||
enable_critic_lstm=True,
|
enable_critic_lstm=True,
|
||||||
lstm_hidden_size=4,
|
lstm_hidden_size=4,
|
||||||
lstm_kwargs=dict(dropout=0.5),
|
lstm_kwargs=dict(dropout=0.5),
|
||||||
|
n_lstm_layers=2,
|
||||||
),
|
),
|
||||||
dict(
|
dict(
|
||||||
enable_critic_lstm=False,
|
enable_critic_lstm=False,
|
||||||
lstm_hidden_size=4,
|
lstm_hidden_size=4,
|
||||||
lstm_kwargs=dict(dropout=0.5),
|
lstm_kwargs=dict(dropout=0.5),
|
||||||
|
n_lstm_layers=2,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -162,11 +166,13 @@ def test_run_sde():
|
||||||
enable_critic_lstm=True,
|
enable_critic_lstm=True,
|
||||||
lstm_hidden_size=4,
|
lstm_hidden_size=4,
|
||||||
lstm_kwargs=dict(dropout=0.5),
|
lstm_kwargs=dict(dropout=0.5),
|
||||||
|
n_lstm_layers=2,
|
||||||
),
|
),
|
||||||
dict(
|
dict(
|
||||||
enable_critic_lstm=False,
|
enable_critic_lstm=False,
|
||||||
lstm_hidden_size=4,
|
lstm_hidden_size=4,
|
||||||
lstm_kwargs=dict(dropout=0.5),
|
lstm_kwargs=dict(dropout=0.5),
|
||||||
|
n_lstm_layers=2,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue