stable-baselines3-contrib-sacd/tests/test_run.py

59 lines
1.4 KiB
Python

import pytest
from sb3_contrib import QRDQN, TQC
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
def test_tqc(ent_coef):
    """Smoke-test TQC on ``Pendulum-v0`` for each parametrized ``ent_coef``.

    The model is built with ``create_eval_env=True`` and then trained for
    300 timesteps with ``eval_freq=250``. Nothing is asserted; the test
    passes when construction and training finish without raising.
    """
    net_config = {"net_arch": [64, 64]}
    agent = TQC(
        "MlpPolicy",
        "Pendulum-v0",
        ent_coef=ent_coef,
        policy_kwargs=net_config,
        learning_starts=100,
        create_eval_env=True,
        verbose=1,
    )
    agent.learn(total_timesteps=300, eval_freq=250)
@pytest.mark.parametrize("n_critics", [1, 3])
def test_n_critics(n_critics):
    """Test TQC with different number of critics.

    ``n_critics`` is forwarded through ``policy_kwargs``; the model is then
    trained for 300 timesteps on ``Pendulum-v0``. Nothing is asserted beyond
    the run completing without an exception.
    """
    critic_config = {"net_arch": [64], "n_critics": n_critics}
    agent = TQC(
        "MlpPolicy",
        "Pendulum-v0",
        policy_kwargs=critic_config,
        verbose=1,
        learning_starts=100,
    )
    agent.learn(total_timesteps=300)
def test_sde():
    """Smoke-test TQC with ``use_sde=True`` on ``Pendulum-v0``.

    After 300 training timesteps, ``reset_noise()`` is called on the policy
    and ``get_std()`` on its actor. Their return values are not inspected;
    the test only requires that the calls do not raise.
    """
    agent = TQC(
        "MlpPolicy",
        "Pendulum-v0",
        use_sde=True,
        policy_kwargs={"net_arch": [64]},
        verbose=1,
        learning_starts=100,
    )
    agent.learn(total_timesteps=300)

    # Exercise the SDE-related entry points once training is done.
    trained_policy = agent.policy
    trained_policy.reset_noise()
    trained_policy.actor.get_std()
def test_qrdqn():
    """Smoke-test QRDQN on ``CartPole-v1``.

    The model is built with 25 quantiles, a 500-transition buffer and an
    evaluation env, then trained for 500 timesteps with ``eval_freq=250``.
    Nothing is asserted; the test passes when the run finishes without raising.
    """
    quantile_config = {"n_quantiles": 25, "net_arch": [64, 64]}
    agent = QRDQN(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs=quantile_config,
        # Replay and optimisation settings.
        buffer_size=500,
        learning_starts=100,
        learning_rate=3e-4,
        # Evaluation and logging.
        create_eval_env=True,
        verbose=1,
    )
    agent.learn(total_timesteps=500, eval_freq=250)