# stable-baselines3-contrib-sacd/tests/test_run.py
import gym
import pytest
from stable_baselines3.common.env_util import make_vec_env

from sb3_contrib import QRDQN, TQC, TRPO
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
def test_tqc(ent_coef):
model = TQC(
"MlpPolicy",
"Pendulum-v0",
policy_kwargs=dict(net_arch=[64, 64]),
learning_starts=100,
verbose=1,
create_eval_env=True,
ent_coef=ent_coef,
)
model.learn(total_timesteps=300, eval_freq=250)
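

# Not part of the original test file: a minimal save/load round-trip sketch for TQC,
# assuming the standard SB3 save()/load() API and pytest's built-in tmp_path fixture.
def test_tqc_save_load(tmp_path):
    model = TQC("MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64]), learning_starts=100)
    model.learn(total_timesteps=200)
    model.save(tmp_path / "tqc_pendulum")
    loaded = TQC.load(tmp_path / "tqc_pendulum")
    # A loaded model should predict valid actions without retraining
    env = gym.make("Pendulum-v0")
    obs = env.reset()
    action, _ = loaded.predict(obs, deterministic=True)
    assert env.action_space.contains(action)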
@pytest.mark.parametrize("n_critics", [1, 3])
def test_n_critics(n_critics):
# Test TQC with different number of critics
model = TQC(
"MlpPolicy",
"Pendulum-v0",
policy_kwargs=dict(net_arch=[64], n_critics=n_critics),
learning_starts=100,
verbose=1,
)
model.learn(total_timesteps=300)


def test_sde():
    model = TQC(
        "MlpPolicy",
        "Pendulum-v0",
        policy_kwargs=dict(net_arch=[64]),
        use_sde=True,
        learning_starts=100,
        verbose=1,
    )
    model.learn(total_timesteps=300)
    # Check that the exploration noise can be reset and that the current std is accessible
    model.policy.reset_noise()
    model.policy.actor.get_std()


def test_qrdqn():
    model = QRDQN(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs=dict(n_quantiles=25, net_arch=[64, 64]),
        learning_starts=100,
        buffer_size=500,
        learning_rate=3e-4,
        verbose=1,
        create_eval_env=True,
    )
    model.learn(total_timesteps=500, eval_freq=250)
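

# Not part of the original test file: a hedged smoke test showing that a trained
# QRDQN model works with SB3's evaluate_policy helper. The hyperparameters here
# are illustrative assumptions, not values taken from the original suite.
def test_qrdqn_eval():
    from stable_baselines3.common.evaluation import evaluate_policy

    model = QRDQN(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs=dict(n_quantiles=10, net_arch=[64]),
        learning_starts=100,
        buffer_size=500,
    )
    model.learn(total_timesteps=200)
    # CartPole rewards are non-negative, so the mean episode reward must be too
    mean_reward, _ = evaluate_policy(model, model.get_env(), n_eval_episodes=2)
    assert mean_reward >= 0.0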
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
def test_trpo(env_id):
model = TRPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1)
model.learn(total_timesteps=500)


def test_trpo_params():
    # Test with gSDE and subsampling
    model = TRPO(
        "MlpPolicy",
        "Pendulum-v0",
        n_steps=64,
        batch_size=32,
        use_sde=True,
        sub_sampling_factor=4,
        seed=0,
        policy_kwargs=dict(net_arch=[dict(pi=[32], vf=[32])]),
        verbose=1,
    )
    model.learn(total_timesteps=500)
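

# Not part of the original test file: a small sketch checking that deterministic
# prediction is repeatable for a trained TRPO model, using only the standard
# SB3 predict() API.
def test_trpo_deterministic_predict():
    model = TRPO("MlpPolicy", "CartPole-v1", n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]))
    model.learn(total_timesteps=200)
    env = gym.make("CartPole-v1")
    obs = env.reset()
    # With deterministic=True, the same observation must yield the same action
    first_action, _ = model.predict(obs, deterministic=True)
    second_action, _ = model.predict(obs, deterministic=True)
    assert first_action == second_action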
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
def test_offpolicy_multi_env(model_class):
if model_class in [TQC]:
env_id = "Pendulum-v0"
policy_kwargs = dict(net_arch=[64], n_critics=1)
else:
env_id = "CartPole-v1"
policy_kwargs = dict(net_arch=[64])
def make_env():
env = gym.make(env_id)
# to check that the code handling timeouts runs
env = gym.wrappers.TimeLimit(env, 50)
return env
env = make_vec_env(make_env, n_envs=2)
model = model_class(
"MlpPolicy",
env,
policy_kwargs=policy_kwargs,
learning_starts=100,
buffer_size=10000,
verbose=0,
train_freq=5,
)
model.learn(total_timesteps=150)
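

# Not part of the original test file: a hedged on-policy counterpart to the
# multi-env test above, running TRPO on two vectorized environments. The
# hyperparameters are illustrative assumptions, not values from the original suite.
def test_trpo_multi_env():
    env = make_vec_env("CartPole-v1", n_envs=2)
    # n_steps=64 with 2 envs gives a rollout buffer of 128 transitions,
    # matching TRPO's default batch_size
    model = TRPO("MlpPolicy", env, n_steps=64, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=0)
    model.learn(total_timesteps=256)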