Fixed bugs that led to wrong results; currently only works with 2
critics
This commit is contained in:
parent
a14ae69b6b
commit
875b8bca0d
|
|
@ -6,6 +6,7 @@ from sb3_contrib.ppo_recurrent import RecurrentPPO
|
||||||
from sb3_contrib.qrdqn import QRDQN
|
from sb3_contrib.qrdqn import QRDQN
|
||||||
from sb3_contrib.tqc import TQC
|
from sb3_contrib.tqc import TQC
|
||||||
from sb3_contrib.trpo import TRPO
|
from sb3_contrib.trpo import TRPO
|
||||||
|
from sb3_contrib.sacd import SACD
|
||||||
|
|
||||||
# Read version from file
|
# Read version from file
|
||||||
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
|
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
|
||||||
|
|
@ -19,4 +20,5 @@ __all__ = [
|
||||||
"QRDQN",
|
"QRDQN",
|
||||||
"TQC",
|
"TQC",
|
||||||
"TRPO",
|
"TRPO",
|
||||||
|
"SACD",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -178,13 +178,13 @@ class DiscreteCritic(BaseModel):
|
||||||
normalize_images=normalize_images,
|
normalize_images=normalize_images,
|
||||||
)
|
)
|
||||||
|
|
||||||
action_dim = get_action_dim(self.action_space)
|
num_actions = self.action_space.n
|
||||||
|
|
||||||
self.share_features_extractor = share_features_extractor
|
self.share_features_extractor = share_features_extractor
|
||||||
self.n_critics = n_critics
|
self.n_critics = n_critics
|
||||||
self.q_networks = []
|
self.q_networks = []
|
||||||
for idx in range(n_critics):
|
for idx in range(n_critics):
|
||||||
q_net = create_mlp(features_dim, action_dim, net_arch, activation_fn)
|
q_net = create_mlp(features_dim, num_actions, net_arch, activation_fn)
|
||||||
q_net = nn.Sequential(*q_net)
|
q_net = nn.Sequential(*q_net)
|
||||||
self.add_module(f"qf{idx}", q_net)
|
self.add_module(f"qf{idx}", q_net)
|
||||||
self.q_networks.append(q_net)
|
self.q_networks.append(q_net)
|
||||||
|
|
|
||||||
|
|
@ -186,6 +186,7 @@ class SACD(OffPolicyAlgorithm):
|
||||||
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
|
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
|
||||||
# self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
|
# self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
|
||||||
self.log_ent_coef = th.zeros(1, device=self.device, requires_grad=True)
|
self.log_ent_coef = th.zeros(1, device=self.device, requires_grad=True)
|
||||||
|
self.ent_coef = th.exp(self.log_ent_coef)
|
||||||
self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
|
self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
|
||||||
else:
|
else:
|
||||||
# Force conversion to float
|
# Force conversion to float
|
||||||
|
|
@ -216,80 +217,19 @@ class SACD(OffPolicyAlgorithm):
|
||||||
# Sample replay buffer
|
# Sample replay buffer
|
||||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]
|
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]
|
||||||
|
|
||||||
# We need to sample because `log_std` may have changed between two gradient steps
|
# Compute the critic loss
|
||||||
# if self.use_sde:
|
critic_loss = self.calc_critic_loss(replay_data)
|
||||||
# self.actor.reset_noise()
|
critic_losses.append(critic_loss.item())
|
||||||
|
self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, 5.0)
|
||||||
|
|
||||||
# Action by the current actor for the sampled state
|
# Compute the actor loss
|
||||||
actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
|
actor_loss, log_action_prob = self.calc_actor_loss(replay_data)
|
||||||
|
|
||||||
# Compute entropy loss
|
|
||||||
ent_coef_loss = None
|
|
||||||
if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
|
|
||||||
# Important: detach the variable from the graph
|
|
||||||
# so we don't change it with other losses
|
|
||||||
# see https://github.com/rail-berkeley/softlearning/issues/60
|
|
||||||
ent_coef = th.exp(self.log_ent_coef.detach())
|
|
||||||
ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
|
|
||||||
ent_coef_losses.append(ent_coef_loss.item())
|
|
||||||
else:
|
|
||||||
ent_coef = self.ent_coef_tensor
|
|
||||||
|
|
||||||
ent_coefs.append(ent_coef.item())
|
|
||||||
|
|
||||||
# print(f"Alpha Loss{ent_coef_loss.item()}")
|
|
||||||
|
|
||||||
# Optimize entropy coefficient, also called
|
|
||||||
# entropy temperature or alpha in the paper
|
|
||||||
if ent_coef_loss is not None and self.ent_coef_optimizer is not None:
|
|
||||||
self.ent_coef_optimizer.zero_grad()
|
|
||||||
ent_coef_loss.backward()
|
|
||||||
self.ent_coef_optimizer.step()
|
|
||||||
|
|
||||||
with th.no_grad():
|
|
||||||
# Select action according to policy
|
|
||||||
action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
|
|
||||||
# Compute the next Q values: min over all critics targets
|
|
||||||
next_q_values = th.cat(self.critic_target(replay_data.next_observations), dim=1)
|
|
||||||
next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
|
|
||||||
|
|
||||||
# add entropy term
|
|
||||||
next_q_values = (action_prob * next_q_values - ent_coef * next_log_prob).sum(dim=1).unsqueeze(-1)
|
|
||||||
# td error + entropy term
|
|
||||||
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
|
|
||||||
|
|
||||||
# Get current Q-values estimates for each critic network
|
|
||||||
# using action from the replay buffer
|
|
||||||
current_q_values = self.critic(replay_data.observations)
|
|
||||||
|
|
||||||
# Compute critic loss
|
|
||||||
critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
|
|
||||||
critic_losses.append(critic_loss.item()) # type: ignore[union-attr]
|
|
||||||
|
|
||||||
# print(f"Critic Loss{critic_loss.item()}")
|
|
||||||
|
|
||||||
# Optimize the critic
|
|
||||||
self.critic.optimizer.zero_grad()
|
|
||||||
critic_loss.backward()
|
|
||||||
th.nn.utils.clip_grad_norm(self.actor.parameters(), 5.0)
|
|
||||||
self.critic.optimizer.step()
|
|
||||||
|
|
||||||
# Compute actor loss
|
|
||||||
# Min over all critic networks
|
|
||||||
q_values_pi = th.cat(self.critic(replay_data.observations), dim=1)
|
|
||||||
min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
|
|
||||||
|
|
||||||
inside_term = ent_coef * log_prob - min_qf_pi
|
|
||||||
actor_loss = (actions_pi * inside_term).sum(dim=1).mean()
|
|
||||||
actor_losses.append(actor_loss.item())
|
actor_losses.append(actor_loss.item())
|
||||||
|
self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, 5.0)
|
||||||
|
|
||||||
# print(f"Actor Loss{actor_loss.item()}")
|
# Compute entropy loss and optimize
|
||||||
|
ent_coeff = self.calc_entropy_loss(log_action_prob)
|
||||||
# Optimize the actor
|
ent_coefs.append(self.ent_coef.item())
|
||||||
self.actor.optimizer.zero_grad()
|
|
||||||
actor_loss.backward()
|
|
||||||
th.nn.utils.clip_grad_norm(self.critic.parameters(), 5.0)
|
|
||||||
self.actor.optimizer.step()
|
|
||||||
|
|
||||||
# Update target networks
|
# Update target networks
|
||||||
if gradient_step % self.target_update_interval == 0:
|
if gradient_step % self.target_update_interval == 0:
|
||||||
|
|
@ -306,6 +246,70 @@ class SACD(OffPolicyAlgorithm):
|
||||||
if len(ent_coef_losses) > 0:
|
if len(ent_coef_losses) > 0:
|
||||||
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
||||||
|
|
||||||
|
def take_optimisation_step(self, optimizer, network, loss, clipping_norm=None):
    """Run one optimiser update for ``loss``.

    Clears the gradients held by ``optimizer``, back-propagates ``loss``,
    optionally clips the gradient norm of ``network``'s parameters and then
    applies the optimiser step.

    :param optimizer: optimiser whose ``zero_grad``/``step`` are called.
    :param network: module whose ``parameters()`` are clipped.
    :param loss: tensor on which ``backward()`` is called.
    :param clipping_norm: maximum gradient norm; ``None`` disables clipping.
    """
    optimizer.zero_grad()
    loss.backward()
    should_clip = clipping_norm is not None
    if should_clip:
        # Clip gradients to help stabilise training
        th.nn.utils.clip_grad_norm_(network.parameters(), max_norm=clipping_norm)
    optimizer.step()
