# stable-baselines3-contrib-sacd/sb3_contrib/sacd/policies.py
#
# 522 lines
# 20 KiB
# Python

from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch as th
from gymnasium import spaces
from torch import nn
from torch.distributions import Categorical
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BasePolicy, BaseModel
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
NatureCNN,
create_mlp,
get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import Schedule
class Actor(BasePolicy):
    """
    Actor network (policy) for discrete-action SAC.

    An MLP maps the extracted features to one output per discrete action and a
    softmax over dimension 1 turns those outputs into action probabilities.

    :param observation_space: Observation space
    :param action_space: Action space; ``action_space.n`` sets the number of outputs
    :param net_arch: Network architecture
    :param features_extractor: Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: Number of features
    :param activation_fn: Activation function class used between the hidden layers
    :param use_sde: Stored so the object can be re-created at loading; not used by this actor
    :param log_std_init: Stored so the object can be re-created at loading; not used by this actor
    :param full_std: Stored so the object can be re-created at loading; not used by this actor
    :param use_expln: Stored so the object can be re-created at loading; not used by this actor
    :param clip_mean: Stored so the object can be re-created at loading; not used by this actor
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    """

    action_space: spaces.Discrete

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        net_arch: List[int],
        features_extractor: nn.Module,
        features_dim: int,
        # Must be a class (not an instance): ``create_mlp`` instantiates it per hidden layer.
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        full_std: bool = True,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        normalize_images: bool = True,
    ):
        super().__init__(
            observation_space,
            action_space,
            features_extractor=features_extractor,
            normalize_images=normalize_images,
            squash_output=False,
        )
        # Save arguments to re-create object at loading
        self.use_sde = use_sde
        self.sde_features_extractor = None
        self.net_arch = net_arch
        self.features_dim = features_dim
        self.activation_fn = activation_fn
        self.log_std_init = log_std_init
        self.use_expln = use_expln
        self.full_std = full_std
        self.clip_mean = clip_mean

        num_actions = self.action_space.n
        latent_pi_net = create_mlp(features_dim, num_actions, net_arch, activation_fn)
        self.latent_pi = nn.Sequential(*latent_pi_net)
        # Softmax over the action dimension; inputs are expected to be (batch, num_actions)
        self.output_activation = nn.Softmax(dim=1)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the keyword arguments needed to re-create this actor."""
        data = super()._get_constructor_parameters()
        data.update(
            dict(
                net_arch=self.net_arch,
                features_dim=self.features_dim,
                activation_fn=self.activation_fn,
                use_sde=self.use_sde,
                log_std_init=self.log_std_init,
                full_std=self.full_std,
                use_expln=self.use_expln,
                features_extractor=self.features_extractor,
                clip_mean=self.clip_mean,
            )
        )
        return data

    def reset_noise(self, batch_size: int = 1) -> None:
        """
        Do nothing: this actor samples from a categorical distribution and
        holds no exploration-noise parameters to resample.

        :param batch_size: Ignored
        """
        return None

    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
        """
        Select one action index per observation.

        :param obs: Batch of observations
        :param deterministic: If True take the most probable action,
            otherwise sample from the categorical distribution
        :return: Tensor of action indices, one per batch row
        """
        features = self.extract_features(obs, self.features_extractor)
        action_probabilities = self.output_activation(self.latent_pi(features))
        if deterministic:
            # Reduce over the action dimension only; without ``dim`` the argmax
            # would be taken over the flattened (batch * num_actions) tensor.
            action = th.argmax(action_probabilities, dim=1)
        else:
            # random action according to policy
            dist = Categorical(probs=action_probabilities)
            action = dist.sample()
        return action

    def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        Compute the probability and log-probability of every action.

        :param obs: Batch of observations
        :return: ``(action_prob, log_action_prob)``, both with one column per action
        """
        features = self.extract_features(obs, self.features_extractor)
        action_prob = self.output_activation(self.latent_pi(features))
        # Have to deal with situation of 0.0 probabilities because we can't do log 0:
        # add 1e-8 only to the entries that are exactly zero.
        z = action_prob == 0.0
        z = z.float() * 1e-8
        log_action_prob = th.log(action_prob + z)
        return action_prob, log_action_prob

    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
        """Return the action chosen by :meth:`forward` for ``observation``."""
        return self(observation, deterministic)
