Code style changes
This commit is contained in:
parent
7711813dbb
commit
4a37f58259
|
|
@ -4,9 +4,9 @@ from sb3_contrib.ars import ARS
|
||||||
from sb3_contrib.ppo_mask import MaskablePPO
|
from sb3_contrib.ppo_mask import MaskablePPO
|
||||||
from sb3_contrib.ppo_recurrent import RecurrentPPO
|
from sb3_contrib.ppo_recurrent import RecurrentPPO
|
||||||
from sb3_contrib.qrdqn import QRDQN
|
from sb3_contrib.qrdqn import QRDQN
|
||||||
|
from sb3_contrib.sacd import SACD
|
||||||
from sb3_contrib.tqc import TQC
|
from sb3_contrib.tqc import TQC
|
||||||
from sb3_contrib.trpo import TRPO
|
from sb3_contrib.trpo import TRPO
|
||||||
from sb3_contrib.sacd import SACD
|
|
||||||
|
|
||||||
# Read version from file
|
# Read version from file
|
||||||
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
|
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import torch as th
|
import torch as th
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
from torch import nn
|
from stable_baselines3.common.policies import BaseModel, BasePolicy
|
||||||
|
|
||||||
from torch.distributions import Categorical
|
|
||||||
|
|
||||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
|
||||||
from stable_baselines3.common.policies import BasePolicy, BaseModel
|
|
||||||
from stable_baselines3.common.preprocessing import get_action_dim
|
|
||||||
from stable_baselines3.common.torch_layers import (
|
from stable_baselines3.common.torch_layers import (
|
||||||
BaseFeaturesExtractor,
|
BaseFeaturesExtractor,
|
||||||
CombinedExtractor,
|
CombinedExtractor,
|
||||||
|
|
@ -18,10 +12,13 @@ from stable_baselines3.common.torch_layers import (
|
||||||
get_actor_critic_arch,
|
get_actor_critic_arch,
|
||||||
)
|
)
|
||||||
from stable_baselines3.common.type_aliases import Schedule
|
from stable_baselines3.common.type_aliases import Schedule
|
||||||
|
from torch import nn
|
||||||
|
from torch.distributions import Categorical
|
||||||
|
|
||||||
|
|
||||||
class Actor(BasePolicy):
|
class Actor(BasePolicy):
|
||||||
"""
|
"""
|
||||||
Actor network (policy) for SAC.
|
Actor network (policy) for SACD
|
||||||
|
|
||||||
:param observation_space: Obervation space
|
:param observation_space: Obervation space
|
||||||
:param action_space: Action space
|
:param action_space: Action space
|
||||||
|
|
@ -42,7 +39,7 @@ class Actor(BasePolicy):
|
||||||
dividing by 255.0 (True by default)
|
dividing by 255.0 (True by default)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
action_space: spaces.Box
|
action_space: spaces.Discrete
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -51,7 +48,7 @@ class Actor(BasePolicy):
|
||||||
net_arch: List[int],
|
net_arch: List[int],
|
||||||
features_extractor: nn.Module,
|
features_extractor: nn.Module,
|
||||||
features_dim: int,
|
features_dim: int,
|
||||||
activation_fn: Type[nn.Module] = nn.Softmax(dim=1),
|
activation_fn: Type[nn.Module] = nn.Softmax,
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
log_std_init: float = -3,
|
log_std_init: float = -3,
|
||||||
full_std: bool = True,
|
full_std: bool = True,
|
||||||
|
|
@ -132,6 +129,7 @@ class Actor(BasePolicy):
|
||||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||||
return self(observation, deterministic)
|
return self(observation, deterministic)
|
||||||
|
|
||||||
|
|
||||||
class DiscreteCritic(BaseModel):
|
class DiscreteCritic(BaseModel):
|
||||||
"""
|
"""
|
||||||
Critic network(s) for DDPG/SAC/TD3.
|
Critic network(s) for DDPG/SAC/TD3.
|
||||||
|
|
@ -159,10 +157,12 @@ class DiscreteCritic(BaseModel):
|
||||||
between the actor and the critic (this saves computation time)
|
between the actor and the critic (this saves computation time)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
action_space: spaces.Discrete
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Box,
|
action_space: spaces.Discrete,
|
||||||
net_arch: List[int],
|
net_arch: List[int],
|
||||||
features_extractor: BaseFeaturesExtractor,
|
features_extractor: BaseFeaturesExtractor,
|
||||||
features_dim: int,
|
features_dim: int,
|
||||||
|
|
@ -199,9 +199,10 @@ class DiscreteCritic(BaseModel):
|
||||||
features = self.extract_features(obs, self.features_extractor)
|
features = self.extract_features(obs, self.features_extractor)
|
||||||
return tuple(q_net(features) for q_net in self.q_networks)
|
return tuple(q_net(features) for q_net in self.q_networks)
|
||||||
|
|
||||||
class SACPolicy(BasePolicy):
|
|
||||||
|
class SACDPolicy(BasePolicy):
|
||||||
"""
|
"""
|
||||||
Policy class (with both actor and critic) for SAC.
|
Policy class (with both actor and critic) for SACD.
|
||||||
|
|
||||||
:param observation_space: Observation space
|
:param observation_space: Observation space
|
||||||
:param action_space: Action space
|
:param action_space: Action space
|
||||||
|
|
@ -351,14 +352,6 @@ class SACPolicy(BasePolicy):
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def reset_noise(self, batch_size: int = 1) -> None:
|
|
||||||
"""
|
|
||||||
Sample new weights for the exploration matrix, when using gSDE.
|
|
||||||
|
|
||||||
:param batch_size:
|
|
||||||
"""
|
|
||||||
self.actor.reset_noise(batch_size=batch_size)
|
|
||||||
|
|
||||||
def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
|
def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
|
||||||
actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
|
actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
|
||||||
return Actor(**actor_kwargs).to(self.device)
|
return Actor(**actor_kwargs).to(self.device)
|
||||||
|
|
@ -386,12 +379,12 @@ class SACPolicy(BasePolicy):
|
||||||
self.training = mode
|
self.training = mode
|
||||||
|
|
||||||
|
|
||||||
MlpPolicy = SACPolicy
|
MlpPolicy = SACDPolicy
|
||||||
|
|
||||||
|
|
||||||
class CnnPolicy(SACPolicy):
|
class CnnPolicy(SACDPolicy):
|
||||||
"""
|
"""
|
||||||
Policy class (with both actor and critic) for SAC.
|
Policy class (with both actor and critic) for SACD.
|
||||||
|
|
||||||
:param observation_space: Observation space
|
:param observation_space: Observation space
|
||||||
:param action_space: Action space
|
:param action_space: Action space
|
||||||
|
|
@ -455,9 +448,9 @@ class CnnPolicy(SACPolicy):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MultiInputPolicy(SACPolicy):
|
class MultiInputPolicy(SACDPolicy):
|
||||||
"""
|
"""
|
||||||
Policy class (with both actor and critic) for SAC.
|
Policy class (with both actor and critic) for SACD.
|
||||||
|
|
||||||
:param observation_space: Observation space
|
:param observation_space: Observation space
|
||||||
:param action_space: Action space
|
:param action_space: Action space
|
||||||
|
|
|
||||||
|
|
@ -3,23 +3,22 @@ from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Un
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from stable_baselines3.common.buffers import ReplayBuffer
|
from stable_baselines3.common.buffers import ReplayBuffer
|
||||||
from stable_baselines3.common.noise import ActionNoise
|
from stable_baselines3.common.noise import ActionNoise
|
||||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from sb3_contrib.sacd.policies import Actor, DiscreteCritic, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
|
from sb3_contrib.sacd.policies import Actor, CnnPolicy, DiscreteCritic, MlpPolicy, MultiInputPolicy, SACDPolicy
|
||||||
|
|
||||||
SelfSACD = TypeVar("SelfSACD", bound="SACD")
|
SelfSACD = TypeVar("SelfSACD", bound="SACD")
|
||||||
|
|
||||||
|
|
||||||
class SACD(OffPolicyAlgorithm):
|
class SACD(OffPolicyAlgorithm):
|
||||||
"""
|
"""
|
||||||
Soft Actor-Critic (SAC)
|
Soft Actor-Critic (SACD)
|
||||||
Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,
|
Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,
|
||||||
This implementation borrows code from original implementation (https://github.com/haarnoja/sac)
|
This implementation borrows code from original implementation (https://github.com/haarnoja/sac)
|
||||||
from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo
|
from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo
|
||||||
|
|
@ -84,14 +83,14 @@ class SACD(OffPolicyAlgorithm):
|
||||||
"CnnPolicy": CnnPolicy,
|
"CnnPolicy": CnnPolicy,
|
||||||
"MultiInputPolicy": MultiInputPolicy,
|
"MultiInputPolicy": MultiInputPolicy,
|
||||||
}
|
}
|
||||||
policy: SACPolicy
|
policy: SACDPolicy
|
||||||
actor: Actor
|
actor: Actor
|
||||||
critic: DiscreteCritic
|
critic: DiscreteCritic
|
||||||
critic_target: DiscreteCritic
|
critic_target: DiscreteCritic
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
policy: Union[str, Type[SACPolicy]],
|
policy: Union[str, Type[SACDPolicy]],
|
||||||
env: Union[GymEnv, str],
|
env: Union[GymEnv, str],
|
||||||
learning_rate: Union[float, Schedule] = 3e-4,
|
learning_rate: Union[float, Schedule] = 3e-4,
|
||||||
buffer_size: int = 1_000_000, # 1e6
|
buffer_size: int = 1_000_000, # 1e6
|
||||||
|
|
@ -189,7 +188,7 @@ class SACD(OffPolicyAlgorithm):
|
||||||
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
|
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
|
||||||
# self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
|
# self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
|
||||||
self.log_ent_coef = th.zeros(1, device=self.device, requires_grad=True)
|
self.log_ent_coef = th.zeros(1, device=self.device, requires_grad=True)
|
||||||
self.ent_coef = th.exp(self.log_ent_coef)
|
self.ent_coef_tensor = th.exp(self.log_ent_coef)
|
||||||
self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
|
self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
|
||||||
else:
|
else:
|
||||||
# Force conversion to float
|
# Force conversion to float
|
||||||
|
|
@ -231,8 +230,8 @@ class SACD(OffPolicyAlgorithm):
|
||||||
self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, self.gradient_clip_norm)
|
self.take_optimisation_step(self.actor.optimizer, self.actor, actor_loss, self.gradient_clip_norm)
|
||||||
|
|
||||||
# Compute entropy loss and optimize
|
# Compute entropy loss and optimize
|
||||||
ent_coeff = self.calc_entropy_loss(log_action_prob)
|
self.ent_coef_tensor = self.calc_entropy_loss(log_action_prob)
|
||||||
ent_coefs.append(self.ent_coef.item())
|
ent_coefs.append(self.ent_coef_tensor.item())
|
||||||
|
|
||||||
# Update target networks
|
# Update target networks
|
||||||
if gradient_step % self.target_update_interval == 0:
|
if gradient_step % self.target_update_interval == 0:
|
||||||
|
|
@ -265,7 +264,7 @@ class SACD(OffPolicyAlgorithm):
|
||||||
next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
|
next_q_values = th.stack(self.critic_target(replay_data.next_observations), dim=2)
|
||||||
next_q_values, _ = th.min(next_q_values, dim=2)
|
next_q_values, _ = th.min(next_q_values, dim=2)
|
||||||
|
|
||||||
next_q_values = (action_prob * (next_q_values - self.ent_coef * next_log_prob)).sum(dim=1).unsqueeze(-1)
|
next_q_values = (action_prob * (next_q_values - self.ent_coef_tensor * next_log_prob)).sum(dim=1).unsqueeze(-1)
|
||||||
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
|
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
|
||||||
|
|
||||||
# Get current Q-values estimates for each critic network
|
# Get current Q-values estimates for each critic network
|
||||||
|
|
@ -273,7 +272,9 @@ class SACD(OffPolicyAlgorithm):
|
||||||
current_q_values = self.critic(replay_data.observations)
|
current_q_values = self.critic(replay_data.observations)
|
||||||
|
|
||||||
# Compute critic loss
|
# Compute critic loss
|
||||||
critic_loss = 0.5 * sum(F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values)
|
critic_loss = 0.5 * sum(
|
||||||
|
F.mse_loss(current_q.gather(1, replay_data.actions), target_q_values) for current_q in current_q_values
|
||||||
|
)
|
||||||
|
|
||||||
return critic_loss
|
return critic_loss
|
||||||
|
|
||||||
|
|
@ -284,25 +285,21 @@ class SACD(OffPolicyAlgorithm):
|
||||||
q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
|
q_values_pi = th.stack(self.critic(replay_data.observations), dim=2)
|
||||||
min_qf_pi, _ = th.min(q_values_pi, dim=2)
|
min_qf_pi, _ = th.min(q_values_pi, dim=2)
|
||||||
|
|
||||||
inside_term = self.ent_coef * log_prob - min_qf_pi
|
inside_term = self.ent_coef_tensor * log_prob - min_qf_pi
|
||||||
actor_loss = (action_prob * inside_term).sum(dim=1).mean()
|
actor_loss = (action_prob * inside_term).sum(dim=1).mean()
|
||||||
return actor_loss, log_prob
|
return actor_loss, log_prob
|
||||||
|
|
||||||
def calc_entropy_loss(self, log_action_prob):
|
def calc_entropy_loss(self, log_action_prob):
|
||||||
ent_coef_loss = None
|
|
||||||
if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
|
if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
|
||||||
# Important: detach the variable from the graph
|
# Important: detach the variable from the graph
|
||||||
# so we don't change it with other losses
|
# so we don't change it with other losses
|
||||||
# see https://github.com/rail-berkeley/softlearning/issues/60
|
# see https://github.com/rail-berkeley/softlearning/issues/60
|
||||||
ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()
|
ent_coef_loss = -(self.log_ent_coef * (log_action_prob + self.target_entropy).detach()).mean()
|
||||||
# ent_coef_losses.append(ent_coef_loss.item())
|
|
||||||
ent_coef_loss.backward()
|
ent_coef_loss.backward()
|
||||||
self.ent_coef_optimizer.step()
|
self.ent_coef_optimizer.step()
|
||||||
self.ent_coef = th.exp(self.log_ent_coef.detach())
|
return th.exp(self.log_ent_coef.detach())
|
||||||
else:
|
else:
|
||||||
self.ent_coef = self.ent_coef_tensor
|
return self.ent_coef_tensor
|
||||||
|
|
||||||
return self.ent_coef
|
|
||||||
|
|
||||||
def learn(
|
def learn(
|
||||||
self: SelfSACD,
|
self: SelfSACD,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue