Upgrade Gym to 0.21 (#59)

* Pendulum-v0 -> Pendulum-v1

* Reformat with black

* Update changelog

* Fix dtype bug in TimeFeatureWrapper

* Update version and removed forward calls

* Update CI

* Fix min version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Adam Gleave 2022-02-22 15:25:43 +00:00 committed by GitHub
parent a78891bd00
commit 901a648507
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 64 additions and 43 deletions

View File

@ -36,8 +36,6 @@ jobs:
pip install .
# Use headless version
pip install opencv-python-headless
# Tmp fix: ROM missing in the newest atari-py version
pip install atari-py==0.2.5
- name: Build the doc
run: |
make doc

View File

@ -12,7 +12,7 @@ Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment.
from sb3_contrib import TQC
model = TQC("MlpPolicy", "Pendulum-v0", top_quantiles_to_drop_per_net=2, verbose=1)
model = TQC("MlpPolicy", "Pendulum-v1", top_quantiles_to_drop_per_net=2, verbose=1)
model.learn(total_timesteps=10_000, log_interval=4)
model.save("tqc_pendulum")
@ -54,7 +54,7 @@ Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environmen
from sb3_contrib import TRPO
model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1)
model = TRPO("MlpPolicy", "Pendulum-v1", gamma=0.9, verbose=1)
model.learn(total_timesteps=100_000, log_interval=4)
model.save("trpo_pendulum")
@ -68,6 +68,6 @@ Train an agent using Augmented Random Search (ARS) agent on the Pendulum environ
from sb3_contrib import ARS
model = ARS("LinearPolicy", "Pendulum-v0", verbose=1)
model = ARS("LinearPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("ars_pendulum")

View File

@ -3,6 +3,27 @@
Changelog
==========
Release 1.4.1a1 (WIP)
-------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Switched minimum Gym version to 0.21.0.
- Upgraded to Stable-Baselines3 >= 1.4.1a1
New Features:
^^^^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
- Removed explicit calls to ``forward()`` method as per PyTorch guidelines
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
Release 1.4.0 (2022-01-19)
-------------------------------

View File

@ -64,7 +64,7 @@ Example
from sb3_contrib import ARS
# Policy can be LinearPolicy or MlpPolicy
model = ARS("LinearPolicy", "Pendulum-v0", verbose=1)
model = ARS("LinearPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("ars_pendulum")

View File

@ -57,7 +57,7 @@ Example
from sb3_contrib import TQC
env = gym.make("Pendulum-v0")
env = gym.make("Pendulum-v1")
policy_kwargs = dict(n_critics=2, n_quantiles=25)
model = TQC("MlpPolicy", env, top_quantiles_to_drop_per_net=2, verbose=1, policy_kwargs=policy_kwargs)

View File

@ -54,7 +54,7 @@ Example
from sb3_contrib import TRPO
env = gym.make("Pendulum-v0")
env = gym.make("Pendulum-v1")
model = TRPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)

View File

@ -76,7 +76,7 @@ class ARSPolicy(BasePolicy):
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
# Non deterministic action does not really make sense for ARS, we ignore this parameter for now..
return self.forward(observation)
return self(observation)
class ARSLinearPolicy(ARSPolicy):

View File

@ -61,7 +61,7 @@ def quantile_huber_loss(
# Note: in both cases, the loss has the same shape as pairwise_delta
pairwise_delta = target_quantiles.unsqueeze(-2) - current_quantiles.unsqueeze(-1)
abs_pairwise_delta = th.abs(pairwise_delta)
huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5)
huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta**2 * 0.5)
loss = th.abs(cum_prob - (pairwise_delta.detach() < 0).float()) * huber_loss
if sum_over_quantiles:
loss = loss.sum(dim=-2).mean()

View File

@ -43,11 +43,12 @@ class TimeFeatureWrapper(gym.Wrapper):
low, high = obs_space.low, obs_space.high
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0]))
self.dtype = obs_space.dtype
if isinstance(env.observation_space, gym.spaces.Dict):
env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=np.float32)
env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
else:
env.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)
env.observation_space = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
super(TimeFeatureWrapper, self).__init__(env)
@ -84,6 +85,7 @@ class TimeFeatureWrapper(gym.Wrapper):
time_feature = 1 - (self._current_step / self._max_steps)
if self._test_mode:
time_feature = 1.0
time_feature = np.array(time_feature, dtype=self.dtype)
if isinstance(obs, dict):
obs["observation"] = np.append(obs["observation"], time_feature)

View File

