import gym
import numpy as np
import pytest
from gym import spaces
from gym.envs.classic_control import CartPoleEnv
from gym.wrappers.time_limit import TimeLimit
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import FakeImageEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecNormalize

from sb3_contrib import RecurrentPPO
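

# This suite exercises RecurrentPPO from sb3_contrib: LSTM policy
# configurations, Dict observation spaces, invalid settings, and a
# memory-requiring CartPole variant with velocity information removed.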
class ToDictWrapper(gym.Wrapper):
    """
    Simple wrapper to test MultiInputPolicy on Dict observations.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space})

    def reset(self):
        return {"obs": self.env.reset()}

    def step(self, action):
        obs, reward, done, infos = self.env.step(action)
        return {"obs": obs}, reward, done, infos

class CartPoleNoVelEnv(CartPoleEnv):
    """Variant of CartPoleEnv with velocity information removed. This task requires memory to solve."""

    def __init__(self):
        super().__init__()
        # Observation bounds for the two remaining components:
        # cart position and pole angle.
        high = np.array(
            [
                self.x_threshold * 2,
                self.theta_threshold_radians * 2,
            ]
        )
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

    @staticmethod
    def _pos_obs(full_obs):
        # Drop the cart and pole velocities, keeping only the positions.
        xpos, _xvel, thetapos, _thetavel = full_obs
        return xpos, thetapos

    def reset(self):
        full_obs = super().reset()
        return CartPoleNoVelEnv._pos_obs(full_obs)

    def step(self, action):
        full_obs, rew, done, info = super().step(action)
        return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info
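

# Each entry below is one LSTM policy configuration: defaults, a separate
# features extractor, a critic sharing the actor's LSTM, and small
# two-layer LSTMs with dropout, with and without a critic LSTM.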
@pytest.mark.parametrize(
    "policy_kwargs",
    [
        {},
        {"share_features_extractor": False},
        dict(shared_lstm=True, enable_critic_lstm=False),
        dict(
            enable_critic_lstm=True,
            lstm_hidden_size=4,
            lstm_kwargs=dict(dropout=0.5),
            n_lstm_layers=2,
        ),
        dict(
            enable_critic_lstm=False,
            lstm_hidden_size=4,
            lstm_kwargs=dict(dropout=0.5),
            n_lstm_layers=2,
        ),
        dict(
            enable_critic_lstm=False,
            lstm_hidden_size=4,
            share_features_extractor=False,
        ),
    ],
)
def test_cnn(policy_kwargs):
    model = RecurrentPPO(
        "CnnLstmPolicy",
        FakeImageEnv(screen_height=40, screen_width=40, n_channels=3),
        n_steps=16,
        seed=0,
        policy_kwargs=dict(**policy_kwargs, features_extractor_kwargs=dict(features_dim=32)),
        n_epochs=2,
    )

    model.learn(total_timesteps=32)

@pytest.mark.parametrize(
    "policy_kwargs",
    [
        {},
        dict(shared_lstm=True, enable_critic_lstm=False),
        dict(
            enable_critic_lstm=True,
            lstm_hidden_size=4,
            lstm_kwargs=dict(dropout=0.5),
            n_lstm_layers=2,
        ),
        dict(
            enable_critic_lstm=False,
            lstm_hidden_size=4,
            lstm_kwargs=dict(dropout=0.5),
            n_lstm_layers=2,
        ),
    ],
)
def test_policy_kwargs(policy_kwargs):
    model = RecurrentPPO(
        "MlpLstmPolicy",
        "CartPole-v1",
        n_steps=16,
        seed=0,
        policy_kwargs=policy_kwargs,
    )

    model.learn(total_timesteps=32)
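

# Invalid combinations must be rejected at construction time: the critic
# cannot both share the actor's LSTM and have its own, and a shared LSTM
# requires a shared features extractor.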
def test_check():
    policy_kwargs = dict(shared_lstm=True, enable_critic_lstm=True)
    with pytest.raises(AssertionError):
        RecurrentPPO(
            "MlpLstmPolicy",
            "CartPole-v1",
            n_steps=16,
            seed=0,
            policy_kwargs=policy_kwargs,
        )

    policy_kwargs = dict(shared_lstm=True, enable_critic_lstm=False, share_features_extractor=False)
    with pytest.raises(AssertionError):
        RecurrentPPO(
            "MlpLstmPolicy",
            "CartPole-v1",
            n_steps=16,
            seed=0,
            policy_kwargs=policy_kwargs,
        )
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
|
|
def test_run(env):
|
|
model = RecurrentPPO(
|
|
"MlpLstmPolicy",
|
|
env,
|
|
n_steps=16,
|
|
seed=0,
|
|
)
|
|
|
|
model.learn(total_timesteps=32)
|
|
|
|
|
|
def test_run_sde():
|
|
model = RecurrentPPO(
|
|
"MlpLstmPolicy",
|
|
"Pendulum-v1",
|
|
n_steps=16,
|
|
seed=0,
|
|
sde_sample_freq=4,
|
|
use_sde=True,
|
|
clip_range_vf=0.1,
|
|
)
|
|
|
|
model.learn(total_timesteps=200)
|
|
|
|
|
|
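

# Same configuration grid as test_policy_kwargs, but observations are routed
# through ToDictWrapper so MultiInputLstmPolicy sees a Dict observation space.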
@pytest.mark.parametrize(
    "policy_kwargs",
    [
        {},
        dict(shared_lstm=True, enable_critic_lstm=False),
        dict(
            enable_critic_lstm=True,
            lstm_hidden_size=4,
            lstm_kwargs=dict(dropout=0.5),
            n_lstm_layers=2,
        ),
        dict(
            enable_critic_lstm=False,
            lstm_hidden_size=4,
            lstm_kwargs=dict(dropout=0.5),
            n_lstm_layers=2,
        ),
    ],
)
def test_dict_obs(policy_kwargs):
    env = make_vec_env("CartPole-v1", n_envs=1, wrapper_class=ToDictWrapper)
    model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32, policy_kwargs=policy_kwargs).learn(64)
    evaluate_policy(model, env, warn=False)
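

# The recurrent policy must solve CartPoleNoVelEnv, which hides velocities:
# the LSTM has to infer them from consecutive position observations.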
@pytest.mark.slow
def test_ppo_lstm_performance():
    # env = make_vec_env("CartPole-v1", n_envs=16)
    def make_env():
        env = CartPoleNoVelEnv()
        env = TimeLimit(env, max_episode_steps=500)
        return env

    env = VecNormalize(make_vec_env(make_env, n_envs=8))

    eval_callback = EvalCallback(
        VecNormalize(make_vec_env(make_env, n_envs=4), training=False, norm_reward=False),
        n_eval_episodes=20,
        # eval_freq is counted in per-env steps, so this evaluates
        # roughly every 5000 total environment steps
        eval_freq=5000 // env.num_envs,
    )

    model = RecurrentPPO(
        "MlpLstmPolicy",
        env,
        n_steps=128,
        learning_rate=0.0007,
        verbose=1,
        batch_size=256,
        seed=1,
        n_epochs=10,
        max_grad_norm=1,
        gae_lambda=0.98,
        policy_kwargs=dict(
            net_arch=[dict(vf=[64])],
            lstm_hidden_size=64,
            ortho_init=False,
            enable_critic_lstm=True,
        ),
    )

    model.learn(total_timesteps=50_000, callback=eval_callback)
    # Maximum episode reward is 500.
    # In CartPole-v1, a non-recurrent policy can easily get >= 450.
    # In CartPoleNoVelEnv, a non-recurrent policy doesn't get more than ~50.
    evaluate_policy(model, env, reward_threshold=450)
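

# For reference (not a test): recurrent policies carry hidden state across
# steps at inference time, so `predict` must be fed the previous LSTM states
# and an episode-start mask. A minimal sketch following the pattern from the
# sb3_contrib README; the helper name and step budget are illustrative only.
def _demo_recurrent_inference():  # hypothetical helper, not collected by pytest
    model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", n_steps=16, seed=0)
    model.learn(total_timesteps=32)

    env = model.get_env()
    obs = env.reset()
    lstm_states = None  # hidden and cell states, initialized on the first call
    episode_starts = np.ones((env.num_envs,), dtype=bool)
    for _ in range(100):
        action, lstm_states = model.predict(
            obs, state=lstm_states, episode_start=episode_starts, deterministic=True
        )
        obs, rewards, dones, infos = env.step(action)
        # Reset the LSTM states of any environment whose episode just ended.
        episode_starts = dones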