Add Gymnasium support (#152)

* Add support for Gym 0.24

* Fixes for gym 0.24

* Fix for new reset signature

* Add tmp SB3 branch

* Fixes for gym 0.26

* Remove unused import

* Fix dependency

* Type annotations fixes

* Reformat

* Reformat with black 23

* Move to gymnasium

* Patch env if needed

* Fix types

* Fix CI

* Fixes for gymnasium

* Fix wrapper annotations

* Update version

* Fix type check

* Update QRDQN type hints and bug fix with multi envs

* Fix TQC type hints

* Fix TRPO type hints

* Additional fixes

* Update SB3 version

* Update issue templates and CI

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
Antonin RAFFIN 2023-04-14 13:52:07 +02:00 committed by GitHub
parent a84ad3aa7d
commit 21cc96cafd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 278 additions and 168 deletions

View File

@ -50,7 +50,8 @@ body:
* GPU models and configuration
* Python version
* PyTorch version
* Gym version
* Gymnasium version
* (if installed) OpenAI Gym version
* Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print relevant packages info:

View File

@ -34,9 +34,9 @@ body:
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
value: |
```python
import gym
import gymnasium as gym
import numpy as np
from gym import spaces
from gymnasium import spaces
from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env
@ -49,15 +49,16 @@ body:
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
self.action_space = spaces.Box(low=-1, high=1, shape=(6,))
def reset(self):
return self.observation_space.sample()
def reset(self, seed=None, options=None):
return self.observation_space.sample(), {}
def step(self, action):
obs = self.observation_space.sample()
reward = 1.0
done = False
terminated = False
truncated = False
info = {}
return obs, reward, done, info
return obs, reward, terminated, truncated, info
env = CustomEnv()
check_env(env)
@ -84,7 +85,8 @@ body:
* GPU models and configuration
* Python version
* PyTorch version
* Gym version
* Gymnasium version
* (if installed) OpenAI Gym version
* Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print relevant packages info:

View File

@ -3,6 +3,42 @@
Changelog
==========
Release 2.0.0a4 (WIP)
--------------------------
**Gymnasium support**
.. warning::
Stable-Baselines3 (SB3) v2.0 will be the last one supporting Python 3.7 (end of life in June 2023).
We highly recommend you upgrade to Python >= 3.8.
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Switched to Gymnasium as primary backend, Gym 0.21 and 0.26 are still supported via the ``shimmy`` package (@carlosluis, @arjun-kg, @tlpss)
- Upgraded to Stable-Baselines3 >= 2.0.0
New Features:
^^^^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
- Fixed QRDQN update interval for multi envs
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
- Fixed ``sb3_contrib/tqc/*.py`` type hints
- Fixed ``sb3_contrib/trpo/*.py`` type hints
Documentation:
^^^^^^^^^^^^^^
Release 1.8.0 (2023-04-07)
--------------------------

View File

@ -91,7 +91,7 @@ to specify the name (see `PR #25 <https://github.com/Stable-Baselines-Team/stabl
.. code-block:: python
import gym
import gymnasium as gym
import numpy as np
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy

View File

@ -53,7 +53,7 @@ Example
.. code-block:: python
import gym
import gymnasium as gym
from sb3_contrib import QRDQN

View File

@ -52,7 +52,7 @@ Example
.. code-block:: python
import gym
import gymnasium as gym
import numpy as np
from sb3_contrib import TQC

View File

@ -49,7 +49,7 @@ Example
.. code-block:: python
import gym
import gymnasium as gym
import numpy as np
from sb3_contrib import TRPO

View File

@ -37,9 +37,6 @@ exclude = """(?x)(
| sb3_contrib/common/maskable/buffers.py$
| sb3_contrib/common/envs/invalid_actions_env.py$
| sb3_contrib/common/vec_env/async_eval.py$
| sb3_contrib/tqc/tqc.py$
| sb3_contrib/tqc/policies.py$
| sb3_contrib/trpo/trpo.py$
| sb3_contrib/ppo_mask/ppo_mask.py$
| tests/test_train_eval_mode.py$
)"""

View File

@ -8,7 +8,7 @@ from typing import Any, Dict, Optional, Type, TypeVar, Union
import numpy as np
import torch as th
import torch.nn.utils
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy

View File

@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Type
import torch as th
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
@ -50,7 +50,7 @@ class ARSPolicy(BasePolicy):
self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output
)
elif isinstance(action_space, spaces.Discrete):
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn, with_bias=with_bias)
actor_net = create_mlp(self.features_dim, int(action_space.n), net_arch, activation_fn, with_bias=with_bias)
else:
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")

View File

@ -1,7 +1,7 @@
from typing import List, Optional
import numpy as np
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.envs import IdentityEnv

View File

@ -2,7 +2,7 @@ from typing import Generator, NamedTuple, Optional, Union
import numpy as np
import torch as th
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.type_aliases import TensorDict
from stable_baselines3.common.vec_env import VecNormalize

View File

@ -3,7 +3,7 @@ from typing import List, Optional, Tuple, TypeVar
import numpy as np
import torch as th
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.distributions import Distribution
from torch import nn
from torch.distributions import Categorical

View File

@ -1,7 +1,7 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import gym
import gymnasium as gym
import numpy as np
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
@ -64,7 +64,7 @@ def evaluate_policy(
is_monitor_wrapped = False
if not isinstance(env, VecEnv):
env = DummyVecEnv([lambda: env])
env = DummyVecEnv([lambda: env]) # type: ignore[list-item, return-value]
is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]
@ -93,7 +93,7 @@ def evaluate_policy(
if use_masking:
action_masks = get_action_masks(env)
actions, state = model.predict(
observations,
observations, # type: ignore[arg-type]
state=states,
episode_start=episode_starts,
deterministic=deterministic,
@ -101,7 +101,10 @@ def evaluate_policy(
)
else:
actions, states = model.predict(
observations, state=states, episode_start=episode_starts, deterministic=deterministic
observations, # type: ignore[arg-type]
state=states,
episode_start=episode_starts,
deterministic=deterministic,
)
observations, rewards, dones, infos = env.step(actions)
current_rewards += rewards

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch as th
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,

View File

@ -3,7 +3,7 @@ from typing import Callable, Generator, Optional, Tuple, Union
import numpy as np
import torch as th
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.vec_env import VecNormalize

View File

@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch as th
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.distributions import Distribution
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.torch_layers import (

View File

@ -8,6 +8,7 @@ import torch as th
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.utils import compat_gym_seed
from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper
@ -56,7 +57,7 @@ def _worker(
results.append((weights_idx, (episode_rewards, episode_lengths)))
remote.send(results)
elif cmd == "seed":
remote.send(env.seed(data))
remote.send(compat_gym_seed(env, seed=data))
elif cmd == "get_obs_rms":
remote.send(obs_rms)
elif cmd == "sync_obs_rms":

View File

@ -1,6 +1,6 @@
from typing import Callable, Union
import gym
import gymnasium as gym
import numpy as np

View File

@ -1,12 +1,14 @@
from typing import Dict, Union
from typing import Any, Dict, SupportsFloat, Tuple, Union
import gym
import gymnasium as gym
import numpy as np
from gym import spaces
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
from gymnasium import spaces
from gymnasium.core import ActType
TimeFeatureObs = Union[np.ndarray, Dict[str, np.ndarray]]
class TimeFeatureWrapper(gym.Wrapper):
class TimeFeatureWrapper(gym.Wrapper[TimeFeatureObs, ActType, TimeFeatureObs, ActType]):
"""
Add remaining, normalized time to observation space for fixed length episodes.
See https://arxiv.org/abs/1712.00378 and https://github.com/aravindr93/mjrl/issues/13.
@ -33,27 +35,30 @@ class TimeFeatureWrapper(gym.Wrapper):
if isinstance(env.observation_space, spaces.Dict):
assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space"
obs_space = env.observation_space.spaces["observation"]
assert isinstance(obs_space, spaces.Box), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
obs_space = env.observation_space.spaces["observation"]
else:
obs_space = env.observation_space
assert isinstance(obs_space, gym.spaces.Box), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
assert len(obs_space.shape) == 1, "Only 1D observation spaces are supported"
low, high = obs_space.low, obs_space.high
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0]))
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0])) # type: ignore[arg-type]
self.dtype = obs_space.dtype
if isinstance(env.observation_space, spaces.Dict):
env.observation_space.spaces["observation"] = spaces.Box(low=low, high=high, dtype=self.dtype)
env.observation_space.spaces["observation"] = spaces.Box(
low=low,
high=high,
dtype=self.dtype, # type: ignore[arg-type]
)
else:
env.observation_space = spaces.Box(low=low, high=high, dtype=self.dtype)
env.observation_space = spaces.Box(low=low, high=high, dtype=self.dtype) # type: ignore[arg-type]
super().__init__(env)
# Try to infer the max number of steps per episode
try:
self._max_steps = env.spec.max_episode_steps
self._max_steps = env.spec.max_episode_steps # type: ignore[union-attr]
except AttributeError:
self._max_steps = None
@ -64,14 +69,15 @@ class TimeFeatureWrapper(gym.Wrapper):
self._current_step = 0
self._test_mode = test_mode
def reset(self) -> GymObs:
def reset(self, **kwargs) -> Tuple[TimeFeatureObs, Dict[str, Any]]:
self._current_step = 0
return self._get_obs(self.env.reset())
obs, info = self.env.reset(**kwargs)
return self._get_obs(obs), info
def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
def step(self, action: ActType) -> Tuple[TimeFeatureObs, SupportsFloat, bool, bool, Dict[str, Any]]:
self._current_step += 1
obs, reward, done, info = self.env.step(action)
return self._get_obs(obs), reward, done, info
obs, reward, done, truncated, info = self.env.step(action)
return self._get_obs(obs), reward, done, truncated, info
def _get_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""
@ -80,11 +86,13 @@ class TimeFeatureWrapper(gym.Wrapper):
:param obs:
:return:
"""
# for mypy
assert self._max_steps is not None
# Remaining time is more general
time_feature = 1 - (self._current_step / self._max_steps)
if self._test_mode:
time_feature = 1.0
time_feature = np.array(time_feature, dtype=self.dtype)
time_feature = np.array(time_feature, dtype=self.dtype) # type: ignore[assignment]
if isinstance(obs, dict):
obs["observation"] = np.append(obs["observation"], time_feature)

View File

@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
@ -232,6 +232,7 @@ class MaskablePPO(OnPolicyAlgorithm):
# Make sure training timesteps are ahead of the internal counter
total_timesteps += self.num_timesteps
self._total_timesteps = total_timesteps
self._num_timesteps_at_start = self.num_timesteps
# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:

View File

@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Type, TypeVar, Union
import numpy as np
import torch as th
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm

View File

@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Type
import torch as th
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
@ -27,10 +27,12 @@ class QuantileNetwork(BasePolicy):
dividing by 255.0 (True by default)
"""
action_space: spaces.Discrete
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
action_space: spaces.Discrete,
features_extractor: BaseFeaturesExtractor,
features_dim: int,
n_quantiles: int = 200,
@ -52,7 +54,7 @@ class QuantileNetwork(BasePolicy):
self.activation_fn = activation_fn
self.features_dim = features_dim
self.n_quantiles = n_quantiles
action_dim = self.action_space.n # number of actions
action_dim = int(self.action_space.n) # number of actions
quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
self.quantile_net = nn.Sequential(*quantile_net)
@ -63,10 +65,8 @@ class QuantileNetwork(BasePolicy):
:param obs: Observation
:return: The estimated quantiles for each action.
"""
# For type checker:
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
return quantiles.view(-1, self.n_quantiles, self.action_space.n)
return quantiles.view(-1, self.n_quantiles, int(self.action_space.n))
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
q_values = self(observation).mean(dim=1)
@ -110,10 +110,13 @@ class QRDQNPolicy(BasePolicy):
excluding the learning rate, to pass to the optimizer
"""
quantile_net: QuantileNetwork
quantile_net_target: QuantileNetwork
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
action_space: spaces.Discrete,
lr_schedule: Schedule,
n_quantiles: int = 200,
net_arch: Optional[List[int]] = None,
@ -152,9 +155,6 @@ class QRDQNPolicy(BasePolicy):
"activation_fn": self.activation_fn,
"normalize_images": normalize_images,
}
self.quantile_net: QuantileNetwork
self.quantile_net_target: QuantileNetwork
self._build(lr_schedule)
def _build(self, lr_schedule: Schedule) -> None:
@ -239,7 +239,7 @@ class CnnPolicy(QRDQNPolicy):
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
action_space: spaces.Discrete,
lr_schedule: Schedule,
n_quantiles: int = 200,
net_arch: Optional[List[int]] = None,
@ -287,7 +287,7 @@ class MultiInputPolicy(QRDQNPolicy):
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
action_space: spaces.Discrete,
lr_schedule: Schedule,
n_quantiles: int = 200,
net_arch: Optional[List[int]] = None,

View File

@ -1,8 +1,9 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
@ -10,7 +11,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy, QuantileNetwork
SelfQRDQN = TypeVar("SelfQRDQN", bound="QRDQN")
@ -64,6 +65,11 @@ class QRDQN(OffPolicyAlgorithm):
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
}
# Linear schedule will be defined in `_setup_model()`
exploration_schedule: Schedule
quantile_net: QuantileNetwork
quantile_net_target: QuantileNetwork
policy: QRDQNPolicy
def __init__(
self,
@ -123,13 +129,11 @@ class QRDQN(OffPolicyAlgorithm):
self.exploration_final_eps = exploration_final_eps
self.exploration_fraction = exploration_fraction
self.target_update_interval = target_update_interval
# For updating the target network with multiple envs:
self._n_calls = 0
self.max_grad_norm = max_grad_norm
# "epsilon" for the epsilon-greedy exploration
self.exploration_rate = 0.0
# Linear schedule will be defined in `_setup_model()`
self.exploration_schedule: Schedule
self.quantile_net: th.nn.Module
self.quantile_net_target: th.nn.Module
if "optimizer_class" not in self.policy_kwargs:
self.policy_kwargs["optimizer_class"] = th.optim.Adam
@ -148,10 +152,20 @@ class QRDQN(OffPolicyAlgorithm):
self.exploration_schedule = get_linear_fn(
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
)
# Account for multiple environments
# each call to step() corresponds to n_envs transitions
if self.n_envs > 1:
if self.n_envs > self.target_update_interval:
warnings.warn(
"The number of environments used is greater than the target network "
f"update interval ({self.n_envs} > {self.target_update_interval}), "
"therefore the target network will be updated after each call to env.step() "
f"which corresponds to {self.n_envs} steps."
)
self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)
def _create_aliases(self) -> None:
# For type checker:
assert isinstance(self.policy, QRDQNPolicy)
self.quantile_net = self.policy.quantile_net
self.quantile_net_target = self.policy.quantile_net_target
self.n_quantiles = self.policy.n_quantiles
@ -161,7 +175,8 @@ class QRDQN(OffPolicyAlgorithm):
Update the exploration rate and target network if needed.
This method is called in ``collect_rollouts()`` after each step in the environment.
"""
if self.num_timesteps % self.target_update_interval == 0:
self._n_calls += 1
if self._n_calls % self.target_update_interval == 0:
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

View File

@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch as th
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BaseModel, BasePolicy
from stable_baselines3.common.preprocessing import get_action_dim
@ -44,10 +44,12 @@ class Actor(BasePolicy):
dividing by 255.0 (True by default)
"""
action_space: spaces.Box
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
action_space: spaces.Box,
net_arch: List[int],
features_extractor: nn.Module,
features_dim: int,
@ -95,9 +97,9 @@ class Actor(BasePolicy):
if clip_mean > 0.0:
self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
else:
self.action_dist = SquashedDiagGaussianDistribution(action_dim)
self.action_dist = SquashedDiagGaussianDistribution(action_dim) # type: ignore[assignment]
self.mu = nn.Linear(last_layer_dim, action_dim)
self.log_std = nn.Linear(last_layer_dim, action_dim)
self.log_std = nn.Linear(last_layer_dim, action_dim) # type: ignore[assignment]
def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()
@ -156,7 +158,7 @@ class Actor(BasePolicy):
if self.use_sde:
return mean_actions, self.log_std, dict(latent_sde=latent_pi)
# Unstructured exploration (Original implementation)
log_std = self.log_std(latent_pi)
log_std = self.log_std(latent_pi) # type: ignore[operator]
# Original Implementation to cap the standard deviation
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
return mean_actions, log_std, {}
@ -192,12 +194,15 @@ class Critic(BaseModel):
between the actor and the critic (this saves computation time)
"""
action_space: spaces.Box
features_extractor: BaseFeaturesExtractor
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
action_space: spaces.Box,
net_arch: List[int],
features_extractor: nn.Module,
features_extractor: BaseFeaturesExtractor,
features_dim: int,
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
@ -221,12 +226,12 @@ class Critic(BaseModel):
self.quantiles_total = n_quantiles * n_critics
for i in range(n_critics):
qf_net = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn)
qf_net = nn.Sequential(*qf_net)
qf_net_list = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn)
qf_net = nn.Sequential(*qf_net_list)
self.add_module(f"qf{i}", qf_net)
self.q_networks.append(qf_net)
def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]:
def forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor:
# Learn the features extractor using the policy loss only
# when the features_extractor is shared with the actor
with th.set_grad_enabled(not self.share_features_extractor):
@ -266,10 +271,14 @@ class TQCPolicy(BasePolicy):
between the actor and the critic (this saves computation time)
"""
actor: Actor
critic: Critic
critic_target: Critic
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
action_space: spaces.Box,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
@ -328,15 +337,17 @@ class TQCPolicy(BasePolicy):
"share_features_extractor": share_features_extractor,
}
self.critic_kwargs.update(tqc_kwargs)
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None
self.share_features_extractor = share_features_extractor
self._build(lr_schedule)
def _build(self, lr_schedule: Schedule) -> None:
self.actor = self.make_actor()
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
self.actor.optimizer = self.optimizer_class( # type: ignore[call-arg]
self.actor.parameters(),
lr=lr_schedule(1),
**self.optimizer_kwargs,
)
if self.share_features_extractor:
self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
@ -347,7 +358,7 @@ class TQCPolicy(BasePolicy):
# Create a separate features extractor for the critic
# this requires more memory and computation
self.critic = self.make_critic(features_extractor=None)
critic_parameters = self.critic.parameters()
critic_parameters = list(self.critic.parameters())
# Critic target should not share the feature extractor with critic
self.critic_target = self.make_critic(features_extractor=None)
@ -356,7 +367,11 @@ class TQCPolicy(BasePolicy):
# Target networks should always be in eval mode
self.critic_target.set_training_mode(False)
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
self.critic.optimizer = self.optimizer_class( # type: ignore[call-arg]
critic_parameters,
lr=lr_schedule(1),
**self.optimizer_kwargs,
)
def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()
@ -447,7 +462,7 @@ class CnnPolicy(TQCPolicy):
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
action_space: spaces.Box,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
@ -516,7 +531,7 @@ class MultiInputPolicy(TQCPolicy):
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
action_space: spaces.Box,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,

View File

@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Un
import numpy as np
import torch as th
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
@ -11,7 +11,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
from sb3_contrib.tqc.policies import Actor, CnnPolicy, Critic, MlpPolicy, MultiInputPolicy, TQCPolicy
SelfTQC = TypeVar("SelfTQC", bound="TQC")
@ -73,6 +73,10 @@ class TQC(OffPolicyAlgorithm):
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
}
policy: TQCPolicy
actor: Actor
critic: Critic
critic_target: Critic
def __init__(
self,
@ -87,7 +91,7 @@ class TQC(OffPolicyAlgorithm):
train_freq: int = 1,
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[ReplayBuffer] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
ent_coef: Union[str, float] = "auto",
@ -139,7 +143,7 @@ class TQC(OffPolicyAlgorithm):
# Inverse of the reward scale
self.ent_coef = ent_coef
self.target_update_interval = target_update_interval
self.ent_coef_optimizer = None
self.ent_coef_optimizer: Optional[th.optim.Adam] = None
self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net
if _init_setup_model:
@ -155,7 +159,7 @@ class TQC(OffPolicyAlgorithm):
# Target entropy is used when learning the entropy coefficient
if self.target_entropy == "auto":
# automatically set target entropy if needed
self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)
self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) # type: ignore
else:
# Force conversion
# this will also throw an error for unexpected string
@ -202,7 +206,7 @@ class TQC(OffPolicyAlgorithm):
for gradient_step in range(gradient_steps):
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]
# We need to sample because `log_std` may have changed between two gradient steps
if self.use_sde:
@ -213,7 +217,7 @@ class TQC(OffPolicyAlgorithm):
log_prob = log_prob.reshape(-1, 1)
ent_coef_loss = None
if self.ent_coef_optimizer is not None:
if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
# Important: detach the variable from the graph
# so we don't change it with other losses
# see https://github.com/rail-berkeley/softlearning/issues/60
@ -224,11 +228,10 @@ class TQC(OffPolicyAlgorithm):
ent_coef = self.ent_coef_tensor
ent_coefs.append(ent_coef.item())
self.replay_buffer.ent_coef = ent_coef.item()
# Optimize entropy coefficient, also called
# entropy temperature or alpha in the paper
if ent_coef_loss is not None:
if ent_coef_loss is not None and self.ent_coef_optimizer is not None:
self.ent_coef_optimizer.zero_grad()
ent_coef_loss.backward()
self.ent_coef_optimizer.step()

View File

@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.distributions import kl_divergence
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
@ -178,16 +178,16 @@ class TRPO(OnPolicyAlgorithm):
"""
# This is necessary because not all the parameters in the policy have gradients w.r.t. the KL divergence
# The policy objective is also called surrogate objective
policy_objective_gradients = []
policy_objective_gradients_list = []
# Contains the gradients of the KL divergence
grad_kl = []
grad_kl_list = []
# Contains the shape of the gradients of the KL divergence w.r.t each parameter
# This way the flattened gradient can be reshaped back into the original shapes and applied to
# the parameters
grad_shape = []
grad_shape: List[Tuple[int, ...]] = []
# Contains the parameters which have non-zeros KL divergence gradients
# The list is used during the line-search to apply the step to each parameter
actor_params = []
actor_params: List[nn.Parameter] = []
for name, param in self.policy.named_parameters():
# Skip parameters related to value function based on name
@ -213,13 +213,13 @@ class TRPO(OnPolicyAlgorithm):
policy_objective_grad, *_ = th.autograd.grad(policy_objective, param, retain_graph=True, only_inputs=True)
grad_shape.append(kl_param_grad.shape)
grad_kl.append(kl_param_grad.reshape(-1))
policy_objective_gradients.append(policy_objective_grad.reshape(-1))
grad_kl_list.append(kl_param_grad.reshape(-1))
policy_objective_gradients_list.append(policy_objective_grad.reshape(-1))
actor_params.append(param)
# Gradients are concatenated before the conjugate gradient step
policy_objective_gradients = th.cat(policy_objective_gradients)
grad_kl = th.cat(grad_kl)
policy_objective_gradients = th.cat(policy_objective_gradients_list)
grad_kl = th.cat(grad_kl_list)
return actor_params, policy_objective_gradients, grad_kl, grad_shape
def train(self) -> None:
@ -243,10 +243,10 @@ class TRPO(OnPolicyAlgorithm):
rollout_data = RolloutBufferSamples(
rollout_data.observations[:: self.sub_sampling_factor],
rollout_data.actions[:: self.sub_sampling_factor],
None, # old values, not used here
None, # type: ignore[arg-type] # old values, not used here
rollout_data.old_log_prob[:: self.sub_sampling_factor],
rollout_data.advantages[:: self.sub_sampling_factor],
None, # returns, not used here
None, # type: ignore[arg-type] # returns, not used here
)
actions = rollout_data.actions
@ -301,7 +301,7 @@ class TRPO(OnPolicyAlgorithm):
line_search_max_step_size /= th.matmul(
search_direction, hessian_vector_product_fn(search_direction, retain_graph=False)
)
line_search_max_step_size = th.sqrt(line_search_max_step_size)
line_search_max_step_size = th.sqrt(line_search_max_step_size) # type: ignore[assignment, arg-type]
line_search_backtrack_coeff = 1.0
original_actor_params = [param.detach().clone() for param in actor_params]
@ -351,7 +351,7 @@ class TRPO(OnPolicyAlgorithm):
param.data = original_param.data.clone()
policy_objective_values.append(policy_objective.item())
kl_divergences.append(0)
kl_divergences.append(0.0)
else:
policy_objective_values.append(new_policy_objective.item())
kl_divergences.append(kl_div.item())

View File

@ -1 +1 @@
1.8.0
2.0.0a4

View File

@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.8.0,<2.0",
"stable_baselines3>=2.0.0a4",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",

View File

@ -43,7 +43,7 @@ def test_cnn(tmp_path, model_class, share_features_extractor):
model = model_class("CnnPolicy", env, **kwargs).learn(250)
obs = env.reset()
obs, _ = env.reset()
# FakeImageEnv is channel last by default and should be wrapped
assert is_vecenv_wrapped(model.get_env(), VecTransposeImage)

View File

@ -1,7 +1,10 @@
import gym
from typing import Dict, Optional
import gymnasium as gym
import numpy as np
import pytest
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.envs import SimpleMultiObsEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
@ -27,8 +30,8 @@ class DummyDictEnv(gym.Env):
else:
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
N_CHANNELS = 1
HEIGHT = 64
WIDTH = 64
HEIGHT = 36
WIDTH = 36
if channel_last:
obs_shape = (HEIGHT, WIDTH, N_CHANNELS)
@ -65,19 +68,31 @@ class DummyDictEnv(gym.Env):
def step(self, action):
reward = 0.0
done = False
return self.observation_space.sample(), reward, done, {}
done = truncated = False
return self.observation_space.sample(), reward, done, truncated, {}
def compute_reward(self, achieved_goal, desired_goal, info):
return np.zeros((len(achieved_goal),))
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
if seed is not None:
self.observation_space.seed(seed)
return self.observation_space.sample(), {}
def reset(self):
return self.observation_space.sample()
def render(self, mode="human"):
def render(self):
pass
@pytest.mark.parametrize("use_discrete_actions", [True, False])
@pytest.mark.parametrize("channel_last", [True, False])
@pytest.mark.parametrize("nested_dict_obs", [True, False])
@pytest.mark.parametrize("vec_only", [True, False])
def test_env(use_discrete_actions, channel_last, nested_dict_obs, vec_only):
# Check the env used for testing
if nested_dict_obs:
with pytest.warns(UserWarning, match="Nested observation spaces are not supported"):
check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only))
else:
check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only))
@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
def test_consistency(model_class):
"""
@ -89,7 +104,7 @@ def test_consistency(model_class):
dict_env = gym.wrappers.TimeLimit(dict_env, 100)
env = gym.wrappers.FlattenObservation(dict_env)
dict_env.seed(10)
obs = dict_env.reset()
obs, _ = dict_env.reset()
kwargs = {}
n_steps = 256

View File

@ -12,7 +12,7 @@ DIM = 4
@pytest.mark.parametrize("model_class", [QRDQN, TRPO])
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
def test_discrete(model_class, env):
env_ = DummyVecEnv([lambda: env])
vec_env = DummyVecEnv([lambda: env])
kwargs = {}
n_steps = 1500
if model_class == QRDQN:
@ -30,9 +30,9 @@ def test_discrete(model_class, env):
elif n_steps == TRPO:
kwargs = dict(n_steps=256, cg_max_steps=5)
model = model_class("MlpPolicy", env_, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps)
model = model_class("MlpPolicy", vec_env, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps)
evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False)
obs = env.reset()
evaluate_policy(model, vec_env, n_eval_episodes=20, reward_threshold=90, warn=False)
obs = vec_env.reset()
assert np.shape(model.predict(obs)[0]) == np.shape(obs)

View File

@ -1,8 +1,10 @@
import random
from typing import Dict, Tuple
import gym
import gymnasium as gym
import numpy as np
import pytest
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.callbacks import EventCallback, StopTrainingOnNoModelImprovement
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
@ -30,12 +32,12 @@ class ToDictWrapper(gym.Wrapper):
super().__init__(env)
self.observation_space = spaces.Dict({"obs": self.env.observation_space})
def reset(self):
return {"obs": self.env.reset()}
def reset(self, **kwargs) -> Tuple[Dict[str, np.ndarray], Dict]:
return {"obs": self.env.reset(seed=kwargs.get("seed", 0))[0]}, {} # type: ignore[dict-item]
def step(self, action):
obs, reward, done, infos = self.env.step(action)
return {"obs": obs}, reward, done, infos
obs, reward, terminated, truncated, infos = self.env.step(action)
return {"obs": obs}, reward, terminated, truncated, infos
def test_identity():

View File

@ -1,10 +1,13 @@
import gym
from typing import Dict, Optional
import gymnasium as gym
import numpy as np
import pytest
from gym import spaces
from gym.envs.classic_control import CartPoleEnv
from gym.wrappers.time_limit import TimeLimit
from gymnasium import spaces
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.wrappers.time_limit import TimeLimit
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import FakeImageEnv
from stable_baselines3.common.evaluation import evaluate_policy
@ -22,12 +25,12 @@ class ToDictWrapper(gym.Wrapper):
super().__init__(env)
self.observation_space = spaces.Dict({"obs": self.env.observation_space})
def reset(self):
return {"obs": self.env.reset()}
def reset(self, **kwargs):
return {"obs": self.env.reset(**kwargs)[0]}, {}
def step(self, action):
obs, reward, done, infos = self.env.step(action)
return {"obs": obs}, reward, done, infos
obs, reward, done, truncated, infos = self.env.step(action)
return {"obs": obs}, reward, done, truncated, infos
class CartPoleNoVelEnv(CartPoleEnv):
@ -46,15 +49,19 @@ class CartPoleNoVelEnv(CartPoleEnv):
@staticmethod
def _pos_obs(full_obs):
xpos, _xvel, thetapos, _thetavel = full_obs
return xpos, thetapos
return np.array([xpos, thetapos])
def reset(self):
full_obs = super().reset()
return CartPoleNoVelEnv._pos_obs(full_obs)
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
full_obs, info = super().reset(seed=seed, options=options)
return CartPoleNoVelEnv._pos_obs(full_obs), info
def step(self, action):
full_obs, rew, done, info = super().step(action)
return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info
full_obs, rew, terminated, truncated, info = super().step(action)
return CartPoleNoVelEnv._pos_obs(full_obs), rew, terminated, truncated, info
def test_env():
check_env(CartPoleNoVelEnv())
@pytest.mark.parametrize(

View File

@ -1,4 +1,4 @@
import gym
import gymnasium as gym
import pytest
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

View File

@ -3,7 +3,7 @@ import pathlib
from collections import OrderedDict
from copy import deepcopy
import gym
import gymnasium as gym
import numpy as np
import pytest
import torch as th
@ -24,7 +24,7 @@ def select_env(model_class: BaseAlgorithm) -> gym.Env:
if model_class == QRDQN:
return IdentityEnv(10)
else:
return IdentityEnvBox(10)
return IdentityEnvBox(-10, 10)
@pytest.mark.parametrize("model_class", MODEL_LIST)

View File

@ -1,6 +1,6 @@
from typing import Union
import gym
import gymnasium as gym
import numpy as np
import pytest
import torch as th
@ -111,7 +111,7 @@ def test_ppo_mask_train_eval_mode():
batch_norm_stats_before = clone_on_policy_batch_norm(model)
observation = env.reset()
observation, _ = env.reset()
action_masks = get_action_masks(env)
first_prediction, _ = model.predict(observation, action_masks=action_masks, deterministic=True)
for _ in range(5):

View File

@ -1,7 +1,7 @@
from typing import List
import pytest
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.envs import IdentityEnv
from sb3_contrib.common.wrappers import ActionMasker
@ -20,10 +20,12 @@ class IdentityEnvDiscrete(IdentityEnv):
super().__init__(ep_length=ep_length, space=space)
def _action_masks(self) -> List[int]:
assert isinstance(self.action_space, spaces.Discrete)
return [i == self.state for i in range(self.action_space.n)]
def action_mask_fn(env: IdentityEnvDiscrete) -> List[int]:
assert isinstance(env.action_space, spaces.Discrete)
return [i == env.state for i in range(env.action_space.n)]

View File

@ -1,14 +1,14 @@
import gym
import gymnasium as gym
import numpy as np
import pytest
from gym import spaces
from gymnasium import spaces
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.envs import BitFlippingEnv
from sb3_contrib.common.wrappers import TimeFeatureWrapper
class CustomGoalEnv(gym.GoalEnv):
class CustomGoalEnv(gym.Env):
def __init__(self):
super().__init__()
self.observation_space = spaces.Dict(
@ -21,10 +21,10 @@ class CustomGoalEnv(gym.GoalEnv):
self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)
def reset(self):
return self.observation_space.sample()
return self.observation_space.sample(), {}
def step(self, action):
return self.observation_space.sample(), 0.0, False, {}
return self.observation_space.sample(), 0.0, False, False, {}
def check_time_feature(obs, timestep, max_timesteps):
@ -37,14 +37,16 @@ def test_time_feature():
check_env(env, warn=False)
# Check for four episodes
max_timesteps = 200
obs = env.reset()
obs, _ = env.reset()
for _ in range(4):
done = False
check_time_feature(obs, timestep=0, max_timesteps=max_timesteps)
for step in range(1, max_timesteps + 1):
obs, _, done, _ = env.step(env.action_space.sample())
obs, _, terminated, truncated, _ = env.step(env.action_space.sample())
check_time_feature(obs, timestep=step, max_timesteps=max_timesteps)
done = terminated or truncated
if done:
obs = env.reset()
obs, _ = env.reset()
env = BitFlippingEnv()
with pytest.raises(AssertionError):
@ -52,16 +54,16 @@ def test_time_feature():
env = CustomGoalEnv()
env = TimeFeatureWrapper(env, max_steps=500)
obs = env.reset()
obs, _ = env.reset()
check_time_feature(obs["observation"], timestep=0, max_timesteps=500)
obs, _, _, _ = env.step(env.action_space.sample())
obs = env.step(env.action_space.sample())[0]
check_time_feature(obs["observation"], timestep=1, max_timesteps=500)
# In test mode, the time feature must be constant
env = gym.make("Pendulum-v1")
env = TimeFeatureWrapper(env, test_mode=True)
obs = env.reset()
obs, _ = env.reset()
check_time_feature(obs, timestep=0, max_timesteps=200)
obs, _, _, _ = env.step(env.action_space.sample())
obs = env.step(env.action_space.sample())[0]
# Should be the same
check_time_feature(obs, timestep=0, max_timesteps=200)