Fixed bugs that led to wrong results; currently only working with 2 critics
Paul Auerbach 2023-08-01 15:09:55 +02:00
parent a14ae69b6b
commit 875b8bca0d
3 changed files with 79 additions and 73 deletions


@@ -6,6 +6,7 @@ from sb3_contrib.ppo_recurrent import RecurrentPPO
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC
from sb3_contrib.trpo import TRPO
from sb3_contrib.sacd import SACD
# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
@@ -19,4 +20,5 @@ __all__ = [
"QRDQN",
"TQC",
"TRPO",
"SACD",
]
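Once the new algorithm is exported here it should be importable alongside the other contrib algorithms. A minimal usage sketch, assuming SACD keeps SAC's constructor signature and the usual "MlpPolicy" alias (environment and hyperparameters below are illustrative, not from this repository):

import gymnasium as gym
from sb3_contrib import SACD  # export added above; API assumed to mirror SAC

env = gym.make("CartPole-v1")  # discrete action space
model = SACD("MlpPolicy", env, verbose=1)  # hypothetical call, defaults assumed
model.learn(total_timesteps=10_000)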


@@ -178,13 +178,13 @@ class DiscreteCritic(BaseModel):
normalize_images=normalize_images,
)
action_dim = get_action_dim(self.action_space)
num_actions = self.action_space.n
self.share_features_extractor = share_features_extractor
self.n_critics = n_critics
self.q_networks = []
for idx in range(n_critics):
q_net = create_mlp(features_dim, action_dim, net_arch, activation_fn)
q_net = create_mlp(features_dim, num_actions, net_arch, activation_fn)
q_net = nn.Sequential(*q_net)
self.add_module(f"qf{idx}", q_net)
self.q_networks.append(q_net)
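The change above matters because SB3's get_action_dim returns 1 for Discrete spaces, so the old critic produced a single Q-value instead of one Q-value per action. A standalone sketch of the intended output shape, using SB3's create_mlp helper (features_dim and net_arch values are made up for illustration):

import torch.nn as nn
from gymnasium import spaces
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import create_mlp

action_space = spaces.Discrete(4)
print(get_action_dim(action_space))  # 1 -> too small for a discrete critic head
print(action_space.n)                # 4 -> one Q-value per discrete action

features_dim, net_arch = 64, [256, 256]  # illustrative values
q_net = nn.Sequential(*create_mlp(features_dim, action_space.n, net_arch, nn.ReLU))
# q_net(features) now has shape (batch, 4): Q(s, a) for every action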


@@ -186,6 +186,7 @@ class SACD(OffPolicyAlgorithm):
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
# self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
self.log_ent_coef = th.zeros(1, device=self.device, requires_grad=True)
self.ent_coef = th.exp(self.log_ent_coef)
self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
else:
# Force conversion to float
@@ -216,80 +217,19 @@ class SACD(OffPolicyAlgorithm):
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]
# We need to sample because `log_std` may have changed between two gradient steps
# if self.use_sde:
# self.actor.reset_noise()
# Compute the critic loss
critic_loss = self.calc_critic_loss(replay_data)
critic_losses.append(critic_loss.item())
self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, 5.0)
# Action by the current actor for the sampled state
actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
# Compute entropy loss
ent_coef_loss = None
if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
# Important: detach the variable from the graph
# so we don't change it with other losses
# see https://github.com/rail-berkeley/softlearning/issues/60
ent_coef = th.exp(self.log_ent_coef.detach())
ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
ent_coef_losses.append(ent_coef_loss.item())
else:
ent_coef = self.ent_coef_tensor
ent_coefs.append(ent_coef.item())
# print(f"Alpha Loss{ent_coef_loss.item()}")
# Optimize entropy coefficient, also called
# entropy temperature or alpha in the paper
if ent_coef_loss is not None and self.ent_coef_optimizer is not None:
self.ent_coef_optimizer.zero_grad()
ent_coef_loss.backward()
self.ent_coef_optimizer.step()
with th.no_grad():
# Select action according to policy
action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
# Compute the next Q values: min over all critics targets
next_q_values = th.cat(self.critic_target(replay_data.next_observations), dim=1)
next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
# add entropy term
next_q_values = (action_prob * next_q_values - ent_coef * next_log_prob).sum(dim=1).unsqueeze(-1)
# td error + entropy term
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
# Get current Q-values estimates for each critic network
# using action from the replay buffer
current_q_values = self.critic(replay_data.observations)
# Compute critic loss
critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
critic_losses.append(critic_loss.item()) # type: ignore[union-attr]
# print(f"Critic Loss{critic_loss.item()}")
# Optimize the critic
self.critic.optimizer.zero_grad()
critic_loss.backward()
th.nn.utils.clip_grad_norm(self.actor.parameters(), 5.0)
self.critic.optimizer.step()
# Compute actor loss
# Min over all critic networks
q_values_pi = th.cat(self.critic(replay_data.observations), dim=1)
min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
inside_term = ent_coef * log_prob - min_qf_pi
actor_loss = (actions_pi * inside_term).sum(dim=1).mean()
# Compute the actor loss
actor_loss, log_action_prob = self.calc_actor_loss(replay_data)
actor_losses.append(actor_loss.item())
self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, 5.0)
# print(f"Actor Loss{actor_loss.item()}")
# Optimize the actor
self.actor.optimizer.zero_grad()
actor_loss.backward()
th.nn.utils.clip_grad_norm(self.critic.parameters(), 5.0)
self.actor.optimizer.step()
# Compute entropy loss and optimize
self.calc_entropy_loss(log_action_prob)  # updates self.ent_coef in place
ent_coefs.append(self.ent_coef.item())
# Update target networks
if gradient_step % self.target_update_interval == 0:
@@ -306,6 +246,70 @@ class SACD(OffPolicyAlgorithm):
if len(ent_coef_losses) > 0:
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
def take_optimisation_step(self, optimizer, network, loss, clipping_norm=None):
    optimizer.zero_grad()
    loss.backward()
    if clipping_norm is not None:
        # Clip gradients to help stabilise training
        th.nn.utils.clip_grad_norm_(network.parameters(), clipping_norm)
    optimizer.step()
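For reference, a self-contained toy run of the same zero_grad -> backward -> clip -> step sequence on a throwaway module (all names below are illustrative):

import torch as th
import torch.nn as nn

net = nn.Linear(4, 2)
opt = th.optim.Adam(net.parameters(), lr=1e-3)
loss = net(th.randn(8, 4)).pow(2).mean()

opt.zero_grad()
loss.backward()
th.nn.utils.clip_grad_norm_(net.parameters(), 5.0)  # in-place clipping of the total gradient norm
opt.step()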
def calc_critic_loss(self, replay_data):
    with th.no_grad():
        # Action probabilities and log-probabilities for all actions in the next state
        action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
        # Compute the next Q values: element-wise min over the target critics
        # (th.min(*tensors) only works with exactly 2 critics)
        next_q_values = self.critic_target(replay_data.next_observations)
        next_q_values = th.min(*next_q_values)
        # Expected next-state value under the policy, including the entropy term
        next_q_values = (action_prob * (next_q_values - self.ent_coef * next_log_prob)).sum(dim=1).unsqueeze(-1)
        # TD target: reward plus discounted next-state value
        target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
    # Get current Q-value estimates for each critic network,
    # selecting the column of the action stored in the replay buffer
    current_q_values = self.critic(replay_data.observations)
    # Compute critic loss
    critic_loss = 0.5 * sum(F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values)
    return critic_loss
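A toy sketch of the same target computation with made-up shapes (batch of 3 transitions, 2 actions, 2 critics), showing why the next-state side is an expectation over all actions while the current Q-values are gathered at the stored discrete action:

import torch as th
import torch.nn.functional as F

batch, n_actions, gamma, ent_coef = 3, 2, 0.99, 0.2          # illustrative values
rewards = th.rand(batch, 1)
dones = th.zeros(batch, 1)
actions = th.randint(n_actions, (batch, 1))                  # replay-buffer actions, shape (batch, 1)

q1_next, q2_next = th.rand(batch, n_actions), th.rand(batch, n_actions)
next_probs = th.softmax(th.rand(batch, n_actions), dim=1)
next_log_probs = th.log(next_probs)

next_q = th.min(q1_next, q2_next)                            # element-wise min of the two critics
next_v = (next_probs * (next_q - ent_coef * next_log_probs)).sum(dim=1, keepdim=True)
target_q = rewards + (1 - dones) * gamma * next_v            # shape (batch, 1)

q1, q2 = th.rand(batch, n_actions), th.rand(batch, n_actions)
critic_loss = 0.5 * sum(F.mse_loss(q.gather(1, actions), target_q) for q in (q1, q2))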
def calc_actor_loss(self, replay_data):
    # Action probabilities and log-probabilities for all actions under the current policy
    action_prob, log_prob = self.actor.action_log_prob(replay_data.observations)
    # Element-wise min over the critic networks (works with exactly 2 critics)
    q_values_pi = self.critic(replay_data.observations)
    min_qf_pi = th.min(*q_values_pi)
    # Expectation over the action distribution of (entropy term - Q)
    inside_term = self.ent_coef * log_prob - min_qf_pi
    actor_loss = (action_prob * inside_term).sum(dim=1).mean()
    return actor_loss, log_prob
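Unlike continuous SAC, the discrete actor loss can weigh every action by its probability instead of relying on a sampled action. A self-contained sketch with illustrative tensors:

import torch as th

batch, n_actions, ent_coef = 3, 2, 0.2                       # illustrative values
probs = th.softmax(th.rand(batch, n_actions), dim=1)         # stand-in for the actor's action probabilities
log_probs = th.log(probs)
q1_pi, q2_pi = th.rand(batch, n_actions), th.rand(batch, n_actions)
min_q_pi = th.min(q1_pi, q2_pi)                              # element-wise min over the 2 critics
actor_loss = (probs * (ent_coef * log_probs - min_q_pi)).sum(dim=1).mean()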
def calc_entropy_loss(self, log_action_prob):
    ent_coef_loss = None
    if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
        # Important: detach the variable from the graph
        # so we don't change it with other losses
        # see https://github.com/rail-berkeley/softlearning/issues/60
        ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()
        # Reset gradients before the temperature update, then step the optimizer
        self.ent_coef_optimizer.zero_grad()
        ent_coef_loss.backward()
        self.ent_coef_optimizer.step()
        self.ent_coef = th.exp(self.log_ent_coef.detach())
    else:
        self.ent_coef = self.ent_coef_tensor
    return self.ent_coef
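A self-contained sketch of the temperature update performed here, parameterising the coefficient as exp(log alpha) so it stays positive (the target entropy and learning rate below are placeholder values, not taken from this commit):

import torch as th

n_actions = 4
log_ent_coef = th.zeros(1, requires_grad=True)               # alpha = exp(log_ent_coef)
opt = th.optim.Adam([log_ent_coef], lr=3e-4)                 # illustrative learning rate

log_action_prob = th.log(th.softmax(th.rand(8, n_actions), dim=1))  # stand-in for the actor's log-probs
target_entropy = -1.0                                         # placeholder value

ent_coef_loss = -(log_ent_coef * (log_action_prob + target_entropy).detach()).mean()
opt.zero_grad()
ent_coef_loss.backward()
opt.step()
ent_coef = th.exp(log_ent_coef.detach())                      # coefficient used by the actor and critic losses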
def learn(
self: SelfSACD,
total_timesteps: int,