Reworked code to work with more than 2 critic networks

Paul Auerbach 2023-08-02 13:17:22 +02:00
parent 875b8bca0d
commit 7711813dbb
1 changed file with 11 additions and 17 deletions


@@ -68,6 +68,7 @@ class SACD(OffPolicyAlgorithm):
         during the warm up phase (before learning starts)
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
+    :param max_grad_norm: The maximum value for the gradient clipping
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
     :param policy_kwargs: additional arguments to be passed to the policy on creation
     :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
@@ -111,6 +112,7 @@ class SACD(OffPolicyAlgorithm):
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
         stats_window_size: int = 100,
+        max_grad_norm: float = 5.0,
         tensorboard_log: Optional[str] = None,
         policy_kwargs: Optional[Dict[str, Any]] = None,
         verbose: int = 0,
@@ -153,6 +155,7 @@ class SACD(OffPolicyAlgorithm):
         self.ent_coef = ent_coef
         self.target_update_interval = target_update_interval
         self.ent_coef_optimizer: Optional[th.optim.Adam] = None
+        self.gradient_clip_norm = max_grad_norm

         if _init_setup_model:
             self._setup_model()
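
The new max_grad_norm argument is stored on the instance here and picked up in train() below. Assuming SACD follows the usual stable-baselines3 constructor pattern (policy name plus environment; the policy string and environment below are placeholder choices, not taken from this commit), overriding the default would look roughly like this:

import gymnasium as gym

# Hypothetical usage sketch; "MlpPolicy" and CartPole-v1 are example placeholders.
env = gym.make("CartPole-v1")                       # any discrete-action env
model = SACD("MlpPolicy", env, max_grad_norm=10.0)  # default is 5.0
model.learn(total_timesteps=10_000)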
@@ -220,12 +223,12 @@ class SACD(OffPolicyAlgorithm):
             # Compute the critic loss
             critic_loss = self.calc_critic_loss(replay_data)
             critic_losses.append(critic_loss.item())
-            self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, 5.0)
+            self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, self.gradient_clip_norm)

             # Compute the actor loss
             actor_loss, log_action_prob = self.calc_actor_loss(replay_data)
             actor_losses.append(actor_loss.item())
-            self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, 5.0)
+            self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, self.gradient_clip_norm)

             # Compute entropy loss and optimize
             ent_coeff = self.calc_entropy_loss(log_action_prob)
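
Both optimisation steps now clip to self.gradient_clip_norm instead of the hard-coded 5.0. The take_optimisation_step helper itself is not part of this diff; a minimal sketch of what such a helper typically does (zero the gradients, backpropagate, clip the global gradient norm, step), assuming it wraps torch's clip_grad_norm_:

from torch.nn.utils import clip_grad_norm_

def take_optimisation_step(optimizer, network, loss, max_grad_norm):
    # Illustrative sketch only; the actual helper lives outside this diff.
    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm to max_grad_norm before applying the update
    clip_grad_norm_(network.parameters(), max_grad_norm)
    optimizer.step()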
@@ -235,7 +238,7 @@ class SACD(OffPolicyAlgorithm):
             if gradient_step % self.target_update_interval == 0:
                 polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                 # Copy running stats, see GH issue #996
-                # polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
+                polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

         self._n_updates += gradient_steps
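
Re-enabling this line copies the batch-norm running statistics to the target critic whenever the target network is updated. polyak_update in stable-baselines3 performs the soft update target <- tau * source + (1 - tau) * target, so tau=1.0 is a hard copy, which is the recommended handling of running stats (GH issue #996 referenced in the comment). A minimal sketch of that update rule, not the library's literal implementation:

import torch as th

def soft_update(params, target_params, tau):
    # target <- tau * source + (1 - tau) * target; tau = 1.0 is a straight copy
    with th.no_grad():
        for param, target_param in zip(params, target_params):
            target_param.mul_(1.0 - tau)
            target_param.add_(tau * param)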
@@ -257,17 +260,10 @@ class SACD(OffPolicyAlgorithm):
         with th.no_grad():
             # Select action according to policy
             action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
-            # Compute the next Q values: min over all critics targets
-            # next_q_values = th.cat(self.critic_target(replay_data.next_observations), dim=1)
-            next_q_values = self.critic_target(replay_data.next_observations)
-            # print(self.critic_target(replay_data.next_observations))
-            # exit(0)
-            # print(next_q_values)
-            # next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
-            next_q_values = th.min(*next_q_values)
-            # print(next_q_values)
-            # exit(0)
+            # Compute the next Q values: min over all critics targets
+            next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
+            next_q_values, _ = th.min(next_q_values, dim=2)
             next_q_values = (action_prob * (next_q_values - self.ent_coef * next_log_prob)).sum(dim=1).unsqueeze(-1)
             target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
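
The old code unpacked the critic tuple into th.min(*next_q_values), which only computes an element-wise minimum of exactly two tensors; stacking along a new critic dimension and reducing over it works for any number of critics. A standalone sketch with made-up shapes (each critic returning per-action Q-values of shape (batch_size, n_actions)):

import torch as th

batch_size, n_actions, n_critics = 4, 6, 3
critic_outputs = [th.randn(batch_size, n_actions) for _ in range(n_critics)]

# Stack along a new trailing critic dimension, then take the element-wise minimum over it
stacked = th.stack(critic_outputs, dim=2)  # (batch_size, n_actions, n_critics)
min_q, _ = th.min(stacked, dim=2)          # (batch_size, n_actions)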
@@ -285,10 +281,8 @@ class SACD(OffPolicyAlgorithm):
         action_prob, log_prob = self.actor.action_log_prob(replay_data.observations)

         # Min over all critic networks
-        # q_values_pi = th.cat(self.critic(replay_data.observations), dim=1)
-        q_values_pi = self.critic(replay_data.observations)
-        # min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
-        min_qf_pi = th.min(*q_values_pi)
+        q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
+        min_qf_pi, _ = th.min(q_values_pi, dim=2)
         inside_term = self.ent_coef * log_prob - min_qf_pi
         actor_loss = (action_prob * inside_term).sum(dim=1).mean()
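
The actor loss applies the same stack-and-min reduction and then takes the probability-weighted expectation over the discrete actions. A small sketch of that expectation step, with assumed shapes and a placeholder entropy coefficient:

import torch as th

batch_size, n_actions = 4, 6
action_prob = th.softmax(th.randn(batch_size, n_actions), dim=1)  # pi(a|s)
log_prob = th.log(action_prob)                                    # log pi(a|s)
min_qf_pi = th.randn(batch_size, n_actions)                       # min over critics
ent_coef = 0.2                                                    # placeholder value

# E_{a ~ pi}[ ent_coef * log pi(a|s) - Q_min(s, a) ], averaged over the batch
inside_term = ent_coef * log_prob - min_qf_pi
actor_loss = (action_prob * inside_term).sum(dim=1).mean()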