class DiscreteCritic(BaseModel):
    """
    Critic network(s) for discrete-action SAC.

    Each critic takes only the (extracted) observation features as input and
    outputs one Q-value per discrete action, i.e. Q(s, .). No action is
    concatenated to the input.
    By default, it creates two critic networks used to reduce overestimation
    thanks to clipped Q-learning (cf TD3 paper).

    :param observation_space: Observation space
    :param action_space: Action space; ``action_space.n`` sets the number of outputs
    :param net_arch: Network architecture
    :param features_extractor: Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: Number of features
    :param activation_fn: Activation function
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether the features extractor is shared or not
        between the actor and the critic (this saves computation time)
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        net_arch: List[int],
        features_extractor: BaseFeaturesExtractor,
        features_dim: int,
        activation_fn: Type[nn.Module] = nn.ReLU,
        normalize_images: bool = True,
        n_critics: int = 2,
        share_features_extractor: bool = True,
    ):
        super().__init__(
            observation_space,
            action_space,
            features_extractor=features_extractor,
            normalize_images=normalize_images,
        )
        # One output per discrete action. ``get_action_dim`` is not used here because
        # it reports the size of a stored action (1 for Discrete), not the number of actions.
        num_actions = int(self.action_space.n)
        self.share_features_extractor = share_features_extractor
        self.n_critics = n_critics
        # Plain list for indexed access; each network is also registered as
        # a submodule named ``qf{idx}`` so its parameters are tracked.
        self.q_networks: List[nn.Module] = []
        for idx in range(n_critics):
            q_net = nn.Sequential(*create_mlp(features_dim, num_actions, net_arch, activation_fn))
            self.add_module(f"qf{idx}", q_net)
            self.q_networks.append(q_net)

    def get_crit_params(self, n):
        """
        Return the parameters of a single critic network.

        :param n: Index of the critic network in ``self.q_networks``
        :return: Iterator over that network's parameters
        """
        return self.q_networks[n].parameters()

    def forward(self, obs: th.Tensor) -> Tuple[th.Tensor, ...]:
        """
        Evaluate every critic network on the observations.

        :param obs: Batch of observations
        :return: Tuple with one tensor of per-action Q-values for each critic
        """
        # Learn the features extractor using the policy loss only
        # when the features_extractor is shared with the actor
        with th.set_grad_enabled(not self.share_features_extractor):
            features = self.extract_features(obs, self.features_extractor)
        return tuple(q_net(features) for q_net in self.q_networks)
class SACPolicy(BasePolicy):
    """
    Policy class (with both actor and critic) for discrete-action SAC.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
        Defaults to ``[256, 256]`` when ``None``.
    :param activation_fn: Activation function
    :param use_sde: Forwarded to the actor constructor
    :param log_std_init: Forwarded to the actor constructor
    :param use_expln: Forwarded to the actor constructor
    :param clip_mean: Forwarded to the actor constructor
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    actor: Actor
    critic: DiscreteCritic
    critic_target: DiscreteCritic

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        n_critics: int = 2,
        share_features_extractor: bool = False,
    ):
        super().__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=True,
            normalize_images=normalize_images,
        )

        if net_arch is None:
            net_arch = [256, 256]

        actor_arch, critic_arch = get_actor_critic_arch(net_arch)

        self.net_arch = net_arch
        self.activation_fn = activation_fn
        # Arguments common to the actor and the critics
        self.net_args = {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "net_arch": actor_arch,
            "activation_fn": self.activation_fn,
            "normalize_images": normalize_images,
        }
        self.actor_kwargs = self.net_args.copy()

        sde_kwargs = {
            "use_sde": use_sde,
            "log_std_init": log_std_init,
            "use_expln": use_expln,
            "clip_mean": clip_mean,
        }
        self.actor_kwargs.update(sde_kwargs)

        self.critic_kwargs = self.net_args.copy()
        self.critic_kwargs.update(
            {
                "n_critics": n_critics,
                "net_arch": critic_arch,
                "share_features_extractor": share_features_extractor,
            }
        )

        self.share_features_extractor = share_features_extractor

        self._build(lr_schedule)

    def _build(self, lr_schedule: Schedule) -> None:
        """
        Create the actor, the critic, the target critic and their optimizers.

        :param lr_schedule: Learning rate schedule; ``lr_schedule(1)`` gives the initial rate
        """
        self.actor = self.make_actor()
        self.actor.optimizer = self.optimizer_class(
            self.actor.parameters(),
            lr=lr_schedule(1),  # type: ignore[call-arg]
            **self.optimizer_kwargs,
        )

        if self.share_features_extractor:
            self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
            # Do not optimize the shared features extractor with the critic loss
            # otherwise, there are gradient computation issues
            critic_parameters = [
                param for name, param in self.critic.named_parameters() if "features_extractor" not in name
            ]
        else:
            # Create a separate features extractor for the critic
            # this requires more memory and computation
            self.critic = self.make_critic(features_extractor=None)
            critic_parameters = list(self.critic.parameters())

        # Critic target should not share the features extractor with critic
        self.critic_target = self.make_critic(features_extractor=None)
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.critic.optimizer = self.optimizer_class(
            critic_parameters,
            lr=lr_schedule(1),  # type: ignore[call-arg]
            **self.optimizer_kwargs,
        )

        # Target networks should always be in eval mode
        self.critic_target.set_training_mode(False)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the keyword arguments needed to re-create this policy."""
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                net_arch=self.net_arch,
                activation_fn=self.net_args["activation_fn"],
                use_sde=self.actor_kwargs["use_sde"],
                log_std_init=self.actor_kwargs["log_std_init"],
                use_expln=self.actor_kwargs["use_expln"],
                clip_mean=self.actor_kwargs["clip_mean"],
                n_critics=self.critic_kwargs["n_critics"],
                # Without this entry a re-created policy would fall back to the
                # constructor default and lose the sharing configuration.
                share_features_extractor=self.share_features_extractor,
                lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
            )
        )
        return data

    def reset_noise(self, batch_size: int = 1) -> None:
        """
        Forward the noise-reset request to the actor.

        NOTE(review): this relies on the actor exposing ``reset_noise`` -- confirm.

        :param batch_size:
        """
        self.actor.reset_noise(batch_size=batch_size)

    def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
        """
        Build an actor on the policy's device.

        :param features_extractor: Features extractor to use, passed through
            ``_update_features_extractor`` together with the stored actor kwargs
        :return: The new actor
        """
        actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
        return Actor(**actor_kwargs).to(self.device)

    def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> DiscreteCritic:
        """
        Build a critic on the policy's device.

        :param features_extractor: Features extractor to use, passed through
            ``_update_features_extractor`` together with the stored critic kwargs
        :return: The new critic
        """
        critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
        return DiscreteCritic(**critic_kwargs).to(self.device)

    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
        """Return the actor's action for ``obs`` (same as :meth:`_predict`)."""
        return self._predict(obs, deterministic=deterministic)

    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
        """Return the actor's action for ``observation``."""
        return self.actor(observation, deterministic)

    def set_training_mode(self, mode: bool) -> None:
        """
        Put the policy in either training or evaluation mode.

        This affects certain modules, such as batch normalisation and dropout.
        Only the actor and the critic are switched; the target critic is left untouched.

        :param mode: if true, set to training mode, else set to evaluation mode
        """
        self.actor.set_training_mode(mode)
        self.critic.set_training_mode(mode)
        self.training = mode
