Upgrade Gym to 0.21 (#59)

* Pendulum-v0 -> Pendulum-v1

* Reformat with black

* Update changelog

* Fix dtype bug in TimeFeatureWrapper

* Update version and removed forward calls

* Update CI

* Fix min version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Adam Gleave 2022-02-22 15:25:43 +00:00 committed by GitHub
parent a78891bd00
commit 901a648507
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 64 additions and 43 deletions

View File

@ -36,8 +36,6 @@ jobs:
pip install . pip install .
# Use headless version # Use headless version
pip install opencv-python-headless pip install opencv-python-headless
# Tmp fix: ROM missing in the newest atari-py version
pip install atari-py==0.2.5
- name: Build the doc - name: Build the doc
run: | run: |
make doc make doc

View File

@ -12,7 +12,7 @@ Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment.
from sb3_contrib import TQC from sb3_contrib import TQC
model = TQC("MlpPolicy", "Pendulum-v0", top_quantiles_to_drop_per_net=2, verbose=1) model = TQC("MlpPolicy", "Pendulum-v1", top_quantiles_to_drop_per_net=2, verbose=1)
model.learn(total_timesteps=10_000, log_interval=4) model.learn(total_timesteps=10_000, log_interval=4)
model.save("tqc_pendulum") model.save("tqc_pendulum")
@ -54,7 +54,7 @@ Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environmen
from sb3_contrib import TRPO from sb3_contrib import TRPO
model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1) model = TRPO("MlpPolicy", "Pendulum-v1", gamma=0.9, verbose=1)
model.learn(total_timesteps=100_000, log_interval=4) model.learn(total_timesteps=100_000, log_interval=4)
model.save("trpo_pendulum") model.save("trpo_pendulum")
@ -68,6 +68,6 @@ Train an agent using Augmented Random Search (ARS) agent on the Pendulum environ
from sb3_contrib import ARS from sb3_contrib import ARS
model = ARS("LinearPolicy", "Pendulum-v0", verbose=1) model = ARS("LinearPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=10000, log_interval=4) model.learn(total_timesteps=10000, log_interval=4)
model.save("ars_pendulum") model.save("ars_pendulum")

View File

@ -3,6 +3,27 @@
Changelog Changelog
========== ==========
Release 1.4.1a1 (WIP)
-------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Switched minimum Gym version to 0.21.0.
- Upgraded to Stable-Baselines3 >= 1.4.1a1
New Features:
^^^^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
- Removed explicit calls to ``forward()`` method as per PyTorch guidelines
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
Release 1.4.0 (2022-01-19) Release 1.4.0 (2022-01-19)
------------------------------- -------------------------------

View File

@ -64,7 +64,7 @@ Example
from sb3_contrib import ARS from sb3_contrib import ARS
# Policy can be LinearPolicy or MlpPolicy # Policy can be LinearPolicy or MlpPolicy
model = ARS("LinearPolicy", "Pendulum-v0", verbose=1) model = ARS("LinearPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=10000, log_interval=4) model.learn(total_timesteps=10000, log_interval=4)
model.save("ars_pendulum") model.save("ars_pendulum")

View File

@ -57,7 +57,7 @@ Example
from sb3_contrib import TQC from sb3_contrib import TQC
env = gym.make("Pendulum-v0") env = gym.make("Pendulum-v1")
policy_kwargs = dict(n_critics=2, n_quantiles=25) policy_kwargs = dict(n_critics=2, n_quantiles=25)
model = TQC("MlpPolicy", env, top_quantiles_to_drop_per_net=2, verbose=1, policy_kwargs=policy_kwargs) model = TQC("MlpPolicy", env, top_quantiles_to_drop_per_net=2, verbose=1, policy_kwargs=policy_kwargs)

View File

@ -54,7 +54,7 @@ Example
from sb3_contrib import TRPO from sb3_contrib import TRPO
env = gym.make("Pendulum-v0") env = gym.make("Pendulum-v1")
model = TRPO("MlpPolicy", env, verbose=1) model = TRPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000, log_interval=4) model.learn(total_timesteps=10000, log_interval=4)

View File

@ -76,7 +76,7 @@ class ARSPolicy(BasePolicy):
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor: def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
# Non deterministic action does not really make sense for ARS, we ignore this parameter for now.. # Non deterministic action does not really make sense for ARS, we ignore this parameter for now..
return self.forward(observation) return self(observation)
class ARSLinearPolicy(ARSPolicy): class ARSLinearPolicy(ARSPolicy):

View File

@ -43,11 +43,12 @@ class TimeFeatureWrapper(gym.Wrapper):
low, high = obs_space.low, obs_space.high low, high = obs_space.low, obs_space.high
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0])) low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0]))
self.dtype = obs_space.dtype
if isinstance(env.observation_space, gym.spaces.Dict): if isinstance(env.observation_space, gym.spaces.Dict):
env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=np.float32) env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
else: else:
env.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32) env.observation_space = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
super(TimeFeatureWrapper, self).__init__(env) super(TimeFeatureWrapper, self).__init__(env)
@ -84,6 +85,7 @@ class TimeFeatureWrapper(gym.Wrapper):
time_feature = 1 - (self._current_step / self._max_steps) time_feature = 1 - (self._current_step / self._max_steps)
if self._test_mode: if self._test_mode:
time_feature = 1.0 time_feature = 1.0
time_feature = np.array(time_feature, dtype=self.dtype)
if isinstance(obs, dict): if isinstance(obs, dict):
obs["observation"] = np.append(obs["observation"], time_feature) obs["observation"] = np.append(obs["observation"], time_feature)

