Reworked code to work with more than 2 critic networks
parent 875b8bca0d
commit 7711813dbb
@@ -68,6 +68,7 @@ class SACD(OffPolicyAlgorithm):
         during the warm up phase (before learning starts)
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
+    :param max_grad_norm: The maximum value for the gradient clipping
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
     :param policy_kwargs: additional arguments to be passed to the policy on creation
     :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
@@ -111,6 +112,7 @@ class SACD(OffPolicyAlgorithm):
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
         stats_window_size: int = 100,
+        max_grad_norm: float = 5.0,
         tensorboard_log: Optional[str] = None,
         policy_kwargs: Optional[Dict[str, Any]] = None,
         verbose: int = 0,
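For context, the new keyword threads through like any other Stable-Baselines3 constructor argument. A minimal usage sketch, assuming a Gymnasium environment with a discrete action space and that SACD follows the usual SB3 constructor convention; the import path is hypothetical and depends on where SACD lives in this repo:

import gymnasium as gym
from sacd import SACD  # hypothetical import path, adjust to this repo's layout

env = gym.make("CartPole-v1")
# max_grad_norm now caps the gradient norm used in take_optimisation_step;
# 5.0 reproduces the previously hard-coded value.
model = SACD("MlpPolicy", env, max_grad_norm=5.0, verbose=1)
model.learn(total_timesteps=10_000)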
@@ -153,6 +155,7 @@ class SACD(OffPolicyAlgorithm):
         self.ent_coef = ent_coef
         self.target_update_interval = target_update_interval
         self.ent_coef_optimizer: Optional[th.optim.Adam] = None
+        self.gradient_clip_norm = max_grad_norm

         if _init_setup_model:
             self._setup_model()
@@ -220,12 +223,12 @@ class SACD(OffPolicyAlgorithm):
             # Compute the critic loss
             critic_loss = self.calc_critic_loss(replay_data)
             critic_losses.append(critic_loss.item())
-            self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, 5.0)
+            self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, self.gradient_clip_norm)

            # Compute the actor loss
             actor_loss, log_action_prob = self.calc_actor_loss(replay_data)
             actor_losses.append(actor_loss.item())
-            self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, 5.0)
+            self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, self.gradient_clip_norm)

             # Compute entropy loss and optimize
             ent_coeff = self.calc_entropy_loss(log_action_prob)
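The body of take_optimisation_step is not part of this diff; a minimal sketch of the clip-then-step pattern it presumably follows, so the gradient_clip_norm argument has somewhere to land. Everything except the call signature seen above is an assumption:

import torch as th
from torch import nn

def take_optimisation_step(optimizer: th.optim.Optimizer,
                           network: nn.Module,
                           loss: th.Tensor,
                           max_grad_norm: float) -> None:
    # Standard PyTorch update: clear stale gradients, backprop the loss,
    # clip the global gradient norm, then apply the optimizer step.
    optimizer.zero_grad()
    loss.backward()
    th.nn.utils.clip_grad_norm_(network.parameters(), max_grad_norm)
    optimizer.step()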
@@ -235,7 +238,7 @@ class SACD(OffPolicyAlgorithm):
             if gradient_step % self.target_update_interval == 0:
                 polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                 # Copy running stats, see GH issue #996
-                # polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
+                polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

         self._n_updates += gradient_steps
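polyak_update is the standard Stable-Baselines3 soft-update helper; with tau=1.0, as used for the batch-norm running stats above, it degenerates to a hard copy. A sketch of the update it performs, target <- tau * source + (1 - tau) * target; the function name here is local to the sketch:

from typing import Iterable
import torch as th

def polyak_update_sketch(params: Iterable[th.Tensor],
                         target_params: Iterable[th.Tensor],
                         tau: float) -> None:
    # In-place soft update of each target tensor toward its source tensor.
    # tau=1.0 copies the source outright (the batch-norm stats case above).
    with th.no_grad():
        for param, target_param in zip(params, target_params):
            target_param.mul_(1 - tau)
            target_param.add_(tau * param)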
@@ -257,17 +260,10 @@ class SACD(OffPolicyAlgorithm):
         with th.no_grad():
             # Select action according to policy
             action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
-            # Compute the next Q values: min over all critics targets
-            # next_q_values = th.cat(self.critic_target(replay_data.next_observations), dim=1)
-            next_q_values = self.critic_target(replay_data.next_observations)
-            # print(self.critic_target(replay_data.next_observations))
-            # exit(0)
-            # print(next_q_values)
-
-            # next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
-            next_q_values = th.min(*next_q_values)
-            # print(next_q_values)
-            # exit(0)
+            # Compute the next Q values: min over all critics targets
+            next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
+            next_q_values, _ = th.min(next_q_values, dim=2)

             next_q_values = (action_prob * (next_q_values - self.ent_coef * next_log_prob)).sum(dim=1).unsqueeze(-1)
             target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
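The stack-then-min rewrite is what lifts the two-critic assumption: th.min(*tensors) only computes an elementwise minimum for exactly two tensors, whereas stacking the N critic outputs of shape (batch, n_actions) along a new dim=2 and reducing over that dim works for any N. A standalone shape check with arbitrary sizes:

import torch as th

batch, n_actions, n_critics = 4, 3, 5  # arbitrary sizes for the demo
q_outputs = [th.randn(batch, n_actions) for _ in range(n_critics)]

# Old path: th.min(*q_outputs) only works when n_critics == 2.
# New path: stack to (batch, n_actions, n_critics), then reduce that dim.
stacked = th.stack(q_outputs, dim=2)
min_q, _ = th.min(stacked, dim=2)

assert stacked.shape == (batch, n_actions, n_critics)
assert min_q.shape == (batch, n_actions)
# Same elementwise minimum, regardless of which dim the critics sit on.
assert th.equal(min_q, th.stack(q_outputs, dim=0).min(dim=0).values)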
@@ -285,10 +281,8 @@ class SACD(OffPolicyAlgorithm):
         action_prob, log_prob = self.actor.action_log_prob(replay_data.observations)

         # Min over all critic networks
-        # q_values_pi = th.cat(self.critic(replay_data.observations), dim=1)
-        q_values_pi = self.critic(replay_data.observations)
-        # min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
-        min_qf_pi = th.min(*q_values_pi)
+        q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
+        min_qf_pi, _ = th.min(q_values_pi, dim=2)

         inside_term = self.ent_coef * log_prob - min_qf_pi
         actor_loss = (action_prob * inside_term).sum(dim=1).mean()
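Because the action space is discrete, the actor loss above takes an exact expectation over the policy's action probabilities rather than a sampled one: loss = E_s[ sum_a pi(a|s) * (alpha * log pi(a|s) - min_i Q_i(s, a)) ]. A self-contained numerical check of that reduction with dummy tensors whose shapes mirror the code above:

import torch as th

batch, n_actions = 8, 4
action_prob = th.softmax(th.randn(batch, n_actions), dim=1)  # pi(a|s)
log_prob = th.log(action_prob)                               # log pi(a|s)
min_qf_pi = th.randn(batch, n_actions)                       # min over critics
ent_coef = 0.2                                               # alpha

# Exact expectation over actions, then mean over the batch of states.
inside_term = ent_coef * log_prob - min_qf_pi
actor_loss = (action_prob * inside_term).sum(dim=1).mean()
print(actor_loss.item())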