Update doc: switch from Gym to Gymnasium (#182)

* Update documentation to switch to Gymnasium

* Fix custom env type hints

* Fix type check
This commit is contained in:
Antonin RAFFIN 2023-05-10 11:40:40 +02:00 committed by GitHub
parent 21cc96cafd
commit 86fb056fda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 48 additions and 49 deletions

View File

@ -3,7 +3,7 @@
Gym Wrappers
============
Additional `Gym Wrappers <https://github.com/openai/gym/tree/master/gym/wrappers>`_ to enhance Gym environments.
Additional `Gymnasium Wrappers <https://gymnasium.farama.org/api/wrappers/>`_ to enhance Gymnasium environments.
.. automodule:: sb3_contrib.common.wrappers

View File

@ -19,7 +19,7 @@ TRPO ✔️ ✔️ ✔️ ✔️
.. note::
``Tuple`` observation spaces are not supported by any environment,
however, single-level ``Dict`` spaces are
however, single-level ``Dict`` spaces are supported.
Actions ``gym.spaces``:

View File

@ -94,15 +94,16 @@ Train a PPO agent with a recurrent policy on the CartPole environment.
model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
model.learn(5000)
env = model.get_env()
obs = env.reset()
# cell and hidden state of the LSTM
vec_env = model.get_env()
obs = vec_env.reset()
# Cell and hidden state of the LSTM
lstm_states = None
num_envs = 1
# Episode start signals are used to reset the lstm states
episode_starts = np.ones((num_envs,), dtype=bool)
while True:
action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
obs, rewards, dones, info = env.step(action)
# Note: vectorized environment resets automatically
obs, rewards, dones, info = vec_env.step(action)
episode_starts = dones
env.render()
vec_env.render("human")

View File

@ -3,7 +3,7 @@
Changelog
==========
Release 2.0.0a4 (WIP)
Release 2.0.0a8 (WIP)
--------------------------
**Gymnasium support**
@ -33,11 +33,12 @@ Others:
^^^^^^^
- Fixed ``sb3_contrib/tqc/*.py`` type hints
- Fixed ``sb3_contrib/trpo/*.py`` type hints
- Fixed ``sb3_contrib/common/envs/invalid_actions_env.py`` type hints
Documentation:
^^^^^^^^^^^^^^
- Update documentation, switch from Gym to Gymnasium
Release 1.8.0 (2023-04-07)
--------------------------

View File

@ -62,7 +62,7 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, verbose=1)
model.learn(5000)
model.learn(5_000)
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)
@ -71,13 +71,12 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe
model = MaskablePPO.load("ppo_mask")
obs = env.reset()
obs, _ = env.reset()
while True:
# Retrieve current action mask
action_masks = get_action_masks(env)
action, _states = model.predict(obs, action_masks=action_masks)
obs, rewards, dones, info = env.step(action)
env.render()
obs, reward, terminated, truncated, info = env.step(action)
If the environment implements the invalid action mask but uses a different name for it, you can use the ``ActionMasker``

View File

@ -64,8 +64,8 @@ Example
model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
model.learn(5000)
env = model.get_env()
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20, warn=False)
vec_env = model.get_env()
mean_reward, std_reward = evaluate_policy(model, vec_env, n_eval_episodes=20, warn=False)
print(mean_reward)
model.save("ppo_recurrent")
@ -73,7 +73,7 @@ Example
model = RecurrentPPO.load("ppo_recurrent")
obs = env.reset()
obs = vec_env.reset()
# Cell and hidden state of the LSTM
lstm_states = None
num_envs = 1
@ -81,9 +81,9 @@ Example
episode_starts = np.ones((num_envs,), dtype=bool)
while True:
action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
obs, rewards, dones, info = env.step(action)
obs, rewards, dones, info = vec_env.step(action)
episode_starts = dones
env.render()
vec_env.render("human")

View File

@ -57,24 +57,24 @@ Example
from sb3_contrib import QRDQN
env = gym.make("CartPole-v1")
env = gym.make("CartPole-v1", render_mode="human")
policy_kwargs = dict(n_quantiles=50)
model = QRDQN("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.learn(total_timesteps=10_000, log_interval=4)
model.save("qrdqn_cartpole")
del model # remove to demonstrate saving and loading
model = QRDQN.load("qrdqn_cartpole")
obs = env.reset()
obs, _ = env.reset()
while True:
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
obs, reward, terminated, truncated, info = env.step(action)
env.render()
if done:
obs = env.reset()
if terminated or truncated:
obs, _ = env.reset()
Results

View File

@ -57,24 +57,24 @@ Example
from sb3_contrib import TQC
env = gym.make("Pendulum-v1")
env = gym.make("Pendulum-v1", render_mode="human")
policy_kwargs = dict(n_critics=2, n_quantiles=25)
model = TQC("MlpPolicy", env, top_quantiles_to_drop_per_net=2, verbose=1, policy_kwargs=policy_kwargs)
model.learn(total_timesteps=10000, log_interval=4)
model.learn(total_timesteps=10_000, log_interval=4)
model.save("tqc_pendulum")
del model # remove to demonstrate saving and loading
model = TQC.load("tqc_pendulum")
obs = env.reset()
obs, _ = env.reset()
while True:
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
obs, reward, terminated, truncated, info = env.step(action)
env.render()
if done:
obs = env.reset()
if terminated or truncated:
obs, _ = env.reset()
Results

View File

@ -54,23 +54,23 @@ Example
from sb3_contrib import TRPO
env = gym.make("Pendulum-v1")
env = gym.make("Pendulum-v1", render_mode="human")
model = TRPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.learn(total_timesteps=10_000, log_interval=4)
model.save("trpo_pendulum")
del model # remove to demonstrate saving and loading
model = TRPO.load("trpo_pendulum")
obs = env.reset()
obs, _ = env.reset()
while True:
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
obs, reward, terminated, truncated, info = env.step(action)
env.render()
if done:
obs = env.reset()
if terminated or truncated:
obs, _ = env.reset()
Results

View File

@ -35,7 +35,6 @@ exclude = """(?x)(
| sb3_contrib/common/maskable/callbacks.py$
| sb3_contrib/common/maskable/policies.py$
| sb3_contrib/common/maskable/buffers.py$
| sb3_contrib/common/envs/invalid_actions_env.py$
| sb3_contrib/common/vec_env/async_eval.py$
| sb3_contrib/ppo_mask/ppo_mask.py$
| tests/test_train_eval_mode.py$
@ -50,10 +49,6 @@ env = [
filterwarnings = [
# Tensorboard warnings
"ignore::DeprecationWarning:tensorboard",
# Gym warnings
"ignore:Parameters to load are deprecated.:DeprecationWarning",
"ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning",
"ignore::UserWarning:gym",
]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')"

View File

@ -5,7 +5,7 @@ from gymnasium import spaces
from stable_baselines3.common.envs import IdentityEnv
class InvalidActionEnvDiscrete(IdentityEnv):
class InvalidActionEnvDiscrete(IdentityEnv[int]):
"""
Identity env with a discrete action space. Supports action masking.
"""
@ -30,17 +30,19 @@ class InvalidActionEnvDiscrete(IdentityEnv):
self.state = self.action_space.sample()
# Randomly choose invalid actions that are not the current state
potential_invalid_actions = [i for i in self.possible_actions if i != self.state]
self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)
self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False).tolist()
def action_masks(self) -> List[bool]:
return [action not in self.invalid_actions for action in self.possible_actions]
class InvalidActionEnvMultiDiscrete(IdentityEnv):
class InvalidActionEnvMultiDiscrete(IdentityEnv[np.ndarray]):
"""
Identity env with a multidiscrete action space. Supports action masking.
"""
action_space: spaces.MultiDiscrete
def __init__(
self,
dims: Optional[List[int]] = None,
@ -70,13 +72,13 @@ class InvalidActionEnvMultiDiscrete(IdentityEnv):
# Randomly choose invalid actions that are not the current state
potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state]
self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)
self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False).tolist()
def action_masks(self) -> List[bool]:
return [action not in self.invalid_actions for action in self.possible_actions]
class InvalidActionEnvMultiBinary(IdentityEnv):
class InvalidActionEnvMultiBinary(IdentityEnv[np.ndarray]):
"""
Identity env with a multibinary action space. Supports action masking.
"""
@ -94,6 +96,7 @@ class InvalidActionEnvMultiBinary(IdentityEnv):
raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {dims}")
space = spaces.MultiBinary(dims)
self.n_dims = dims
self.n_invalid_actions = n_invalid_actions
self.possible_actions = np.arange(2 * dims)
self.invalid_actions: List[int] = []
@ -104,13 +107,13 @@ class InvalidActionEnvMultiBinary(IdentityEnv):
converted_state: List[int] = []
running_total = 0
for i in range(self.action_space.n):
for i in range(self.n_dims):
converted_state.append(running_total + self.state[i])
running_total += 2
# Randomly choose invalid actions that are not the current state
potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state]
self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)
self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False).tolist()
def action_masks(self) -> List[bool]:
return [action not in self.invalid_actions for action in self.possible_actions]

View File

@ -1 +1 @@
2.0.0a4
2.0.0a8