View File

@ -317,7 +317,7 @@ class MaskablePPO(OnPolicyAlgorithm):
if use_masking: if use_masking:
action_masks = get_action_masks(env) action_masks = get_action_masks(env)
actions, values, log_probs = self.policy.forward(obs_tensor, action_masks=action_masks) actions, values, log_probs = self.policy(obs_tensor, action_masks=action_masks)
actions = actions.cpu().numpy() actions = actions.cpu().numpy()
new_obs, rewards, dones, infos = env.step(actions) new_obs, rewards, dones, infos = env.step(actions)

View File

@ -69,7 +69,7 @@ class QuantileNetwork(BasePolicy):
return quantiles.view(-1, self.n_quantiles, self.action_space.n) return quantiles.view(-1, self.n_quantiles, self.action_space.n)
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor: def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
q_values = self.forward(observation).mean(dim=1) q_values = self(observation).mean(dim=1)
# Greedy action # Greedy action
action = q_values.argmax(dim=1).reshape(-1) action = q_values.argmax(dim=1).reshape(-1)
return action return action

View File

@ -181,7 +181,7 @@ class Actor(BasePolicy):
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs) return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self.forward(observation, deterministic) return self(observation, deterministic)
class Critic(BaseModel): class Critic(BaseModel):

View File

@ -1 +1 @@
1.4.0 1.4.1a1

View File

@ -63,7 +63,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")], packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]}, package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[ install_requires=[
"stable_baselines3>=1.4.0", "stable_baselines3>=1.4.1a1",
], ],
description="Contrib package of Stable Baselines3, experimental code.", description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin", author="Antonin Raffin",

View File

@ -17,7 +17,7 @@ def test_deterministic_training_common(algo):
rewards = [[], []] rewards = [[], []]
# Smaller network # Smaller network
kwargs = {"policy_kwargs": dict(net_arch=[64])} kwargs = {"policy_kwargs": dict(net_arch=[64])}
env_id = "Pendulum-v0" env_id = "Pendulum-v1"
if algo == ARS_MULTI: if algo == ARS_MULTI:
algo = ARS algo = ARS
ars_multi = True ars_multi = True

View File

