stable-baselines3-contrib-sacd/tests/test_utils.py

20 lines
736 B
Python

import numpy as np
import pytest
import torch as th
from sb3_contrib.common.utils import quantile_huber_loss
def test_quantile_huber_loss():
    """Exercise ``quantile_huber_loss`` on fixed inputs and on rejected shapes.

    Two value checks use an all-zeros ``(1, 10)`` tensor as the first argument
    and an all-ones ``(1, 10)`` tensor as the second: the default call is
    expected to give ``2.5``, and the call with ``sum_over_quantiles=False``
    is expected to give ``0.25``.

    The remaining checks pass pairs of zero tensors whose shapes the function
    is expected to reject with ``ValueError``.
    """
    # Value checks: fresh tensors are built for each call, as in a plain
    # call-per-assert layout, so no call can observe another call's inputs.
    default_loss = quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10))
    assert np.isclose(default_loss, 2.5)

    unsummed_loss = quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10), sum_over_quantiles=False)
    assert np.isclose(unsummed_loss, 0.25)

    # (first argument shape, second argument shape) pairs that must be refused.
    rejected_shape_pairs = (
        ((1, 4, 4), (1, 4)),
        ((1, 4), (1, 1, 4)),
        ((4, 4), (3, 4)),
        ((4, 4, 4, 4), (4, 4, 4, 4)),
    )
    for first_shape, second_shape in rejected_shape_pairs:
        # Each pair gets its own context so every case must raise on its own.
        with pytest.raises(ValueError):
            quantile_huber_loss(th.zeros(*first_shape), th.zeros(*second_shape))