Upgrade Gym to 0.21 (#59)
* Pendulum-v0 -> Pendulum-v1 * Reformat with black * Update changelog * Fix dtype bug in TimeFeatureWrapper * Update version and removed forward calls * Update CI * Fix min version Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
a78891bd00
commit
901a648507
|
|
@ -36,8 +36,6 @@ jobs:
|
|||
pip install .
|
||||
# Use headless version
|
||||
pip install opencv-python-headless
|
||||
# Tmp fix: ROM missing in the newest atari-py version
|
||||
pip install atari-py==0.2.5
|
||||
- name: Build the doc
|
||||
run: |
|
||||
make doc
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment.
|
|||
|
||||
from sb3_contrib import TQC
|
||||
|
||||
model = TQC("MlpPolicy", "Pendulum-v0", top_quantiles_to_drop_per_net=2, verbose=1)
|
||||
model = TQC("MlpPolicy", "Pendulum-v1", top_quantiles_to_drop_per_net=2, verbose=1)
|
||||
model.learn(total_timesteps=10_000, log_interval=4)
|
||||
model.save("tqc_pendulum")
|
||||
|
||||
|
|
@ -54,7 +54,7 @@ Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environmen
|
|||
|
||||
from sb3_contrib import TRPO
|
||||
|
||||
model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1)
|
||||
model = TRPO("MlpPolicy", "Pendulum-v1", gamma=0.9, verbose=1)
|
||||
model.learn(total_timesteps=100_000, log_interval=4)
|
||||
model.save("trpo_pendulum")
|
||||
|
||||
|
|
@ -68,6 +68,6 @@ Train an agent using Augmented Random Search (ARS) agent on the Pendulum environ
|
|||
|
||||
from sb3_contrib import ARS
|
||||
|
||||
model = ARS("LinearPolicy", "Pendulum-v0", verbose=1)
|
||||
model = ARS("LinearPolicy", "Pendulum-v1", verbose=1)
|
||||
model.learn(total_timesteps=10000, log_interval=4)
|
||||
model.save("ars_pendulum")
|
||||
|
|
|
|||
|
|
@ -3,6 +3,27 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.4.1a1 (WIP)
|
||||
-------------------------------
|
||||
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Switched minimum Gym version to 0.21.0.
|
||||
- Upgraded to Stable-Baselines3 >= 1.4.1a1
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Removed explict calls to ``forward()`` method as per pytorch guidelines
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
|
||||
Release 1.4.0 (2022-01-19)
|
||||
-------------------------------
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ Example
|
|||
from sb3_contrib import ARS
|
||||
|
||||
# Policy can be LinearPolicy or MlpPolicy
|
||||
model = ARS("LinearPolicy", "Pendulum-v0", verbose=1)
|
||||
model = ARS("LinearPolicy", "Pendulum-v1", verbose=1)
|
||||
model.learn(total_timesteps=10000, log_interval=4)
|
||||
model.save("ars_pendulum")
|
||||
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ Example
|
|||
|
||||
from sb3_contrib import TQC
|
||||
|
||||
env = gym.make("Pendulum-v0")
|
||||
env = gym.make("Pendulum-v1")
|
||||
|
||||
policy_kwargs = dict(n_critics=2, n_quantiles=25)
|
||||
model = TQC("MlpPolicy", env, top_quantiles_to_drop_per_net=2, verbose=1, policy_kwargs=policy_kwargs)
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ Example
|
|||
|
||||
from sb3_contrib import TRPO
|
||||
|
||||
env = gym.make("Pendulum-v0")
|
||||
env = gym.make("Pendulum-v1")
|
||||
|
||||
model = TRPO("MlpPolicy", env, verbose=1)
|
||||
model.learn(total_timesteps=10000, log_interval=4)
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ class ARSPolicy(BasePolicy):
|
|||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
# Non deterministic action does not really make sense for ARS, we ignore this parameter for now..
|
||||
return self.forward(observation)
|
||||
return self(observation)
|
||||
|
||||
|
||||
class ARSLinearPolicy(ARSPolicy):
|
||||
|
|
|
|||
|
|
@ -43,11 +43,12 @@ class TimeFeatureWrapper(gym.Wrapper):
|
|||
|
||||
low, high = obs_space.low, obs_space.high
|
||||
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0]))
|
||||
self.dtype = obs_space.dtype
|
||||
|
||||
if isinstance(env.observation_space, gym.spaces.Dict):
|
||||
env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=np.float32)
|
||||
env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
|
||||
else:
|
||||
env.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)
|
||||
env.observation_space = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
|
||||
|
||||
super(TimeFeatureWrapper, self).__init__(env)
|
||||
|
||||
|
|
@ -84,6 +85,7 @@ class TimeFeatureWrapper(gym.Wrapper):
|
|||
time_feature = 1 - (self._current_step / self._max_steps)
|
||||
if self._test_mode:
|
||||
time_feature = 1.0
|
||||
time_feature = np.array(time_feature, dtype=self.dtype)
|
||||
|
||||
if isinstance(obs, dict):
|
||||
obs["observation"] = np.append(obs["observation"], time_feature)
|
||||
|
|
|
|||
|
|
@ -317,7 +317,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
if use_masking:
|
||||
action_masks = get_action_masks(env)
|
||||
|
||||
actions, values, log_probs = self.policy.forward(obs_tensor, action_masks=action_masks)
|
||||
actions, values, log_probs = self.policy(obs_tensor, action_masks=action_masks)
|
||||
|
||||
actions = actions.cpu().numpy()
|
||||
new_obs, rewards, dones, infos = env.step(actions)
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class QuantileNetwork(BasePolicy):
|
|||
return quantiles.view(-1, self.n_quantiles, self.action_space.n)
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
q_values = self.forward(observation).mean(dim=1)
|
||||
q_values = self(observation).mean(dim=1)
|
||||
# Greedy action
|
||||
action = q_values.argmax(dim=1).reshape(-1)
|
||||
return action
|
||||
|
|
|
|||
|
|
@ -181,7 +181,7 @@ class Actor(BasePolicy):
|
|||
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
return self.forward(observation, deterministic)
|
||||
return self(observation, deterministic)
|
||||
|
||||
|
||||
class Critic(BaseModel):
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.4.0
|
||||
1.4.1a1
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -63,7 +63,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=1.4.0",
|
||||
"stable_baselines3>=1.4.1a1",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ def test_deterministic_training_common(algo):
|
|||
rewards = [[], []]
|
||||
# Smaller network
|
||||
kwargs = {"policy_kwargs": dict(net_arch=[64])}
|
||||
env_id = "Pendulum-v0"
|
||||
env_id = "Pendulum-v1"
|
||||
if algo == ARS_MULTI:
|
||||
algo = ARS
|
||||
ars_multi = True
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from sb3_contrib.common.vec_env import AsyncEval
|
|||
def test_tqc(ent_coef):
|
||||
model = TQC(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v0",
|
||||
"Pendulum-v1",
|
||||
policy_kwargs=dict(net_arch=[64, 64]),
|
||||
learning_starts=100,
|
||||
verbose=1,
|
||||
|
|
@ -26,7 +26,7 @@ def test_n_critics(n_critics):
|
|||
# Test TQC with different number of critics
|
||||
model = TQC(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v0",
|
||||
"Pendulum-v1",
|
||||
policy_kwargs=dict(net_arch=[64], n_critics=n_critics),
|
||||
learning_starts=100,
|
||||
verbose=1,
|
||||
|
|
@ -37,7 +37,7 @@ def test_n_critics(n_critics):
|
|||
def test_sde():
|
||||
model = TQC(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v0",
|
||||
"Pendulum-v1",
|
||||
policy_kwargs=dict(net_arch=[64]),
|
||||
use_sde=True,
|
||||
learning_starts=100,
|
||||
|
|
@ -62,7 +62,7 @@ def test_qrdqn():
|
|||
model.learn(total_timesteps=500, eval_freq=250)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
|
||||
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
|
||||
def test_trpo(env_id):
|
||||
model = TRPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1)
|
||||
model.learn(total_timesteps=500)
|
||||
|
|
@ -72,7 +72,7 @@ def test_trpo_params():
|
|||
# Test with gSDE and subsampling
|
||||
model = TRPO(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v0",
|
||||
"Pendulum-v1",
|
||||
n_steps=64,
|
||||
batch_size=32,
|
||||
use_sde=True,
|
||||
|
|
@ -84,7 +84,7 @@ def test_trpo_params():
|
|||
model.learn(total_timesteps=500)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
|
||||
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
|
||||
@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"])
|
||||
def test_ars(policy_str, env_id):
|
||||
model = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0)
|
||||
|
|
@ -92,14 +92,14 @@ def test_ars(policy_str, env_id):
|
|||
|
||||
|
||||
def test_ars_multi_env():
|
||||
env = make_vec_env("Pendulum-v0", n_envs=2)
|
||||
env = make_vec_env("Pendulum-v1", n_envs=2)
|
||||
model = ARS("MlpPolicy", env, n_delta=1)
|
||||
model.learn(total_timesteps=250)
|
||||
|
||||
env = VecNormalize(make_vec_env("Pendulum-v0", n_envs=1))
|
||||
env = VecNormalize(make_vec_env("Pendulum-v1", n_envs=1))
|
||||
model = ARS("MlpPolicy", env, n_delta=2, seed=0)
|
||||
# with parallelism
|
||||
async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v0", n_envs=1)) for _ in range(2)], model.policy)
|
||||
async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v1", n_envs=1)) for _ in range(2)], model.policy)
|
||||
async_eval.seed(0)
|
||||
model.learn(500, async_eval=async_eval)
|
||||
|
||||
|
|
@ -109,17 +109,17 @@ def test_ars_n_top(n_top):
|
|||
n_delta = 3
|
||||
if n_top > n_delta:
|
||||
with pytest.warns(UserWarning):
|
||||
model = ARS("MlpPolicy", "Pendulum-v0", n_delta=n_delta, n_top=n_top)
|
||||
model = ARS("MlpPolicy", "Pendulum-v1", n_delta=n_delta, n_top=n_top)
|
||||
model.learn(total_timesteps=500)
|
||||
else:
|
||||
model = ARS("MlpPolicy", "Pendulum-v0", n_delta=n_delta, n_top=n_top)
|
||||
model = ARS("MlpPolicy", "Pendulum-v1", n_delta=n_delta, n_top=n_top)
|
||||
model.learn(total_timesteps=500)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [TQC, QRDQN])
|
||||
def test_offpolicy_multi_env(model_class):
|
||||
if model_class in [TQC]:
|
||||
env_id = "Pendulum-v0"
|
||||
env_id = "Pendulum-v1"
|
||||
policy_kwargs = dict(net_arch=[64], n_critics=1)
|
||||
else:
|
||||
env_id = "CartPole-v1"
|
||||
|
|
|
|||
|
|
@ -447,7 +447,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str):
|
|||
|
||||
|
||||
def test_save_load_pytorch_var(tmp_path):
|
||||
model = TQC("MlpPolicy", "Pendulum-v0", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
|
||||
model = TQC("MlpPolicy", "Pendulum-v1", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
|
||||
model.learn(200)
|
||||
save_path = str(tmp_path / "tqc_pendulum")
|
||||
model.save(save_path)
|
||||
|
|
@ -464,7 +464,7 @@ def test_save_load_pytorch_var(tmp_path):
|
|||
assert not th.allclose(log_ent_coef_before, log_ent_coef_after)
|
||||
|
||||
# With a fixed entropy coef
|
||||
model = TQC("MlpPolicy", "Pendulum-v0", seed=3, ent_coef=0.01, policy_kwargs=dict(net_arch=[64], n_critics=1))
|
||||
model = TQC("MlpPolicy", "Pendulum-v1", seed=3, ent_coef=0.01, policy_kwargs=dict(net_arch=[64], n_critics=1))
|
||||
model.learn(200)
|
||||
save_path = str(tmp_path / "tqc_pendulum")
|
||||
model.save(save_path)
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ def test_qrdqn_train_with_batch_norm():
|
|||
def test_tqc_train_with_batch_norm():
|
||||
model = TQC(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v0",
|
||||
"Pendulum-v1",
|
||||
policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
|
||||
learning_starts=0,
|
||||
tau=0, # do not copy the target
|
||||
|
|
@ -203,7 +203,7 @@ def test_offpolicy_collect_rollout_batch_norm(model_class):
|
|||
if model_class in [QRDQN]:
|
||||
env_id = "CartPole-v1"
|
||||
else:
|
||||
env_id = "Pendulum-v0"
|
||||
env_id = "Pendulum-v1"
|
||||
|
||||
clone_helper = CLONE_HELPERS[model_class]
|
||||
|
||||
|
|
@ -230,7 +230,7 @@ def test_offpolicy_collect_rollout_batch_norm(model_class):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
|
||||
@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"])
|
||||
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
|
||||
def test_predict_with_dropout_batch_norm(model_class, env_id):
|
||||
if env_id == "CartPole-v1":
|
||||
if model_class in [TQC]:
|
||||
|
|
|
|||
|
|
@ -55,10 +55,10 @@ def test_trpo_warnings():
|
|||
|
||||
# Only 1 step: advantage normalization will return NaN
|
||||
with pytest.raises(AssertionError):
|
||||
TRPO("MlpPolicy", "Pendulum-v0", n_steps=1)
|
||||
TRPO("MlpPolicy", "Pendulum-v1", n_steps=1)
|
||||
# One step not advantage normalization: ok
|
||||
TRPO("MlpPolicy", "Pendulum-v0", n_steps=1, normalize_advantage=False, batch_size=1)
|
||||
TRPO("MlpPolicy", "Pendulum-v1", n_steps=1, normalize_advantage=False, batch_size=1)
|
||||
|
||||
# Truncated mini-batch
|
||||
with pytest.warns(UserWarning):
|
||||
TRPO("MlpPolicy", "Pendulum-v0", n_steps=6, batch_size=8)
|
||||
TRPO("MlpPolicy", "Pendulum-v1", n_steps=6, batch_size=8)
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ def check_time_feature(obs, timestep, max_timesteps):
|
|||
|
||||
|
||||
def test_time_feature():
|
||||
env = gym.make("Pendulum-v0")
|
||||
env = gym.make("Pendulum-v1")
|
||||
env = TimeFeatureWrapper(env)
|
||||
check_env(env, warn=False)
|
||||
# Check for four episodes
|
||||
|
|
@ -58,7 +58,7 @@ def test_time_feature():
|
|||
check_time_feature(obs["observation"], timestep=1, max_timesteps=500)
|
||||
|
||||
# In test mode, the time feature must be constant
|
||||
env = gym.make("Pendulum-v0")
|
||||
env = gym.make("Pendulum-v1")
|
||||
env = TimeFeatureWrapper(env, test_mode=True)
|
||||
obs = env.reset()
|
||||
check_time_feature(obs, timestep=0, max_timesteps=200)
|
||||
|
|
|
|||
Loading…
Reference in New Issue