Code style changes

Paul Auerbach 2023-08-02 13:52:40 +02:00
parent 7711813dbb
commit 4a37f58259
3 changed files with 37 additions and 47 deletions


@@ -4,9 +4,9 @@ from sb3_contrib.ars import ARS
 from sb3_contrib.ppo_mask import MaskablePPO
 from sb3_contrib.ppo_recurrent import RecurrentPPO
 from sb3_contrib.qrdqn import QRDQN
+from sb3_contrib.sacd import SACD
 from sb3_contrib.tqc import TQC
 from sb3_contrib.trpo import TRPO
-from sb3_contrib.sacd import SACD

 # Read version from file
 version_file = os.path.join(os.path.dirname(__file__), "version.txt")
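Since the package root re-exports SACD next to the other algorithms, it can be imported directly from sb3_contrib. A minimal usage sketch, assuming a standard SB3-style constructor and a discrete-action Gymnasium environment; the environment id, timestep budget, and hyperparameter values below are illustrative and not part of this commit:

import gymnasium as gym

from sb3_contrib import SACD

# Any discrete-action environment; the SACD policies annotate action_space as spaces.Discrete.
env = gym.make("CartPole-v1")

# "MlpPolicy" is assumed to resolve to the SACDPolicy alias (MlpPolicy = SACDPolicy) in sb3_contrib.sacd.policies.
model = SACD("MlpPolicy", env, learning_rate=3e-4, buffer_size=100_000)
model.learn(total_timesteps=10_000)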


@@ -2,13 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
 import torch as th
 from gymnasium import spaces
-from torch import nn
-from torch.distributions import Categorical
-from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
-from stable_baselines3.common.policies import BasePolicy, BaseModel
-from stable_baselines3.common.preprocessing import get_action_dim
+from stable_baselines3.common.policies import BaseModel, BasePolicy
 from stable_baselines3.common.torch_layers import (
     BaseFeaturesExtractor,
     CombinedExtractor,
@@ -18,10 +12,13 @@ from stable_baselines3.common.torch_layers import (
     get_actor_critic_arch,
 )
 from stable_baselines3.common.type_aliases import Schedule
+from torch import nn
+from torch.distributions import Categorical

+
 class Actor(BasePolicy):
     """
-    Actor network (policy) for SAC.
+    Actor network (policy) for SACD

     :param observation_space: Obervation space
     :param action_space: Action space
@@ -42,7 +39,7 @@ class Actor(BasePolicy):
         dividing by 255.0 (True by default)
     """

-    action_space: spaces.Box
+    action_space: spaces.Discrete

     def __init__(
         self,
@@ -51,7 +48,7 @@ class Actor(BasePolicy):
         net_arch: List[int],
         features_extractor: nn.Module,
         features_dim: int,
-        activation_fn: Type[nn.Module] = nn.Softmax(dim=1),
+        activation_fn: Type[nn.Module] = nn.Softmax,
         use_sde: bool = False,
         log_std_init: float = -3,
         full_std: bool = True,
@@ -132,6 +129,7 @@ class Actor(BasePolicy):
     def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
         return self(observation, deterministic)

+
 class DiscreteCritic(BaseModel):
     """
     Critic network(s) for DDPG/SAC/TD3.
@@ -159,10 +157,12 @@ class DiscreteCritic(BaseModel):
         between the actor and the critic (this saves computation time)
     """

+    action_space: spaces.Discrete
+
     def __init__(
         self,
         observation_space: spaces.Space,
-        action_space: spaces.Box,
+        action_space: spaces.Discrete,
         net_arch: List[int],
         features_extractor: BaseFeaturesExtractor,
         features_dim: int,
@@ -199,9 +199,10 @@ class DiscreteCritic(BaseModel):
         features = self.extract_features(obs, self.features_extractor)
         return tuple(q_net(features) for q_net in self.q_networks)

-class SACPolicy(BasePolicy):
+
+class SACDPolicy(BasePolicy):
     """
-    Policy class (with both actor and critic) for SAC.
+    Policy class (with both actor and critic) for SACD.

     :param observation_space: Observation space
     :param action_space: Action space
@@ -351,14 +352,6 @@ class SACPolicy(BasePolicy):
         )
         return data

-    def reset_noise(self, batch_size: int = 1) -> None:
-        """
-        Sample new weights for the exploration matrix, when using gSDE.
-
-        :param batch_size:
-        """
-        self.actor.reset_noise(batch_size=batch_size)
-
     def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
         actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
         return Actor(**actor_kwargs).to(self.device)
@@ -386,12 +379,12 @@ class SACPolicy(BasePolicy):
         self.training = mode


-MlpPolicy = SACPolicy
+MlpPolicy = SACDPolicy


-class CnnPolicy(SACPolicy):
+class CnnPolicy(SACDPolicy):
     """
-    Policy class (with both actor and critic) for SAC.
+    Policy class (with both actor and critic) for SACD.

     :param observation_space: Observation space
     :param action_space: Action space
@@ -455,9 +448,9 @@ class CnnPolicy(SACPolicy):
         )


-class MultiInputPolicy(SACPolicy):
+class MultiInputPolicy(SACDPolicy):
     """
-    Policy class (with both actor and critic) for SAC.
+    Policy class (with both actor and critic) for SACD.

     :param observation_space: Observation space
     :param action_space: Action space
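The activation_fn fix above replaces the instance nn.Softmax(dim=1) with the class nn.Softmax, matching the Type[nn.Module] annotation: Stable-Baselines3 helpers such as create_mlp instantiate the activation class themselves. A minimal sketch of that convention, and of how a discrete actor head can feed a Categorical distribution; the function name and shapes are hypothetical, not this repository's code:

from typing import List, Type

import torch as th
from stable_baselines3.common.torch_layers import create_mlp
from torch import nn


def build_actor_head(features_dim: int, n_actions: int, net_arch: List[int],
                     activation_fn: Type[nn.Module] = nn.ReLU) -> nn.Sequential:
    # create_mlp calls activation_fn() for each hidden layer, so it needs the class,
    # not a pre-built module such as nn.Softmax(dim=1).
    return nn.Sequential(*create_mlp(features_dim, n_actions, net_arch, activation_fn))


head = build_actor_head(features_dim=8, n_actions=4, net_arch=[64, 64])
logits = head(th.zeros(1, 8))
# The torch.distributions.Categorical import added above suggests the actor samples
# discrete actions from logits like these.
dist = th.distributions.Categorical(logits=logits)
action = dist.sample()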


@@ -3,23 +3,22 @@ from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Un
 import numpy as np
 import torch as th
 from gymnasium import spaces
-from torch.nn import functional as F
-
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.noise import ActionNoise
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
+from torch.nn import functional as F

-from sb3_contrib.sacd.policies import Actor, DiscreteCritic, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
+from sb3_contrib.sacd.policies import Actor, CnnPolicy, DiscreteCritic, MlpPolicy, MultiInputPolicy, SACDPolicy

 SelfSACD = TypeVar("SelfSACD", bound="SACD")


 class SACD(OffPolicyAlgorithm):
     """
-    Soft Actor-Critic (SAC)
+    Soft Actor-Critic (SACD)
     Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,
     This implementation borrows code from original implementation (https://github.com/haarnoja/sac)
     from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo
@@ -84,14 +83,14 @@ class SACD(OffPolicyAlgorithm):
         "CnnPolicy": CnnPolicy,
         "MultiInputPolicy": MultiInputPolicy,
     }
-    policy: SACPolicy
+    policy: SACDPolicy
     actor: Actor
     critic: DiscreteCritic
     critic_target: DiscreteCritic

     def __init__(
         self,
-        policy: Union[str, Type[SACPolicy]],
+        policy: Union[str, Type[SACDPolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 3e-4,
         buffer_size: int = 1_000_000,  # 1e6
@@ -189,7 +188,7 @@ class SACD(OffPolicyAlgorithm):
             # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
             # self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
             self.log_ent_coef = th.zeros(1, device=self.device, requires_grad=True)
-            self.ent_coef = th.exp(self.log_ent_coef)
+            self.ent_coef_tensor = th.exp(self.log_ent_coef)
             self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
         else:
             # Force conversion to float
@@ -231,8 +230,8 @@ class SACD(OffPolicyAlgorithm):
             self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, self.gradient_clip_norm)

             # Compute entropy loss and optimize
-            ent_coeff = self.calc_entropy_loss(log_action_prob)
-            ent_coefs.append(self.ent_coef.item())
+            self.ent_coef_tensor = self.calc_entropy_loss(log_action_prob)
+            ent_coefs.append(self.ent_coef_tensor.item())

             # Update target networks
             if gradient_step % self.target_update_interval == 0:
@@ -265,7 +264,7 @@ class SACD(OffPolicyAlgorithm):
             next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
             next_q_values, _ = th.min(next_q_values, dim=2)
-            next_q_values = (action_prob * (next_q_values - self.ent_coef * next_log_prob)).sum(dim=1).unsqueeze(-1)
+            next_q_values = (action_prob * (next_q_values - self.ent_coef_tensor * next_log_prob)).sum(dim=1).unsqueeze(-1)
             target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

         # Get current Q-values estimates for each critic network
@@ -273,7 +272,9 @@ class SACD(OffPolicyAlgorithm):
         current_q_values = self.critic(replay_data.observations)

         # Compute critic loss
-        critic_loss = 0.5 * sum(F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values)
+        critic_loss = 0.5 * sum(
+            F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values
+        )

         return critic_loss
@@ -284,25 +285,21 @@ class SACD(OffPolicyAlgorithm):
         q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
         min_qf_pi, _ = th.min(q_values_pi, dim=2)
-        inside_term = self.ent_coef * log_prob - min_qf_pi
+        inside_term = self.ent_coef_tensor * log_prob - min_qf_pi
         actor_loss = (action_prob * inside_term).sum(dim=1).mean()
         return actor_loss, log_prob

     def calc_entropy_loss(self, log_action_prob):
-        ent_coef_loss = None
         if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
             # Important: detach the variable from the graph
             # so we don't change it with other losses
             # see https://github.com/rail-berkeley/softlearning/issues/60
             ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()
-            # ent_coef_losses.append(ent_coef_loss.item())
             ent_coef_loss.backward()
             self.ent_coef_optimizer.step()
-            self.ent_coef = th.exp(self.log_ent_coef.detach())
+            return th.exp(self.log_ent_coef.detach())
         else:
-            self.ent_coef = self.ent_coef_tensor
-        return self.ent_coef
+            return self.ent_coef_tensor

     def learn(
         self: SelfSACD,
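The entropy-coefficient hunks above rename the learned coefficient to ent_coef_tensor and have calc_entropy_loss return the updated value instead of assigning self.ent_coef as a side effect. A standalone sketch of that automatic entropy tuning step, mirroring the update shape shown in the diff; the function name, the explicit zero_grad call, and the example values are assumptions for illustration:

import torch as th


def update_ent_coef(log_ent_coef: th.Tensor,
                    optimizer: th.optim.Optimizer,
                    log_action_prob: th.Tensor,
                    target_entropy: float) -> th.Tensor:
    # Push log(ent_coef) so the policy entropy moves toward target_entropy;
    # detach the inner term so only log_ent_coef receives gradients.
    ent_coef_loss = -(log_ent_coef * (log_action_prob + target_entropy).detach()).mean()
    optimizer.zero_grad()
    ent_coef_loss.backward()
    optimizer.step()
    # Return the coefficient for the caller to store (ent_coef_tensor in the diff).
    return th.exp(log_ent_coef.detach())


# Illustrative usage with made-up tensors (not taken from the commit):
log_ent_coef = th.zeros(1, requires_grad=True)
optimizer = th.optim.Adam([log_ent_coef], lr=3e-4)
log_action_prob = th.log_softmax(th.randn(32, 4), dim=1)  # (batch, n_actions)
ent_coef_tensor = update_ent_coef(log_ent_coef, optimizer, log_action_prob, target_entropy=0.5)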