317 lines
10 KiB
Python
317 lines
10 KiB
Python
import numpy as np
|
|
import pytest
|
|
import torch as th
|
|
|
|
from sb3_contrib.common.maskable.distributions import (
|
|
MaskableBernoulliDistribution,
|
|
MaskableCategorical,
|
|
MaskableCategoricalDistribution,
|
|
MaskableMultiCategoricalDistribution,
|
|
)
|
|
|
|
|
|
class TestMaskableCategorical:
    """
    Tests for applying, replacing and removing masks on ``MaskableCategorical``.
    """

    @staticmethod
    def _assert_perplexity(dist, expected: int) -> None:
        """
        Assert that ``exp(entropy)`` of ``dist`` equals ``expected``.

        The comparison uses a floating-point tolerance rather than ``int()``
        truncation, so a value such as 2.9999998 is not mistaken for 2.

        :param dist: object exposing ``entropy()`` that returns a tensor
        :param expected: expected value of ``exp(entropy)``
        """
        perplexity = dist.entropy().exp()
        assert th.allclose(perplexity, th.full_like(perplexity, float(expected)))

    def test_applying_mask(self) -> None:
        """
        Show that probs change as a result of masking
        """
        starting_probs = th.tensor([[0.2, 0.2, 0.6], [1.0, 0.0, 0.0]])
        expected_probs = th.tensor([[0.0, 0.25, 0.75], [0.0, 0.5, 0.5]])
        # The first outcome is masked out in both rows.
        mask = np.array([[False, True, True], [False, True, True]])

        distribution = MaskableCategorical(probs=starting_probs)
        distribution.apply_masking(mask)

        assert th.allclose(distribution.probs, expected_probs)

    def test_modifying_mask(self) -> None:
        """
        Show that masks apply independently of each other
        """
        starting_probs = th.tensor([[0.2, 0.2, 0.6], [1.0, 0.0, 0.0]])
        expected_probs = th.tensor([[0.5, 0.5, 0.0], [0.0, 1.0, 0.0]])
        first_mask = np.array([[False, True, True], [False, True, True]])
        second_mask = np.array([[True, True, False], [False, True, False]])

        # pytorch converts probs to logits in a way that loses some precision and makes
        # 0 probability outcomes slightly non-zero.
        atol = 2e-07

        distribution = MaskableCategorical(probs=starting_probs)
        assert th.allclose(distribution.probs, starting_probs, atol=atol)

        # Build the reference through the same probs -> logits conversion.
        target_distribution = MaskableCategorical(probs=expected_probs)

        # Only the second mask should determine the final probabilities.
        distribution.apply_masking(first_mask)
        distribution.apply_masking(second_mask)

        assert th.allclose(distribution.probs, target_distribution.probs, atol=atol)

    def test_removing_mask(self) -> None:
        """
        Show that masking may be unapplied to recover original probs
        """
        starting_probs = th.tensor([[0.2, 0.2, 0.6], [1.0, 0.0, 0.0]])
        mask = np.array([[False, True, True], [False, True, True]])

        distribution = MaskableCategorical(probs=starting_probs)
        target_distribution = MaskableCategorical(probs=starting_probs)

        distribution.apply_masking(mask)
        # Passing None removes the mask again.
        distribution.apply_masking(None)

        assert th.allclose(distribution.probs, target_distribution.probs)

    def test_masking_affects_entropy(self) -> None:
        """
        Show that exp(entropy) equals the number of unmasked outcomes
        """
        # All outcomes equally likely
        NUM_DIMS = 3
        logits = th.zeros(1, NUM_DIMS)
        dist = MaskableCategorical(logits=logits)

        # For each possible number of valid actions v, show that e^entropy == v
        for v in range(1, NUM_DIMS + 1):
            # The first v outcomes are valid, the remaining ones are masked out.
            masks = [True] * v + [False] * (NUM_DIMS - v)
            dist.apply_masking(masks)
            self._assert_perplexity(dist, v)
