# stable-baselines3-contrib-sacd/tests/test_distributions.py
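"""Unit tests for the maskable action distributions in sb3_contrib."""
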
import numpy as np
import pytest
import torch as th
from sb3_contrib.common.maskable.distributions import (
    MaskableBernoulliDistribution,
    MaskableCategorical,
    MaskableCategoricalDistribution,
    MaskableMultiCategoricalDistribution,
)
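

# Masking semantics exercised throughout this file: applying a mask zeroes out
# the probability of invalid actions and renormalizes what remains. For
# example, probs [0.2, 0.2, 0.6] under mask [False, True, True] become
# [0.0, 0.2, 0.6] / 0.8 = [0.0, 0.25, 0.75], which is exactly what
# test_applying_mask below asserts.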
class TestMaskableCategorical:
    def test_applying_mask(self):
        """
        Show that probs change as a result of masking
        """
        starting_probs = th.Tensor([[0.2, 0.2, 0.6], [1, 0, 0]])
        expected_probs = th.Tensor([[0, 0.25, 0.75], [0, 0.5, 0.5]])
        mask = np.array([[False, True, True], [False, True, True]])

        distribution = MaskableCategorical(probs=starting_probs)
        distribution.apply_masking(mask)
        assert th.allclose(distribution.probs, expected_probs)

    def test_modifying_mask(self):
        """
        Show that masks apply independently of each other
        """
        starting_probs = th.Tensor([[0.2, 0.2, 0.6], [1, 0, 0]])
        expected_probs = th.Tensor([[0.5, 0.5, 0], [0, 1, 0]])
        first_mask = np.array([[False, True, True], [False, True, True]])
        second_mask = np.array([[True, True, False], [False, True, False]])

        # pytorch converts probs to logits in a way that loses some precision and makes
        # 0 probability outcomes slightly non-zero.
        atol = 2e-07

        distribution = MaskableCategorical(probs=starting_probs)
        assert th.allclose(distribution.probs, starting_probs, atol=atol)

        target_distribution = MaskableCategorical(probs=expected_probs)
        distribution.apply_masking(first_mask)
        distribution.apply_masking(second_mask)
        assert th.allclose(distribution.probs, target_distribution.probs, atol=atol)

    def test_removing_mask(self):
        """
        Show that masking may be unapplied to recover original probs
        """
        starting_probs = th.Tensor([[0.2, 0.2, 0.6], [1, 0, 0]])
        mask = np.array([[False, True, True], [False, True, True]])

        distribution = MaskableCategorical(probs=starting_probs)
        target_distribution = MaskableCategorical(probs=starting_probs)
        distribution.apply_masking(mask)
        distribution.apply_masking(None)
        assert th.allclose(distribution.probs, target_distribution.probs)
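
    # Entropy identity used below: a uniform distribution over v valid actions
    # has entropy ln(v), so exp(entropy) recovers the count of valid actions.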
    def test_masking_affects_entropy(self):
        # All outcomes equally likely
        NUM_DIMS = 3
        logits = th.Tensor([[0] * NUM_DIMS])
        dist = MaskableCategorical(logits=logits)

        # For each possible number of valid actions v, show that e^entropy == v
        for v in range(1, NUM_DIMS + 1):
            masks = [j < v for j in range(NUM_DIMS)]
            dist.apply_masking(masks)
            assert int(dist.entropy().exp()) == v
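

# A minimal usage sketch of the low-level distribution, based only on the
# calls exercised above:
#
#   dist = MaskableCategorical(logits=th.zeros(1, 3))
#   dist.apply_masking(np.array([[True, False, True]]))  # action 1 is invalid
#   action = dist.sample()                               # action 1 never drawn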
class TestMaskableCategoricalDistribution:
    def test_distribution_must_be_initialized(self):
        """
        Cannot use distribution before it has logits
        """
        DIMS = 2
        dist = MaskableCategoricalDistribution(DIMS)
        with pytest.raises(AssertionError):
            dist.log_prob(th.randint(DIMS - 1, (1, 3)))
        with pytest.raises(AssertionError):
            dist.entropy()
        with pytest.raises(AssertionError):
            dist.sample()
        with pytest.raises(AssertionError):
            dist.mode()
        with pytest.raises(AssertionError):
            dist.apply_masking(None)

        # But now we can
        action_logits = th.randn(1, DIMS)
        dist.proba_distribution(action_logits)
        actions = th.randint(DIMS - 1, (3, 1))
        dist.log_prob(actions)
        dist.entropy()
        dist.sample()
        dist.mode()

        # Test api
        dist.actions_from_params(action_logits)
        dist.log_prob_from_params(action_logits)
        dist.apply_masking(None)
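
    # proba_distribution appears to realign logits to (batch, action_dims),
    # presumably via a reshape, so any tensor whose total size is a multiple
    # of action_dims is accepted; the test below checks both failing and
    # passing shapes.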
    def test_logits_must_align_with_dims(self):
        NUM_DIMS = 3
        dist = MaskableCategoricalDistribution(NUM_DIMS)

        # There should be one logit per dim; we're one short
        logits = th.randn(1, NUM_DIMS - 1)
        with pytest.raises(RuntimeError):
            dist.proba_distribution(logits)

        # That's better
        logits = th.randn(1, NUM_DIMS)
        dist.proba_distribution(logits)

        # Other numbers of dimensions are acceptable as long as they can be realigned
        logits = th.randn(NUM_DIMS)
        dist.proba_distribution(logits)
        logits = th.randn(3, NUM_DIMS, 3)
        dist.proba_distribution(logits)

    def test_dim_masking(self):
        NUM_DIMS = 2
        dist = MaskableCategoricalDistribution(NUM_DIMS)
        logits = th.Tensor([[0] * NUM_DIMS])
        dist.proba_distribution(logits)

        assert (dist.distribution.probs == 0.5).all()
        assert int(dist.entropy().exp()) == NUM_DIMS

        for i in range(NUM_DIMS):
            mask = np.array([False] * NUM_DIMS)
            mask[i] = True
            dist.apply_masking(mask)
            probs = dist.distribution.probs
            assert probs.sum() == 1
            assert probs[0][i] == 1
            assert int(dist.entropy().exp()) == 1

        dist.apply_masking(None)
        assert (dist.distribution.probs == 0.5).all()
        assert int(dist.entropy().exp()) == NUM_DIMS
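

# MaskableMultiCategoricalDistribution treats the action space as NUM_CATS
# independent categoricals over a single flat logits/mask vector. Entropies of
# independent distributions add, so for uniform components exp(entropy) is the
# product of per-category action counts: DIMS_PER_CAT ** NUM_CATS below.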
class TestMaskableMultiCategoricalDistribution:
    def test_distribution_must_be_initialized(self):
        """
        Cannot use distribution before it has logits
        """
        DIMS_PER_CAT = 2
        NUM_CATS = 2
        dist = MaskableMultiCategoricalDistribution([DIMS_PER_CAT] * NUM_CATS)
        with pytest.raises(AssertionError):
            dist.log_prob(th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS)))
        with pytest.raises(AssertionError):
            dist.entropy()
        with pytest.raises(AssertionError):
            dist.sample()
        with pytest.raises(AssertionError):
            dist.mode()
        with pytest.raises(AssertionError):
            dist.apply_masking(None)

        # But now we can
        action_logits = th.randn(1, DIMS_PER_CAT * NUM_CATS)
        dist.proba_distribution(action_logits)
        actions = th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS))
        dist.log_prob(actions)
        dist.entropy()
        dist.sample()
        dist.mode()

        # Test api
        dist.actions_from_params(action_logits)
        dist.log_prob_from_params(action_logits)
        dist.apply_masking(None)

    def test_logits_must_align_with_dims(self):
        DIMS_PER_CAT = 3
        NUM_CATS = 2
        dist = MaskableMultiCategoricalDistribution([DIMS_PER_CAT] * NUM_CATS)

        # There should be one logit per dim; we're one short
        logits = th.randn(1, DIMS_PER_CAT * NUM_CATS - 1)
        with pytest.raises(RuntimeError):
            dist.proba_distribution(logits)

        # That's better
        logits = th.randn(1, DIMS_PER_CAT * NUM_CATS)
        dist.proba_distribution(logits)

        # Other numbers of dimensions are acceptable as long as they can be realigned
        logits = th.randn(DIMS_PER_CAT * NUM_CATS)
        dist.proba_distribution(logits)
        logits = th.randn(3, DIMS_PER_CAT * NUM_CATS, 3)
        dist.proba_distribution(logits)

    def test_dim_masking(self):
        DIMS_PER_CAT = 2
        NUM_CATS = 3
        dist = MaskableMultiCategoricalDistribution([DIMS_PER_CAT] * NUM_CATS)
        logits = th.Tensor([[0] * DIMS_PER_CAT * NUM_CATS])
        dist.proba_distribution(logits)

        assert len(dist.distributions) == NUM_CATS
        for i in range(NUM_CATS):
            assert (dist.distributions[i].probs == 0.5).all()
        assert int(dist.entropy().exp()) == DIMS_PER_CAT**NUM_CATS

        for i in range(DIMS_PER_CAT):
            mask = np.array([False] * DIMS_PER_CAT * NUM_CATS)
            for j in range(NUM_CATS):
                mask[j * DIMS_PER_CAT + i] = True
            dist.apply_masking(mask)
            for j in range(NUM_CATS):
                probs = dist.distributions[j].probs
                assert probs.sum() == 1
                assert probs[0][i] == 1
            assert int(dist.entropy().exp()) == 1

        dist.apply_masking(None)
        for i in range(NUM_CATS):
            assert (dist.distributions[i].probs == 0.5).all()
        assert int(dist.entropy().exp()) == DIMS_PER_CAT**NUM_CATS
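

# MaskableBernoulliDistribution models NUM_DIMS independent binary choices as
# NUM_DIMS two-way categoricals, hence the two logits per dim checked below.
# Note the entropy assertion against BINARY_STATES * NUM_DIMS matches the
# exact value BINARY_STATES ** NUM_DIMS only because both factors are 2 here.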
class TestMaskableBernoulliDistribution:
    def test_distribution_must_be_initialized(self):
        """
        Cannot use distribution before it has logits
        """
        DIMS = 2
        dist = MaskableBernoulliDistribution(DIMS)
        with pytest.raises(AssertionError):
            dist.log_prob(th.randint(1, (2, DIMS)))
        with pytest.raises(AssertionError):
            dist.entropy()
        with pytest.raises(AssertionError):
            dist.sample()
        with pytest.raises(AssertionError):
            dist.mode()
        with pytest.raises(AssertionError):
            dist.apply_masking(None)

        # But now we can
        action_logits = th.randn(1, 2 * DIMS)
        dist.proba_distribution(action_logits)
        actions = th.randint(1, (2, DIMS))
        dist.log_prob(actions)
        dist.entropy()
        dist.sample()
        dist.mode()

        # Test api
        dist.actions_from_params(action_logits)
        dist.log_prob_from_params(action_logits)
        dist.apply_masking(None)

    def test_logits_must_align_with_dims(self):
        NUM_DIMS = 3
        dist = MaskableBernoulliDistribution(NUM_DIMS)

        # There should be two logits per dim; we're one short
        logits = th.randn(1, 2 * NUM_DIMS - 1)
        with pytest.raises(RuntimeError):
            dist.proba_distribution(logits)

        # That's better
        logits = th.randn(1, 2 * NUM_DIMS)
        dist.proba_distribution(logits)

        # Other numbers of dimensions are acceptable as long as they can be realigned
        logits = th.randn(2 * NUM_DIMS)
        dist.proba_distribution(logits)
        logits = th.randn(3, 2 * NUM_DIMS, 3)
        dist.proba_distribution(logits)

    def test_dim_masking(self):
        NUM_DIMS = 2
        BINARY_STATES = 2
        dist = MaskableBernoulliDistribution(NUM_DIMS)
        logits = th.Tensor([[0] * BINARY_STATES * NUM_DIMS])
        dist.proba_distribution(logits)

        assert len(dist.distributions) == NUM_DIMS
        for i in range(NUM_DIMS):
            assert (dist.distributions[i].probs == 0.5).all()
        assert int(dist.entropy().exp()) == BINARY_STATES * NUM_DIMS

        for i in range(BINARY_STATES):
            mask = np.array([False] * BINARY_STATES * NUM_DIMS)
            for j in range(NUM_DIMS):
                mask[j * BINARY_STATES + i] = True
            dist.apply_masking(mask)
            for j in range(NUM_DIMS):
                probs = dist.distributions[j].probs
                assert probs.sum() == 1
                assert probs[0][i] == 1
            assert int(dist.entropy().exp()) == 1

        dist.apply_masking(None)
        for i in range(NUM_DIMS):
            assert (dist.distributions[i].probs == 0.5).all()
        assert int(dist.entropy().exp()) == BINARY_STATES * NUM_DIMS
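

# To run just this module (standard pytest invocation, from the repo root):
#
#   pytest tests/test_distributions.py -q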