# ``MlpPolicy`` is an alias: it refers to the very same class object as ``SACPolicy``.
MlpPolicy = SACPolicy
class CnnPolicy(SACPolicy):
    """
    Policy class (with both actor and critic) for discrete-action SAC.

    Identical to :class:`SACPolicy` except that the default features
    extractor is ``NatureCNN``.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param use_sde: Forwarded unchanged to :class:`SACPolicy`
    :param log_std_init: Forwarded unchanged to :class:`SACPolicy`
    :param use_expln: Forwarded unchanged to :class:`SACPolicy`
    :param clip_mean: Forwarded unchanged to :class:`SACPolicy`
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        n_critics: int = 2,
        share_features_extractor: bool = False,
    ):
        # Forward by keyword so a reordering of the parent signature cannot
        # silently mis-assign one of the sixteen arguments.
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            use_sde=use_sde,
            log_std_init=log_std_init,
            use_expln=use_expln,
            clip_mean=clip_mean,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            n_critics=n_critics,
            share_features_extractor=share_features_extractor,
        )
class MultiInputPolicy(SACPolicy):
    """
    Policy class (with both actor and critic) for discrete-action SAC.

    Identical to :class:`SACPolicy` except that the default features
    extractor is ``CombinedExtractor``.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param use_sde: Forwarded unchanged to :class:`SACPolicy`
    :param log_std_init: Forwarded unchanged to :class:`SACPolicy`
    :param use_expln: Forwarded unchanged to :class:`SACPolicy`
    :param clip_mean: Forwarded unchanged to :class:`SACPolicy`
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        n_critics: int = 2,
        share_features_extractor: bool = False,
    ):
        # Forward by keyword so a reordering of the parent signature cannot
        # silently mis-assign one of the sixteen arguments.
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            use_sde=use_sde,
            log_std_init=log_std_init,
            use_expln=use_expln,
            clip_mean=clip_mean,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            n_critics=n_critics,
            share_features_extractor=share_features_extractor,
        )