|
||||||
|
|
||||||
|
def calc_critic_loss(self, replay_data):
    """Compute the summed TD loss of all critics for one replay batch.

    The bootstrap target is built from the next observation: the
    element-wise minimum over the target critics, minus the entropy term,
    weighted by the current policy's action probabilities and summed over
    the action axis.

    :param replay_data: batch exposing ``observations``, ``actions``,
        ``rewards``, ``dones`` and ``next_observations``.
    :return: ``0.5`` times the sum, over critics, of the MSE between the
        Q-value of the stored action and the target.
    """
    with th.no_grad():
        # Action probabilities / log-probabilities of the current policy
        # (presumably shaped (batch, n_actions) -- confirm against the actor).
        action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)

        # Element-wise minimum over *all* target critics. Stacking along a
        # new leading axis works for any number of critics, whereas
        # ``th.min(*outputs)`` is only valid for exactly two.
        target_outputs = tuple(self.critic_target(replay_data.next_observations))
        next_q_values = th.stack(target_outputs, dim=0).min(dim=0).values

        # Expectation over actions of (min Q - alpha * log pi)
        next_q_values = (action_prob * (next_q_values - self.ent_coef * next_log_prob)).sum(dim=1).unsqueeze(-1)
        # TD target; terminal transitions do not bootstrap
        target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

    # Q-value estimates of every critic for all actions
    current_q_values = self.critic(replay_data.observations)

    # ``gather`` needs an int64 index; the cast is a no-op when the buffer
    # already stores long actions.
    action_index = replay_data.actions.long()
    critic_loss = 0.5 * sum(
        F.mse_loss(current_q.gather(1, action_index), target_q_values) for current_q in current_q_values
    )

    return critic_loss
