From 4a37f58259148faf12540badc99521dac16b101f Mon Sep 17 00:00:00 2001
From: Paul Auerbach
Date: Wed, 2 Aug 2023 13:52:40 +0200
Subject: [PATCH] Code style changes

---
 sb3_contrib/__init__.py      |  2 +-
 sb3_contrib/sacd/policies.py | 45 +++++++++++++++---------------
 sb3_contrib/sacd/sacd.py     | 37 ++++++++++++++---------------
 3 files changed, 37 insertions(+), 47 deletions(-)

diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py
index 9e5f4cd..f49a725 100644
--- a/sb3_contrib/__init__.py
+++ b/sb3_contrib/__init__.py
@@ -4,9 +4,9 @@ from sb3_contrib.ars import ARS
 from sb3_contrib.ppo_mask import MaskablePPO
 from sb3_contrib.ppo_recurrent import RecurrentPPO
 from sb3_contrib.qrdqn import QRDQN
+from sb3_contrib.sacd import SACD
 from sb3_contrib.tqc import TQC
 from sb3_contrib.trpo import TRPO
-from sb3_contrib.sacd import SACD

 # Read version from file
 version_file = os.path.join(os.path.dirname(__file__), "version.txt")
diff --git a/sb3_contrib/sacd/policies.py b/sb3_contrib/sacd/policies.py
index f4efef5..49f560a 100644
--- a/sb3_contrib/sacd/policies.py
+++ b/sb3_contrib/sacd/policies.py
@@ -2,13 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union

 import torch as th
 from gymnasium import spaces
-from torch import nn
-
-from torch.distributions import Categorical
-
-from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
-from stable_baselines3.common.policies import BasePolicy, BaseModel
-from stable_baselines3.common.preprocessing import get_action_dim
+from stable_baselines3.common.policies import BaseModel, BasePolicy
 from stable_baselines3.common.torch_layers import (
     BaseFeaturesExtractor,
     CombinedExtractor,
@@ -18,10 +12,13 @@ from stable_baselines3.common.torch_layers import (
     get_actor_critic_arch,
 )
 from stable_baselines3.common.type_aliases import Schedule
+from torch import nn
+from torch.distributions import Categorical
+

 class Actor(BasePolicy):
     """
-    Actor network (policy) for SAC.
+    Actor network (policy) for SACD.

     :param observation_space: Observation space
     :param action_space: Action space
@@ -42,7 +39,7 @@
         dividing by 255.0 (True by default)
     """

-    action_space: spaces.Box
+    action_space: spaces.Discrete

     def __init__(
         self,
@@ -51,7 +48,7 @@
         net_arch: List[int],
         features_extractor: nn.Module,
         features_dim: int,
-        activation_fn: Type[nn.Module] = nn.Softmax(dim=1),
+        activation_fn: Type[nn.Module] = nn.Softmax,
         use_sde: bool = False,
         log_std_init: float = -3,
         full_std: bool = True,
@@ -132,6 +129,7 @@
     def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
         return self(observation, deterministic)

+
 class DiscreteCritic(BaseModel):
     """
     Critic network(s) for SACD.
@@ -159,10 +157,12 @@
         between the actor and the critic (this saves computation time)
     """

+    action_space: spaces.Discrete
+
     def __init__(
         self,
         observation_space: spaces.Space,
-        action_space: spaces.Box,
+        action_space: spaces.Discrete,
         net_arch: List[int],
         features_extractor: BaseFeaturesExtractor,
         features_dim: int,
@@ -199,9 +199,10 @@
         features = self.extract_features(obs, self.features_extractor)
         return tuple(q_net(features) for q_net in self.q_networks)

-class SACPolicy(BasePolicy):
+
+class SACDPolicy(BasePolicy):
     """
-    Policy class (with both actor and critic) for SAC.
+    Policy class (with both actor and critic) for SACD.

     :param observation_space: Observation space
     :param action_space: Action space
@@ -351,14 +352,6 @@ class SACPolicy(BasePolicy):
         )
         return data

-    def reset_noise(self, batch_size: int = 1) -> None:
-        """
-        Sample new weights for the exploration matrix, when using gSDE.
-
-        :param batch_size:
-        """
-        self.actor.reset_noise(batch_size=batch_size)
-
     def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
         actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
         return Actor(**actor_kwargs).to(self.device)
@@ -386,12 +379,12 @@ class SACPolicy(BasePolicy):
         self.training = mode


-MlpPolicy = SACPolicy
+MlpPolicy = SACDPolicy


-class CnnPolicy(SACPolicy):
+class CnnPolicy(SACDPolicy):
     """
-    Policy class (with both actor and critic) for SAC.
+    Policy class (with both actor and critic) for SACD.

     :param observation_space: Observation space
     :param action_space: Action space
@@ -455,9 +448,9 @@ class CnnPolicy(SACPolicy):
         )


-class MultiInputPolicy(SACPolicy):
+class MultiInputPolicy(SACDPolicy):
     """
-    Policy class (with both actor and critic) for SAC.
+    Policy class (with both actor and critic) for SACD.

     :param observation_space: Observation space
     :param action_space: Action space
diff --git a/sb3_contrib/sacd/sacd.py b/sb3_contrib/sacd/sacd.py
index 9e48513..8433728 100644
--- a/sb3_contrib/sacd/sacd.py
+++ b/sb3_contrib/sacd/sacd.py
@@ -3,23 +3,22 @@ from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Un
 import numpy as np
 import torch as th
 from gymnasium import spaces
-from torch.nn import functional as F
-
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.noise import ActionNoise
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
+from torch.nn import functional as F

-from sb3_contrib.sacd.policies import Actor, DiscreteCritic, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
+from sb3_contrib.sacd.policies import Actor, CnnPolicy, DiscreteCritic, MlpPolicy, MultiInputPolicy, SACDPolicy

 SelfSACD = TypeVar("SelfSACD", bound="SACD")


 class SACD(OffPolicyAlgorithm):
     """
-    Soft Actor-Critic (SAC)
+    Soft Actor-Critic Discrete (SACD)
     Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.
     This implementation borrows code from the original implementation (https://github.com/haarnoja/sac),
     from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo
@@ -84,14 +83,14 @@
         "CnnPolicy": CnnPolicy,
         "MultiInputPolicy": MultiInputPolicy,
     }
-    policy: SACPolicy
+    policy: SACDPolicy
     actor: Actor
     critic: DiscreteCritic
     critic_target: DiscreteCritic

     def __init__(
         self,
-        policy: Union[str, Type[SACPolicy]],
+        policy: Union[str, Type[SACDPolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 3e-4,
         buffer_size: int = 1_000_000,  # 1e6
@@ -112,7 +111,7 @@
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
         stats_window_size: int = 100,
-        max_grad_norm = 5.0,
+        max_grad_norm: float = 5.0,
         tensorboard_log: Optional[str] = None,
         policy_kwargs: Optional[Dict[str, Any]] = None,
         verbose: int = 0,
@@ -189,7 +188,7 @@
             # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
             # self.log_ent_coef = 
th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
             self.log_ent_coef = th.zeros(1, device=self.device, requires_grad=True)
-            self.ent_coef = th.exp(self.log_ent_coef)
+            self.ent_coef_tensor = th.exp(self.log_ent_coef)
             self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
         else:
             # Force conversion to float
@@ -231,8 +230,8 @@
             self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, self.gradient_clip_norm)

             # Compute entropy loss and optimize
-            ent_coeff = self.calc_entropy_loss(log_action_prob)
-            ent_coefs.append(self.ent_coef.item())
+            self.ent_coef_tensor = self.calc_entropy_loss(log_action_prob)
+            ent_coefs.append(self.ent_coef_tensor.item())

             # Update target networks
             if gradient_step % self.target_update_interval == 0:
@@ -253,7 +252,7 @@
         optimizer.zero_grad()
         loss.backward()
         if clipping_norm is not None:
-            th.nn.utils.clip_grad_norm_(network.parameters(), clipping_norm) #clip gradients to help stabilise training
+            th.nn.utils.clip_grad_norm_(network.parameters(), clipping_norm)  # clip gradients to help stabilise training
         optimizer.step()

     def calc_critic_loss(self, replay_data):
@@ -265,7 +264,7 @@
             next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
             next_q_values, _ = th.min(next_q_values, dim=2)

-            next_q_values = (action_prob * (next_q_values - self.ent_coef * next_log_prob)).sum(dim=1).unsqueeze(-1)
+            next_q_values = (action_prob * (next_q_values - self.ent_coef_tensor * next_log_prob)).sum(dim=1).unsqueeze(-1)
             target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

         # Get current Q-values estimates for each critic network
@@ -273,7 +272,9 @@
         current_q_values = self.critic(replay_data.observations)

         # Compute critic loss
-        critic_loss = 0.5 * sum(F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values)
+        critic_loss = 0.5 * sum(
+            F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values
+        )

         return critic_loss

@@ -284,25 +285,22 @@
         q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
         min_qf_pi, _ = th.min(q_values_pi, dim=2)

-        inside_term = self.ent_coef * log_prob - min_qf_pi
+        inside_term = self.ent_coef_tensor * log_prob - min_qf_pi
         actor_loss = (action_prob * inside_term).sum(dim=1).mean()

         return actor_loss, log_prob

     def calc_entropy_loss(self, log_action_prob):
-        ent_coef_loss = None
         if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
             # Important: detach the variable from the graph
             # so we don't change it with other losses
             # see https://github.com/rail-berkeley/softlearning/issues/60
             ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()
-            # ent_coef_losses.append(ent_coef_loss.item())
+            self.ent_coef_optimizer.zero_grad()  # reset stale gradients before the entropy-coefficient update
             ent_coef_loss.backward()
             self.ent_coef_optimizer.step()
-            self.ent_coef = th.exp(self.log_ent_coef.detach())
+            return th.exp(self.log_ent_coef.detach())
         else:
-            self.ent_coef = self.ent_coef_tensor
-
-        return self.ent_coef
+            return self.ent_coef_tensor

     def learn(
         self: SelfSACD,
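
Usage sketch (illustrative, not part of the diff): assuming this patch is applied
and sb3_contrib is importable, the SACD class exported above should train on any
Gymnasium environment with a spaces.Discrete action space. The environment
("CartPole-v1") and the timestep budget below are arbitrary choices for the
example, not values taken from the patch.

    import gymnasium as gym

    from sb3_contrib import SACD

    # Any discrete-action env works; CartPole-v1 is just a stand-in here.
    env = gym.make("CartPole-v1")

    # "MlpPolicy" resolves to the MlpPolicy alias defined in the patch
    # (MlpPolicy = SACDPolicy).
    model = SACD("MlpPolicy", env, learning_rate=3e-4, verbose=1)
    model.learn(total_timesteps=10_000)

    # Greedy action from the trained policy.
    obs, _ = env.reset()
    action, _ = model.predict(obs, deterministic=True)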