|
|
|
|
|
|
class TestMaskableCategoricalDistribution:
    """
    Tests for ``MaskableCategoricalDistribution``: initialization requirements,
    logit shape handling and the effect of masking on probs and entropy.
    """

    @staticmethod
    def _assert_perplexity(dist, expected: int) -> None:
        """
        Assert that ``exp(entropy)`` of ``dist`` equals ``expected``.

        The comparison uses a floating-point tolerance rather than ``int()``
        truncation, so a value such as 1.9999999 is not mistaken for 1.

        :param dist: object exposing ``entropy()`` that returns a tensor
        :param expected: expected value of ``exp(entropy)``
        """
        perplexity = dist.entropy().exp()
        assert th.allclose(perplexity, th.full_like(perplexity, float(expected)))

    def test_distribution_must_be_initialized(self) -> None:
        """
        Cannot use distribution before it has logits
        """
        DIMS = 2
        dist = MaskableCategoricalDistribution(DIMS)
        # ``randint``'s upper bound is exclusive, so ``DIMS`` covers every action index.
        actions = th.randint(DIMS, (3, 1))

        with pytest.raises(AssertionError):
            dist.log_prob(actions)

        with pytest.raises(AssertionError):
            dist.entropy()

        with pytest.raises(AssertionError):
            dist.sample()

        with pytest.raises(AssertionError):
            dist.mode()

        with pytest.raises(AssertionError):
            dist.apply_masking(None)

        # But now we can
        dist.proba_distribution(th.randn(1, DIMS))
        dist.log_prob(actions)
        dist.entropy()
        dist.sample()
        dist.mode()
        dist.apply_masking(None)

    def test_logits_must_align_with_dims(self) -> None:
        """
        Logits must provide one value per dim (after realignment)
        """
        NUM_DIMS = 3
        dist = MaskableCategoricalDistribution(NUM_DIMS)

        # There should be one logit per dim, we're one short
        too_few_logits = th.randn(1, NUM_DIMS - 1)
        with pytest.raises(RuntimeError):
            dist.proba_distribution(too_few_logits)

        # That's better
        dist.proba_distribution(th.randn(1, NUM_DIMS))

        # Other numbers of dimensions are acceptable as long as they can be realigned
        dist.proba_distribution(th.randn(NUM_DIMS))
        dist.proba_distribution(th.randn(3, NUM_DIMS, 3))

    def test_dim_masking(self) -> None:
        """
        Masking all but one dim makes that dim certain; unmasking restores uniformity
        """
        NUM_DIMS = 2
        dist = MaskableCategoricalDistribution(NUM_DIMS)

        # Equal logits: every outcome is equally likely.
        dist.proba_distribution(th.zeros(1, NUM_DIMS))

        assert (dist.distribution.probs == 0.5).all()
        self._assert_perplexity(dist, NUM_DIMS)

        for i in range(NUM_DIMS):
            # Only outcome i stays valid.
            mask = np.arange(NUM_DIMS) == i
            dist.apply_masking(mask)

            probs = dist.distribution.probs
            assert probs.sum() == 1
            assert probs[0][i] == 1
            self._assert_perplexity(dist, 1)

        dist.apply_masking(None)
        assert (dist.distribution.probs == 0.5).all()
        self._assert_perplexity(dist, NUM_DIMS)
|
|
|
|
|
|
class TestMaskableMultiCategoricalDistribution:
    """
    Tests for ``MaskableMultiCategoricalDistribution``: initialization
    requirements, logit shape handling and the effect of masking on the
    per-category probs and on the entropy.
    """

    @staticmethod
    def _assert_perplexity(dist, expected: int) -> None:
        """
        Assert that ``exp(entropy)`` of ``dist`` equals ``expected``.

        The comparison uses a floating-point tolerance rather than ``int()``
        truncation, so a value such as 7.9999995 is not mistaken for 7.

        :param dist: object exposing ``entropy()`` that returns a tensor
        :param expected: expected value of ``exp(entropy)``
        """
        perplexity = dist.entropy().exp()
        assert th.allclose(perplexity, th.full_like(perplexity, float(expected)))

    def test_distribution_must_be_initialized(self) -> None:
        """
        Cannot use distribution before it has logits
        """
        DIMS_PER_CAT = 2
        NUM_CATS = 2
        dist = MaskableMultiCategoricalDistribution([DIMS_PER_CAT] * NUM_CATS)
        # ``randint``'s upper bound is exclusive, so ``DIMS_PER_CAT`` covers
        # every choice index within a category.
        actions = th.randint(DIMS_PER_CAT, (3, NUM_CATS))

        with pytest.raises(AssertionError):
            dist.log_prob(actions)

        with pytest.raises(AssertionError):
            dist.entropy()

        with pytest.raises(AssertionError):
            dist.sample()

        with pytest.raises(AssertionError):
            dist.mode()

        with pytest.raises(AssertionError):
            dist.apply_masking(None)

        # But now we can
        dist.proba_distribution(th.randn(1, DIMS_PER_CAT * NUM_CATS))
        dist.log_prob(actions)
        dist.entropy()
        dist.sample()
        dist.mode()
        dist.apply_masking(None)

    def test_logits_must_align_with_dims(self) -> None:
        """
        Logits must provide one value per dim of every category (after realignment)
        """
        DIMS_PER_CAT = 3
        NUM_CATS = 2
        total_dims = DIMS_PER_CAT * NUM_CATS
        dist = MaskableMultiCategoricalDistribution([DIMS_PER_CAT] * NUM_CATS)

        # There should be one logit per dim, we're one short
        too_few_logits = th.randn(1, total_dims - 1)
        with pytest.raises(RuntimeError):
            dist.proba_distribution(too_few_logits)

        # That's better
        dist.proba_distribution(th.randn(1, total_dims))

        # Other numbers of dimensions are acceptable as long as they can be realigned
        dist.proba_distribution(th.randn(total_dims))
        dist.proba_distribution(th.randn(3, total_dims, 3))

    def test_dim_masking(self) -> None:
        """
        Masking all but one choice per category makes that choice certain;
        unmasking restores uniformity
        """
        DIMS_PER_CAT = 2
        NUM_CATS = 3
        dist = MaskableMultiCategoricalDistribution([DIMS_PER_CAT] * NUM_CATS)

        # Equal logits: every choice within each category is equally likely.
        dist.proba_distribution(th.zeros(1, DIMS_PER_CAT * NUM_CATS))

        assert len(dist.distributions) == NUM_CATS
        for cat_dist in dist.distributions:
            assert (cat_dist.probs == 0.5).all()
        self._assert_perplexity(dist, DIMS_PER_CAT**NUM_CATS)

        for choice in range(DIMS_PER_CAT):
            # Flat mask with the same single valid choice in every category.
            mask = np.tile(np.arange(DIMS_PER_CAT) == choice, NUM_CATS)
            dist.apply_masking(mask)

            for cat_dist in dist.distributions:
                probs = cat_dist.probs
                assert probs.sum() == 1
                assert probs[0][choice] == 1

            self._assert_perplexity(dist, 1)

        dist.apply_masking(None)
        assert len(dist.distributions) == NUM_CATS
        for cat_dist in dist.distributions:
            assert (cat_dist.probs == 0.5).all()
        self._assert_perplexity(dist, DIMS_PER_CAT**NUM_CATS)
