stable-baselines3-contrib-sacd/tests/test_identity.py

39 lines
1.3 KiB
Python

import numpy as np
import pytest
from stable_baselines3.common.envs import IdentityEnv, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv
from sb3_contrib import QRDQN, TRPO
# Size argument passed to every Identity*Env constructor in the parametrization below
# (presumably the size of the env's action/observation space — confirm in SB3's envs).
DIM = 4
@pytest.mark.parametrize("model_class", [QRDQN, TRPO])
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
def test_discrete(model_class, env):
    """Briefly train ``model_class`` on an identity env and check the result.

    The model is trained for ``n_steps`` timesteps and evaluated with
    ``evaluate_policy`` using ``reward_threshold=90``. The test then checks that
    ``model.predict`` returns an action batch with the same shape as the
    observation batch from ``vec_env.reset()``.

    QRDQN (like DQN) only supports discrete actions, so it is skipped on the
    MultiDiscrete / MultiBinary identity envs.
    """
    # Guard first: skip (rather than silently pass) the unsupported combination
    # before doing any setup work.
    if model_class == QRDQN and isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
        pytest.skip("QRDQN only supports discrete actions")

    vec_env = DummyVecEnv([lambda: env])
    n_steps = 1500
    kwargs = {}
    if model_class == QRDQN:
        kwargs = dict(
            learning_starts=0,
            policy_kwargs=dict(n_quantiles=25, net_arch=[32]),
            target_update_interval=10,
            train_freq=2,
            batch_size=256,
        )
    elif model_class == TRPO:
        # Dispatch on model_class: the original compared the int ``n_steps`` to the
        # TRPO class, which is always False, so these kwargs were never applied.
        kwargs = dict(n_steps=256, cg_max_steps=5)

    model = model_class("MlpPolicy", vec_env, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps)

    evaluate_policy(model, vec_env, n_eval_episodes=20, reward_threshold=90, warn=False)
    obs = vec_env.reset()
    assert np.shape(model.predict(obs)[0]) == np.shape(obs)