stable-baselines3-contrib-sacd/tests/test_run.py

151 lines
4.3 KiB
Python

import gym
import pytest
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize
from sb3_contrib import ARS, QRDQN, TQC, TRPO, MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.vec_env import AsyncEval
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
def test_tqc(ent_coef):
    """Smoke-test TQC on Pendulum-v1 for each parametrized ``ent_coef`` value.

    :param ent_coef: Entropy coefficient setting forwarded unchanged to ``TQC``.
    """
    tqc_settings = {
        "policy_kwargs": {"net_arch": [64, 64]},
        "learning_starts": 100,
        "verbose": 1,
        "ent_coef": ent_coef,
    }
    agent = TQC("MlpPolicy", "Pendulum-v1", **tqc_settings)
    # The progress bar is enabled here so that code path is exercised as well.
    agent.learn(total_timesteps=300, progress_bar=True)
@pytest.mark.parametrize("n_critics", [1, 3])
def test_n_critics(n_critics):
    """Check that TQC trains with different numbers of critics.

    :param n_critics: Number of critics passed through ``policy_kwargs``.
    """
    critic_policy_kwargs = {"net_arch": [64], "n_critics": n_critics}
    agent = TQC(
        "MlpPolicy",
        "Pendulum-v1",
        verbose=1,
        learning_starts=100,
        policy_kwargs=critic_policy_kwargs,
    )
    agent.learn(total_timesteps=300)
def test_sde():
    """Train TQC with ``use_sde=True``, then call the policy's noise helpers."""
    agent = TQC(
        "MlpPolicy",
        "Pendulum-v1",
        use_sde=True,
        policy_kwargs={"net_arch": [64]},
        learning_starts=100,
        verbose=1,
    )
    agent.learn(total_timesteps=300)

    # After training, make sure these calls run without raising.
    policy = agent.policy
    policy.reset_noise()
    policy.actor.get_std()
def test_qrdqn():
    """Smoke-test QRDQN on CartPole-v1 with 25 quantiles and a small replay buffer."""
    quantile_policy_kwargs = {"n_quantiles": 25, "net_arch": [64, 64]}
    agent = QRDQN(
        "MlpPolicy",
        "CartPole-v1",
        learning_rate=3e-4,
        buffer_size=500,
        learning_starts=100,
        policy_kwargs=quantile_policy_kwargs,
        verbose=1,
    )
    agent.learn(total_timesteps=500)
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_trpo(env_id):
    """Smoke-test TRPO on each parametrized environment.

    :param env_id: Id of the environment handed to ``TRPO``.
    """
    agent = TRPO(
        "MlpPolicy",
        env_id,
        policy_kwargs={"net_arch": [16]},
        n_steps=128,
        seed=0,
        verbose=1,
    )
    agent.learn(total_timesteps=500)
def test_trpo_params():
    """Run TRPO on Pendulum-v1 with gSDE and sub-sampling enabled."""
    # Separate layer sizes for the policy (pi) and value (vf) networks.
    split_arch = {"pi": [32], "vf": [32]}
    trpo_settings = {
        "n_steps": 64,
        "batch_size": 32,
        "use_sde": True,
        "sub_sampling_factor": 4,
        "seed": 0,
        "policy_kwargs": {"net_arch": [split_arch]},
        "verbose": 1,
    }
    agent = TRPO("MlpPolicy", "Pendulum-v1", **trpo_settings)
    agent.learn(total_timesteps=500)
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"])
def test_ars(policy_str, env_id):
    """Smoke-test ARS for every policy / environment combination.

    :param policy_str: Policy name passed to ``ARS``.
    :param env_id: Id of the environment handed to ``ARS``.
    """
    agent = ARS(
        policy_str,
        env_id,
        seed=0,
        verbose=1,
        n_delta=1,
    )
    agent.learn(total_timesteps=500, log_interval=1)
def test_ars_multi_env():
    """Run ARS on a two-env vectorized environment, then with an ``AsyncEval``."""

    def make_normalized_env():
        # Environment factory handed to AsyncEval: one normalized Pendulum env.
        return VecNormalize(make_vec_env("Pendulum-v1", n_envs=1))

    # First pass: two environments, no asynchronous evaluation.
    two_envs = make_vec_env("Pendulum-v1", n_envs=2)
    agent = ARS("MlpPolicy", two_envs, n_delta=1)
    agent.learn(total_timesteps=250)

    # Second pass: a single normalized environment ...
    normalized_env = VecNormalize(make_vec_env("Pendulum-v1", n_envs=1))
    agent = ARS("MlpPolicy", normalized_env, n_delta=2, seed=0)

    # ... with parallelism, using an AsyncEval built from two env factories.
    env_factories = [make_normalized_env for _ in range(2)]
    async_eval = AsyncEval(env_factories, agent.policy)
    async_eval.seed(0)
    agent.learn(500, async_eval=async_eval)
@pytest.mark.parametrize("n_top", [2, 8])
def test_ars_n_top(n_top):
    """Train ARS with ``n_top`` below and above ``n_delta``.

    When ``n_top`` exceeds ``n_delta`` a ``UserWarning`` is expected while the
    model is being built or trained.

    :param n_top: Value forwarded to the ``n_top`` argument of ``ARS``.
    """
    n_delta = 3
    ars_kwargs = {"n_delta": n_delta, "n_top": n_top}

    if n_top <= n_delta:
        # Valid configuration: plain construction and training.
        agent = ARS("MlpPolicy", "Pendulum-v1", **ars_kwargs)
        agent.learn(total_timesteps=500)
        return

    # n_top is larger than n_delta: both construction and training run inside
    # the warning check.
    with pytest.warns(UserWarning):
        agent = ARS("MlpPolicy", "Pendulum-v1", **ars_kwargs)
        agent.learn(total_timesteps=500)
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
def test_offpolicy_multi_env(model_class):
    """Train an off-policy algorithm on two time-limited environments.

    TQC runs on Pendulum-v1 with a single critic; any other class runs on
    CartPole-v1.

    :param model_class: Algorithm class to instantiate and train.
    """
    uses_tqc = model_class in [TQC]
    env_id = "Pendulum-v1" if uses_tqc else "CartPole-v1"
    policy_kwargs = {"net_arch": [64]}
    if uses_tqc:
        policy_kwargs["n_critics"] = 1

    def make_env():
        # Wrap in a 50-step TimeLimit to check that the code handling
        # timeouts runs.
        return gym.wrappers.TimeLimit(gym.make(env_id), 50)

    vec_env = make_vec_env(make_env, n_envs=2)
    agent = model_class(
        "MlpPolicy",
        vec_env,
        train_freq=5,
        buffer_size=10000,
        learning_starts=100,
        policy_kwargs=policy_kwargs,
        verbose=0,
    )
    agent.learn(total_timesteps=150)
@pytest.mark.parametrize("normalize_advantage", [False, True])
def test_advantage_normalization(normalize_advantage):
    """Train MaskablePPO with advantage normalization switched off and on.

    :param normalize_advantage: Flag forwarded unchanged to ``MaskablePPO``.
    """
    masked_env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
    agent = MaskablePPO(
        "MlpPolicy",
        masked_env,
        normalize_advantage=normalize_advantage,
        n_steps=64,
    )
    agent.learn(64)