From 86fb056fdaaca0fa5490529759a37e885007d7e6 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 10 May 2023 11:40:40 +0200 Subject: [PATCH] Update doc: switch from Gym to Gymnasium (#182) * Update documentation to switch to Gymnasium * Fix custom env type hints * Fix type check --- docs/common/wrappers.rst | 2 +- docs/guide/algos.rst | 2 +- docs/guide/examples.rst | 11 ++++++----- docs/misc/changelog.rst | 5 +++-- docs/modules/ppo_mask.rst | 7 +++---- docs/modules/ppo_recurrent.rst | 10 +++++----- docs/modules/qrdqn.rst | 12 ++++++------ docs/modules/tqc.rst | 12 ++++++------ docs/modules/trpo.rst | 12 ++++++------ pyproject.toml | 5 ----- sb3_contrib/common/envs/invalid_actions_env.py | 17 ++++++++++------- sb3_contrib/version.txt | 2 +- 12 files changed, 48 insertions(+), 49 deletions(-) diff --git a/docs/common/wrappers.rst b/docs/common/wrappers.rst index cc2f11e..bf905d2 100644 --- a/docs/common/wrappers.rst +++ b/docs/common/wrappers.rst @@ -3,7 +3,7 @@ Gym Wrappers ============ -Additional `Gym Wrappers `_ to enhance Gym environments. +Additional `Gymnasium Wrappers `_ to enhance Gymnasium environments. .. automodule:: sb3_contrib.common.wrappers diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 234e6f8..8181382 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -19,7 +19,7 @@ TRPO ✔️ ✔️ ✔️ ✔️ .. note:: ``Tuple`` observation spaces are not supported by any environment, - however, single-level ``Dict`` spaces are + however, single-level ``Dict`` spaces are supported. Actions ``gym.spaces``: diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index cd4851d..de1f946 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -94,15 +94,16 @@ Train a PPO agent with a recurrent policy on the CartPole environment. model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1) model.learn(5000) - env = model.get_env() - obs = env.reset() - # cell and hidden state of the LSTM + vec_env = model.get_env() + obs = vec_env.reset() + # Cell and hidden state of the LSTM lstm_states = None num_envs = 1 # Episode start signals are used to reset the lstm states episode_starts = np.ones((num_envs,), dtype=bool) while True: action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True) - obs, rewards, dones, info = env.step(action) + # Note: vectorized environment resets automatically + obs, rewards, dones, info = vec_env.step(action) episode_starts = dones - env.render() + vec_env.render("human") diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index bd6973e..b54eb65 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.0.0a4 (WIP) +Release 2.0.0a8 (WIP) -------------------------- **Gymnasium support** @@ -33,11 +33,12 @@ Others: ^^^^^^^ - Fixed ``sb3_contrib/tqc/*.py`` type hints - Fixed ``sb3_contrib/trpo/*.py`` type hints +- Fixed ``sb3_contrib/common/envs/invalid_actions_env.py`` type hints Documentation: ^^^^^^^^^^^^^^ - +- Update documentation, switch from Gym to Gymnasium Release 1.8.0 (2023-04-07) -------------------------- diff --git a/docs/modules/ppo_mask.rst b/docs/modules/ppo_mask.rst index 027c89a..89d14d1 100644 --- a/docs/modules/ppo_mask.rst +++ b/docs/modules/ppo_mask.rst @@ -62,7 +62,7 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60) model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, verbose=1) - model.learn(5000) + model.learn(5_000) evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False) @@ -71,13 +71,12 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe model = MaskablePPO.load("ppo_mask") - obs = env.reset() + obs, _ = env.reset() while True: # Retrieve current action mask action_masks = get_action_masks(env) action, _states = model.predict(obs, action_masks=action_masks) - obs, rewards, dones, info = env.step(action) - env.render() + obs, reward, terminated, truncated, info = env.step(action) If the environment implements the invalid action mask but using a different name, you can use the ``ActionMasker`` diff --git a/docs/modules/ppo_recurrent.rst b/docs/modules/ppo_recurrent.rst index 54d9bdd..e62c2e2 100644 --- a/docs/modules/ppo_recurrent.rst +++ b/docs/modules/ppo_recurrent.rst @@ -64,8 +64,8 @@ Example model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1) model.learn(5000) - env = model.get_env() - mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20, warn=False) + vec_env = model.get_env() + mean_reward, std_reward = evaluate_policy(model, vec_env, n_eval_episodes=20, warn=False) print(mean_reward) model.save("ppo_recurrent") @@ -73,7 +73,7 @@ Example model = RecurrentPPO.load("ppo_recurrent") - obs = env.reset() + obs = vec_env.reset() # cell and hidden state of the LSTM lstm_states = None num_envs = 1 @@ -81,9 +81,9 @@ Example episode_starts = np.ones((num_envs,), dtype=bool) while True: action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True) - obs, rewards, dones, info = env.step(action) + obs, rewards, dones, info = vec_env.step(action) episode_starts = dones - env.render() + vec_env.render("human") diff --git a/docs/modules/qrdqn.rst b/docs/modules/qrdqn.rst index 2bd59cd..2a913e1 100644 --- a/docs/modules/qrdqn.rst +++ b/docs/modules/qrdqn.rst @@ -57,24 +57,24 @@ Example from sb3_contrib import QRDQN - env = gym.make("CartPole-v1") + env = gym.make("CartPole-v1", render_mode="human") policy_kwargs = dict(n_quantiles=50) model = QRDQN("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1) - model.learn(total_timesteps=10000, log_interval=4) + model.learn(total_timesteps=10_000, log_interval=4) model.save("qrdqn_cartpole") del model # remove to demonstrate saving and loading model = QRDQN.load("qrdqn_cartpole") - obs = env.reset() + obs, _ = env.reset() while True: action, _states = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) env.render() - if done: - obs = env.reset() + if terminated or truncated: + obs, _ = env.reset() Results diff --git a/docs/modules/tqc.rst b/docs/modules/tqc.rst index 71aa311..72a4003 100644 --- a/docs/modules/tqc.rst +++ b/docs/modules/tqc.rst @@ -57,24 +57,24 @@ Example from sb3_contrib import TQC - env = gym.make("Pendulum-v1") + env = gym.make("Pendulum-v1", render_mode="human") policy_kwargs = dict(n_critics=2, n_quantiles=25) model = TQC("MlpPolicy", env, top_quantiles_to_drop_per_net=2, verbose=1, policy_kwargs=policy_kwargs) - model.learn(total_timesteps=10000, log_interval=4) + model.learn(total_timesteps=10_000, log_interval=4) model.save("tqc_pendulum") del model # remove to demonstrate saving and loading model = TQC.load("tqc_pendulum") - obs = env.reset() + obs, _ = env.reset() while True: action, _states = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) env.render() - if done: - obs = env.reset() + if terminated or truncated: + obs, _ = env.reset() Results diff --git a/docs/modules/trpo.rst b/docs/modules/trpo.rst index 2a12dad..0a7e985 100644 --- a/docs/modules/trpo.rst +++ b/docs/modules/trpo.rst @@ -54,23 +54,23 @@ Example from sb3_contrib import TRPO - env = gym.make("Pendulum-v1") + env = gym.make("Pendulum-v1", render_mode="human") model = TRPO("MlpPolicy", env, verbose=1) - model.learn(total_timesteps=10000, log_interval=4) + model.learn(total_timesteps=10_000, log_interval=4) model.save("trpo_pendulum") del model # remove to demonstrate saving and loading model = TRPO.load("trpo_pendulum") - obs = env.reset() + obs, _ = env.reset() while True: action, _states = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) env.render() - if done: - obs = env.reset() + if terminated or truncated: + obs, _ = env.reset() Results diff --git a/pyproject.toml b/pyproject.toml index cff637f..2a2913e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ exclude = """(?x)( | sb3_contrib/common/maskable/callbacks.py$ | sb3_contrib/common/maskable/policies.py$ | sb3_contrib/common/maskable/buffers.py$ - | sb3_contrib/common/envs/invalid_actions_env.py$ | sb3_contrib/common/vec_env/async_eval.py$ | sb3_contrib/ppo_mask/ppo_mask.py$ | tests/test_train_eval_mode.py$ @@ -50,10 +49,6 @@ env = [ filterwarnings = [ # Tensorboard warnings "ignore::DeprecationWarning:tensorboard", - # Gym warnings - "ignore:Parameters to load are deprecated.:DeprecationWarning", - "ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning", - "ignore::UserWarning:gym", ] markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')" diff --git a/sb3_contrib/common/envs/invalid_actions_env.py b/sb3_contrib/common/envs/invalid_actions_env.py index c1fe753..8306897 100644 --- a/sb3_contrib/common/envs/invalid_actions_env.py +++ b/sb3_contrib/common/envs/invalid_actions_env.py @@ -5,7 +5,7 @@ from gymnasium import spaces from stable_baselines3.common.envs import IdentityEnv -class InvalidActionEnvDiscrete(IdentityEnv): +class InvalidActionEnvDiscrete(IdentityEnv[int]): """ Identity env with a discrete action space. Supports action masking. """ @@ -30,17 +30,19 @@ class InvalidActionEnvDiscrete(IdentityEnv): self.state = self.action_space.sample() # Randomly choose invalid actions that are not the current state potential_invalid_actions = [i for i in self.possible_actions if i != self.state] - self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False) + self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False).tolist() def action_masks(self) -> List[bool]: return [action not in self.invalid_actions for action in self.possible_actions] -class InvalidActionEnvMultiDiscrete(IdentityEnv): +class InvalidActionEnvMultiDiscrete(IdentityEnv[np.ndarray]): """ Identity env with a multidiscrete action space. Supports action masking. """ + action_space: spaces.MultiDiscrete + def __init__( self, dims: Optional[List[int]] = None, @@ -70,13 +72,13 @@ class InvalidActionEnvMultiDiscrete(IdentityEnv): # Randomly choose invalid actions that are not the current state potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state] - self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False) + self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False).tolist() def action_masks(self) -> List[bool]: return [action not in self.invalid_actions for action in self.possible_actions] -class InvalidActionEnvMultiBinary(IdentityEnv): +class InvalidActionEnvMultiBinary(IdentityEnv[np.ndarray]): """ Identity env with a multibinary action space. Supports action masking. """ @@ -94,6 +96,7 @@ class InvalidActionEnvMultiBinary(IdentityEnv): raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {dims}") space = spaces.MultiBinary(dims) + self.n_dims = dims self.n_invalid_actions = n_invalid_actions self.possible_actions = np.arange(2 * dims) self.invalid_actions: List[int] = [] @@ -104,13 +107,13 @@ class InvalidActionEnvMultiBinary(IdentityEnv): converted_state: List[int] = [] running_total = 0 - for i in range(self.action_space.n): + for i in range(self.n_dims): converted_state.append(running_total + self.state[i]) running_total += 2 # Randomly choose invalid actions that are not the current state potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state] - self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False) + self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False).tolist() def action_masks(self) -> List[bool]: return [action not in self.invalid_actions for action in self.possible_actions] diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 997bba2..55a349a 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -2.0.0a4 +2.0.0a8