65 lines
2.1 KiB
Python
65 lines
2.1 KiB
Python
import numpy as np
|
|
import pytest
|
|
import torch as th
|
|
from stable_baselines3.common.utils import set_random_seed
|
|
|
|
from sb3_contrib import TRPO
|
|
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad, quantile_huber_loss
|
|
|
|
|
|
def test_quantile_huber_loss():
    """Check ``quantile_huber_loss`` on a fixed input pair and on rejected shapes.

    With all-zero current quantiles and all-one targets of shape ``(1, 10)``
    the loss is expected to be ``2.5`` by default and ``0.25`` when
    ``sum_over_quantiles=False``. Four shape combinations are then expected
    to raise ``ValueError``.
    """

    def make_inputs():
        # Fresh tensors for every call so the two checks share no state.
        return th.zeros(1, 10), th.ones(1, 10)

    summed_loss = quantile_huber_loss(*make_inputs())
    assert np.isclose(summed_loss, 2.5)

    mean_loss = quantile_huber_loss(*make_inputs(), sum_over_quantiles=False)
    assert np.isclose(mean_loss, 0.25)

    # (current_quantiles shape, target_quantiles shape) pairs that must be rejected.
    rejected_shape_pairs = [
        ((1, 4, 4), (1, 4)),
        ((1, 4), (1, 1, 4)),
        ((4, 4), (3, 4)),
        ((4, 4, 4, 4), (4, 4, 4, 4)),
    ]
    for current_shape, target_shape in rejected_shape_pairs:
        with pytest.raises(ValueError):
            quantile_huber_loss(th.zeros(current_shape), th.zeros(target_shape))
|
|
|
|
|
|
def test_cg():
    """Check that ``conjugate_gradient_solver`` recovers a known solution of ``Ax = b``.

    The right-hand side ``b`` is computed from a chosen ``x``, so the solver
    output can be compared directly against that ``x``.
    """
    set_random_seed(4)

    # Symmetric 3x3 matrix: all ones except the mirrored (0, 1) / (1, 0) entries.
    matrix = th.tensor(
        [
            [1.0, 2.0, 1.0],
            [2.0, 1.0, 1.0],
            [1.0, 1.0, 1.0],
        ]
    )
    expected_solution = th.rand(3) + 1.0
    rhs = matrix @ expected_solution

    # The solver only sees the matrix through this matrix-vector product.
    solution = conjugate_gradient_solver(
        lambda vector: matrix @ vector,
        rhs,
        max_iter=5,
        residual_tol=1e-10,
    )

    assert th.allclose(solution, expected_solution)
|
|
|
|
|
|
def test_flat_grad():
    """Check that ``flat_grad`` returns a single 1-D vector covering every listed parameter.

    The same ``(2, 2)`` parameter is passed three times, so the result is
    expected to hold ``3 * 4 = 12`` gradient entries.
    """
    n_repeats = 3
    weights = th.nn.Parameter(th.ones(2, 2, requires_grad=True))
    loss = weights.pow(2).sum()

    gradient = flat_grad(loss, [weights] * n_repeats)

    # The output must be flattened to one dimension.
    assert gradient.dim() == 1

    # d(sum(w^2))/dw = 2w and every entry of w is 1, so each gradient entry is 2.
    expected = th.full((n_repeats * weights.numel(),), 2.0)
    assert th.allclose(gradient, expected)
|
|
|
|
|
|
def test_trpo_warnings():
    """Check how TRPO construction reacts to problematic rollout buffer sizes.

    Three configurations are exercised in order: one expected to fail an
    assertion, one expected to be accepted, and one expected to emit a
    ``UserWarning``.
    """
    # NOTE(review): "Pendulum-v0" is reportedly replaced by "Pendulum-v1" in
    # newer Gym releases -- verify against the Gym version pinned for this project.
    policy_and_env = ("MlpPolicy", "Pendulum-v0")

    # A single step with advantage normalization (the default here) would give NaN.
    with pytest.raises(AssertionError):
        TRPO(*policy_and_env, n_steps=1)

    # A single step is accepted once advantage normalization is turned off.
    TRPO(*policy_and_env, n_steps=1, normalize_advantage=False, batch_size=1)

    # Batch size larger than the number of steps: truncated mini-batch.
    with pytest.warns(UserWarning):
        TRPO(*policy_and_env, n_steps=6, batch_size=8)
|