|
|
|
|
|
|
class TestMaskableBernoulliDistribution:
    """
    Tests for ``MaskableBernoulliDistribution``: initialization requirements,
    logit shape handling and the effect of masking on the per-dim probs and
    on the entropy.
    """

    @staticmethod
    def _assert_perplexity(dist, expected: int) -> None:
        """
        Assert that ``exp(entropy)`` of ``dist`` equals ``expected``.

        The comparison uses a floating-point tolerance rather than ``int()``
        truncation, so a value such as 3.9999998 is not mistaken for 3.

        :param dist: object exposing ``entropy()`` that returns a tensor
        :param expected: expected value of ``exp(entropy)``
        """
        perplexity = dist.entropy().exp()
        assert th.allclose(perplexity, th.full_like(perplexity, float(expected)))

    def test_distribution_must_be_initialized(self) -> None:
        """
        Cannot use distribution before it has logits
        """
        DIMS = 2
        BINARY_STATES = 2
        dist = MaskableBernoulliDistribution(DIMS)
        # ``randint``'s upper bound is exclusive, so ``BINARY_STATES`` yields
        # both 0 and 1 instead of only 0.
        actions = th.randint(BINARY_STATES, (2, DIMS))

        with pytest.raises(AssertionError):
            dist.log_prob(actions)

        with pytest.raises(AssertionError):
            dist.entropy()

        with pytest.raises(AssertionError):
            dist.sample()

        with pytest.raises(AssertionError):
            dist.mode()

        with pytest.raises(AssertionError):
            dist.apply_masking(None)

        # But now we can
        dist.proba_distribution(th.randn(1, BINARY_STATES * DIMS))
        dist.log_prob(actions)
        dist.entropy()
        dist.sample()
        dist.mode()
        dist.apply_masking(None)

    def test_logits_must_align_with_dims(self) -> None:
        """
        Logits must provide two values per dim (after realignment)
        """
        NUM_DIMS = 3
        dist = MaskableBernoulliDistribution(NUM_DIMS)

        # There should be two logits per dim, we're one short
        too_few_logits = th.randn(1, 2 * NUM_DIMS - 1)
        with pytest.raises(RuntimeError):
            dist.proba_distribution(too_few_logits)

        # That's better
        dist.proba_distribution(th.randn(1, 2 * NUM_DIMS))

        # Other numbers of dimensions are acceptable as long as they can be realigned
        dist.proba_distribution(th.randn(2 * NUM_DIMS))
        dist.proba_distribution(th.randn(3, 2 * NUM_DIMS, 3))

    def test_dim_masking(self) -> None:
        """
        Masking one state per dim makes the other state certain;
        unmasking restores uniformity
        """
        NUM_DIMS = 2
        BINARY_STATES = 2
        # With every state equally likely, exp(entropy) is the number of joint
        # outcomes: BINARY_STATES ** NUM_DIMS (not the product, which only
        # coincides for 2 and 2).
        num_joint_outcomes = BINARY_STATES**NUM_DIMS
        dist = MaskableBernoulliDistribution(NUM_DIMS)

        # Equal logits: both states of every dim are equally likely.
        dist.proba_distribution(th.zeros(1, BINARY_STATES * NUM_DIMS))

        assert len(dist.distributions) == NUM_DIMS
        for dim_dist in dist.distributions:
            assert (dim_dist.probs == 0.5).all()
        self._assert_perplexity(dist, num_joint_outcomes)

        for state in range(BINARY_STATES):
            # Flat mask with the same single valid state in every dim.
            mask = np.tile(np.arange(BINARY_STATES) == state, NUM_DIMS)
            dist.apply_masking(mask)

            for dim_dist in dist.distributions:
                probs = dim_dist.probs
                assert probs.sum() == 1
                assert probs[0][state] == 1

            self._assert_perplexity(dist, 1)

        dist.apply_masking(None)
        assert len(dist.distributions) == NUM_DIMS
        for dim_dist in dist.distributions:
            assert (dim_dist.probs == 0.5).all()
        self._assert_perplexity(dist, num_joint_outcomes)
|