Prepared files for merge request (minor cleanup)

This commit is contained in:
Paul Auerbach 2023-08-04 18:19:34 +02:00
parent 4a37f58259
commit fca2c6d490
2 changed files with 54 additions and 69 deletions

View File

@@ -61,7 +61,6 @@ class Actor(BasePolicy):
             action_space,
             features_extractor=features_extractor,
             normalize_images=normalize_images,
-            # squash_output=True,
             squash_output=False,
         )
@@ -78,7 +77,7 @@ class Actor(BasePolicy):
         num_actions = self.action_space.n
-        latent_pi_net = create_mlp(features_dim, num_actions, net_arch, activation_fn)
+        latent_pi_net = create_mlp(features_dim, num_actions.item(), net_arch, activation_fn)
         self.latent_pi = nn.Sequential(*latent_pi_net)
         self.output_activation = nn.Softmax(dim=1)
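Note on the `.item()` change: `self.action_space.n` is a NumPy integer in recent Gym/Gymnasium releases rather than a plain Python `int`, which is presumably why it is converted before being passed to `create_mlp`. A minimal sketch of the distinction (the `Discrete(4)` space is illustrative):

    import numpy as np
    from gymnasium.spaces import Discrete

    n = Discrete(4).n                 # np.int64, not a Python int
    assert not isinstance(n, int)
    assert isinstance(n.item(), int)  # .item() yields a plain int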
@@ -184,8 +183,8 @@ class DiscreteCritic(BaseModel):
         self.n_critics = n_critics
         self.q_networks = []
         for idx in range(n_critics):
-            q_net = create_mlp(features_dim, num_actions, net_arch, activation_fn)
-            q_net = nn.Sequential(*q_net)
+            q_net_list = create_mlp(features_dim, num_actions.item(), net_arch, activation_fn)
+            q_net = nn.Sequential(*q_net_list)
             self.add_module(f"qf{idx}", q_net)
             self.q_networks.append(q_net)
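The critics are collected in a plain Python list, so each Q-network has to be registered explicitly with `add_module`; otherwise its parameters would not show up in `parameters()` and the optimizer would never see them. A standalone sketch of the pattern (toy layer sizes, not the real MLP):

    import torch.nn as nn

    class Critics(nn.Module):
        def __init__(self, n_critics: int = 2):
            super().__init__()
            self.q_networks = []                    # plain list: invisible to nn.Module
            for idx in range(n_critics):
                q_net = nn.Linear(8, 4)             # stand-in for the Q-network MLP
                self.add_module(f"qf{idx}", q_net)  # registers its parameters
                self.q_networks.append(q_net)

    assert len(list(Critics().parameters())) == 4   # 2 critics x (weight, bias)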
@@ -195,8 +194,9 @@ class DiscreteCritic(BaseModel):
     def forward(self, obs: th.Tensor) -> Tuple[th.Tensor, ...]:
         # Learn the features extractor using the policy loss only
         # when the features_extractor is shared with the actor
-        with th.set_grad_enabled(not self.share_features_extractor):
-            features = self.extract_features(obs, self.features_extractor)
+        if self.features_extractor is not None:
+            with th.set_grad_enabled(not self.share_features_extractor):
+                features = self.extract_features(obs, self.features_extractor)
         return tuple(q_net(features) for q_net in self.q_networks)
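For reference, `th.set_grad_enabled` acts as a context manager that toggles gradient tracking, which is how the critic avoids training a features extractor it shares with the actor. A quick illustration:

    import torch as th

    x = th.ones(3, requires_grad=True)
    with th.set_grad_enabled(False):  # as when share_features_extractor is True
        y = x * 2
    assert not y.requires_grad        # no gradient flows back through y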

View File

@@ -18,17 +18,14 @@ SelfSACD = TypeVar("SelfSACD", bound="SACD")
 class SACD(OffPolicyAlgorithm):
     """
-    Soft Actor-Critic (SACD)
-    Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,
-    This implementation borrows code from original implementation (https://github.com/haarnoja/sac)
-    from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo
-    (https://github.com/rail-berkeley/softlearning/)
-    and from Stable Baselines (https://github.com/hill-a/stable-baselines)
-    Paper: https://arxiv.org/abs/1801.01290
-    Introduction to SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
-    Note: we use double q target and not value target as discussed
-    in https://github.com/hill-a/stable-baselines/issues/270
+    Discrete Soft Actor-Critic (SACD)
+    The Soft Actor-Critic algorithm modified for discrete action spaces.
+    This implementation borrows code from the original implementation by
+    the paper's author (https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch),
+    from toshikwa (https://github.com/toshikwa/sac-discrete.pytorch#references)
+    and from Stable Baselines (https://github.com/DLR-RM/stable-baselines3)
+
+    Paper: https://arxiv.org/abs/1910.07207

     :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
     :param env: The environment to learn from (if registered in Gym, can be str)
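Since SACD follows the usual Stable Baselines 3 algorithm interface, usage would presumably look like the sketch below (the `sacd` import path, environment, and hyperparameters are assumptions for illustration, not taken from this commit):

    import gymnasium as gym
    from sacd import SACD  # assumed module name

    env = gym.make("CartPole-v1")
    model = SACD("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10_000)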
@@ -168,7 +165,7 @@ class SACD(OffPolicyAlgorithm):
         # Target entropy is used when learning the entropy coefficient
         if self.target_entropy == "auto":
             # we set the max possible entropy as the target entropy
-            self.target_entropy = 0.98 * -np.log(1 / np.prod(self.env.action_space.shape))
+            self.target_entropy = 0.98 * -np.log(1 / np.prod(np.array(self.action_space.shape)))
         else:
             # Force conversion
             # this will also throw an error for unexpected string
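As a worked example of the "auto" target: for a discrete policy the maximum entropy is log(n_actions), and SAC-Discrete style implementations aim slightly below that maximum (the six-action space here is illustrative):

    import numpy as np

    n_actions = 6
    max_entropy = -np.log(1 / n_actions)  # log(6) ≈ 1.792
    target_entropy = 0.98 * max_entropy   # ≈ 1.756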
@@ -220,17 +217,45 @@ class SACD(OffPolicyAlgorithm):
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

             # Compute the critic loss
-            critic_loss = self.calc_critic_loss(replay_data)
+            with th.no_grad():
+                # Select action according to policy
+                action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
+                # Compute the next Q values: min over all critics targets
+                next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
+                next_q_values, _ = th.min(next_q_values, dim=2)
+                next_q_values = (action_prob * (next_q_values - self.ent_coef_tensor * next_log_prob)).sum(dim=1).unsqueeze(-1)
+                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
+
+            # Get current Q-values estimates for each critic network
+            current_q_values = self.critic(replay_data.observations)
+            critic_loss = 0.5 * sum(
+                F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values
+            )
+            assert isinstance(critic_loss, th.Tensor)  # for type checker
             critic_losses.append(critic_loss.item())
             self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, self.gradient_clip_norm)

             # Compute the actor loss
-            actor_loss, log_action_prob = self.calc_actor_loss(replay_data)
+            action_prob, log_action_prob = self.actor.action_log_prob(replay_data.observations)
+            # Min over all critic networks
+            q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
+            min_qf_pi, _ = th.min(q_values_pi, dim=2)
+            inside_term = self.ent_coef_tensor * log_action_prob - min_qf_pi
+            actor_loss = (action_prob * inside_term).sum(dim=1).mean()
             actor_losses.append(actor_loss.item())
             self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, self.gradient_clip_norm)

-            # Compute entropy loss and optimize
-            self.ent_coef_tensor = self.calc_entropy_loss(log_action_prob)
+            # Compute entropy loss
+            if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
+                ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()
+                ent_coef_losses.append(ent_coef_loss.item())
+                self.take_optimisation_step(self.ent_coef_optimizer, None, ent_coef_loss, None)
+                self.ent_coef_tensor = th.exp(self.log_ent_coef.detach())
             ent_coefs.append(self.ent_coef_tensor.item())

             # Update target networks
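The inlined critic target uses the central trick of discrete SAC: with an enumerable action space, the soft state value can be computed as an exact expectation over the policy's action probabilities instead of from a sampled action, so no reparameterization is needed. A standalone shape sketch under assumed sizes (batch 32, 4 actions, 2 critics):

    import torch as th

    B, A, C = 32, 4, 2                         # batch, actions, critics
    probs = th.softmax(th.randn(B, A), dim=1)  # pi(a|s') from the actor
    log_probs = th.log(probs + 1e-8)           # guard against log(0)
    q_targets = th.randn(B, A, C)              # stacked target-critic outputs
    min_q, _ = th.min(q_targets, dim=2)        # (B, A): min over critics
    ent_coef = 0.2
    # V(s') = sum_a pi(a|s') * (min_q(s', a) - alpha * log pi(a|s'))
    next_v = (probs * (min_q - ent_coef * log_probs)).sum(dim=1, keepdim=True)
    assert next_v.shape == (B, 1)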
@@ -248,59 +273,19 @@ class SACD(OffPolicyAlgorithm):
         if len(ent_coef_losses) > 0:
             self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))

-    def take_optimisation_step(self, optimizer, network, loss, clipping_norm=None):
+    def take_optimisation_step(
+        self,
+        optimizer: th.optim.Optimizer,
+        network: Optional[th.nn.Module],
+        loss: th.Tensor,
+        clipping_norm: Optional[float] = None,
+    ) -> None:
         optimizer.zero_grad()
         loss.backward()
-        if clipping_norm is not None:
+        if clipping_norm is not None and network is not None:
             th.nn.utils.clip_grad_norm_(network.parameters(), clipping_norm)  # clip gradients to help stabilise training
         optimizer.step()

-    def calc_critic_loss(self, replay_data):
-        with th.no_grad():
-            # Select action according to policy
-            action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
-            # Compute the next Q values: min over all critics targets
-            next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
-            next_q_values, _ = th.min(next_q_values, dim=2)
-            next_q_values = (action_prob * (next_q_values - self.ent_coef_tensor * next_log_prob)).sum(dim=1).unsqueeze(-1)
-            target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
-        # Get current Q-values estimates for each critic network
-        # using action from the replay buffer
-        current_q_values = self.critic(replay_data.observations)
-        # Compute critic loss
-        critic_loss = 0.5 * sum(
-            F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values
-        )
-        return critic_loss
-
-    def calc_actor_loss(self, replay_data):
-        action_prob, log_prob = self.actor.action_log_prob(replay_data.observations)
-        # Min over all critic networks
-        q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
-        min_qf_pi, _ = th.min(q_values_pi, dim=2)
-        inside_term = self.ent_coef_tensor * log_prob - min_qf_pi
-        actor_loss = (action_prob * inside_term).sum(dim=1).mean()
-        return actor_loss, log_prob
-
-    def calc_entropy_loss(self, log_action_prob):
-        if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
-            # Important: detach the variable from the graph
-            # so we don't change it with other losses
-            # see https://github.com/rail-berkeley/softlearning/issues/60
-            ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()
-            ent_coef_loss.backward()
-            self.ent_coef_optimizer.step()
-            return th.exp(self.log_ent_coef.detach())
-        else:
-            return self.ent_coef_tensor
-
     def learn(
         self: SelfSACD,
         total_timesteps: int,