Prepared files for merge request (minor cleanup)
This commit is contained in:
parent
4a37f58259
commit
fca2c6d490
|
|
@ -61,7 +61,6 @@ class Actor(BasePolicy):
|
|||
action_space,
|
||||
features_extractor=features_extractor,
|
||||
normalize_images=normalize_images,
|
||||
# squash_output=True,
|
||||
squash_output=False,
|
||||
)
|
||||
|
||||
|
|
@ -78,7 +77,7 @@ class Actor(BasePolicy):
|
|||
|
||||
num_actions = self.action_space.n
|
||||
|
||||
latent_pi_net = create_mlp(features_dim, num_actions, net_arch, activation_fn)
|
||||
latent_pi_net = create_mlp(features_dim, num_actions.item(), net_arch, activation_fn)
|
||||
self.latent_pi = nn.Sequential(*latent_pi_net)
|
||||
|
||||
self.output_activation = nn.Softmax(dim=1)
|
||||
|
|
@ -184,8 +183,8 @@ class DiscreteCritic(BaseModel):
|
|||
self.n_critics = n_critics
|
||||
self.q_networks = []
|
||||
for idx in range(n_critics):
|
||||
q_net = create_mlp(features_dim, num_actions, net_arch, activation_fn)
|
||||
q_net = nn.Sequential(*q_net)
|
||||
q_net_list = create_mlp(features_dim, num_actions.item(), net_arch, activation_fn)
|
||||
q_net = nn.Sequential(*q_net_list)
|
||||
self.add_module(f"qf{idx}", q_net)
|
||||
self.q_networks.append(q_net)
|
||||
|
||||
|
|
@ -195,8 +194,9 @@ class DiscreteCritic(BaseModel):
|
|||
def forward(self, obs: th.Tensor) -> Tuple[th.Tensor, ...]:
|
||||
# Learn the features extractor using the policy loss only
|
||||
# when the features_extractor is shared with the actor
|
||||
with th.set_grad_enabled(not self.share_features_extractor):
|
||||
features = self.extract_features(obs, self.features_extractor)
|
||||
if self.features_extractor is not None:
|
||||
with th.set_grad_enabled(not self.share_features_extractor):
|
||||
features = self.extract_features(obs, self.features_extractor)
|
||||
return tuple(q_net(features) for q_net in self.q_networks)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -18,17 +18,14 @@ SelfSACD = TypeVar("SelfSACD", bound="SACD")
|
|||
|
||||
class SACD(OffPolicyAlgorithm):
|
||||
"""
|
||||
Soft Actor-Critic (SACD)
|
||||
Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,
|
||||
This implementation borrows code from original implementation (https://github.com/haarnoja/sac)
|
||||
from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo
|
||||
(https://github.com/rail-berkeley/softlearning/)
|
||||
and from Stable Baselines (https://github.com/hill-a/stable-baselines)
|
||||
Paper: https://arxiv.org/abs/1801.01290
|
||||
Introduction to SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
|
||||
Discrete Soft Actor-Critic (SACD)
|
||||
The Soft Actor-Critic Algorithm modified for discrete Action spaces.
|
||||
This implementation borrows code from the original implementation of
|
||||
the papers author (https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch)
|
||||
from toshikawa (https://github.com/toshikwa/sac-discrete.pytorch#references)
|
||||
and from Stable Baselines (https://github.com/DLR-RM/stable-baselines3)
|
||||
|
||||
Note: we use double q target and not value target as discussed
|
||||
in https://github.com/hill-a/stable-baselines/issues/270
|
||||
Paper: https://arxiv.org/abs/1910.07207
|
||||
|
||||
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||
:param env: The environment to learn from (if registered in Gym, can be str)
|
||||
|
|
@ -168,7 +165,7 @@ class SACD(OffPolicyAlgorithm):
|
|||
# Target entropy is used when learning the entropy coefficient
|
||||
if self.target_entropy == "auto":
|
||||
# we set the max possible entropy as the target entropy
|
||||
self.target_entropy = 0.98 * -np.log(1 / np.prod(self.env.action_space.shape))
|
||||
self.target_entropy = 0.98 * -np.log(1 / np.prod(np.array(self.action_space.shape)))
|
||||
else:
|
||||
# Force conversion
|
||||
# this will also throw an error for unexpected string
|
||||
|
|
@ -220,17 +217,45 @@ class SACD(OffPolicyAlgorithm):
|
|||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]
|
||||
|
||||
# Compute the critic loss
|
||||
critic_loss = self.calc_critic_loss(replay_data)
|
||||
with th.no_grad():
|
||||
# Select action according to policy
|
||||
action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
|
||||
|
||||
# Compute the next Q values: min over all critics targets
|
||||
next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
|
||||
next_q_values, _ = th.min(next_q_values, dim=2)
|
||||
|
||||
next_q_values = (action_prob * (next_q_values - self.ent_coef_tensor * next_log_prob)).sum(dim=1).unsqueeze(-1)
|
||||
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
|
||||
|
||||
# Get current Q-values estimates for each critic network
|
||||
current_q_values = self.critic(replay_data.observations)
|
||||
critic_loss = 0.5 * sum(
|
||||
F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values
|
||||
)
|
||||
assert isinstance(critic_loss, th.Tensor) # for type checker
|
||||
critic_losses.append(critic_loss.item())
|
||||
self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, self.gradient_clip_norm)
|
||||
|
||||
# Compute the actor loss
|
||||
actor_loss, log_action_prob = self.calc_actor_loss(replay_data)
|
||||
action_prob, log_action_prob = self.actor.action_log_prob(replay_data.observations)
|
||||
|
||||
# Min over all critic networks
|
||||
q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
|
||||
min_qf_pi, _ = th.min(q_values_pi, dim=2)
|
||||
|
||||
inside_term = self.ent_coef_tensor * log_action_prob - min_qf_pi
|
||||
actor_loss = (action_prob * inside_term).sum(dim=1).mean()
|
||||
actor_losses.append(actor_loss.item())
|
||||
self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, self.gradient_clip_norm)
|
||||
|
||||
# Compute entropy loss and optimize
|
||||
self.ent_coef_tensor = self.calc_entropy_loss(log_action_prob)
|
||||
# Compute entropy loss
|
||||
if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
|
||||
ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()
|
||||
ent_coef_losses.append(ent_coef_loss.item())
|
||||
self.take_optimisation_step(self.ent_coef_optimizer, None, ent_coef_loss, None)
|
||||
self.ent_coef_tensor = th.exp(self.log_ent_coef.detach())
|
||||
|
||||
ent_coefs.append(self.ent_coef_tensor.item())
|
||||
|
||||
# Update target networks
|
||||
|
|
@ -248,59 +273,19 @@ class SACD(OffPolicyAlgorithm):
|
|||
if len(ent_coef_losses) > 0:
|
||||
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
||||
|
||||
def take_optimisation_step(self, optimizer, network, loss, clipping_norm=None):
|
||||
def take_optimisation_step(
|
||||
self,
|
||||
optimizer: th.optim.Optimizer,
|
||||
network: Optional[th.nn.Module],
|
||||
loss: th.Tensor,
|
||||
clipping_norm: Optional[float] = None,
|
||||
) -> None:
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
if clipping_norm is not None:
|
||||
if clipping_norm is not None and network is not None:
|
||||
th.nn.utils.clip_grad_norm_(network.parameters(), clipping_norm) # clip gradients to help stabilise training
|
||||
optimizer.step()
|
||||
|
||||
def calc_critic_loss(self, replay_data):
|
||||
with th.no_grad():
|
||||
# Select action according to policy
|
||||
action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
|
||||
|
||||
# Compute the next Q values: min over all critics targets
|
||||
next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
|
||||
next_q_values, _ = th.min(next_q_values, dim=2)
|
||||
|
||||
next_q_values = (action_prob * (next_q_values - self.ent_coef_tensor * next_log_prob)).sum(dim=1).unsqueeze(-1)
|
||||
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
|
||||
|
||||
# Get current Q-values estimates for each critic network
|
||||
# using action from the replay buffer
|
||||
current_q_values = self.critic(replay_data.observations)
|
||||
|
||||
# Compute critic loss
|
||||
critic_loss = 0.5 * sum(
|
||||
F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values
|
||||
)
|
||||
|
||||
return critic_loss
|
||||
|
||||
def calc_actor_loss(self, replay_data):
|
||||
action_prob, log_prob = self.actor.action_log_prob(replay_data.observations)
|
||||
|
||||
# Min over all critic networks
|
||||
q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
|
||||
min_qf_pi, _ = th.min(q_values_pi, dim=2)
|
||||
|
||||
inside_term = self.ent_coef_tensor * log_prob - min_qf_pi
|
||||
actor_loss = (action_prob * inside_term).sum(dim=1).mean()
|
||||
return actor_loss, log_prob
|
||||
|
||||
def calc_entropy_loss(self, log_action_prob):
|
||||
if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
|
||||
# Important: detach the variable from the graph
|
||||
# so we don't change it with other losses
|
||||
# see https://github.com/rail-berkeley/softlearning/issues/60
|
||||
ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()
|
||||
ent_coef_loss.backward()
|
||||
self.ent_coef_optimizer.step()
|
||||
return th.exp(self.log_ent_coef.detach())
|
||||
else:
|
||||
return self.ent_coef_tensor
|
||||
|
||||
def learn(
|
||||
self: SelfSACD,
|
||||
total_timesteps: int,
|
||||
|
|
|
|||
Loading…
Reference in New Issue