@ -11,7 +11,7 @@ from sb3_contrib.common.vec_env import AsyncEval
def test_tqc(ent_coef): def test_tqc(ent_coef):
model = TQC( model = TQC(
"MlpPolicy", "MlpPolicy",
"Pendulum-v0", "Pendulum-v1",
policy_kwargs=dict(net_arch=[64, 64]), policy_kwargs=dict(net_arch=[64, 64]),
learning_starts=100, learning_starts=100,
verbose=1, verbose=1,
@ -26,7 +26,7 @@ def test_n_critics(n_critics):
# Test TQC with different number of critics # Test TQC with different number of critics
model = TQC( model = TQC(
"MlpPolicy", "MlpPolicy",
"Pendulum-v0", "Pendulum-v1",
policy_kwargs=dict(net_arch=[64], n_critics=n_critics), policy_kwargs=dict(net_arch=[64], n_critics=n_critics),
learning_starts=100, learning_starts=100,
verbose=1, verbose=1,
@ -37,7 +37,7 @@ def test_n_critics(n_critics):
def test_sde(): def test_sde():
model = TQC( model = TQC(
"MlpPolicy", "MlpPolicy",
"Pendulum-v0", "Pendulum-v1",
policy_kwargs=dict(net_arch=[64]), policy_kwargs=dict(net_arch=[64]),
use_sde=True, use_sde=True,
learning_starts=100, learning_starts=100,
@ -62,7 +62,7 @@ def test_qrdqn():
model.learn(total_timesteps=500, eval_freq=250) model.learn(total_timesteps=500, eval_freq=250)
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"]) @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_trpo(env_id): def test_trpo(env_id):
model = TRPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1) model = TRPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1)
model.learn(total_timesteps=500) model.learn(total_timesteps=500)
@ -72,7 +72,7 @@ def test_trpo_params():
# Test with gSDE and subsampling # Test with gSDE and subsampling
model = TRPO( model = TRPO(
"MlpPolicy", "MlpPolicy",
"Pendulum-v0", "Pendulum-v1",
n_steps=64, n_steps=64,
batch_size=32, batch_size=32,
use_sde=True, use_sde=True,
@ -84,7 +84,7 @@ def test_trpo_params():
model.learn(total_timesteps=500) model.learn(total_timesteps=500)
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"]) @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"]) @pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"])
def test_ars(policy_str, env_id): def test_ars(policy_str, env_id):
model = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0) model = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0)
@ -92,14 +92,14 @@ def test_ars(policy_str, env_id):
def test_ars_multi_env(): def test_ars_multi_env():
env = make_vec_env("Pendulum-v0", n_envs=2) env = make_vec_env("Pendulum-v1", n_envs=2)
model = ARS("MlpPolicy", env, n_delta=1) model = ARS("MlpPolicy", env, n_delta=1)
model.learn(total_timesteps=250) model.learn(total_timesteps=250)
env = VecNormalize(make_vec_env("Pendulum-v0", n_envs=1)) env = VecNormalize(make_vec_env("Pendulum-v1", n_envs=1))
model = ARS("MlpPolicy", env, n_delta=2, seed=0) model = ARS("MlpPolicy", env, n_delta=2, seed=0)
# with parallelism # with parallelism
async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v0", n_envs=1)) for _ in range(2)], model.policy) async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v1", n_envs=1)) for _ in range(2)], model.policy)
async_eval.seed(0) async_eval.seed(0)
model.learn(500, async_eval=async_eval) model.learn(500, async_eval=async_eval)
@ -109,17 +109,17 @@ def test_ars_n_top(n_top):
n_delta = 3 n_delta = 3
if n_top > n_delta: if n_top > n_delta:
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
model = ARS("MlpPolicy", "Pendulum-v0", n_delta=n_delta, n_top=n_top) model = ARS("MlpPolicy", "Pendulum-v1", n_delta=n_delta, n_top=n_top)
model.learn(total_timesteps=500) model.learn(total_timesteps=500)
else: else:
model = ARS("MlpPolicy", "Pendulum-v0", n_delta=n_delta, n_top=n_top) model = ARS("MlpPolicy", "Pendulum-v1", n_delta=n_delta, n_top=n_top)
model.learn(total_timesteps=500) model.learn(total_timesteps=500)
@pytest.mark.parametrize("model_class", [TQC, QRDQN]) @pytest.mark.parametrize("model_class", [TQC, QRDQN])
def test_offpolicy_multi_env(model_class): def test_offpolicy_multi_env(model_class):
if model_class in [TQC]: if model_class in [TQC]:
env_id = "Pendulum-v0" env_id = "Pendulum-v1"
policy_kwargs = dict(net_arch=[64], n_critics=1) policy_kwargs = dict(net_arch=[64], n_critics=1)
else: else:
env_id = "CartPole-v1" env_id = "CartPole-v1"

View File

