70 lines
3.6 KiB
Python
70 lines
3.6 KiB
Python
from typing import Optional
|
|
|
|
import torch as th
|
|
|
|
|
|
def quantile_huber_loss(
    current_quantiles: th.Tensor,
    target_quantiles: th.Tensor,
    cum_prob: Optional[th.Tensor] = None,
    sum_over_quantiles: bool = True,
) -> th.Tensor:
    """
    The quantile-regression (quantile Huber) loss, as described in the QR-DQN and TQC papers.
    Partially taken from https://github.com/bayesgroup/tqc_pytorch.

    Every current quantile is compared against every target quantile; the Huber
    loss (threshold 1) of each pairwise difference is weighted by
    ``|cum_prob - 1{delta < 0}|``.

    :param current_quantiles: current estimate of quantiles, must be either
        (batch_size, n_quantiles) or (batch_size, n_critics, n_quantiles)
    :param target_quantiles: target of quantiles, must have the same number of dimensions
        and the same batch size as ``current_quantiles``: either (batch_size, n_target_quantiles),
        (batch_size, 1, n_target_quantiles), or (batch_size, n_critics, n_target_quantiles)
    :param cum_prob: cumulative probabilities to calculate quantiles (also called midpoints in QR-DQN paper).
        If None, the unit midpoints ``(i + 0.5) / n_quantiles`` are generated and reshaped so that
        they broadcast along the ``n_quantiles`` axis.
        A user-supplied tensor is used as-is (it is not reshaped here), so it must already be
        broadcastable to the pairwise tensor of shape
        (batch_size, n_quantiles, n_target_quantiles) or
        (batch_size, n_critics, n_quantiles, n_target_quantiles),
        e.g. (batch_size, n_quantiles, 1) or (batch_size, n_critics, n_quantiles, 1).
    :param sum_over_quantiles: if True, sum over the ``n_quantiles`` axis and then average the
        remaining elements; if False, average over all elements
    :return: the loss, as a 0-dim tensor
    :raises ValueError: if the inputs differ in number of dimensions or batch size,
        or if they are not 2-D or 3-D
    """
    if current_quantiles.ndim != target_quantiles.ndim:
        raise ValueError(
            f"Error: The dimension of current_quantiles ({current_quantiles.ndim}) needs to match "
            f"the dimension of target_quantiles ({target_quantiles.ndim})."
        )
    # Check the rank before touching shape[0]: a 0-dim tensor has no batch axis and
    # indexing its shape would raise an IndexError instead of a descriptive ValueError.
    if current_quantiles.ndim not in (2, 3):
        raise ValueError(f"Error: The dimension of current_quantiles ({current_quantiles.ndim}) needs to be either 2 or 3.")
    if current_quantiles.shape[0] != target_quantiles.shape[0]:
        raise ValueError(
            f"Error: The batch size of current_quantiles ({current_quantiles.shape[0]}) needs to match "
            f"the batch size of target_quantiles ({target_quantiles.shape[0]})."
        )

    if cum_prob is None:
        n_quantiles = current_quantiles.shape[-1]
        # Cumulative probabilities to calculate quantiles (midpoints of n_quantiles equal bins).
        cum_prob = (th.arange(n_quantiles, device=current_quantiles.device, dtype=th.float) + 0.5) / n_quantiles
        if current_quantiles.ndim == 2:
            # For QR-DQN, current_quantiles have a shape (batch_size, n_quantiles), and make cum_prob
            # broadcastable to (batch_size, n_quantiles, n_target_quantiles)
            cum_prob = cum_prob.view(1, -1, 1)
        elif current_quantiles.ndim == 3:
            # For TQC, current_quantiles have a shape (batch_size, n_critics, n_quantiles), and make cum_prob
            # broadcastable to (batch_size, n_critics, n_quantiles, n_target_quantiles)
            cum_prob = cum_prob.view(1, 1, -1, 1)

    # QR-DQN
    # target_quantiles: (batch_size, n_target_quantiles) -> (batch_size, 1, n_target_quantiles)
    # current_quantiles: (batch_size, n_quantiles) -> (batch_size, n_quantiles, 1)
    # pairwise_delta: (batch_size, n_quantiles, n_target_quantiles)
    # TQC
    # target_quantiles: (batch_size, 1, n_target_quantiles) -> (batch_size, 1, 1, n_target_quantiles)
    # current_quantiles: (batch_size, n_critics, n_quantiles) -> (batch_size, n_critics, n_quantiles, 1)
    # pairwise_delta: (batch_size, n_critics, n_quantiles, n_target_quantiles)
    # Note: in both cases, the loss has the same shape as pairwise_delta
    pairwise_delta = target_quantiles.unsqueeze(-2) - current_quantiles.unsqueeze(-1)
    abs_pairwise_delta = th.abs(pairwise_delta)
    # Huber loss with threshold 1: quadratic near zero, linear beyond.
    huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5)
    # The indicator is computed on a detached delta, so the asymmetric weight carries no gradient.
    loss = th.abs(cum_prob - (pairwise_delta.detach() < 0).float()) * huber_loss
    if sum_over_quantiles:
        # dim=-2 is the n_quantiles axis of pairwise_delta in both layouts.
        loss = loss.sum(dim=-2).mean()
    else:
        loss = loss.mean()
    return loss
|