|
||||||
|
|
||||||
|
def calc_actor_loss(self, replay_data):
    """Compute the policy loss for one replay batch.

    For every action the term ``alpha * log pi - min Q`` is weighted by the
    policy's action probability, summed over the action axis and averaged
    over the batch.

    :param replay_data: batch exposing ``observations``.
    :return: tuple ``(actor_loss, log_prob)`` where ``log_prob`` is the
        second value returned by ``self.actor.action_log_prob`` (reused by
        the caller for the entropy-coefficient update).
    """
    action_prob, log_prob = self.actor.action_log_prob(replay_data.observations)

    # Element-wise minimum over *all* critics. Stacking along a new leading
    # axis works for any number of critics, whereas ``th.min(*outputs)`` is
    # only valid for exactly two.
    critic_outputs = tuple(self.critic(replay_data.observations))
    min_qf_pi = th.stack(critic_outputs, dim=0).min(dim=0).values

    inside_term = self.ent_coef * log_prob - min_qf_pi
    actor_loss = (action_prob * inside_term).sum(dim=1).mean()
    return actor_loss, log_prob
|
||||||
|
|
||||||
|
def calc_entropy_loss(self, log_action_prob):
    """Update the entropy coefficient and return its current value.

    When the coefficient is learned (``ent_coef_optimizer`` and
    ``log_ent_coef`` are both set), one optimiser step is taken on
    ``-(log_ent_coef * (log_action_prob + target_entropy)).mean()`` and
    ``self.ent_coef`` is refreshed from the updated ``log_ent_coef``.
    Otherwise ``self.ent_coef`` is set to the fixed ``ent_coef_tensor``.

    :param log_action_prob: log-probabilities produced by the actor for the
        current batch.
    :return: the value stored in ``self.ent_coef`` after the update.
    """
    if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
        # Important: detach the policy term from the graph so this loss only
        # moves ``log_ent_coef``,
        # see https://github.com/rail-berkeley/softlearning/issues/60
        # NOTE(review): the mean is taken over every element of
        # ``log_action_prob`` without weighting by the action probabilities
        # -- confirm this is the intended objective.
        ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()

        # Clear the previous step's gradient first; without this the
        # gradients on ``log_ent_coef`` accumulate across gradient steps.
        self.ent_coef_optimizer.zero_grad()
        ent_coef_loss.backward()
        self.ent_coef_optimizer.step()

        self.ent_coef = th.exp(self.log_ent_coef.detach())
    else:
        self.ent_coef = self.ent_coef_tensor

    return self.ent_coef
|
||||||
|
|
||||||
def learn(
|
def learn(
|
||||||
self: SelfSACD,
|
self: SelfSACD,
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue