Update doc: switch from Gym to Gymnasium (#182)
* Update documentation to switch to Gymnasium * Fix custom env type hints * Fix type check
This commit is contained in:
parent
21cc96cafd
commit
86fb056fda
|
|
@ -3,7 +3,7 @@
|
|||
Gym Wrappers
|
||||
============
|
||||
|
||||
Additional `Gym Wrappers <https://github.com/openai/gym/tree/master/gym/wrappers>`_ to enhance Gym environments.
|
||||
Additional `Gymnasium Wrappers <https://gymnasium.farama.org/api/wrappers/>`_ to enhance Gymnasium environments.
|
||||
|
||||
.. automodule:: sb3_contrib.common.wrappers
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ TRPO ✔️ ✔️ ✔️ ✔️
|
|||
|
||||
.. note::
|
||||
``Tuple`` observation spaces are not supported by any environment,
|
||||
however, single-level ``Dict`` spaces are
|
||||
however, single-level ``Dict`` spaces are supported.
|
||||
|
||||
Actions ``gym.spaces``:
|
||||
|
||||
|
|
|
|||
|
|
@ -94,15 +94,16 @@ Train a PPO agent with a recurrent policy on the CartPole environment.
|
|||
model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
|
||||
model.learn(5000)
|
||||
|
||||
env = model.get_env()
|
||||
obs = env.reset()
|
||||
# cell and hidden state of the LSTM
|
||||
vec_env = model.get_env()
|
||||
obs = vec_env.reset()
|
||||
# Cell and hidden state of the LSTM
|
||||
lstm_states = None
|
||||
num_envs = 1
|
||||
# Episode start signals are used to reset the lstm states
|
||||
episode_starts = np.ones((num_envs,), dtype=bool)
|
||||
while True:
|
||||
action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
|
||||
obs, rewards, dones, info = env.step(action)
|
||||
# Note: vectorized environment resets automatically
|
||||
obs, rewards, dones, info = vec_env.step(action)
|
||||
episode_starts = dones
|
||||
env.render()
|
||||
vec_env.render("human")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.0.0a4 (WIP)
|
||||
Release 2.0.0a8 (WIP)
|
||||
--------------------------
|
||||
|
||||
**Gymnasium support**
|
||||
|
|
@ -33,11 +33,12 @@ Others:
|
|||
^^^^^^^
|
||||
- Fixed ``sb3_contrib/tqc/*.py`` type hints
|
||||
- Fixed ``sb3_contrib/trpo/*.py`` type hints
|
||||
- Fixed ``sb3_contrib/common/envs/invalid_actions_env.py`` type hints
|
||||
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
- Update documentation, switch from Gym to Gymnasium
|
||||
|
||||
Release 1.8.0 (2023-04-07)
|
||||
--------------------------
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe
|
|||
|
||||
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
|
||||
model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, verbose=1)
|
||||
model.learn(5000)
|
||||
model.learn(5_000)
|
||||
|
||||
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)
|
||||
|
||||
|
|
@ -71,13 +71,12 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe
|
|||
|
||||
model = MaskablePPO.load("ppo_mask")
|
||||
|
||||
obs = env.reset()
|
||||
obs, _ = env.reset()
|
||||
while True:
|
||||
# Retrieve current action mask
|
||||
action_masks = get_action_masks(env)
|
||||
action, _states = model.predict(obs, action_masks=action_masks)
|
||||
obs, rewards, dones, info = env.step(action)
|
||||
env.render()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
|
||||
|
||||
If the environment implements the invalid action mask but using a different name, you can use the ``ActionMasker``
|
||||
|
|
|
|||
|
|
@ -64,8 +64,8 @@ Example
|
|||
model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
|
||||
model.learn(5000)
|
||||
|
||||
env = model.get_env()
|
||||
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20, warn=False)
|
||||
vec_env = model.get_env()
|
||||
mean_reward, std_reward = evaluate_policy(model, vec_env, n_eval_episodes=20, warn=False)
|
||||
print(mean_reward)
|
||||
|
||||
model.save("ppo_recurrent")
|
||||
|
|
@ -73,7 +73,7 @@ Example
|
|||
|
||||
model = RecurrentPPO.load("ppo_recurrent")
|
||||
|
||||
obs = env.reset()
|
||||
obs = vec_env.reset()
|
||||
# cell and hidden state of the LSTM
|
||||
lstm_states = None
|
||||
num_envs = 1
|
||||
|
|
@ -81,9 +81,9 @@ Example
|
|||
episode_starts = np.ones((num_envs,), dtype=bool)
|
||||
while True:
|
||||
action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
|
||||
obs, rewards, dones, info = env.step(action)
|
||||
obs, rewards, dones, info = vec_env.step(action)
|
||||
episode_starts = dones
|
||||
env.render()
|
||||
vec_env.render("human")
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -57,24 +57,24 @@ Example
|
|||
|
||||
from sb3_contrib import QRDQN
|
||||
|
||||
env = gym.make("CartPole-v1")
|
||||
env = gym.make("CartPole-v1", render_mode="human")
|
||||
|
||||
policy_kwargs = dict(n_quantiles=50)
|
||||
model = QRDQN("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
|
||||
model.learn(total_timesteps=10000, log_interval=4)
|
||||
model.learn(total_timesteps=10_000, log_interval=4)
|
||||
model.save("qrdqn_cartpole")
|
||||
|
||||
del model # remove to demonstrate saving and loading
|
||||
|
||||
model = QRDQN.load("qrdqn_cartpole")
|
||||
|
||||
obs = env.reset()
|
||||
obs, _ = env.reset()
|
||||
while True:
|
||||
action, _states = model.predict(obs, deterministic=True)
|
||||
obs, reward, done, info = env.step(action)
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
env.render()
|
||||
if done:
|
||||
obs = env.reset()
|
||||
if terminated or truncated:
|
||||
obs, _ = env.reset()
|
||||
|
||||
|
||||
Results
|
||||
|
|
|
|||
|
|
@ -57,24 +57,24 @@ Example
|
|||
|
||||
from sb3_contrib import TQC
|
||||
|
||||
env = gym.make("Pendulum-v1")
|
||||
env = gym.make("Pendulum-v1", render_mode="human")
|
||||
|
||||
policy_kwargs = dict(n_critics=2, n_quantiles=25)
|
||||
model = TQC("MlpPolicy", env, top_quantiles_to_drop_per_net=2, verbose=1, policy_kwargs=policy_kwargs)
|
||||
model.learn(total_timesteps=10000, log_interval=4)
|
||||
model.learn(total_timesteps=10_000, log_interval=4)
|
||||
model.save("tqc_pendulum")
|
||||
|
||||
del model # remove to demonstrate saving and loading
|
||||
|
||||
model = TQC.load("tqc_pendulum")
|
||||
|
||||
obs = env.reset()
|
||||
obs, _ = env.reset()
|
||||
while True:
|
||||
action, _states = model.predict(obs, deterministic=True)
|
||||
obs, reward, done, info = env.step(action)
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
env.render()
|
||||
if done:
|
||||
obs = env.reset()
|
||||
if terminated or truncated:
|
||||
obs, _ = env.reset()
|
||||
|
||||
|
||||
Results
|
||||
|
|
|
|||
|
|
@ -54,23 +54,23 @@ Example
|
|||
|
||||
from sb3_contrib import TRPO
|
||||
|
||||
env = gym.make("Pendulum-v1")
|
||||
env = gym.make("Pendulum-v1", render_mode="human")
|
||||
|
||||
model = TRPO("MlpPolicy", env, verbose=1)
|
||||
model.learn(total_timesteps=10000, log_interval=4)
|
||||
model.learn(total_timesteps=10_000, log_interval=4)
|
||||
model.save("trpo_pendulum")
|
||||
|
||||
del model # remove to demonstrate saving and loading
|
||||
|
||||
model = TRPO.load("trpo_pendulum")
|
||||
|
||||
obs = env.reset()
|
||||
obs, _ = env.reset()
|
||||
while True:
|
||||
action, _states = model.predict(obs, deterministic=True)
|
||||
obs, reward, done, info = env.step(action)
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
env.render()
|
||||
if done:
|
||||
obs = env.reset()
|
||||
if terminated or truncated:
|
||||
obs, _ = env.reset()
|
||||
|
||||
|
||||
Results
|
||||
|
|
|
|||
|
|
@ -35,7 +35,6 @@ exclude = """(?x)(
|
|||
| sb3_contrib/common/maskable/callbacks.py$
|
||||
| sb3_contrib/common/maskable/policies.py$
|
||||
| sb3_contrib/common/maskable/buffers.py$
|
||||
| sb3_contrib/common/envs/invalid_actions_env.py$
|
||||
| sb3_contrib/common/vec_env/async_eval.py$
|
||||
| sb3_contrib/ppo_mask/ppo_mask.py$
|
||||
| tests/test_train_eval_mode.py$
|
||||
|
|
@ -50,10 +49,6 @@ env = [
|
|||
filterwarnings = [
|
||||
# Tensorboard warnings
|
||||
"ignore::DeprecationWarning:tensorboard",
|
||||
# Gym warnings
|
||||
"ignore:Parameters to load are deprecated.:DeprecationWarning",
|
||||
"ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning",
|
||||
"ignore::UserWarning:gym",
|
||||
]
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')"
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from gymnasium import spaces
|
|||
from stable_baselines3.common.envs import IdentityEnv
|
||||
|
||||
|
||||
class InvalidActionEnvDiscrete(IdentityEnv):
|
||||
class InvalidActionEnvDiscrete(IdentityEnv[int]):
|
||||
"""
|
||||
Identity env with a discrete action space. Supports action masking.
|
||||
"""
|
||||
|
|
@ -30,17 +30,19 @@ class InvalidActionEnvDiscrete(IdentityEnv):
|
|||
self.state = self.action_space.sample()
|
||||
# Randomly choose invalid actions that are not the current state
|
||||
potential_invalid_actions = [i for i in self.possible_actions if i != self.state]
|
||||
self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)
|
||||
self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False).tolist()
|
||||
|
||||
def action_masks(self) -> List[bool]:
|
||||
return [action not in self.invalid_actions for action in self.possible_actions]
|
||||
|
||||
|
||||
class InvalidActionEnvMultiDiscrete(IdentityEnv):
|
||||
class InvalidActionEnvMultiDiscrete(IdentityEnv[np.ndarray]):
|
||||
"""
|
||||
Identity env with a multidiscrete action space. Supports action masking.
|
||||
"""
|
||||
|
||||
action_space: spaces.MultiDiscrete
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: Optional[List[int]] = None,
|
||||
|
|
@ -70,13 +72,13 @@ class InvalidActionEnvMultiDiscrete(IdentityEnv):
|
|||
|
||||
# Randomly choose invalid actions that are not the current state
|
||||
potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state]
|
||||
self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)
|
||||
self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False).tolist()
|
||||
|
||||
def action_masks(self) -> List[bool]:
|
||||
return [action not in self.invalid_actions for action in self.possible_actions]
|
||||
|
||||
|
||||
class InvalidActionEnvMultiBinary(IdentityEnv):
|
||||
class InvalidActionEnvMultiBinary(IdentityEnv[np.ndarray]):
|
||||
"""
|
||||
Identity env with a multibinary action space. Supports action masking.
|
||||
"""
|
||||
|
|
@ -94,6 +96,7 @@ class InvalidActionEnvMultiBinary(IdentityEnv):
|
|||
raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {dims}")
|
||||
|
||||
space = spaces.MultiBinary(dims)
|
||||
self.n_dims = dims
|
||||
self.n_invalid_actions = n_invalid_actions
|
||||
self.possible_actions = np.arange(2 * dims)
|
||||
self.invalid_actions: List[int] = []
|
||||
|
|
@ -104,13 +107,13 @@ class InvalidActionEnvMultiBinary(IdentityEnv):
|
|||
|
||||
converted_state: List[int] = []
|
||||
running_total = 0
|
||||
for i in range(self.action_space.n):
|
||||
for i in range(self.n_dims):
|
||||
converted_state.append(running_total + self.state[i])
|
||||
running_total += 2
|
||||
|
||||
# Randomly choose invalid actions that are not the current state
|
||||
potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state]
|
||||
self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)
|
||||
self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False).tolist()
|
||||
|
||||
def action_masks(self) -> List[bool]:
|
||||
return [action not in self.invalid_actions for action in self.possible_actions]
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.0.0a4
|
||||
2.0.0a8
|
||||
|
|
|
|||
Loading…
Reference in New Issue