@ -317,7 +317,7 @@ class MaskablePPO(OnPolicyAlgorithm):
if use_masking:
action_masks = get_action_masks(env)
actions, values, log_probs = self.policy.forward(obs_tensor, action_masks=action_masks)
actions, values, log_probs = self.policy(obs_tensor, action_masks=action_masks)
actions = actions.cpu().numpy()
new_obs, rewards, dones, infos = env.step(actions)

View File

@ -69,7 +69,7 @@ class QuantileNetwork(BasePolicy):
return quantiles.view(-1, self.n_quantiles, self.action_space.n)
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
q_values = self.forward(observation).mean(dim=1)
q_values = self(observation).mean(dim=1)
# Greedy action
action = q_values.argmax(dim=1).reshape(-1)
return action

View File

@ -181,7 +181,7 @@ class Actor(BasePolicy):
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self.forward(observation, deterministic)
return self(observation, deterministic)
class Critic(BaseModel):

View File

@ -1 +1 @@
1.4.0
1.4.1a1

View File

@ -63,7 +63,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.4.0",
"stable_baselines3>=1.4.1a1",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",

View File

@ -17,7 +17,7 @@ def test_deterministic_training_common(algo):
rewards = [[], []]
# Smaller network
kwargs = {"policy_kwargs": dict(net_arch=[64])}
env_id = "Pendulum-v0"
env_id = "Pendulum-v1"
if algo == ARS_MULTI:
algo = ARS
ars_multi = True

View File

@ -222,7 +222,7 @@ class TestMaskableMultiCategoricalDistribution:
assert len(dist.distributions) == NUM_CATS
for i in range(NUM_CATS):
assert (dist.distributions[i].probs == 0.5).all()
assert int(dist.entropy().exp()) == DIMS_PER_CAT ** NUM_CATS
assert int(dist.entropy().exp()) == DIMS_PER_CAT**NUM_CATS
for i in range(DIMS_PER_CAT):
mask = np.array([False] * DIMS_PER_CAT * NUM_CATS)
@ -240,7 +240,7 @@ class TestMaskableMultiCategoricalDistribution:
dist.apply_masking(None)
for i in range(NUM_CATS):
assert (dist.distributions[i].probs == 0.5).all()
assert int(dist.entropy().exp()) == DIMS_PER_CAT ** NUM_CATS
assert int(dist.entropy().exp()) == DIMS_PER_CAT**NUM_CATS
class TestMaskableBernoulliDistribution:

View File

