import gym
import pytest
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

from sb3_contrib import ARS, QRDQN, TQC, TRPO, MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.vec_env import AsyncEval


@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
def test_tqc(ent_coef):
    """Construct TQC on Pendulum with the given ``ent_coef`` and train briefly.

    Both the constructor and ``learn`` are run under ``pytest.warns`` because
    `create_eval_env` and `eval_freq` are deprecated.
    """
    with pytest.warns(DeprecationWarning):
        agent = TQC(
            "MlpPolicy",
            "Pendulum-v1",
            policy_kwargs={"net_arch": [64, 64]},
            learning_starts=100,
            verbose=1,
            create_eval_env=True,
            ent_coef=ent_coef,
        )
        agent.learn(total_timesteps=300, eval_freq=250)


@pytest.mark.parametrize("n_critics", [1, 3])
def test_n_critics(n_critics):
    """Train TQC briefly with a varying number of critics."""
    critic_kwargs = {"net_arch": [64], "n_critics": n_critics}
    agent = TQC(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=critic_kwargs,
        learning_starts=100,
        verbose=1,
    )
    agent.learn(total_timesteps=300)


def test_sde():
    """Train TQC with ``use_sde=True``, then call the policy noise helpers."""
    agent = TQC(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs={"net_arch": [64]},
        use_sde=True,
        learning_starts=100,
        verbose=1,
    )
    agent.learn(total_timesteps=300)

    # Exercise the noise-related entry points after training.
    agent.policy.reset_noise()
    agent.policy.actor.get_std()


def test_qrdqn():
    """Construct QRDQN on CartPole and train briefly.

    Runs under ``pytest.warns`` because `create_eval_env` and `eval_freq`
    are deprecated.
    """
    with pytest.warns(DeprecationWarning):
        agent = QRDQN(
            "MlpPolicy",
            "CartPole-v1",
            policy_kwargs={"n_quantiles": 25, "net_arch": [64, 64]},
            learning_starts=100,
            buffer_size=500,
            learning_rate=3e-4,
            verbose=1,
            create_eval_env=True,
        )
        agent.learn(total_timesteps=500, eval_freq=250)


@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_trpo(env_id):
    """Train TRPO briefly on each parametrized environment id."""
    agent = TRPO(
        "MlpPolicy",
        env_id,
        n_steps=128,
        seed=0,
        policy_kwargs={"net_arch": [16]},
        verbose=1,
    )
    agent.learn(total_timesteps=500)


def test_trpo_params():
    """Train TRPO with gSDE and subsampling enabled."""
    separate_nets = {"pi": [32], "vf": [32]}
    agent = TRPO(
        "MlpPolicy",
        "Pendulum-v1",
        n_steps=64,
        batch_size=32,
        use_sde=True,
        sub_sampling_factor=4,
        seed=0,
        policy_kwargs={"net_arch": [separate_nets]},
        verbose=1,
    )
    agent.learn(total_timesteps=500)


@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"])
def test_ars(policy_str, env_id):
    """Train ARS for every policy / environment combination.

    Only ``learn`` is wrapped in ``pytest.warns``: `create_eval_env` and
    `eval_freq` are deprecated, and ``eval_freq`` is passed to ``learn`` here.
    """
    agent = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0)
    with pytest.warns(DeprecationWarning):
        agent.learn(total_timesteps=500, log_interval=1, eval_freq=250)


def test_ars_multi_env():
    """Train ARS on a two-env vectorized env, then with ``AsyncEval``."""
    two_envs = make_vec_env("Pendulum-v1", n_envs=2)
    agent = ARS("MlpPolicy", two_envs, n_delta=1)
    agent.learn(total_timesteps=250)

    normalized_env = VecNormalize(make_vec_env("Pendulum-v1", n_envs=1))
    agent = ARS("MlpPolicy", normalized_env, n_delta=2, seed=0)

    # with parallelism
    env_factories = [
        lambda: VecNormalize(make_vec_env("Pendulum-v1", n_envs=1))
        for _ in range(2)
    ]
    async_eval = AsyncEval(env_factories, agent.policy)
    async_eval.seed(0)
    agent.learn(500, async_eval=async_eval)


@pytest.mark.parametrize("n_top", [2, 8])
def test_ars_n_top(n_top):
    """Train ARS with ``n_top`` below and above ``n_delta``.

    When ``n_top`` exceeds ``n_delta`` the construction and training are
    expected to emit a ``UserWarning``.
    """
    n_delta = 3

    def build_and_train():
        agent = ARS("MlpPolicy", "Pendulum-v1", n_delta=n_delta, n_top=n_top)
        agent.learn(total_timesteps=500)

    if n_top <= n_delta:
        build_and_train()
        return

    with pytest.warns(UserWarning):
        build_and_train()


@pytest.mark.parametrize("model_class", [TQC, QRDQN])
def test_offpolicy_multi_env(model_class):
    """Train an off-policy algorithm on two time-limited environments."""
    if model_class in [TQC]:
        env_id, policy_kwargs = "Pendulum-v1", {"net_arch": [64], "n_critics": 1}
    else:
        env_id, policy_kwargs = "CartPole-v1", {"net_arch": [64]}

    def make_env():
        # to check that the code handling timeouts runs
        return gym.wrappers.TimeLimit(gym.make(env_id), 50)

    vec_env = make_vec_env(make_env, n_envs=2)
    agent = model_class(
        "MlpPolicy",
        vec_env,
        policy_kwargs=policy_kwargs,
        learning_starts=100,
        buffer_size=10000,
        verbose=0,
        train_freq=5,
    )
    agent.learn(total_timesteps=150)


@pytest.mark.parametrize("normalize_advantage", [False, True])
def test_advantage_normalization(normalize_advantage):
    """Train MaskablePPO with advantage normalization switched off and on."""
    masked_env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
    agent = MaskablePPO(
        "MlpPolicy",
        masked_env,
        n_steps=64,
        normalize_advantage=normalize_advantage,
    )
    agent.learn(64)