diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 4ad1f39..6908e2d 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -50,7 +50,8 @@ body: * GPU models and configuration * Python version * PyTorch version - * Gym version + * Gymnasium version + * (if installed) OpenAI Gym version * Versions of any other relevant libraries You can use `sb3.get_system_info()` to print relevant packages info: diff --git a/.github/ISSUE_TEMPLATE/custom_env.yml b/.github/ISSUE_TEMPLATE/custom_env.yml index c35211c..c0b4093 100644 --- a/.github/ISSUE_TEMPLATE/custom_env.yml +++ b/.github/ISSUE_TEMPLATE/custom_env.yml @@ -34,9 +34,9 @@ body: Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces. value: | ```python - import gym + import gymnasium as gym import numpy as np - from gym import spaces + from gymnasium import spaces from stable_baselines3 import A2C from stable_baselines3.common.env_checker import check_env @@ -49,15 +49,16 @@ body: self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,)) self.action_space = spaces.Box(low=-1, high=1, shape=(6,)) - def reset(self): - return self.observation_space.sample() + def reset(self, seed=None, options=None): + return self.observation_space.sample(), {} def step(self, action): obs = self.observation_space.sample() reward = 1.0 - done = False + terminated = False + truncated = False info = {} - return obs, reward, done, info + return obs, reward, terminated, truncated, info env = CustomEnv() check_env(env) @@ -84,7 +85,8 @@ body: * GPU models and configuration * Python version * PyTorch version - * Gym version + * Gymnasium version + * (if installed) OpenAI Gym version * Versions of any other relevant libraries You can use `sb3.get_system_info()` to print relevant packages info: diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a53fd99..bd6973e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,42 @@ Changelog ========== +Release 2.0.0a4 (WIP) +-------------------------- + +**Gymnasium support** + +.. warning:: + + Stable-Baselines3 (SB3) v2.0 will be the last one supporting python 3.7 (end of life in June 2023). + We highly recommended you to upgrade to Python >= 3.8. + + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Switched to Gymnasium as primary backend, Gym 0.21 and 0.26 are still supported via the ``shimmy`` package (@carlosluis, @arjun-kg, @tlpss) +- Upgraded to Stable-Baselines3 >= 2.0.0 + +New Features: +^^^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ +- Fixed QRDQN update interval for multi envs + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ +- Fixed ``sb3_contrib/tqc/*.py`` type hints +- Fixed ``sb3_contrib/trpo/*.py`` type hints + + +Documentation: +^^^^^^^^^^^^^^ + + Release 1.8.0 (2023-04-07) -------------------------- diff --git a/docs/modules/ppo_mask.rst b/docs/modules/ppo_mask.rst index 87309e1..027c89a 100644 --- a/docs/modules/ppo_mask.rst +++ b/docs/modules/ppo_mask.rst @@ -91,7 +91,7 @@ to specify the name (see `PR #25 GymObs: + def reset(self, **kwargs) -> Tuple[TimeFeatureObs, Dict[str, Any]]: self._current_step = 0 - return self._get_obs(self.env.reset()) + obs, info = self.env.reset(**kwargs) + return self._get_obs(obs), info - def step(self, action: Union[int, np.ndarray]) -> GymStepReturn: + def step(self, action: ActType) -> Tuple[TimeFeatureObs, SupportsFloat, bool, bool, Dict[str, Any]]: self._current_step += 1 - obs, reward, done, info = self.env.step(action) - return self._get_obs(obs), reward, done, info + obs, reward, done, truncated, info = self.env.step(action) + return self._get_obs(obs), reward, done, truncated, info def _get_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]: """ @@ -80,11 +86,13 @@ class TimeFeatureWrapper(gym.Wrapper): :param obs: :return: """ + # for mypy + assert self._max_steps is not None # Remaining time is more general time_feature = 1 - (self._current_step / self._max_steps) if self._test_mode: time_feature = 1.0 - time_feature = np.array(time_feature, dtype=self.dtype) + time_feature = np.array(time_feature, dtype=self.dtype) # type: ignore[assignment] if isinstance(obs, dict): obs["observation"] = np.append(obs["observation"], time_feature) diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index d27c724..b97213b 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common import utils from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback @@ -232,6 +232,7 @@ class MaskablePPO(OnPolicyAlgorithm): # Make sure training timesteps are ahead of the internal counter total_timesteps += self.num_timesteps self._total_timesteps = total_timesteps + self._num_timesteps_at_start = self.num_timesteps # Avoid resetting the environment when calling ``.learn()`` consecutive times if reset_num_timesteps or self._last_obs is None: diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 344fbda..c8b0a62 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Type, TypeVar, Union import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm diff --git a/sb3_contrib/qrdqn/policies.py b/sb3_contrib/qrdqn/policies.py index 0491f50..a58e47b 100644 --- a/sb3_contrib/qrdqn/policies.py +++ b/sb3_contrib/qrdqn/policies.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Type import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, @@ -27,10 +27,12 @@ class QuantileNetwork(BasePolicy): dividing by 255.0 (True by default) """ + action_space: spaces.Discrete + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Discrete, features_extractor: BaseFeaturesExtractor, features_dim: int, n_quantiles: int = 200, @@ -52,7 +54,7 @@ class QuantileNetwork(BasePolicy): self.activation_fn = activation_fn self.features_dim = features_dim self.n_quantiles = n_quantiles - action_dim = self.action_space.n # number of actions + action_dim = int(self.action_space.n) # number of actions quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn) self.quantile_net = nn.Sequential(*quantile_net) @@ -63,10 +65,8 @@ class QuantileNetwork(BasePolicy): :param obs: Observation :return: The estimated quantiles for each action. """ - # For type checker: - assert isinstance(self.features_extractor, BaseFeaturesExtractor) quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor)) - return quantiles.view(-1, self.n_quantiles, self.action_space.n) + return quantiles.view(-1, self.n_quantiles, int(self.action_space.n)) def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor: q_values = self(observation).mean(dim=1) @@ -110,10 +110,13 @@ class QRDQNPolicy(BasePolicy): excluding the learning rate, to pass to the optimizer """ + quantile_net: QuantileNetwork + quantile_net_target: QuantileNetwork + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Discrete, lr_schedule: Schedule, n_quantiles: int = 200, net_arch: Optional[List[int]] = None, @@ -152,9 +155,6 @@ class QRDQNPolicy(BasePolicy): "activation_fn": self.activation_fn, "normalize_images": normalize_images, } - - self.quantile_net: QuantileNetwork - self.quantile_net_target: QuantileNetwork self._build(lr_schedule) def _build(self, lr_schedule: Schedule) -> None: @@ -239,7 +239,7 @@ class CnnPolicy(QRDQNPolicy): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Discrete, lr_schedule: Schedule, n_quantiles: int = 200, net_arch: Optional[List[int]] = None, @@ -287,7 +287,7 @@ class MultiInputPolicy(QRDQNPolicy): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Discrete, lr_schedule: Schedule, n_quantiles: int = 200, net_arch: Optional[List[int]] = None, diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index 62c6a30..a2a4341 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -1,8 +1,9 @@ +import warnings from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy @@ -10,7 +11,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update from sb3_contrib.common.utils import quantile_huber_loss -from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy +from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy, QuantileNetwork SelfQRDQN = TypeVar("SelfQRDQN", bound="QRDQN") @@ -64,6 +65,11 @@ class QRDQN(OffPolicyAlgorithm): "CnnPolicy": CnnPolicy, "MultiInputPolicy": MultiInputPolicy, } + # Linear schedule will be defined in `_setup_model()` + exploration_schedule: Schedule + quantile_net: QuantileNetwork + quantile_net_target: QuantileNetwork + policy: QRDQNPolicy def __init__( self, @@ -123,13 +129,11 @@ class QRDQN(OffPolicyAlgorithm): self.exploration_final_eps = exploration_final_eps self.exploration_fraction = exploration_fraction self.target_update_interval = target_update_interval + # For updating the target network with multiple envs: + self._n_calls = 0 self.max_grad_norm = max_grad_norm # "epsilon" for the epsilon-greedy exploration self.exploration_rate = 0.0 - # Linear schedule will be defined in `_setup_model()` - self.exploration_schedule: Schedule - self.quantile_net: th.nn.Module - self.quantile_net_target: th.nn.Module if "optimizer_class" not in self.policy_kwargs: self.policy_kwargs["optimizer_class"] = th.optim.Adam @@ -148,10 +152,20 @@ class QRDQN(OffPolicyAlgorithm): self.exploration_schedule = get_linear_fn( self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction ) + # Account for multiple environments + # each call to step() corresponds to n_envs transitions + if self.n_envs > 1: + if self.n_envs > self.target_update_interval: + warnings.warn( + "The number of environments used is greater than the target network " + f"update interval ({self.n_envs} > {self.target_update_interval}), " + "therefore the target network will be updated after each call to env.step() " + f"which corresponds to {self.n_envs} steps." + ) + + self.target_update_interval = max(self.target_update_interval // self.n_envs, 1) def _create_aliases(self) -> None: - # For type checker: - assert isinstance(self.policy, QRDQNPolicy) self.quantile_net = self.policy.quantile_net self.quantile_net_target = self.policy.quantile_net_target self.n_quantiles = self.policy.n_quantiles @@ -161,7 +175,8 @@ class QRDQN(OffPolicyAlgorithm): Update the exploration rate and target network if needed. This method is called in ``collect_rollouts()`` after each step in the environment. """ - if self.num_timesteps % self.target_update_interval == 0: + self._n_calls += 1 + if self._n_calls % self.target_update_interval == 0: polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau) # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996 polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0) diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index ce7fb26..9913d80 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution from stable_baselines3.common.policies import BaseModel, BasePolicy from stable_baselines3.common.preprocessing import get_action_dim @@ -44,10 +44,12 @@ class Actor(BasePolicy): dividing by 255.0 (True by default) """ + action_space: spaces.Box + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, net_arch: List[int], features_extractor: nn.Module, features_dim: int, @@ -95,9 +97,9 @@ class Actor(BasePolicy): if clip_mean > 0.0: self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean)) else: - self.action_dist = SquashedDiagGaussianDistribution(action_dim) + self.action_dist = SquashedDiagGaussianDistribution(action_dim) # type: ignore[assignment] self.mu = nn.Linear(last_layer_dim, action_dim) - self.log_std = nn.Linear(last_layer_dim, action_dim) + self.log_std = nn.Linear(last_layer_dim, action_dim) # type: ignore[assignment] def _get_constructor_parameters(self) -> Dict[str, Any]: data = super()._get_constructor_parameters() @@ -156,7 +158,7 @@ class Actor(BasePolicy): if self.use_sde: return mean_actions, self.log_std, dict(latent_sde=latent_pi) # Unstructured exploration (Original implementation) - log_std = self.log_std(latent_pi) + log_std = self.log_std(latent_pi) # type: ignore[operator] # Original Implementation to cap the standard deviation log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) return mean_actions, log_std, {} @@ -192,12 +194,15 @@ class Critic(BaseModel): between the actor and the critic (this saves computation time) """ + action_space: spaces.Box + features_extractor: BaseFeaturesExtractor + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, net_arch: List[int], - features_extractor: nn.Module, + features_extractor: BaseFeaturesExtractor, features_dim: int, activation_fn: Type[nn.Module] = nn.ReLU, normalize_images: bool = True, @@ -221,12 +226,12 @@ class Critic(BaseModel): self.quantiles_total = n_quantiles * n_critics for i in range(n_critics): - qf_net = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn) - qf_net = nn.Sequential(*qf_net) + qf_net_list = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn) + qf_net = nn.Sequential(*qf_net_list) self.add_module(f"qf{i}", qf_net) self.q_networks.append(qf_net) - def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]: + def forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor: # Learn the features extractor using the policy loss only # when the features_extractor is shared with the actor with th.set_grad_enabled(not self.share_features_extractor): @@ -266,10 +271,14 @@ class TQCPolicy(BasePolicy): between the actor and the critic (this saves computation time) """ + actor: Actor + critic: Critic + critic_target: Critic + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -328,15 +337,17 @@ class TQCPolicy(BasePolicy): "share_features_extractor": share_features_extractor, } self.critic_kwargs.update(tqc_kwargs) - self.actor, self.actor_target = None, None - self.critic, self.critic_target = None, None self.share_features_extractor = share_features_extractor self._build(lr_schedule) def _build(self, lr_schedule: Schedule) -> None: self.actor = self.make_actor() - self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + self.actor.optimizer = self.optimizer_class( # type: ignore[call-arg] + self.actor.parameters(), + lr=lr_schedule(1), + **self.optimizer_kwargs, + ) if self.share_features_extractor: self.critic = self.make_critic(features_extractor=self.actor.features_extractor) @@ -347,7 +358,7 @@ class TQCPolicy(BasePolicy): # Create a separate features extractor for the critic # this requires more memory and computation self.critic = self.make_critic(features_extractor=None) - critic_parameters = self.critic.parameters() + critic_parameters = list(self.critic.parameters()) # Critic target should not share the feature extactor with critic self.critic_target = self.make_critic(features_extractor=None) @@ -356,7 +367,11 @@ class TQCPolicy(BasePolicy): # Target networks should always be in eval mode self.critic_target.set_training_mode(False) - self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs) + self.critic.optimizer = self.optimizer_class( # type: ignore[call-arg] + critic_parameters, + lr=lr_schedule(1), + **self.optimizer_kwargs, + ) def _get_constructor_parameters(self) -> Dict[str, Any]: data = super()._get_constructor_parameters() @@ -447,7 +462,7 @@ class CnnPolicy(TQCPolicy): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -516,7 +531,7 @@ class MultiInputPolicy(TQCPolicy): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 5d8889f..21c7b74 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Un import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm @@ -11,7 +11,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from sb3_contrib.common.utils import quantile_huber_loss -from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy +from sb3_contrib.tqc.policies import Actor, CnnPolicy, Critic, MlpPolicy, MultiInputPolicy, TQCPolicy SelfTQC = TypeVar("SelfTQC", bound="TQC") @@ -73,6 +73,10 @@ class TQC(OffPolicyAlgorithm): "CnnPolicy": CnnPolicy, "MultiInputPolicy": MultiInputPolicy, } + policy: TQCPolicy + actor: Actor + critic: Critic + critic_target: Critic def __init__( self, @@ -87,7 +91,7 @@ class TQC(OffPolicyAlgorithm): train_freq: int = 1, gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[ReplayBuffer] = None, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, ent_coef: Union[str, float] = "auto", @@ -139,7 +143,7 @@ class TQC(OffPolicyAlgorithm): # Inverse of the reward scale self.ent_coef = ent_coef self.target_update_interval = target_update_interval - self.ent_coef_optimizer = None + self.ent_coef_optimizer: Optional[th.optim.Adam] = None self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net if _init_setup_model: @@ -155,7 +159,7 @@ class TQC(OffPolicyAlgorithm): # Target entropy is used when learning the entropy coefficient if self.target_entropy == "auto": # automatically set target entropy if needed - self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) + self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) # type: ignore else: # Force conversion # this will also throw an error for unexpected string @@ -202,7 +206,7 @@ class TQC(OffPolicyAlgorithm): for gradient_step in range(gradient_steps): # Sample replay buffer - replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) + replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr] # We need to sample because `log_std` may have changed between two gradient steps if self.use_sde: @@ -213,7 +217,7 @@ class TQC(OffPolicyAlgorithm): log_prob = log_prob.reshape(-1, 1) ent_coef_loss = None - if self.ent_coef_optimizer is not None: + if self.ent_coef_optimizer is not None and self.log_ent_coef is not None: # Important: detach the variable from the graph # so we don't change it with other losses # see https://github.com/rail-berkeley/softlearning/issues/60 @@ -224,11 +228,10 @@ class TQC(OffPolicyAlgorithm): ent_coef = self.ent_coef_tensor ent_coefs.append(ent_coef.item()) - self.replay_buffer.ent_coef = ent_coef.item() # Optimize entropy coefficient, also called # entropy temperature or alpha in the paper - if ent_coef_loss is not None: + if ent_coef_loss is not None and self.ent_coef_optimizer is not None: self.ent_coef_optimizer.zero_grad() ent_coef_loss.backward() self.ent_coef_optimizer.step() diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index d97cefd..e885da5 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.distributions import kl_divergence from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy @@ -178,16 +178,16 @@ class TRPO(OnPolicyAlgorithm): """ # This is necessary because not all the parameters in the policy have gradients w.r.t. the KL divergence # The policy objective is also called surrogate objective - policy_objective_gradients = [] + policy_objective_gradients_list = [] # Contains the gradients of the KL divergence - grad_kl = [] + grad_kl_list = [] # Contains the shape of the gradients of the KL divergence w.r.t each parameter # This way the flattened gradient can be reshaped back into the original shapes and applied to # the parameters - grad_shape = [] + grad_shape: List[Tuple[int, ...]] = [] # Contains the parameters which have non-zeros KL divergence gradients # The list is used during the line-search to apply the step to each parameters - actor_params = [] + actor_params: List[nn.Parameter] = [] for name, param in self.policy.named_parameters(): # Skip parameters related to value function based on name @@ -213,13 +213,13 @@ class TRPO(OnPolicyAlgorithm): policy_objective_grad, *_ = th.autograd.grad(policy_objective, param, retain_graph=True, only_inputs=True) grad_shape.append(kl_param_grad.shape) - grad_kl.append(kl_param_grad.reshape(-1)) - policy_objective_gradients.append(policy_objective_grad.reshape(-1)) + grad_kl_list.append(kl_param_grad.reshape(-1)) + policy_objective_gradients_list.append(policy_objective_grad.reshape(-1)) actor_params.append(param) # Gradients are concatenated before the conjugate gradient step - policy_objective_gradients = th.cat(policy_objective_gradients) - grad_kl = th.cat(grad_kl) + policy_objective_gradients = th.cat(policy_objective_gradients_list) + grad_kl = th.cat(grad_kl_list) return actor_params, policy_objective_gradients, grad_kl, grad_shape def train(self) -> None: @@ -243,10 +243,10 @@ class TRPO(OnPolicyAlgorithm): rollout_data = RolloutBufferSamples( rollout_data.observations[:: self.sub_sampling_factor], rollout_data.actions[:: self.sub_sampling_factor], - None, # old values, not used here + None, # type: ignore[arg-type] # old values, not used here rollout_data.old_log_prob[:: self.sub_sampling_factor], rollout_data.advantages[:: self.sub_sampling_factor], - None, # returns, not used here + None, # type: ignore[arg-type] # returns, not used here ) actions = rollout_data.actions @@ -301,7 +301,7 @@ class TRPO(OnPolicyAlgorithm): line_search_max_step_size /= th.matmul( search_direction, hessian_vector_product_fn(search_direction, retain_graph=False) ) - line_search_max_step_size = th.sqrt(line_search_max_step_size) + line_search_max_step_size = th.sqrt(line_search_max_step_size) # type: ignore[assignment, arg-type] line_search_backtrack_coeff = 1.0 original_actor_params = [param.detach().clone() for param in actor_params] @@ -351,7 +351,7 @@ class TRPO(OnPolicyAlgorithm): param.data = original_param.data.clone() policy_objective_values.append(policy_objective.item()) - kl_divergences.append(0) + kl_divergences.append(0.0) else: policy_objective_values.append(new_policy_objective.item()) kl_divergences.append(kl_div.item()) diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 27f9cd3..997bba2 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.8.0 +2.0.0a4 diff --git a/setup.py b/setup.py index ac81d5b..5b7e392 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.8.0,<2.0", + "stable_baselines3>=2.0.0a4", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", diff --git a/tests/test_cnn.py b/tests/test_cnn.py index be8758a..e505ea5 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -43,7 +43,7 @@ def test_cnn(tmp_path, model_class, share_features_extractor): model = model_class("CnnPolicy", env, **kwargs).learn(250) - obs = env.reset() + obs, _ = env.reset() # FakeImageEnv is channel last by default and should be wrapped assert is_vecenv_wrapped(model.get_env(), VecTransposeImage) diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index eada97d..9722da3 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -1,7 +1,10 @@ -import gym +from typing import Dict, Optional + +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.envs import SimpleMultiObsEnv from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize @@ -27,8 +30,8 @@ class DummyDictEnv(gym.Env): else: self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) N_CHANNELS = 1 - HEIGHT = 64 - WIDTH = 64 + HEIGHT = 36 + WIDTH = 36 if channel_last: obs_shape = (HEIGHT, WIDTH, N_CHANNELS) @@ -65,19 +68,31 @@ class DummyDictEnv(gym.Env): def step(self, action): reward = 0.0 - done = False - return self.observation_space.sample(), reward, done, {} + done = truncated = False + return self.observation_space.sample(), reward, done, truncated, {} - def compute_reward(self, achieved_goal, desired_goal, info): - return np.zeros((len(achieved_goal),)) + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + self.observation_space.seed(seed) + return self.observation_space.sample(), {} - def reset(self): - return self.observation_space.sample() - - def render(self, mode="human"): + def render(self): pass +@pytest.mark.parametrize("use_discrete_actions", [True, False]) +@pytest.mark.parametrize("channel_last", [True, False]) +@pytest.mark.parametrize("nested_dict_obs", [True, False]) +@pytest.mark.parametrize("vec_only", [True, False]) +def test_env(use_discrete_actions, channel_last, nested_dict_obs, vec_only): + # Check the env used for testing + if nested_dict_obs: + with pytest.warns(UserWarning, match="Nested observation spaces are not supported"): + check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only)) + else: + check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only)) + + @pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO]) def test_consistency(model_class): """ @@ -89,7 +104,7 @@ def test_consistency(model_class): dict_env = gym.wrappers.TimeLimit(dict_env, 100) env = gym.wrappers.FlattenObservation(dict_env) dict_env.seed(10) - obs = dict_env.reset() + obs, _ = dict_env.reset() kwargs = {} n_steps = 256 diff --git a/tests/test_identity.py b/tests/test_identity.py index 6ad0317..3725db8 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -12,7 +12,7 @@ DIM = 4 @pytest.mark.parametrize("model_class", [QRDQN, TRPO]) @pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)]) def test_discrete(model_class, env): - env_ = DummyVecEnv([lambda: env]) + vec_env = DummyVecEnv([lambda: env]) kwargs = {} n_steps = 1500 if model_class == QRDQN: @@ -30,9 +30,9 @@ def test_discrete(model_class, env): elif n_steps == TRPO: kwargs = dict(n_steps=256, cg_max_steps=5) - model = model_class("MlpPolicy", env_, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps) + model = model_class("MlpPolicy", vec_env, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps) - evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False) - obs = env.reset() + evaluate_policy(model, vec_env, n_eval_episodes=20, reward_threshold=90, warn=False) + obs = vec_env.reset() assert np.shape(model.predict(obs)[0]) == np.shape(obs) diff --git a/tests/test_invalid_actions.py b/tests/test_invalid_actions.py index c617660..3d119b0 100644 --- a/tests/test_invalid_actions.py +++ b/tests/test_invalid_actions.py @@ -1,8 +1,10 @@ import random +from typing import Dict, Tuple -import gym +import gymnasium as gym +import numpy as np import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.callbacks import EventCallback, StopTrainingOnNoModelImprovement from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox @@ -30,12 +32,12 @@ class ToDictWrapper(gym.Wrapper): super().__init__(env) self.observation_space = spaces.Dict({"obs": self.env.observation_space}) - def reset(self): - return {"obs": self.env.reset()} + def reset(self, **kwargs) -> Tuple[Dict[str, np.ndarray], Dict]: + return {"obs": self.env.reset(seed=kwargs.get("seed", 0))[0]}, {} # type: ignore[dict-item] def step(self, action): - obs, reward, done, infos = self.env.step(action) - return {"obs": obs}, reward, done, infos + obs, reward, terminated, truncated, infos = self.env.step(action) + return {"obs": obs}, reward, terminated, truncated, infos def test_identity(): diff --git a/tests/test_lstm.py b/tests/test_lstm.py index dc7cab7..f0243df 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -1,10 +1,13 @@ -import gym +from typing import Dict, Optional + +import gymnasium as gym import numpy as np import pytest -from gym import spaces -from gym.envs.classic_control import CartPoleEnv -from gym.wrappers.time_limit import TimeLimit +from gymnasium import spaces +from gymnasium.envs.classic_control import CartPoleEnv +from gymnasium.wrappers.time_limit import TimeLimit from stable_baselines3.common.callbacks import EvalCallback +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import FakeImageEnv from stable_baselines3.common.evaluation import evaluate_policy @@ -22,12 +25,12 @@ class ToDictWrapper(gym.Wrapper): super().__init__(env) self.observation_space = spaces.Dict({"obs": self.env.observation_space}) - def reset(self): - return {"obs": self.env.reset()} + def reset(self, **kwargs): + return {"obs": self.env.reset(**kwargs)[0]}, {} def step(self, action): - obs, reward, done, infos = self.env.step(action) - return {"obs": obs}, reward, done, infos + obs, reward, done, truncated, infos = self.env.step(action) + return {"obs": obs}, reward, done, truncated, infos class CartPoleNoVelEnv(CartPoleEnv): @@ -46,15 +49,19 @@ class CartPoleNoVelEnv(CartPoleEnv): @staticmethod def _pos_obs(full_obs): xpos, _xvel, thetapos, _thetavel = full_obs - return xpos, thetapos + return np.array([xpos, thetapos]) - def reset(self): - full_obs = super().reset() - return CartPoleNoVelEnv._pos_obs(full_obs) + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + full_obs, info = super().reset(seed=seed, options=options) + return CartPoleNoVelEnv._pos_obs(full_obs), info def step(self, action): - full_obs, rew, done, info = super().step(action) - return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info + full_obs, rew, terminated, truncated, info = super().step(action) + return CartPoleNoVelEnv._pos_obs(full_obs), rew, terminated, truncated, info + + +def test_env(): + check_env(CartPoleNoVelEnv()) @pytest.mark.parametrize( diff --git a/tests/test_run.py b/tests/test_run.py index 22c568b..90b5703 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,4 +1,4 @@ -import gym +import gymnasium as gym import pytest from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.vec_env import VecNormalize diff --git a/tests/test_save_load.py b/tests/test_save_load.py index b2a62d2..502d239 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -3,7 +3,7 @@ import pathlib from collections import OrderedDict from copy import deepcopy -import gym +import gymnasium as gym import numpy as np import pytest import torch as th @@ -24,7 +24,7 @@ def select_env(model_class: BaseAlgorithm) -> gym.Env: if model_class == QRDQN: return IdentityEnv(10) else: - return IdentityEnvBox(10) + return IdentityEnvBox(-10, 10) @pytest.mark.parametrize("model_class", MODEL_LIST) diff --git a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py index 1973d45..cb656b7 100644 --- a/tests/test_train_eval_mode.py +++ b/tests/test_train_eval_mode.py @@ -1,6 +1,6 @@ from typing import Union -import gym +import gymnasium as gym import numpy as np import pytest import torch as th @@ -111,7 +111,7 @@ def test_ppo_mask_train_eval_mode(): batch_norm_stats_before = clone_on_policy_batch_norm(model) - observation = env.reset() + observation, _ = env.reset() action_masks = get_action_masks(env) first_prediction, _ = model.predict(observation, action_masks=action_masks, deterministic=True) for _ in range(5): diff --git a/tests/wrappers/test_action_masker.py b/tests/wrappers/test_action_masker.py index be42987..ebfb83c 100644 --- a/tests/wrappers/test_action_masker.py +++ b/tests/wrappers/test_action_masker.py @@ -1,7 +1,7 @@ from typing import List import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.envs import IdentityEnv from sb3_contrib.common.wrappers import ActionMasker @@ -20,10 +20,12 @@ class IdentityEnvDiscrete(IdentityEnv): super().__init__(ep_length=ep_length, space=space) def _action_masks(self) -> List[int]: + assert isinstance(self.action_space, spaces.Discrete) return [i == self.state for i in range(self.action_space.n)] def action_mask_fn(env: IdentityEnvDiscrete) -> List[int]: + assert isinstance(env.action_space, spaces.Discrete) return [i == env.state for i in range(env.action_space.n)] diff --git a/tests/wrappers/test_time_feature.py b/tests/wrappers/test_time_feature.py index bfbf33f..1d53f89 100644 --- a/tests/wrappers/test_time_feature.py +++ b/tests/wrappers/test_time_feature.py @@ -1,14 +1,14 @@ -import gym +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.envs import BitFlippingEnv from sb3_contrib.common.wrappers import TimeFeatureWrapper -class CustomGoalEnv(gym.GoalEnv): +class CustomGoalEnv(gym.Env): def __init__(self): super().__init__() self.observation_space = spaces.Dict( @@ -21,10 +21,10 @@ class CustomGoalEnv(gym.GoalEnv): self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32) def reset(self): - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, False, {} + return self.observation_space.sample(), 0.0, False, False, {} def check_time_feature(obs, timestep, max_timesteps): @@ -37,14 +37,16 @@ def test_time_feature(): check_env(env, warn=False) # Check for four episodes max_timesteps = 200 - obs = env.reset() + obs, _ = env.reset() for _ in range(4): + done = False check_time_feature(obs, timestep=0, max_timesteps=max_timesteps) for step in range(1, max_timesteps + 1): - obs, _, done, _ = env.step(env.action_space.sample()) + obs, _, terminated, truncated, _ = env.step(env.action_space.sample()) check_time_feature(obs, timestep=step, max_timesteps=max_timesteps) + done = terminated or truncated if done: - obs = env.reset() + obs, _ = env.reset() env = BitFlippingEnv() with pytest.raises(AssertionError): @@ -52,16 +54,16 @@ def test_time_feature(): env = CustomGoalEnv() env = TimeFeatureWrapper(env, max_steps=500) - obs = env.reset() + obs, _ = env.reset() check_time_feature(obs["observation"], timestep=0, max_timesteps=500) - obs, _, _, _ = env.step(env.action_space.sample()) + obs = env.step(env.action_space.sample())[0] check_time_feature(obs["observation"], timestep=1, max_timesteps=500) # In test mode, the time feature must be constant env = gym.make("Pendulum-v1") env = TimeFeatureWrapper(env, test_mode=True) - obs = env.reset() + obs, _ = env.reset() check_time_feature(obs, timestep=0, max_timesteps=200) - obs, _, _, _ = env.step(env.action_space.sample()) + obs = env.step(env.action_space.sample())[0] # Should be the same check_time_feature(obs, timestep=0, max_timesteps=200)