# stable-baselines3-contrib-sacd/sb3_contrib/common/utils.py

from typing import Callable, Optional, Sequence
import torch as th
from torch import nn


def quantile_huber_loss(
    current_quantiles: th.Tensor,
    target_quantiles: th.Tensor,
    cum_prob: Optional[th.Tensor] = None,
    sum_over_quantiles: bool = True,
) -> th.Tensor:
"""
The quantile-regression loss, as described in the QR-DQN and TQC papers.
Partially taken from https://github.com/bayesgroup/tqc_pytorch.
:param current_quantiles: current estimate of quantiles, must be either
(batch_size, n_quantiles) or (batch_size, n_critics, n_quantiles)
:param target_quantiles: target of quantiles, must be either (batch_size, n_target_quantiles),
(batch_size, 1, n_target_quantiles), or (batch_size, n_critics, n_target_quantiles)
:param cum_prob: cumulative probabilities to calculate quantiles (also called midpoints in QR-DQN paper),
must be either (batch_size, n_quantiles), (batch_size, 1, n_quantiles), or (batch_size, n_critics, n_quantiles).
(if None, calculating unit quantiles)
:param sum_over_quantiles: if summing over the quantile dimension or not
:return: the loss
"""
    if current_quantiles.ndim != target_quantiles.ndim:
        raise ValueError(
            f"Error: The dimension of current_quantiles ({current_quantiles.ndim}) needs to match "
            f"the dimension of target_quantiles ({target_quantiles.ndim})."
        )
    if current_quantiles.shape[0] != target_quantiles.shape[0]:
        raise ValueError(
            f"Error: The batch size of current_quantiles ({current_quantiles.shape[0]}) needs to match "
            f"the batch size of target_quantiles ({target_quantiles.shape[0]})."
        )
    if current_quantiles.ndim not in (2, 3):
        raise ValueError(f"Error: The dimension of current_quantiles ({current_quantiles.ndim}) needs to be either 2 or 3.")

    if cum_prob is None:
        n_quantiles = current_quantiles.shape[-1]
        # Cumulative probabilities to calculate quantiles.
        cum_prob = (th.arange(n_quantiles, device=current_quantiles.device, dtype=th.float) + 0.5) / n_quantiles
        if current_quantiles.ndim == 2:
            # For QR-DQN, current_quantiles have a shape (batch_size, n_quantiles), and make cum_prob
            # broadcastable to (batch_size, n_quantiles, n_target_quantiles)
            cum_prob = cum_prob.view(1, -1, 1)
        elif current_quantiles.ndim == 3:
            # For TQC, current_quantiles have a shape (batch_size, n_critics, n_quantiles), and make cum_prob
            # broadcastable to (batch_size, n_critics, n_quantiles, n_target_quantiles)
            cum_prob = cum_prob.view(1, 1, -1, 1)

    # QR-DQN
    # target_quantiles: (batch_size, n_target_quantiles) -> (batch_size, 1, n_target_quantiles)
    # current_quantiles: (batch_size, n_quantiles) -> (batch_size, n_quantiles, 1)
    # pairwise_delta: (batch_size, n_quantiles, n_target_quantiles)
    # TQC
    # target_quantiles: (batch_size, 1, n_target_quantiles) -> (batch_size, 1, 1, n_target_quantiles)
    # current_quantiles: (batch_size, n_critics, n_quantiles) -> (batch_size, n_critics, n_quantiles, 1)
    # pairwise_delta: (batch_size, n_critics, n_quantiles, n_target_quantiles)
    # Note: in both cases, the loss has the same shape as pairwise_delta
    pairwise_delta = target_quantiles.unsqueeze(-2) - current_quantiles.unsqueeze(-1)
    abs_pairwise_delta = th.abs(pairwise_delta)
    huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta**2 * 0.5)
    loss = th.abs(cum_prob - (pairwise_delta.detach() < 0).float()) * huber_loss
    if sum_over_quantiles:
        loss = loss.sum(dim=-2).mean()
    else:
        loss = loss.mean()
    return loss
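

# Illustrative usage sketch (hypothetical helper, not part of the original module):
# quantile_huber_loss in the 2D (QR-DQN style) case; the tensor sizes are arbitrary.
def _example_quantile_huber_loss() -> None:
    batch_size, n_quantiles, n_target_quantiles = 32, 200, 200
    current = th.randn(batch_size, n_quantiles)
    target = th.randn(batch_size, n_target_quantiles)
    # cum_prob is left as None, so the uniform midpoints (i + 0.5) / n_quantiles are used
    # and broadcast internally to (1, n_quantiles, 1).
    loss = quantile_huber_loss(current, target)
    assert loss.ndim == 0  # scalar, since sum_over_quantiles=True by default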


def conjugate_gradient_solver(
    matrix_vector_dot_fn: Callable[[th.Tensor], th.Tensor],
    b: th.Tensor,
    max_iter: int = 10,
    residual_tol: float = 1e-10,
) -> th.Tensor:
"""
Finds an approximate solution to a set of linear equations Ax = b
Sources:
- https://github.com/ajlangley/trpo-pytorch/blob/master/conjugate_gradient.py
- https://github.com/joschu/modular_rl/blob/master/modular_rl/trpo.py#L122
Reference:
- https://epubs.siam.org/doi/abs/10.1137/1.9781611971446.ch6
:param matrix_vector_dot_fn:
a function that right multiplies a matrix A by a vector v
:param b:
the right hand term in the set of linear equations Ax = b
:param max_iter:
the maximum number of iterations (default is 10)
:param residual_tol:
residual tolerance for early stopping of the solving (default is 1e-10)
:return x:
the approximate solution to the system of equations defined by `matrix_vector_dot_fn`
and b
"""
    # The vector is not initialized at 0 because of instability issues when the gradient becomes small.
    # A small random gaussian noise is used for the initialization.
    x = 1e-4 * th.randn_like(b)
    residual = b - matrix_vector_dot_fn(x)
    # Equivalent to th.linalg.norm(residual) ** 2 (L2 norm squared)
    residual_squared_norm = th.matmul(residual, residual)

    if residual_squared_norm < residual_tol:
        # If the gradient becomes extremely small,
        # the denominator in alpha will become zero,
        # leading to a division by zero
        return x

    p = residual.clone()

    for i in range(max_iter):
        # A @ p (matrix vector multiplication)
        A_dot_p = matrix_vector_dot_fn(p)

        alpha = residual_squared_norm / p.dot(A_dot_p)
        x += alpha * p

        if i == max_iter - 1:
            return x

        residual -= alpha * A_dot_p
        new_residual_squared_norm = th.matmul(residual, residual)

        if new_residual_squared_norm < residual_tol:
            return x

        beta = new_residual_squared_norm / residual_squared_norm
        residual_squared_norm = new_residual_squared_norm
        p = residual + beta * p
    # Note: this return statement is only used when max_iter=0
    return x
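

# Illustrative usage sketch (hypothetical helper, not part of the original module):
# solving a small symmetric positive-definite system A x = b. In TRPO,
# matrix_vector_dot_fn would instead compute a Fisher/Hessian-vector product.
def _example_conjugate_gradient() -> None:
    A = th.tensor([[4.0, 1.0], [1.0, 3.0]])
    b = th.tensor([1.0, 2.0])
    # Only matrix-vector products with A are needed, never A itself.
    x = conjugate_gradient_solver(lambda v: A @ v, b, max_iter=10)
    assert th.allclose(A @ x, b, atol=1e-4)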


def flat_grad(
    output,
    parameters: Sequence[nn.parameter.Parameter],
    create_graph: bool = False,
    retain_graph: bool = False,
) -> th.Tensor:
"""
Returns the gradients of the passed sequence of parameters into a flat gradient.
Order of parameters is preserved.
:param output: functional output to compute the gradient for
:param parameters: sequence of ``Parameter``
:param retain_graph: If ``False``, the graph used to compute the grad will be freed.
Defaults to the value of ``create_graph``.
:param create_graph: If ``True``, graph of the derivative will be constructed,
allowing to compute higher order derivative products. Default: ``False``.
:return: Tensor containing the flattened gradients
"""
    # allow_unused=True returns None for parameters that do not affect output;
    # those entries are filtered out before concatenation.
    grads = th.autograd.grad(
        output,
        parameters,
        create_graph=create_graph,
        retain_graph=retain_graph,
        allow_unused=True,
    )
    return th.cat([th.ravel(grad) for grad in grads if grad is not None])
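

# Illustrative usage sketch (hypothetical helper, not part of the original module):
# flattening the gradient of a scalar loss over all parameters of a small layer.
def _example_flat_grad() -> None:
    layer = nn.Linear(3, 2)
    loss = layer(th.randn(5, 3)).pow(2).mean()
    grad = flat_grad(loss, list(layer.parameters()))
    # One entry per parameter scalar: 3 * 2 weights + 2 biases = 8.
    assert grad.shape == (8,)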