Update doc: switch from Gym to Gymnasium (#182)

* Update documentation to switch to Gymnasium

* Fix custom env type hints

* Fix type check
Antonin RAFFIN 2023-05-10 11:40:40 +02:00 committed by GitHub
parent 21cc96cafd
commit 86fb056fda
12 changed files with 48 additions and 49 deletions
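For orientation, the core API shift behind these edits, as a minimal sketch of standard Gymnasium usage (illustrative only, not code from this commit):

import gymnasium as gym

# The render mode is now fixed at construction time instead of being passed to render()
env = gym.make("CartPole-v1", render_mode="human")

# reset() now returns (observation, info) instead of only the observation
obs, info = env.reset()

# step() now returns five values: the old `done` flag is split into
# `terminated` (the MDP reached a terminal state) and `truncated`
# (the episode was cut short, e.g. by a time limit)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
if terminated or truncated:
    obs, info = env.reset()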

View File

@@ -3,7 +3,7 @@
 Gym Wrappers
 ============
-Additional `Gym Wrappers <https://github.com/openai/gym/tree/master/gym/wrappers>`_ to enhance Gym environments.
+Additional `Gymnasium Wrappers <https://gymnasium.farama.org/api/wrappers/>`_ to enhance Gymnasium environments.
 .. automodule:: sb3_contrib.common.wrappers

View File

@@ -19,7 +19,7 @@ TRPO ✔️ ✔️ ✔️ ✔️
 .. note::
   ``Tuple`` observation spaces are not supported by any environment,
-  however, single-level ``Dict`` spaces are
+  however, single-level ``Dict`` spaces are supported.
 Actions ``gym.spaces``:

View File

@@ -94,15 +94,16 @@ Train a PPO agent with a recurrent policy on the CartPole environment.
 model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
 model.learn(5000)
-env = model.get_env()
-obs = env.reset()
-# cell and hidden state of the LSTM
+vec_env = model.get_env()
+obs = vec_env.reset()
+# Cell and hidden state of the LSTM
 lstm_states = None
 num_envs = 1
 # Episode start signals are used to reset the lstm states
 episode_starts = np.ones((num_envs,), dtype=bool)
 while True:
     action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
-    obs, rewards, dones, info = env.step(action)
+    # Note: vectorized environment resets automatically
+    obs, rewards, dones, info = vec_env.step(action)
     episode_starts = dones
-    env.render()
+    vec_env.render("human")
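Note that this loop keeps the Gym-style 4-tuple return of ``step()``: SB3's ``VecEnv`` layer is unchanged by the Gymnasium migration, still returns ``dones``, and resets finished episodes automatically. A sketch of the same loop with several parallel environments, assuming SB3's ``make_vec_env`` helper (the step count is arbitrary):

import numpy as np
from stable_baselines3.common.env_util import make_vec_env
from sb3_contrib import RecurrentPPO

vec_env = make_vec_env("CartPole-v1", n_envs=4)
model = RecurrentPPO("MlpLstmPolicy", vec_env, verbose=1)
model.learn(5000)

obs = vec_env.reset()
lstm_states = None
# One episode-start flag per parallel environment
episode_starts = np.ones((vec_env.num_envs,), dtype=bool)
for _ in range(500):
    action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
    # VecEnv keeps the 4-tuple API and auto-resets finished episodes
    obs, rewards, dones, infos = vec_env.step(action)
    episode_starts = dones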

View File

@@ -3,7 +3,7 @@
 Changelog
 ==========
-Release 2.0.0a4 (WIP)
+Release 2.0.0a8 (WIP)
 --------------------------
 **Gymnasium support**
@@ -33,11 +33,12 @@ Others:
 ^^^^^^^
 - Fixed ``sb3_contrib/tqc/*.py`` type hints
 - Fixed ``sb3_contrib/trpo/*.py`` type hints
+- Fixed ``sb3_contrib/common/envs/invalid_actions_env.py`` type hints
 Documentation:
 ^^^^^^^^^^^^^^
+- Update documentation, switch from Gym to Gymnasium
 Release 1.8.0 (2023-04-07)
 --------------------------

View File

@@ -62,7 +62,7 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` otherwise):
 env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
 model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, verbose=1)
-model.learn(5000)
+model.learn(5_000)
 evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)
@@ -71,13 +71,12 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` otherwise):
 model = MaskablePPO.load("ppo_mask")
-obs = env.reset()
+obs, _ = env.reset()
 while True:
     # Retrieve current action mask
     action_masks = get_action_masks(env)
     action, _states = model.predict(obs, action_masks=action_masks)
-    obs, rewards, dones, info = env.step(action)
-    env.render()
+    obs, reward, terminated, truncated, info = env.step(action)
 If the environment implements the invalid action mask but under a different name, you can use the ``ActionMasker``
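A sketch of that ``ActionMasker`` pattern; ``valid_action_mask`` and the env id are hypothetical stand-ins for whatever the environment actually provides:

import numpy as np
import gymnasium as gym
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker

def mask_fn(env: gym.Env) -> np.ndarray:
    # Assumption: the env exposes its mask under a custom name
    return env.unwrapped.valid_action_mask()

env = gym.make("YourMaskedEnv-v0")  # hypothetical env id
env = ActionMasker(env, mask_fn)  # exposes the mask where MaskablePPO expects it
model = MaskablePPO("MlpPolicy", env, verbose=1)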

View File

@@ -64,8 +64,8 @@ Example
 model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
 model.learn(5000)
-env = model.get_env()
-mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20, warn=False)
+vec_env = model.get_env()
+mean_reward, std_reward = evaluate_policy(model, vec_env, n_eval_episodes=20, warn=False)
 print(mean_reward)
 model.save("ppo_recurrent")
@@ -73,7 +73,7 @@ Example
 model = RecurrentPPO.load("ppo_recurrent")
-obs = env.reset()
+obs = vec_env.reset()
 # cell and hidden state of the LSTM
 lstm_states = None
 num_envs = 1
@@ -81,9 +81,9 @@ Example
 episode_starts = np.ones((num_envs,), dtype=bool)
 while True:
     action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
-    obs, rewards, dones, info = env.step(action)
+    obs, rewards, dones, info = vec_env.step(action)
     episode_starts = dones
-    env.render()
+    vec_env.render("human")

View File

@@ -57,24 +57,24 @@ Example
 from sb3_contrib import QRDQN
-env = gym.make("CartPole-v1")
+env = gym.make("CartPole-v1", render_mode="human")
 policy_kwargs = dict(n_quantiles=50)
 model = QRDQN("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
-model.learn(total_timesteps=10000, log_interval=4)
+model.learn(total_timesteps=10_000, log_interval=4)
 model.save("qrdqn_cartpole")
 del model  # remove to demonstrate saving and loading
 model = QRDQN.load("qrdqn_cartpole")
-obs = env.reset()
+obs, _ = env.reset()
 while True:
     action, _states = model.predict(obs, deterministic=True)
-    obs, reward, done, info = env.step(action)
+    obs, reward, terminated, truncated, info = env.step(action)
     env.render()
-    if done:
-        obs = env.reset()
+    if terminated or truncated:
+        obs, _ = env.reset()
 Results

View File

@@ -57,24 +57,24 @@ Example
 from sb3_contrib import TQC
-env = gym.make("Pendulum-v1")
+env = gym.make("Pendulum-v1", render_mode="human")
 policy_kwargs = dict(n_critics=2, n_quantiles=25)
 model = TQC("MlpPolicy", env, top_quantiles_to_drop_per_net=2, verbose=1, policy_kwargs=policy_kwargs)
-model.learn(total_timesteps=10000, log_interval=4)
+model.learn(total_timesteps=10_000, log_interval=4)
 model.save("tqc_pendulum")
 del model  # remove to demonstrate saving and loading
 model = TQC.load("tqc_pendulum")
-obs = env.reset()
+obs, _ = env.reset()
 while True:
     action, _states = model.predict(obs, deterministic=True)
-    obs, reward, done, info = env.step(action)
+    obs, reward, terminated, truncated, info = env.step(action)
     env.render()
-    if done:
-        obs = env.reset()
+    if terminated or truncated:
+        obs, _ = env.reset()
 Results

View File

@@ -54,23 +54,23 @@ Example
 from sb3_contrib import TRPO
-env = gym.make("Pendulum-v1")
+env = gym.make("Pendulum-v1", render_mode="human")
 model = TRPO("MlpPolicy", env, verbose=1)
-model.learn(total_timesteps=10000, log_interval=4)
+model.learn(total_timesteps=10_000, log_interval=4)
 model.save("trpo_pendulum")
 del model  # remove to demonstrate saving and loading
 model = TRPO.load("trpo_pendulum")
-obs = env.reset()
+obs, _ = env.reset()
 while True:
     action, _states = model.predict(obs, deterministic=True)
-    obs, reward, done, info = env.step(action)
+    obs, reward, terminated, truncated, info = env.step(action)
     env.render()
-    if done:
-        obs = env.reset()
+    if terminated or truncated:
+        obs, _ = env.reset()
 Results

View File

@@ -35,7 +35,6 @@ exclude = """(?x)(
     | sb3_contrib/common/maskable/callbacks.py$
     | sb3_contrib/common/maskable/policies.py$
     | sb3_contrib/common/maskable/buffers.py$
-    | sb3_contrib/common/envs/invalid_actions_env.py$
     | sb3_contrib/common/vec_env/async_eval.py$
     | sb3_contrib/ppo_mask/ppo_mask.py$
     | tests/test_train_eval_mode.py$
@@ -50,10 +49,6 @@ env = [
 filterwarnings = [
     # Tensorboard warnings
     "ignore::DeprecationWarning:tensorboard",
-    # Gym warnings
-    "ignore:Parameters to load are deprecated.:DeprecationWarning",
-    "ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning",
-    "ignore::UserWarning:gym",
 ]
 markers = [
     "slow: marks tests as slow (deselect with '-m \"not slow\"')"

View File

@@ -5,7 +5,7 @@ from gymnasium import spaces
 from stable_baselines3.common.envs import IdentityEnv

-class InvalidActionEnvDiscrete(IdentityEnv):
+class InvalidActionEnvDiscrete(IdentityEnv[int]):
     """
     Identity env with a discrete action space. Supports action masking.
     """
@@ -30,17 +30,19 @@ class InvalidActionEnvDiscrete(IdentityEnv):
         self.state = self.action_space.sample()
         # Randomly choose invalid actions that are not the current state
         potential_invalid_actions = [i for i in self.possible_actions if i != self.state]
-        self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)
+        self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False).tolist()

     def action_masks(self) -> List[bool]:
         return [action not in self.invalid_actions for action in self.possible_actions]

-class InvalidActionEnvMultiDiscrete(IdentityEnv):
+class InvalidActionEnvMultiDiscrete(IdentityEnv[np.ndarray]):
     """
     Identity env with a multidiscrete action space. Supports action masking.
     """

+    action_space: spaces.MultiDiscrete
+
     def __init__(
         self,
         dims: Optional[List[int]] = None,
@@ -70,13 +72,13 @@ class InvalidActionEnvMultiDiscrete(IdentityEnv):
         # Randomly choose invalid actions that are not the current state
         potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state]
-        self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)
+        self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False).tolist()

     def action_masks(self) -> List[bool]:
         return [action not in self.invalid_actions for action in self.possible_actions]

-class InvalidActionEnvMultiBinary(IdentityEnv):
+class InvalidActionEnvMultiBinary(IdentityEnv[np.ndarray]):
     """
     Identity env with a multibinary action space. Supports action masking.
     """
@@ -94,6 +96,7 @@ class InvalidActionEnvMultiBinary(IdentityEnv):
             raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {dims}")
         space = spaces.MultiBinary(dims)
+        self.n_dims = dims
         self.n_invalid_actions = n_invalid_actions
         self.possible_actions = np.arange(2 * dims)
         self.invalid_actions: List[int] = []
@@ -104,13 +107,13 @@ class InvalidActionEnvMultiBinary(IdentityEnv):
         converted_state: List[int] = []
         running_total = 0
-        for i in range(self.action_space.n):
+        for i in range(self.n_dims):
             converted_state.append(running_total + self.state[i])
             running_total += 2
         # Randomly choose invalid actions that are not the current state
         potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state]
-        self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)
+        self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False).tolist()

     def action_masks(self) -> List[bool]:
         return [action not in self.invalid_actions for action in self.possible_actions]
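On the recurring ``.tolist()`` change: ``np.random.choice`` returns an ``ndarray``, while ``invalid_actions`` is annotated ``List[int]``, so the conversion is what satisfies the type checker. A standalone illustration (not code from this commit):

from typing import List

import numpy as np

# np.random.choice returns an ndarray; .tolist() yields plain Python ints,
# matching the List[int] annotation that mypy now checks for this module
invalid_actions: List[int] = np.random.choice([0, 1, 2, 3], 2, replace=False).tolist()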

View File

@@ -1 +1 @@
-2.0.0a4
+2.0.0a8