20 lines
736 B
Python
20 lines
736 B
Python
import numpy as np
|
|
import pytest
|
|
import torch as th
|
|
|
|
from sb3_contrib.common.utils import quantile_huber_loss
|
|
|
|
|
|
def test_quantile_huber_loss():
    """Check quantile_huber_loss on a known input and on rejected input shapes.

    With current quantiles of all zeros and target quantiles of all ones
    (shape ``(1, 10)`` each), the loss is expected to be 2.5 with the default
    reduction and 0.25 when ``sum_over_quantiles=False``. A set of
    incompatible or unsupported shape combinations must raise ``ValueError``.
    """
    expected_summed = 2.5
    expected_unsummed = 0.25

    summed_loss = quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10))
    assert np.isclose(summed_loss, expected_summed)

    unsummed_loss = quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10), sum_over_quantiles=False)
    assert np.isclose(unsummed_loss, expected_unsummed)

    # (first argument shape, second argument shape) pairs that must be rejected.
    rejected_shapes = (
        ((1, 4, 4), (1, 4)),  # 3-D first argument vs 2-D second argument
        ((1, 4), (1, 1, 4)),  # 2-D first argument vs 3-D second argument
        ((4, 4), (3, 4)),  # leading dimensions differ (4 vs 3)
        ((4, 4, 4, 4), (4, 4, 4, 4)),  # matching but 4-D inputs
    )
    for first_shape, second_shape in rejected_shapes:
        with pytest.raises(ValueError):
            quantile_huber_loss(th.zeros(*first_shape), th.zeros(*second_shape))
|