from typing import Any, Dict, List, Optional, Tuple, Type, Union

import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.distributions import Distribution
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor,
    MlpExtractor,
    NatureCNN,
)
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import zip_strict
from torch import nn

from sb3_contrib.common.recurrent.type_aliases import RNNStates


class RecurrentActorCriticPolicy(ActorCriticPolicy):
    """
    Recurrent policy class for actor-critic algorithms (has both policy and value prediction).
    To be used with A2C, PPO and the likes.
    It assumes that both the actor and the critic LSTM
    have the same architecture.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether or not to use orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It keeps the variance
        above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this ensures boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param lstm_hidden_size: Number of hidden units for each LSTM layer.
    :param n_lstm_layers: Number of LSTM layers.
    :param shared_lstm: Whether the LSTM is shared between the actor and the critic
        (in that case, only the actor gradient is used).
        By default, the actor and the critic have two separate LSTMs.
    :param enable_critic_lstm: Use a separate LSTM for the critic.
    :param lstm_kwargs: Additional keyword arguments to pass to the LSTM
        constructor.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        lstm_hidden_size: int = 256,
        n_lstm_layers: int = 1,
        shared_lstm: bool = False,
        enable_critic_lstm: bool = True,
        lstm_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.lstm_output_dim = lstm_hidden_size
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            share_features_extractor,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )

        self.lstm_kwargs = lstm_kwargs or {}
        self.shared_lstm = shared_lstm
        self.enable_critic_lstm = enable_critic_lstm
        self.lstm_actor = nn.LSTM(
            self.features_dim,
            lstm_hidden_size,
            num_layers=n_lstm_layers,
            **self.lstm_kwargs,
        )
        # For the predict() method, to initialize hidden states
        # (n_lstm_layers, batch_size, lstm_hidden_size)
        self.lstm_hidden_state_shape = (n_lstm_layers, 1, lstm_hidden_size)
        self.critic = None
        self.lstm_critic = None
        assert not (
            self.shared_lstm and self.enable_critic_lstm
        ), "You must choose between shared LSTM, separate or no LSTM for the critic."

        assert not (
            self.shared_lstm and not self.share_features_extractor
        ), "If the features extractor is not shared, the LSTM cannot be shared."

        # No LSTM for the critic, we still need to convert
        # output of features extractor to the correct size
        # (size of the output of the actor lstm)
        if not (self.shared_lstm or self.enable_critic_lstm):
            self.critic = nn.Linear(self.features_dim, lstm_hidden_size)

        # Use a separate LSTM for the critic
        if self.enable_critic_lstm:
            self.lstm_critic = nn.LSTM(
                self.features_dim,
                lstm_hidden_size,
                num_layers=n_lstm_layers,
                **self.lstm_kwargs,
            )

        # Setup optimizer with initial learning rate
        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

    def _build_mlp_extractor(self) -> None:
        """
        Create the policy and value networks.
        Part of the layers can be shared.
        """
        self.mlp_extractor = MlpExtractor(
            self.lstm_output_dim,
            net_arch=self.net_arch,
            activation_fn=self.activation_fn,
            device=self.device,
        )

    @staticmethod
    def _process_sequence(
        features: th.Tensor,
        lstm_states: Tuple[th.Tensor, th.Tensor],
        episode_starts: th.Tensor,
        lstm: nn.LSTM,
    ) -> Tuple[th.Tensor, th.Tensor]:
        """
        Do a forward pass in the LSTM network.

        :param features: Input tensor
        :param lstm_states: previous cell and hidden states of the LSTM
        :param episode_starts: Indicates when a new episode starts,
            in that case, we need to reset LSTM states.
        :param lstm: LSTM object.
        :return: LSTM output and updated LSTM states.
        """
        # LSTM logic
        # (sequence length, batch size, features dim)
        # (batch size = n_envs for data collection or n_seq when doing gradient update)
        n_seq = lstm_states[0].shape[1]
        # Batch to sequence
        # (padded batch size, features_dim) -> (n_seq, max length, features_dim) -> (max length, n_seq, features_dim)
        # note: max length (max sequence length) is always 1 during data collection
        features_sequence = features.reshape((n_seq, -1, lstm.input_size)).swapaxes(0, 1)
        episode_starts = episode_starts.reshape((n_seq, -1)).swapaxes(0, 1)

        # If we don't have to reset the state in the middle of a sequence
        # we can avoid the for loop, which speeds up things
        if th.all(episode_starts == 0.0):
            lstm_output, lstm_states = lstm(features_sequence, lstm_states)
            lstm_output = th.flatten(lstm_output.transpose(0, 1), start_dim=0, end_dim=1)
            return lstm_output, lstm_states

        lstm_output = []
        # Iterate over the sequence
        for features, episode_start in zip_strict(features_sequence, episode_starts):
            hidden, lstm_states = lstm(
                features.unsqueeze(dim=0),
                (
                    # Reset the states at the beginning of a new episode
                    (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[0],
                    (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[1],
                ),
            )
            lstm_output += [hidden]
        # Sequence to batch
        # (sequence length, n_seq, lstm_out_dim) -> (batch_size, lstm_out_dim)
        lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1)
        return lstm_output, lstm_states
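    # A minimal sketch of how ``_process_sequence`` handles episode boundaries
    # (all shapes and values below are illustrative, not part of the class):
    #
    #     lstm = nn.LSTM(input_size=4, hidden_size=8)
    #     features = th.zeros(6, 4)  # 2 sequences of length 3, flattened into a batch
    #     states = (th.zeros(1, 2, 8), th.zeros(1, 2, 8))  # (n_layers, n_seq, hidden)
    #     starts = th.tensor([1.0, 0.0, 0.0, 1.0, 0.0, 0.0])  # each sequence starts an episode
    #     out, states = RecurrentActorCriticPolicy._process_sequence(features, states, starts, lstm)
    #     # out has shape (6, 8): per-step outputs flattened back to a batch
    #
    # Because ``starts`` contains ones, the per-step loop runs and multiplies the states
    # by ``(1.0 - episode_start)``, zeroing them wherever an episode begins.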

    def forward(
        self,
        obs: th.Tensor,
        lstm_states: RNNStates,
        episode_starts: th.Tensor,
        deterministic: bool = False,
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, RNNStates]:
        """
        Forward pass in all the networks (actor and critic)

        :param obs: Observation
        :param lstm_states: The last hidden and memory states for the LSTM.
        :param episode_starts: Whether the observations correspond to new episodes
            or not (we reset the lstm states in that case).
        :param deterministic: Whether to sample or use deterministic actions
        :return: action, value, log probability of the action and new LSTM states
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        if self.share_features_extractor:
            pi_features = vf_features = features  # alias
        else:
            pi_features, vf_features = features
        # latent_pi, latent_vf = self.mlp_extractor(features)
        latent_pi, lstm_states_pi = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor)
        if self.lstm_critic is not None:
            latent_vf, lstm_states_vf = self._process_sequence(vf_features, lstm_states.vf, episode_starts, self.lstm_critic)
        elif self.shared_lstm:
            # Re-use LSTM features but do not backpropagate
            latent_vf = latent_pi.detach()
            lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach())
        else:
            # Critic only has a feedforward network
            latent_vf = self.critic(vf_features)
            lstm_states_vf = lstm_states_pi

        latent_pi = self.mlp_extractor.forward_actor(latent_pi)
        latent_vf = self.mlp_extractor.forward_critic(latent_vf)

        # Evaluate the values for the given observations
        values = self.value_net(latent_vf)
        distribution = self._get_action_dist_from_latent(latent_pi)
        actions = distribution.get_actions(deterministic=deterministic)
        log_prob = distribution.log_prob(actions)
        return actions, values, log_prob, RNNStates(lstm_states_pi, lstm_states_vf)
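    # A hedged sketch of a single collection step through ``forward`` (``policy``,
    # ``n_envs`` and ``obs_tensor``, a batch of 2 observations, are hypothetical):
    #
    #     n_envs = 2
    #     zeros = th.zeros(policy.lstm_hidden_state_shape).repeat(1, n_envs, 1)
    #     states = RNNStates((zeros, zeros.clone()), (zeros.clone(), zeros.clone()))
    #     starts = th.ones(n_envs)  # first step of every episode, so states get reset
    #     actions, values, log_prob, states = policy.forward(obs_tensor, states, starts)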

    def get_distribution(
        self,
        obs: th.Tensor,
        lstm_states: Tuple[th.Tensor, th.Tensor],
        episode_starts: th.Tensor,
    ) -> Tuple[Distribution, Tuple[th.Tensor, ...]]:
        """
        Get the current policy distribution given the observations.

        :param obs: Observation.
        :param lstm_states: The last hidden and memory states for the LSTM.
        :param episode_starts: Whether the observations correspond to new episodes
            or not (we reset the lstm states in that case).
        :return: the action distribution and new hidden states.
        """
        # Call the method from the parent of the parent class
        features = super(ActorCriticPolicy, self).extract_features(obs, self.pi_features_extractor)
        latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor)
        latent_pi = self.mlp_extractor.forward_actor(latent_pi)
        return self._get_action_dist_from_latent(latent_pi), lstm_states

    def predict_values(
        self,
        obs: th.Tensor,
        lstm_states: Tuple[th.Tensor, th.Tensor],
        episode_starts: th.Tensor,
    ) -> th.Tensor:
        """
        Get the estimated values according to the current policy given the observations.

        :param obs: Observation.
        :param lstm_states: The last hidden and memory states for the LSTM.
        :param episode_starts: Whether the observations correspond to new episodes
            or not (we reset the lstm states in that case).
        :return: the estimated values.
        """
        # Call the method from the parent of the parent class
        features = super(ActorCriticPolicy, self).extract_features(obs, self.vf_features_extractor)

        if self.lstm_critic is not None:
            latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic)
        elif self.shared_lstm:
            # Use LSTM from the actor
            latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor)
            latent_vf = latent_pi.detach()
        else:
            latent_vf = self.critic(features)

        latent_vf = self.mlp_extractor.forward_critic(latent_vf)
        return self.value_net(latent_vf)

    def evaluate_actions(
        self, obs: th.Tensor, actions: th.Tensor, lstm_states: RNNStates, episode_starts: th.Tensor
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Evaluate actions according to the current policy,
        given the observations.

        :param obs: Observation.
        :param actions: Actions to evaluate.
        :param lstm_states: The last hidden and memory states for the LSTM.
        :param episode_starts: Whether the observations correspond to new episodes
            or not (we reset the lstm states in that case).
        :return: estimated value, log likelihood of taking those actions
            and entropy of the action distribution.
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        if self.share_features_extractor:
            pi_features = vf_features = features  # alias
        else:
            pi_features, vf_features = features
        latent_pi, _ = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor)
        if self.lstm_critic is not None:
            latent_vf, _ = self._process_sequence(vf_features, lstm_states.vf, episode_starts, self.lstm_critic)
        elif self.shared_lstm:
            latent_vf = latent_pi.detach()
        else:
            latent_vf = self.critic(vf_features)

        latent_pi = self.mlp_extractor.forward_actor(latent_pi)
        latent_vf = self.mlp_extractor.forward_critic(latent_vf)

        distribution = self._get_action_dist_from_latent(latent_pi)
        log_prob = distribution.log_prob(actions)
        values = self.value_net(latent_vf)
        return values, log_prob, distribution.entropy()

    def _predict(
        self,
        observation: th.Tensor,
        lstm_states: Tuple[th.Tensor, th.Tensor],
        episode_starts: th.Tensor,
        deterministic: bool = False,
    ) -> Tuple[th.Tensor, Tuple[th.Tensor, ...]]:
        """
        Get the action according to the policy for a given observation.

        :param observation: The input observation.
        :param lstm_states: The last hidden and memory states for the LSTM.
        :param episode_starts: Whether the observations correspond to new episodes
            or not (we reset the lstm states in that case).
        :param deterministic: Whether to use stochastic or deterministic actions
        :return: Taken action according to the policy and hidden states of the RNN
        """
        distribution, lstm_states = self.get_distribution(observation, lstm_states, episode_starts)
        return distribution.get_actions(deterministic=deterministic), lstm_states

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :param state: The last hidden and memory states for the LSTM.
        :param episode_start: Whether the observations correspond to new episodes
            or not (we reset the lstm states in that case).
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next hidden state
            (used in recurrent policies)
        """
        # Switch to eval mode (this affects batch norm / dropout)
        self.set_training_mode(False)

        observation, vectorized_env = self.obs_to_tensor(observation)

        if isinstance(observation, dict):
            n_envs = observation[next(iter(observation.keys()))].shape[0]
        else:
            n_envs = observation.shape[0]
        # state : (n_layers, n_envs, dim)
        if state is None:
            # Initialize hidden states to zeros
            state = np.concatenate([np.zeros(self.lstm_hidden_state_shape) for _ in range(n_envs)], axis=1)
            state = (state, state)

        if episode_start is None:
            episode_start = np.array([False for _ in range(n_envs)])

        with th.no_grad():
            # Convert to PyTorch tensors
            states = th.tensor(state[0], dtype=th.float32, device=self.device), th.tensor(
                state[1], dtype=th.float32, device=self.device
            )
            episode_starts = th.tensor(episode_start, dtype=th.float32, device=self.device)
            actions, states = self._predict(
                observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic
            )
            states = (states[0].cpu().numpy(), states[1].cpu().numpy())

        # Convert to numpy
        actions = actions.cpu().numpy()

        if isinstance(self.action_space, spaces.Box):
            if self.squash_output:
                # Rescale to proper domain when using squashing
                actions = self.unscale_action(actions)
            else:
                # Actions could be on arbitrary scale, so clip the actions to avoid
                # out of bound error (e.g. if sampling from a Gaussian distribution)
                actions = np.clip(actions, self.action_space.low, self.action_space.high)

        # Remove batch dimension if needed
        if not vectorized_env:
            actions = actions.squeeze(axis=0)

        return actions, states
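    # A hedged usage sketch of the stateful prediction loop (the ``RecurrentPPO``
    # import and environment name are illustrative; only ``predict`` above is assumed):
    #
    #     from sb3_contrib import RecurrentPPO
    #     model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1").learn(5000)
    #     vec_env = model.get_env()
    #     obs = vec_env.reset()
    #     lstm_states = None  # zero-initialized on the first call
    #     episode_starts = np.ones((vec_env.num_envs,), dtype=bool)
    #     while True:
    #         action, lstm_states = model.predict(
    #             obs, state=lstm_states, episode_start=episode_starts, deterministic=True
    #         )
    #         obs, rewards, dones, infos = vec_env.step(action)
    #         episode_starts = dones  # reset hidden states where episodes ended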


class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
    """
    CNN recurrent policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether or not to use orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It keeps the variance
        above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this ensures boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param lstm_hidden_size: Number of hidden units for each LSTM layer.
    :param n_lstm_layers: Number of LSTM layers.
    :param shared_lstm: Whether the LSTM is shared between the actor and the critic.
        By default, the actor and the critic have two separate LSTMs.
    :param enable_critic_lstm: Use a separate LSTM for the critic.
    :param lstm_kwargs: Additional keyword arguments to pass to the LSTM
        constructor.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        lstm_hidden_size: int = 256,
        n_lstm_layers: int = 1,
        shared_lstm: bool = False,
        enable_critic_lstm: bool = True,
        lstm_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            share_features_extractor,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
            lstm_hidden_size,
            n_lstm_layers,
            shared_lstm,
            enable_critic_lstm,
            lstm_kwargs,
        )


class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
    """
    Multi-input recurrent policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether or not to use orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It keeps the variance
        above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this ensures boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param lstm_hidden_size: Number of hidden units for each LSTM layer.
    :param n_lstm_layers: Number of LSTM layers.
    :param shared_lstm: Whether the LSTM is shared between the actor and the critic.
        By default, the actor and the critic have two separate LSTMs.
    :param enable_critic_lstm: Use a separate LSTM for the critic.
    :param lstm_kwargs: Additional keyword arguments to pass to the LSTM
        constructor.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        lstm_hidden_size: int = 256,
        n_lstm_layers: int = 1,
        shared_lstm: bool = False,
        enable_critic_lstm: bool = True,
        lstm_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            share_features_extractor,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
            lstm_hidden_size,
            n_lstm_layers,
            shared_lstm,
            enable_critic_lstm,
            lstm_kwargs,
        )
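# A hedged note: these classes are normally selected through RecurrentPPO's policy
# aliases (assumed here from the sb3_contrib documentation) rather than instantiated
# directly:
#
#     RecurrentPPO("MlpLstmPolicy", env)         # -> RecurrentActorCriticPolicy
#     RecurrentPPO("CnnLstmPolicy", env)         # -> RecurrentActorCriticCnnPolicy
#     RecurrentPPO("MultiInputLstmPolicy", env)  # -> RecurrentMultiInputActorCriticPolicy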