Reworked code to work whith more than 2 critic networks

This commit is contained in:
Paul Auerbach 2023-08-02 13:17:22 +02:00
parent 875b8bca0d
commit 7711813dbb
1 changed files with 11 additions and 17 deletions

View File

@ -68,6 +68,7 @@ class SACD(OffPolicyAlgorithm):
during the warm up phase (before learning starts)
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param max_grad_norm: The maximum value for the gradient clipping
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
@ -111,6 +112,7 @@ class SACD(OffPolicyAlgorithm):
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
stats_window_size: int = 100,
max_grad_norm = 5.0,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
@ -153,6 +155,7 @@ class SACD(OffPolicyAlgorithm):
self.ent_coef = ent_coef
self.target_update_interval = target_update_interval
self.ent_coef_optimizer: Optional[th.optim.Adam] = None
self.gradient_clip_norm = max_grad_norm
if _init_setup_model:
self._setup_model()
@ -220,12 +223,12 @@ class SACD(OffPolicyAlgorithm):
# Compute the critic loss
critic_loss = self.calc_critic_loss(replay_data)
critic_losses.append(critic_loss.item())
self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, 5.0)
self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, self.gradient_clip_norm)
# Compute the actor loss
actor_loss, log_action_prob = self.calc_actor_loss(replay_data)
actor_losses.append(actor_loss.item())
self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, 5.0)
self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, self.gradient_clip_norm)
# Compute entropy loss and optimize
ent_coeff = self.calc_entropy_loss(log_action_prob)
@ -235,7 +238,7 @@ class SACD(OffPolicyAlgorithm):
if gradient_step % self.target_update_interval == 0:
polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
# Copy running stats, see GH issue #996
# polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
self._n_updates += gradient_steps
@ -257,17 +260,10 @@ class SACD(OffPolicyAlgorithm):
with th.no_grad():
# Select action according to policy
action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
# Compute the next Q values: min over all critics targets
# next_q_values = th.cat(self.critic_target(replay_data.next_observations), dim=1)
next_q_values = self.critic_target(replay_data.next_observations)
# print(self.critic_target(replay_data.next_observations))
# exit(0)
# print(next_q_values)
# next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
next_q_values = th.min(*next_q_values)
# print(next_q_values)
# exit(0)
# Compute the next Q values: min over all critics targets
next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
next_q_values, _ = th.min(next_q_values, dim=2)
next_q_values = (action_prob * (next_q_values - self.ent_coef * next_log_prob)).sum(dim=1).unsqueeze(-1)
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
@ -285,10 +281,8 @@ class SACD(OffPolicyAlgorithm):
action_prob, log_prob = self.actor.action_log_prob(replay_data.observations)
# Min over all critic networks
# q_values_pi = th.cat(self.critic(replay_data.observations), dim=1)
q_values_pi = self.critic(replay_data.observations)
# min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
min_qf_pi = th.min(*q_values_pi)
q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
min_qf_pi, _ = th.min(q_values_pi, dim=2)
inside_term = self.ent_coef * log_prob - min_qf_pi
actor_loss = (action_prob * inside_term).sum(dim=1).mean()