@ -447,7 +447,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str):
def test_save_load_pytorch_var(tmp_path): def test_save_load_pytorch_var(tmp_path):
model = TQC("MlpPolicy", "Pendulum-v0", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1)) model = TQC("MlpPolicy", "Pendulum-v1", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
model.learn(200) model.learn(200)
save_path = str(tmp_path / "tqc_pendulum") save_path = str(tmp_path / "tqc_pendulum")
model.save(save_path) model.save(save_path)
@ -464,7 +464,7 @@ def test_save_load_pytorch_var(tmp_path):
assert not th.allclose(log_ent_coef_before, log_ent_coef_after) assert not th.allclose(log_ent_coef_before, log_ent_coef_after)
# With a fixed entropy coef # With a fixed entropy coef
model = TQC("MlpPolicy", "Pendulum-v0", seed=3, ent_coef=0.01, policy_kwargs=dict(net_arch=[64], n_critics=1)) model = TQC("MlpPolicy", "Pendulum-v1", seed=3, ent_coef=0.01, policy_kwargs=dict(net_arch=[64], n_critics=1))
model.learn(200) model.learn(200)
save_path = str(tmp_path / "tqc_pendulum") save_path = str(tmp_path / "tqc_pendulum")
model.save(save_path) model.save(save_path)

View File

@ -161,7 +161,7 @@ def test_qrdqn_train_with_batch_norm():
def test_tqc_train_with_batch_norm(): def test_tqc_train_with_batch_norm():
model = TQC( model = TQC(
"MlpPolicy", "MlpPolicy",
"Pendulum-v0", "Pendulum-v1",
policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor), policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
learning_starts=0, learning_starts=0,
tau=0, # do not copy the target tau=0, # do not copy the target
@ -203,7 +203,7 @@ def test_offpolicy_collect_rollout_batch_norm(model_class):
if model_class in [QRDQN]: if model_class in [QRDQN]:
env_id = "CartPole-v1" env_id = "CartPole-v1"
else: else:
env_id = "Pendulum-v0" env_id = "Pendulum-v1"
clone_helper = CLONE_HELPERS[model_class] clone_helper = CLONE_HELPERS[model_class]
@ -230,7 +230,7 @@ def test_offpolicy_collect_rollout_batch_norm(model_class):
@pytest.mark.parametrize("model_class", [QRDQN, TQC]) @pytest.mark.parametrize("model_class", [QRDQN, TQC])
@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"]) @pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
def test_predict_with_dropout_batch_norm(model_class, env_id): def test_predict_with_dropout_batch_norm(model_class, env_id):
if env_id == "CartPole-v1": if env_id == "CartPole-v1":
if model_class in [TQC]: if model_class in [TQC]:

View File

@ -55,10 +55,10 @@ def test_trpo_warnings():
# Only 1 step: advantage normalization will return NaN # Only 1 step: advantage normalization will return NaN
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
TRPO("MlpPolicy", "Pendulum-v0", n_steps=1) TRPO("MlpPolicy", "Pendulum-v1", n_steps=1)
# One step not advantage normalization: ok # One step not advantage normalization: ok
TRPO("MlpPolicy", "Pendulum-v0", n_steps=1, normalize_advantage=False, batch_size=1) TRPO("MlpPolicy", "Pendulum-v1", n_steps=1, normalize_advantage=False, batch_size=1)
# Truncated mini-batch # Truncated mini-batch
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
TRPO("MlpPolicy", "Pendulum-v0", n_steps=6, batch_size=8) TRPO("MlpPolicy", "Pendulum-v1", n_steps=6, batch_size=8)

View File

@ -32,7 +32,7 @@ def check_time_feature(obs, timestep, max_timesteps):
def test_time_feature(): def test_time_feature():
env = gym.make("Pendulum-v0") env = gym.make("Pendulum-v1")
env = TimeFeatureWrapper(env) env = TimeFeatureWrapper(env)
check_env(env, warn=False) check_env(env, warn=False)
# Check for four episodes # Check for four episodes
@ -58,7 +58,7 @@ def test_time_feature():
check_time_feature(obs["observation"], timestep=1, max_timesteps=500) check_time_feature(obs["observation"], timestep=1, max_timesteps=500)
# In test mode, the time feature must be constant # In test mode, the time feature must be constant
env = gym.make("Pendulum-v0") env = gym.make("Pendulum-v1")
env = TimeFeatureWrapper(env, test_mode=True) env = TimeFeatureWrapper(env, test_mode=True)
obs = env.reset() obs = env.reset()
check_time_feature(obs, timestep=0, max_timesteps=200) check_time_feature(obs, timestep=0, max_timesteps=200)