Code style changes

Paul Auerbach 2023-08-02 13:52:40 +02:00
parent 7711813dbb
commit 4a37f58259
3 changed files with 37 additions and 47 deletions

sb3_contrib/__init__.py

@@ -4,9 +4,9 @@ from sb3_contrib.ars import ARS
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.ppo_recurrent import RecurrentPPO
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.sacd import SACD
from sb3_contrib.tqc import TQC
from sb3_contrib.trpo import TRPO
from sb3_contrib.sacd import SACD
# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")

sb3_contrib/sacd/policies.py

@@ -2,13 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch as th
from gymnasium import spaces
from torch import nn
from torch.distributions import Categorical
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BasePolicy, BaseModel
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.policies import BaseModel, BasePolicy
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
@@ -18,10 +12,13 @@ from stable_baselines3.common.torch_layers import (
get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import Schedule
from torch import nn
from torch.distributions import Categorical
class Actor(BasePolicy):
"""
Actor network (policy) for SAC.
Actor network (policy) for SACD.
:param observation_space: Observation space
:param action_space: Action space
@@ -42,7 +39,7 @@ class Actor(BasePolicy):
dividing by 255.0 (True by default)
"""
action_space: spaces.Box
action_space: spaces.Discrete
def __init__(
self,
@@ -51,7 +48,7 @@ class Actor(BasePolicy):
net_arch: List[int],
features_extractor: nn.Module,
features_dim: int,
activation_fn: Type[nn.Module] = nn.Softmax(dim=1),
activation_fn: Type[nn.Module] = nn.Softmax,
use_sde: bool = False,
log_std_init: float = -3,
full_std: bool = True,
@@ -132,6 +129,7 @@ class Actor(BasePolicy):
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self(observation, deterministic)
class DiscreteCritic(BaseModel):
"""
Critic network(s) for DDPG/SAC/TD3.
@@ -159,10 +157,12 @@ class DiscreteCritic(BaseModel):
between the actor and the critic (this saves computation time)
"""
action_space: spaces.Discrete
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Box,
action_space: spaces.Discrete,
net_arch: List[int],
features_extractor: BaseFeaturesExtractor,
features_dim: int,
@@ -199,9 +199,10 @@ class DiscreteCritic(BaseModel):
features = self.extract_features(obs, self.features_extractor)
return tuple(q_net(features) for q_net in self.q_networks)
class SACPolicy(BasePolicy):
class SACDPolicy(BasePolicy):
"""
Policy class (with both actor and critic) for SAC.
Policy class (with both actor and critic) for SACD.
:param observation_space: Observation space
:param action_space: Action space
@@ -351,14 +352,6 @@ class SACPolicy(BasePolicy):
)
return data
def reset_noise(self, batch_size: int = 1) -> None:
"""
Sample new weights for the exploration matrix, when using gSDE.
:param batch_size:
"""
self.actor.reset_noise(batch_size=batch_size)
def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
return Actor(**actor_kwargs).to(self.device)
@@ -386,12 +379,12 @@
self.training = mode
MlpPolicy = SACPolicy
MlpPolicy = SACDPolicy
class CnnPolicy(SACPolicy):
class CnnPolicy(SACDPolicy):
"""
Policy class (with both actor and critic) for SAC.
Policy class (with both actor and critic) for SACD.
:param observation_space: Observation space
:param action_space: Action space
@@ -455,9 +448,9 @@ class CnnPolicy(SACPolicy):
)
class MultiInputPolicy(SACPolicy):
class MultiInputPolicy(SACDPolicy):
"""
Policy class (with both actor and critic) for SAC.
Policy class (with both actor and critic) for SACD.
:param observation_space: Observation space
:param action_space: Action space

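Note on the policies.py changes above: the Actor is now annotated with action_space: spaces.Discrete, and the default activation_fn was corrected from nn.Softmax(dim=1) (an instance) to nn.Softmax (a class), matching the Type[nn.Module] annotation; the network builder is expected to instantiate the class itself. A minimal, hypothetical sketch of a discrete actor in this style (not the repo's Actor; the layer sizes and the hidden nn.ReLU are made-up placeholders):

import torch as th
from torch import nn
from torch.distributions import Categorical

class TinyDiscreteActor(nn.Module):
    # activation_fn is passed as a class and instantiated inside the module,
    # which is why the default must be nn.Softmax rather than nn.Softmax(dim=1)
    def __init__(self, obs_dim: int, n_actions: int, activation_fn: type = nn.ReLU):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 64),
            activation_fn(),
            nn.Linear(64, n_actions),
        )

    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
        logits = self.net(obs)
        if deterministic:
            return logits.argmax(dim=1)  # greedy action
        return Categorical(logits=logits).sample()  # stochastic discrete action

actions = TinyDiscreteActor(obs_dim=4, n_actions=2)(th.randn(8, 4))  # LongTensor of shape (8,)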
sb3_contrib/sacd/sacd.py

@@ -3,23 +3,22 @@ from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
from gymnasium import spaces
from torch.nn import functional as F
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from torch.nn import functional as F
from sb3_contrib.sacd.policies import Actor, DiscreteCritic, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
from sb3_contrib.sacd.policies import Actor, CnnPolicy, DiscreteCritic, MlpPolicy, MultiInputPolicy, SACDPolicy
SelfSACD = TypeVar("SelfSACD", bound="SACD")
class SACD(OffPolicyAlgorithm):
"""
Soft Actor-Critic (SAC)
Soft Actor-Critic Discrete (SACD)
Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,
This implementation borrows code from original implementation (https://github.com/haarnoja/sac)
from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo
@@ -84,14 +83,14 @@ class SACD(OffPolicyAlgorithm):
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
}
policy: SACPolicy
policy: SACDPolicy
actor: Actor
critic: DiscreteCritic
critic_target: DiscreteCritic
def __init__(
self,
policy: Union[str, Type[SACPolicy]],
policy: Union[str, Type[SACDPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 3e-4,
buffer_size: int = 1_000_000, # 1e6
@@ -112,7 +111,7 @@ class SACD(OffPolicyAlgorithm):
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
stats_window_size: int = 100,
max_grad_norm = 5.0,
max_grad_norm=5.0,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
@@ -189,7 +188,7 @@ class SACD(OffPolicyAlgorithm):
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
# self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
self.log_ent_coef = th.zeros(1, device=self.device, requires_grad=True)
self.ent_coef = th.exp(self.log_ent_coef)
self.ent_coef_tensor = th.exp(self.log_ent_coef)
self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
else:
# Force conversion to float
@@ -231,8 +230,8 @@ class SACD(OffPolicyAlgorithm):
self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, self.gradient_clip_norm)
# Compute entropy loss and optimize
ent_coeff = self.calc_entropy_loss(log_action_prob)
ent_coefs.append(self.ent_coef.item())
self.ent_coef_tensor = self.calc_entropy_loss(log_action_prob)
ent_coefs.append(self.ent_coef_tensor.item())
# Update target networks
if gradient_step % self.target_update_interval == 0:
@@ -253,7 +252,7 @@
optimizer.zero_grad()
loss.backward()
if clipping_norm is not None:
th.nn.utils.clip_grad_norm_(network.parameters(), clipping_norm) #clip gradients to help stabilise training
th.nn.utils.clip_grad_norm_(network.parameters(), clipping_norm) # clip gradients to help stabilise training
optimizer.step()
def calc_critic_loss(self, replay_data):
@@ -265,7 +264,7 @@
next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
next_q_values, _ = th.min(next_q_values, dim=2)
next_q_values = (action_prob * (next_q_values - self.ent_coef * next_log_prob)).sum(dim=1).unsqueeze(-1)
next_q_values = (action_prob * (next_q_values - self.ent_coef_tensor * next_log_prob)).sum(dim=1).unsqueeze(-1)
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
# Get current Q-values estimates for each critic network
@@ -273,7 +272,9 @@
current_q_values = self.critic(replay_data.observations)
# Compute critic loss
critic_loss = 0.5 * sum(F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values)
critic_loss = 0.5 * sum(
F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values
)
return critic_loss
@@ -284,25 +285,21 @@
q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
min_qf_pi, _ = th.min(q_values_pi, dim=2)
inside_term = self.ent_coef * log_prob - min_qf_pi
inside_term = self.ent_coef_tensor * log_prob - min_qf_pi
actor_loss = (action_prob * inside_term).sum(dim=1).mean()
return actor_loss, log_prob
def calc_entropy_loss(self, log_action_prob):
ent_coef_loss = None
if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
# Important: detach the variable from the graph
# so we don't change it with other losses
# see https://github.com/rail-berkeley/softlearning/issues/60
ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()
# ent_coef_losses.append(ent_coef_loss.item())
ent_coef_loss.backward()
self.ent_coef_optimizer.step()
self.ent_coef = th.exp(self.log_ent_coef.detach())
return th.exp(self.log_ent_coef.detach())
else:
self.ent_coef = self.ent_coef_tensor
return self.ent_coef
return self.ent_coef_tensor
def learn(
self: SelfSACD,
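
For reference, in calc_critic_loss and the actor-loss computation above, a discrete action space lets the expectations over actions be computed exactly as probability-weighted sums instead of by sampling, and the replayed actions index each critic's Q-values via gather. A toy, self-contained replay of both computations (random tensors stand in for replay_data, the policy, and the critics; the two-critic ensemble and all numbers are assumptions):

import torch as th
from torch.nn import functional as F

batch, n_actions, gamma, ent_coef = 8, 4, 0.99, 0.1
action_prob = th.softmax(th.randn(batch, n_actions), dim=1)   # pi(a'|s') from the actor
next_log_prob = th.log(action_prob)
next_qs = [th.randn(batch, n_actions) for _ in range(2)]      # two target critics (assumed)

# Critic target: min over the ensemble, then an exact expectation over next actions
next_q, _ = th.min(th.stack(next_qs, dim=2), dim=2)
next_v = (action_prob * (next_q - ent_coef * next_log_prob)).sum(dim=1, keepdim=True)
rewards, dones = th.randn(batch, 1), th.zeros(batch, 1)
target_q = rewards + (1 - dones) * gamma * next_v

stored_actions = th.randint(n_actions, (batch, 1))            # actions from the replay buffer
current_qs = [th.randn(batch, n_actions) for _ in range(2)]   # current critic outputs
critic_loss = 0.5 * sum(F.mse_loss(q.gather(1, stored_actions), target_q) for q in current_qs)

# Actor loss: the same exact-expectation trick, with the current policy at s
pi_prob = th.softmax(th.randn(batch, n_actions), dim=1)
log_pi = th.log(pi_prob)
min_qf_pi, _ = th.min(th.stack(current_qs, dim=2), dim=2)
actor_loss = (pi_prob * (ent_coef * log_pi - min_qf_pi)).sum(dim=1).mean()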
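
And on the calc_entropy_loss refactor: rather than mutating self.ent_coef as a side effect, the method now returns the detached coefficient, which train() stores in self.ent_coef_tensor. A standalone sketch of that temperature update; the target_entropy value is an assumed placeholder (a common discrete-SAC heuristic of 98% of maximum entropy, not something this diff shows), and an explicit zero_grad() is added here to make the step self-contained:

import torch as th

log_ent_coef = th.zeros(1, requires_grad=True)     # ent_coef = exp(0) = 1.0 initially
ent_coef_optimizer = th.optim.Adam([log_ent_coef], lr=3e-4)
target_entropy = 0.98 * th.log(th.tensor(4.0))     # assumed: 0.98 * log(n_actions) with n_actions = 4

def update_ent_coef(log_action_prob: th.Tensor) -> th.Tensor:
    # Detach the actor's log-probabilities so the gradient flows into log_ent_coef only
    ent_coef_loss = -(log_ent_coef * (log_action_prob + target_entropy).detach()).mean()
    ent_coef_optimizer.zero_grad()
    ent_coef_loss.backward()
    ent_coef_optimizer.step()
    return th.exp(log_ent_coef.detach())           # returned detached, as in the refactored method

ent_coef = update_ent_coef(th.log(th.full((8, 4), 0.25)))  # e.g. a uniform policy over 4 actions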