@ -11,7 +11,7 @@ from sb3_contrib.common.vec_env import AsyncEval
def test_tqc(ent_coef):
model = TQC(
"MlpPolicy",
"Pendulum-v0",
"Pendulum-v1",
policy_kwargs=dict(net_arch=[64, 64]),
learning_starts=100,
verbose=1,
@ -26,7 +26,7 @@ def test_n_critics(n_critics):
# Test TQC with different number of critics
model = TQC(
"MlpPolicy",
"Pendulum-v0",
"Pendulum-v1",
policy_kwargs=dict(net_arch=[64], n_critics=n_critics),
learning_starts=100,
verbose=1,
@ -37,7 +37,7 @@ def test_n_critics(n_critics):
def test_sde():
model = TQC(
"MlpPolicy",
"Pendulum-v0",
"Pendulum-v1",
policy_kwargs=dict(net_arch=[64]),
use_sde=True,
learning_starts=100,
@ -62,7 +62,7 @@ def test_qrdqn():
model.learn(total_timesteps=500, eval_freq=250)
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_trpo(env_id):
model = TRPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1)
model.learn(total_timesteps=500)
@ -72,7 +72,7 @@ def test_trpo_params():
# Test with gSDE and subsampling
model = TRPO(
"MlpPolicy",
"Pendulum-v0",
"Pendulum-v1",
n_steps=64,
batch_size=32,
use_sde=True,
@ -84,7 +84,7 @@ def test_trpo_params():
model.learn(total_timesteps=500)
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"])
def test_ars(policy_str, env_id):
model = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0)
@ -92,14 +92,14 @@ def test_ars(policy_str, env_id):
def test_ars_multi_env():
env = make_vec_env("Pendulum-v0", n_envs=2)
env = make_vec_env("Pendulum-v1", n_envs=2)
model = ARS("MlpPolicy", env, n_delta=1)
model.learn(total_timesteps=250)
env = VecNormalize(make_vec_env("Pendulum-v0", n_envs=1))
env = VecNormalize(make_vec_env("Pendulum-v1", n_envs=1))
model = ARS("MlpPolicy", env, n_delta=2, seed=0)
# with parallelism
async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v0", n_envs=1)) for _ in range(2)], model.policy)
async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v1", n_envs=1)) for _ in range(2)], model.policy)
async_eval.seed(0)
model.learn(500, async_eval=async_eval)
@ -109,17 +109,17 @@ def test_ars_n_top(n_top):
n_delta = 3
if n_top > n_delta:
with pytest.warns(UserWarning):
model = ARS("MlpPolicy", "Pendulum-v0", n_delta=n_delta, n_top=n_top)
model = ARS("MlpPolicy", "Pendulum-v1", n_delta=n_delta, n_top=n_top)
model.learn(total_timesteps=500)
else:
model = ARS("MlpPolicy", "Pendulum-v0", n_delta=n_delta, n_top=n_top)
model = ARS("MlpPolicy", "Pendulum-v1", n_delta=n_delta, n_top=n_top)
model.learn(total_timesteps=500)
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
def test_offpolicy_multi_env(model_class):
if model_class in [TQC]:
env_id = "Pendulum-v0"
env_id = "Pendulum-v1"
policy_kwargs = dict(net_arch=[64], n_critics=1)
else:
env_id = "CartPole-v1"

View File

@ -447,7 +447,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str):
def test_save_load_pytorch_var(tmp_path):
model = TQC("MlpPolicy", "Pendulum-v0", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
model = TQC("MlpPolicy", "Pendulum-v1", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
model.learn(200)
save_path = str(tmp_path / "tqc_pendulum")
model.save(save_path)
@ -464,7 +464,7 @@ def test_save_load_pytorch_var(tmp_path):
assert not th.allclose(log_ent_coef_before, log_ent_coef_after)
# With a fixed entropy coef
model = TQC("MlpPolicy", "Pendulum-v0", seed=3, ent_coef=0.01, policy_kwargs=dict(net_arch=[64], n_critics=1))
model = TQC("MlpPolicy", "Pendulum-v1", seed=3, ent_coef=0.01, policy_kwargs=dict(net_arch=[64], n_critics=1))
model.learn(200)
save_path = str(tmp_path / "tqc_pendulum")
model.save(save_path)

View File

@ -161,7 +161,7 @@ def test_qrdqn_train_with_batch_norm():
def test_tqc_train_with_batch_norm():
model = TQC(
"MlpPolicy",
"Pendulum-v0",
"Pendulum-v1",
policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
learning_starts=0,
tau=0, # do not copy the target
@ -203,7 +203,7 @@ def test_offpolicy_collect_rollout_batch_norm(model_class):
if model_class in [QRDQN]:
env_id = "CartPole-v1"
else:
env_id = "Pendulum-v0"
env_id = "Pendulum-v1"
clone_helper = CLONE_HELPERS[model_class]
@ -230,7 +230,7 @@ def test_offpolicy_collect_rollout_batch_norm(model_class):
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"])
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
def test_predict_with_dropout_batch_norm(model_class, env_id):
if env_id == "CartPole-v1":
if model_class in [TQC]:

View File

@ -42,7 +42,7 @@ def test_cg():
def test_flat_grad():
n_parameters = 12 # 3 * (2 * 2)
x = th.nn.Parameter(th.ones(2, 2, requires_grad=True))
y = (x ** 2).sum()
y = (x**2).sum()
flat_grad_out = flat_grad(y, [x, x, x])
assert len(flat_grad_out.shape) == 1
# dy/dx = 2
@ -55,10 +55,10 @@ def test_trpo_warnings():
# Only 1 step: advantage normalization will return NaN
with pytest.raises(AssertionError):
TRPO("MlpPolicy", "Pendulum-v0", n_steps=1)
TRPO("MlpPolicy", "Pendulum-v1", n_steps=1)
# One step not advantage normalization: ok
TRPO("MlpPolicy", "Pendulum-v0", n_steps=1, normalize_advantage=False, batch_size=1)
TRPO("MlpPolicy", "Pendulum-v1", n_steps=1, normalize_advantage=False, batch_size=1)
# Truncated mini-batch
with pytest.warns(UserWarning):
TRPO("MlpPolicy", "Pendulum-v0", n_steps=6, batch_size=8)
TRPO("MlpPolicy", "Pendulum-v1", n_steps=6, batch_size=8)

View File

@ -32,7 +32,7 @@ def check_time_feature(obs, timestep, max_timesteps):
def test_time_feature():
env = gym.make("Pendulum-v0")
env = gym.make("Pendulum-v1")
env = TimeFeatureWrapper(env)
check_env(env, warn=False)
# Check for four episodes
@ -58,7 +58,7 @@ def test_time_feature():
check_time_feature(obs["observation"], timestep=1, max_timesteps=500)
# In test mode, the time feature must be constant
env = gym.make("Pendulum-v0")
env = gym.make("Pendulum-v1")
env = TimeFeatureWrapper(env, test_mode=True)
obs = env.reset()
check_time_feature(obs, timestep=0, max_timesteps=200)