Prepared files for merge request (minor cleanup)
This commit is contained in:
parent
4a37f58259
commit
fca2c6d490
|
|
@ -61,7 +61,6 @@ class Actor(BasePolicy):
|
||||||
action_space,
|
action_space,
|
||||||
features_extractor=features_extractor,
|
features_extractor=features_extractor,
|
||||||
normalize_images=normalize_images,
|
normalize_images=normalize_images,
|
||||||
# squash_output=True,
|
|
||||||
squash_output=False,
|
squash_output=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -78,7 +77,7 @@ class Actor(BasePolicy):
|
||||||
|
|
||||||
num_actions = self.action_space.n
|
num_actions = self.action_space.n
|
||||||
|
|
||||||
latent_pi_net = create_mlp(features_dim, num_actions, net_arch, activation_fn)
|
latent_pi_net = create_mlp(features_dim, num_actions.item(), net_arch, activation_fn)
|
||||||
self.latent_pi = nn.Sequential(*latent_pi_net)
|
self.latent_pi = nn.Sequential(*latent_pi_net)
|
||||||
|
|
||||||
self.output_activation = nn.Softmax(dim=1)
|
self.output_activation = nn.Softmax(dim=1)
|
||||||
|
|
@ -184,8 +183,8 @@ class DiscreteCritic(BaseModel):
|
||||||
self.n_critics = n_critics
|
self.n_critics = n_critics
|
||||||
self.q_networks = []
|
self.q_networks = []
|
||||||
for idx in range(n_critics):
|
for idx in range(n_critics):
|
||||||
q_net = create_mlp(features_dim, num_actions, net_arch, activation_fn)
|
q_net_list = create_mlp(features_dim, num_actions.item(), net_arch, activation_fn)
|
||||||
q_net = nn.Sequential(*q_net)
|
q_net = nn.Sequential(*q_net_list)
|
||||||
self.add_module(f"qf{idx}", q_net)
|
self.add_module(f"qf{idx}", q_net)
|
||||||
self.q_networks.append(q_net)
|
self.q_networks.append(q_net)
|
||||||
|
|
||||||
|
|
@ -195,8 +194,9 @@ class DiscreteCritic(BaseModel):
|
||||||
def forward(self, obs: th.Tensor) -> Tuple[th.Tensor, ...]:
|
def forward(self, obs: th.Tensor) -> Tuple[th.Tensor, ...]:
|
||||||
# Learn the features extractor using the policy loss only
|
# Learn the features extractor using the policy loss only
|
||||||
# when the features_extractor is shared with the actor
|
# when the features_extractor is shared with the actor
|
||||||
with th.set_grad_enabled(not self.share_features_extractor):
|
if self.features_extractor is not None:
|
||||||
features = self.extract_features(obs, self.features_extractor)
|
with th.set_grad_enabled(not self.share_features_extractor):
|
||||||
|
features = self.extract_features(obs, self.features_extractor)
|
||||||
return tuple(q_net(features) for q_net in self.q_networks)
|
return tuple(q_net(features) for q_net in self.q_networks)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,17 +18,14 @@ SelfSACD = TypeVar("SelfSACD", bound="SACD")
|
||||||
|
|
||||||
class SACD(OffPolicyAlgorithm):
|
class SACD(OffPolicyAlgorithm):
|
||||||
"""
|
"""
|
||||||
Soft Actor-Critic (SACD)
|
Discrete Soft Actor-Critic (SACD)
|
||||||
Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,
|
The Soft Actor-Critic Algorithm modified for discrete Action spaces.
|
||||||
This implementation borrows code from original implementation (https://github.com/haarnoja/sac)
|
This implementation borrows code from the original implementation of
|
||||||
from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo
|
the papers author (https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch)
|
||||||
(https://github.com/rail-berkeley/softlearning/)
|
from toshikawa (https://github.com/toshikwa/sac-discrete.pytorch#references)
|
||||||
and from Stable Baselines (https://github.com/hill-a/stable-baselines)
|
and from Stable Baselines (https://github.com/DLR-RM/stable-baselines3)
|
||||||
Paper: https://arxiv.org/abs/1801.01290
|
|
||||||
Introduction to SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
|
|
||||||
|
|
||||||
Note: we use double q target and not value target as discussed
|
Paper: https://arxiv.org/abs/1910.07207
|
||||||
in https://github.com/hill-a/stable-baselines/issues/270
|
|
||||||
|
|
||||||
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||||
:param env: The environment to learn from (if registered in Gym, can be str)
|
:param env: The environment to learn from (if registered in Gym, can be str)
|
||||||
|
|
@ -168,7 +165,7 @@ class SACD(OffPolicyAlgorithm):
|
||||||
# Target entropy is used when learning the entropy coefficient
|
# Target entropy is used when learning the entropy coefficient
|
||||||
if self.target_entropy == "auto":
|
if self.target_entropy == "auto":
|
||||||
# we set the max possible entropy as the target entropy
|
# we set the max possible entropy as the target entropy
|
||||||
self.target_entropy = 0.98 * -np.log(1 / np.prod(self.env.action_space.shape))
|
self.target_entropy = 0.98 * -np.log(1 / np.prod(np.array(self.action_space.shape)))
|
||||||
else:
|
else:
|
||||||
# Force conversion
|
# Force conversion
|
||||||
# this will also throw an error for unexpected string
|
# this will also throw an error for unexpected string
|
||||||
|
|
@ -220,17 +217,45 @@ class SACD(OffPolicyAlgorithm):
|
||||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]
|
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]
|
||||||
|
|
||||||
# Compute the critic loss
|
# Compute the critic loss
|
||||||
critic_loss = self.calc_critic_loss(replay_data)
|
with th.no_grad():
|
||||||
|
# Select action according to policy
|
||||||
|
action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
|
||||||
|
|
||||||
|
# Compute the next Q values: min over all critics targets
|
||||||
|
next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
|
||||||
|
next_q_values, _ = th.min(next_q_values, dim=2)
|
||||||
|
|
||||||
|
next_q_values = (action_prob * (next_q_values - self.ent_coef_tensor * next_log_prob)).sum(dim=1).unsqueeze(-1)
|
||||||
|
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
|
||||||
|
|
||||||
|
# Get current Q-values estimates for each critic network
|
||||||
|
current_q_values = self.critic(replay_data.observations)
|
||||||
|
critic_loss = 0.5 * sum(
|
||||||
|
F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values
|
||||||
|
)
|
||||||
|
assert isinstance(critic_loss, th.Tensor) # for type checker
|
||||||
critic_losses.append(critic_loss.item())
|
critic_losses.append(critic_loss.item())
|
||||||
self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, self.gradient_clip_norm)
|
self.take_optimisation_step(self.critic.optimizer, self.critic, critic_loss, self.gradient_clip_norm)
|
||||||
|
|
||||||
# Compute the actor loss
|
# Compute the actor loss
|
||||||
actor_loss, log_action_prob = self.calc_actor_loss(replay_data)
|
action_prob, log_action_prob = self.actor.action_log_prob(replay_data.observations)
|
||||||
|
|
||||||
|
# Min over all critic networks
|
||||||
|
q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
|
||||||
|
min_qf_pi, _ = th.min(q_values_pi, dim=2)
|
||||||
|
|
||||||
|
inside_term = self.ent_coef_tensor * log_action_prob - min_qf_pi
|
||||||
|
actor_loss = (action_prob * inside_term).sum(dim=1).mean()
|
||||||
actor_losses.append(actor_loss.item())
|
actor_losses.append(actor_loss.item())
|
||||||
self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, self.gradient_clip_norm)
|
self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, self.gradient_clip_norm)
|
||||||
|
|
||||||
# Compute entropy loss and optimize
|
# Compute entropy loss
|
||||||
self.ent_coef_tensor = self.calc_entropy_loss(log_action_prob)
|
if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
|
||||||
|
ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()
|
||||||
|
ent_coef_losses.append(ent_coef_loss.item())
|
||||||
|
self.take_optimisation_step(self.ent_coef_optimizer, None, ent_coef_loss, None)
|
||||||
|
self.ent_coef_tensor = th.exp(self.log_ent_coef.detach())
|
||||||
|
|
||||||
ent_coefs.append(self.ent_coef_tensor.item())
|
ent_coefs.append(self.ent_coef_tensor.item())
|
||||||
|
|
||||||
# Update target networks
|
# Update target networks
|
||||||
|
|
@ -248,59 +273,19 @@ class SACD(OffPolicyAlgorithm):
|
||||||
if len(ent_coef_losses) > 0:
|
if len(ent_coef_losses) > 0:
|
||||||
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
||||||
|
|
||||||
def take_optimisation_step(self, optimizer, network, loss, clipping_norm=None):
|
def take_optimisation_step(
|
||||||
|
self,
|
||||||
|
optimizer: th.optim.Optimizer,
|
||||||
|
network: Optional[th.nn.Module],
|
||||||
|
loss: th.Tensor,
|
||||||
|
clipping_norm: Optional[float] = None,
|
||||||
|
) -> None:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
if clipping_norm is not None:
|
if clipping_norm is not None and network is not None:
|
||||||
th.nn.utils.clip_grad_norm_(network.parameters(), clipping_norm) # clip gradients to help stabilise training
|
th.nn.utils.clip_grad_norm_(network.parameters(), clipping_norm) # clip gradients to help stabilise training
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
def calc_critic_loss(self, replay_data):
|
|
||||||
with th.no_grad():
|
|
||||||
# Select action according to policy
|
|
||||||
action_prob, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
|
|
||||||
|
|
||||||
# Compute the next Q values: min over all critics targets
|
|
||||||
next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
|
|
||||||
next_q_values, _ = th.min(next_q_values, dim=2)
|
|
||||||
|
|
||||||
next_q_values = (action_prob * (next_q_values - self.ent_coef_tensor * next_log_prob)).sum(dim=1).unsqueeze(-1)
|
|
||||||
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
|
|
||||||
|
|
||||||
# Get current Q-values estimates for each critic network
|
|
||||||
# using action from the replay buffer
|
|
||||||
current_q_values = self.critic(replay_data.observations)
|
|
||||||
|
|
||||||
# Compute critic loss
|
|
||||||
critic_loss = 0.5 * sum(
|
|
||||||
F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values
|
|
||||||
)
|
|
||||||
|
|
||||||
return critic_loss
|
|
||||||
|
|
||||||
def calc_actor_loss(self, replay_data):
|
|
||||||
action_prob, log_prob = self.actor.action_log_prob(replay_data.observations)
|
|
||||||
|
|
||||||
# Min over all critic networks
|
|
||||||
q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
|
|
||||||
min_qf_pi, _ = th.min(q_values_pi, dim=2)
|
|
||||||
|
|
||||||
inside_term = self.ent_coef_tensor * log_prob - min_qf_pi
|
|
||||||
actor_loss = (action_prob * inside_term).sum(dim=1).mean()
|
|
||||||
return actor_loss, log_prob
|
|
||||||
|
|
||||||
def calc_entropy_loss(self, log_action_prob):
|
|
||||||
if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
|
|
||||||
# Important: detach the variable from the graph
|
|
||||||
# so we don't change it with other losses
|
|
||||||
# see https://github.com/rail-berkeley/softlearning/issues/60
|
|
||||||
ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()
|
|
||||||
ent_coef_loss.backward()
|
|
||||||
self.ent_coef_optimizer.step()
|
|
||||||
return th.exp(self.log_ent_coef.detach())
|
|
||||||
else:
|
|
||||||
return self.ent_coef_tensor
|
|
||||||
|
|
||||||
def learn(
|
def learn(
|
||||||
self: SelfSACD,
|
self: SelfSACD,
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue