diff --git a/.github/ISSUE_TEMPLATE/issue-template.md b/.github/ISSUE_TEMPLATE/issue-template.md
index b84c0aa..e185110 100644
--- a/.github/ISSUE_TEMPLATE/issue-template.md
+++ b/.github/ISSUE_TEMPLATE/issue-template.md
@@ -52,9 +52,8 @@ Describe the characteristic of your environment:
  * Versions of any other relevant libraries

 You can use `sb3.get_system_info()` to print relevant packages info:
-```python
-import stable_baselines3 as sb3
-sb3.get_system_info()
+```sh
+python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
 ```

 **Additional context**
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 4c8c17b..dee9f86 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -3,7 +3,7 @@
 Changelog
 ==========

-Release 1.7.0a10 (WIP)
+Release 1.7.0a11 (WIP)
 --------------------------

 Breaking Changes:
@@ -11,7 +11,7 @@ Breaking Changes:
 - Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
   please use an ``EvalCallback`` instead
 - Removed deprecated ``sde_net_arch`` parameter
-- Upgraded to Stable-Baselines3 >= 1.7.0a10
+- Upgraded to Stable-Baselines3 >= 1.7.0a11

 New Features:
 ^^^^^^^^^^^^^
@@ -41,6 +41,7 @@ Others:
 - Removed ignores on Flake8 F401 (@ZikangXiong)
 - Upgraded GitHub CI/setup-python to v4 and checkout to v3
 - Set tensors construction directly on the device
+- Standardized the use of ``from gym import spaces``

 Release 1.6.2 (2022-10-10)
 --------------------------
diff --git a/sb3_contrib/ars/ars.py b/sb3_contrib/ars/ars.py
index ef498d1..4b05412 100644
--- a/sb3_contrib/ars/ars.py
+++ b/sb3_contrib/ars/ars.py
@@ -5,10 +5,10 @@ import warnings
 from functools import partial
 from typing import Any, Dict, Optional, Type, TypeVar, Union

-import gym
 import numpy as np
 import torch as th
 import torch.nn.utils
+from gym import spaces
 from stable_baselines3.common.base_class import BaseAlgorithm
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.evaluation import evaluate_policy
@@ -80,7 +80,7 @@ class ARS(BaseAlgorithm):
             policy_kwargs=policy_kwargs,
             verbose=verbose,
             device=device,
-            supported_action_spaces=(gym.spaces.Box, gym.spaces.Discrete),
+            supported_action_spaces=(spaces.Box, spaces.Discrete),
             support_multi_env=True,
             seed=seed,
         )
diff --git a/sb3_contrib/ars/policies.py b/sb3_contrib/ars/policies.py
index b44e576..782d441 100644
--- a/sb3_contrib/ars/policies.py
+++ b/sb3_contrib/ars/policies.py
@@ -1,7 +1,7 @@
 from typing import Any, Dict, List, Optional, Type

-import gym
 import torch as th
+from gym import spaces
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.preprocessing import get_action_dim
 from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
@@ -23,8 +23,8 @@ class ARSPolicy(BasePolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         net_arch: Optional[List[int]] = None,
         activation_fn: Type[nn.Module] = nn.ReLU,
         with_bias: bool = True,
@@ -34,7 +34,7 @@ class ARSPolicy(BasePolicy):
         super().__init__(
             observation_space,
             action_space,
-            squash_output=isinstance(action_space, gym.spaces.Box) and squash_output,
+            squash_output=isinstance(action_space, spaces.Box) and squash_output,
         )

         if net_arch is None:
@@ -45,12 +45,12 @@ class ARSPolicy(BasePolicy):
         self.features_dim = self.features_extractor.features_dim
         self.activation_fn = activation_fn

-        if isinstance(action_space, gym.spaces.Box):
+        if isinstance(action_space, spaces.Box):
             action_dim = get_action_dim(action_space)
             actor_net = create_mlp(
                 self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output
             )
-        elif isinstance(action_space, gym.spaces.Discrete):
+        elif isinstance(action_space, spaces.Discrete):
             actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn, with_bias=with_bias)
         else:
             raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
@@ -72,9 +72,9 @@ class ARSPolicy(BasePolicy):
         assert isinstance(self.features_extractor, BaseFeaturesExtractor)
         features = self.extract_features(obs, self.features_extractor)

-        if isinstance(self.action_space, gym.spaces.Box):
+        if isinstance(self.action_space, spaces.Box):
             return self.action_net(features)
-        elif isinstance(self.action_space, gym.spaces.Discrete):
+        elif isinstance(self.action_space, spaces.Discrete):
             logits = self.action_net(features)
             return th.argmax(logits, dim=1)
         else:
@@ -98,8 +98,8 @@ class ARSLinearPolicy(ARSPolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         with_bias: bool = False,
         squash_output: bool = False,
     ):
diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py
index c4fbdcc..fbf597e 100644
--- a/sb3_contrib/common/maskable/policies.py
+++ b/sb3_contrib/common/maskable/policies.py
@@ -1,9 +1,9 @@
 from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Type, Union

-import gym
 import numpy as np
 import torch as th
+from gym import spaces
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.torch_layers import (
     BaseFeaturesExtractor,
@@ -43,8 +43,8 @@ class MaskableActorCriticPolicy(BasePolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         lr_schedule: Schedule,
         net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
         activation_fn: Type[nn.Module] = nn.Tanh,
@@ -281,7 +281,7 @@ class MaskableActorCriticPolicy(BasePolicy):
         # Convert to numpy
         actions = actions.cpu().numpy()

-        if isinstance(self.action_space, gym.spaces.Box):
+        if isinstance(self.action_space, spaces.Box):
             if self.squash_output:
                 # Rescale to proper domain when using squashing
                 actions = self.unscale_action(actions)
@@ -379,8 +379,8 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         lr_schedule: Schedule,
         net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
         activation_fn: Type[nn.Module] = nn.Tanh,
@@ -433,8 +433,8 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Dict,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Dict,
+        action_space: spaces.Space,
         lr_schedule: Schedule,
         net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
         activation_fn: Type[nn.Module] = nn.Tanh,
diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py
index 953fafd..1dd4869 100644
--- a/sb3_contrib/common/recurrent/policies.py
+++ b/sb3_contrib/common/recurrent/policies.py
@@ -1,8 +1,8 @@
 from typing import Any, Dict, List, Optional, Tuple, Type, Union

-import gym
 import numpy as np
 import torch as th
+from gym import spaces
 from stable_baselines3.common.distributions import Distribution
 from stable_baselines3.common.policies import ActorCriticPolicy
 from stable_baselines3.common.torch_layers import (
@@ -63,8 +63,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         lr_schedule: Schedule,
         net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
         activation_fn: Type[nn.Module] = nn.Tanh,
@@ -415,7 +415,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
         # Convert to numpy
         actions = actions.cpu().numpy()

-        if isinstance(self.action_space, gym.spaces.Box):
+        if isinstance(self.action_space, spaces.Box):
             if self.squash_output:
                 # Rescale to proper domain when using squashing
                 actions = self.unscale_action(actions)
@@ -472,8 +472,8 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         lr_schedule: Schedule,
         net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
         activation_fn: Type[nn.Module] = nn.Tanh,
@@ -562,8 +562,8 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         lr_schedule: Schedule,
         net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
         activation_fn: Type[nn.Module] = nn.Tanh,
diff --git a/sb3_contrib/common/wrappers/time_feature.py b/sb3_contrib/common/wrappers/time_feature.py
index 2aa5a9d..4aa5a2f 100644
--- a/sb3_contrib/common/wrappers/time_feature.py
+++ b/sb3_contrib/common/wrappers/time_feature.py
@@ -2,6 +2,7 @@ from typing import Dict, Union

 import gym
 import numpy as np
+from gym import spaces
 from stable_baselines3.common.type_aliases import GymObs, GymStepReturn


@@ -25,16 +26,14 @@ class TimeFeatureWrapper(gym.Wrapper):

     def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False):
         assert isinstance(
-            env.observation_space, (gym.spaces.Box, gym.spaces.Dict)
-        ), "`TimeFeatureWrapper` only supports `gym.spaces.Box` and `gym.spaces.Dict` (`gym.GoalEnv`) observation spaces."
+            env.observation_space, (spaces.Box, spaces.Dict)
+        ), "`TimeFeatureWrapper` only supports `gym.spaces.Box` and `spaces.Dict` (`gym.GoalEnv`) observation spaces."

         # Add a time feature to the observation
-        if isinstance(env.observation_space, gym.spaces.Dict):
+        if isinstance(env.observation_space, spaces.Dict):
             assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space"
             obs_space = env.observation_space.spaces["observation"]
-            assert isinstance(
-                obs_space, gym.spaces.Box
-            ), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
+            assert isinstance(obs_space, spaces.Box), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
             obs_space = env.observation_space.spaces["observation"]
         else:
             obs_space = env.observation_space
@@ -45,10 +44,10 @@ class TimeFeatureWrapper(gym.Wrapper):
         low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0]))
         self.dtype = obs_space.dtype

-        if isinstance(env.observation_space, gym.spaces.Dict):
-            env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
+        if isinstance(env.observation_space, spaces.Dict):
+            env.observation_space.spaces["observation"] = spaces.Box(low=low, high=high, dtype=self.dtype)
         else:
-            env.observation_space = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
+            env.observation_space = spaces.Box(low=low, high=high, dtype=self.dtype)

         super().__init__(env)
diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py
index 78aa58f..cbeee09 100644
--- a/sb3_contrib/ppo_mask/ppo_mask.py
+++ b/sb3_contrib/ppo_mask/ppo_mask.py
@@ -137,9 +137,7 @@ class MaskablePPO(OnPolicyAlgorithm):
         self._setup_lr_schedule()
         self.set_random_seed(self.seed)

-        buffer_cls = (
-            MaskableDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else MaskableRolloutBuffer
-        )
+        buffer_cls = MaskableDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else MaskableRolloutBuffer

         self.policy = self.policy_class(
             self.observation_space,
diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
index ee71471..9cff998 100644
--- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py
+++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
@@ -3,7 +3,6 @@ import time
 from copy import deepcopy
 from typing import Any, Dict, Optional, Type, TypeVar, Union

-import gym
 import numpy as np
 import torch as th
 from gym import spaces
@@ -139,9 +138,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
         self._setup_lr_schedule()
         self.set_random_seed(self.seed)

-        buffer_cls = (
-            RecurrentDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RecurrentRolloutBuffer
-        )
+        buffer_cls = RecurrentDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RecurrentRolloutBuffer

         self.policy = self.policy_class(
             self.observation_space,
@@ -247,7 +244,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
             # Rescale and perform action
             clipped_actions = actions
             # Clip the actions to avoid out of bound error
-            if isinstance(self.action_space, gym.spaces.Box):
+            if isinstance(self.action_space, spaces.Box):
                 clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

             new_obs, rewards, dones, infos = env.step(clipped_actions)
@@ -262,7 +259,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
             self._update_info_buffer(infos)
             n_steps += 1

-            if isinstance(self.action_space, gym.spaces.Discrete):
+            if isinstance(self.action_space, spaces.Discrete):
                 # Reshape in case of discrete action
                 actions = actions.reshape(-1, 1)

diff --git a/sb3_contrib/qrdqn/policies.py b/sb3_contrib/qrdqn/policies.py
index a4a0010..cc78dfb 100644
--- a/sb3_contrib/qrdqn/policies.py
+++ b/sb3_contrib/qrdqn/policies.py
@@ -1,7 +1,7 @@
 from typing import Any, Dict, List, Optional, Type

-import gym
 import torch as th
+from gym import spaces
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.torch_layers import (
     BaseFeaturesExtractor,
@@ -29,8 +29,8 @@ class QuantileNetwork(BasePolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         features_extractor: nn.Module,
         features_dim: int,
         n_quantiles: int = 200,
@@ -111,8 +111,8 @@ class QRDQNPolicy(BasePolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         lr_schedule: Schedule,
         n_quantiles: int = 200,
         net_arch: Optional[List[int]] = None,
@@ -233,8 +233,8 @@ class CnnPolicy(QRDQNPolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         lr_schedule: Schedule,
         n_quantiles: int = 200,
         net_arch: Optional[List[int]] = None,
@@ -281,8 +281,8 @@ class MultiInputPolicy(QRDQNPolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         lr_schedule: Schedule,
         n_quantiles: int = 200,
         net_arch: Optional[List[int]] = None,
diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py
index 22c5810..66ea3c8 100644
--- a/sb3_contrib/qrdqn/qrdqn.py
+++ b/sb3_contrib/qrdqn/qrdqn.py
@@ -1,8 +1,8 @@
 from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

-import gym
 import numpy as np
 import torch as th
+from gym import spaces
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
@@ -113,7 +113,7 @@ class QRDQN(OffPolicyAlgorithm):
             seed=seed,
             sde_support=False,
             optimize_memory_usage=optimize_memory_usage,
-            supported_action_spaces=(gym.spaces.Discrete,),
+            supported_action_spaces=(spaces.Discrete,),
             support_multi_env=True,
         )

@@ -235,7 +235,7 @@ class QRDQN(OffPolicyAlgorithm):
         """
         if not deterministic and np.random.rand() < self.exploration_rate:
             if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
-                if isinstance(self.observation_space, gym.spaces.Dict):
+                if isinstance(self.observation_space, spaces.Dict):
                     n_batch = observation[list(observation.keys())[0]].shape[0]
                 else:
                     n_batch = observation.shape[0]
diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py
index c25ebe3..ce7fb26 100644
--- a/sb3_contrib/tqc/policies.py
+++ b/sb3_contrib/tqc/policies.py
@@ -1,7 +1,7 @@
 from typing import Any, Dict, List, Optional, Tuple, Type, Union

-import gym
 import torch as th
+from gym import spaces
 from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
 from stable_baselines3.common.policies import BaseModel, BasePolicy
 from stable_baselines3.common.preprocessing import get_action_dim
@@ -46,8 +46,8 @@ class Actor(BasePolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         net_arch: List[int],
         features_extractor: nn.Module,
         features_dim: int,
@@ -194,8 +194,8 @@ class Critic(BaseModel):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         net_arch: List[int],
         features_extractor: nn.Module,
         features_dim: int,
@@ -268,8 +268,8 @@ class TQCPolicy(BasePolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         lr_schedule: Schedule,
         net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
         activation_fn: Type[nn.Module] = nn.ReLU,
@@ -446,8 +446,8 @@ class CnnPolicy(TQCPolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         lr_schedule: Schedule,
         net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
         activation_fn: Type[nn.Module] = nn.ReLU,
@@ -515,8 +515,8 @@ class MultiInputPolicy(TQCPolicy):

     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
         lr_schedule: Schedule,
         net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
         activation_fn: Type[nn.Module] = nn.ReLU,
diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py
index 17151e8..9d452d4 100644
--- a/sb3_contrib/tqc/tqc.py
+++ b/sb3_contrib/tqc/tqc.py
@@ -1,8 +1,8 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

-import gym
 import numpy as np
 import torch as th
+from gym import spaces
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.noise import ActionNoise
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
@@ -125,7 +125,7 @@ class TQC(OffPolicyAlgorithm):
             sde_sample_freq=sde_sample_freq,
             use_sde_at_warmup=use_sde_at_warmup,
             optimize_memory_usage=optimize_memory_usage,
-            supported_action_spaces=(gym.spaces.Box),
+            supported_action_spaces=(spaces.Box,),
             support_multi_env=True,
         )

diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt
index 89e17c2..a02b7e4 100644
--- a/sb3_contrib/version.txt
+++ b/sb3_contrib/version.txt
@@ -1 +1 @@
-1.7.0a10
+1.7.0a11
diff --git a/setup.py b/setup.py
index c332330..ba766cd 100644
--- a/setup.py
+++ b/setup.py
@@ -65,7 +65,7 @@ setup(
     packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
     package_data={"sb3_contrib": ["py.typed", "version.txt"]},
     install_requires=[
-        "stable_baselines3>=1.7.0a10",
+        "stable_baselines3>=1.7.0a11",
     ],
     description="Contrib package of Stable Baselines3, experimental code.",
     author="Antonin Raffin",
diff --git a/tests/test_invalid_actions.py b/tests/test_invalid_actions.py
index 3c2c5f8..c617660 100644
--- a/tests/test_invalid_actions.py
+++ b/tests/test_invalid_actions.py
@@ -2,6 +2,7 @@ import random

 import gym
 import pytest
+from gym import spaces
 from stable_baselines3.common.callbacks import EventCallback, StopTrainingOnNoModelImprovement
 from stable_baselines3.common.env_util import make_vec_env
 from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
@@ -27,7 +28,7 @@ class ToDictWrapper(gym.Wrapper):

     def __init__(self, env):
         super().__init__(env)
-        self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space})
+        self.observation_space = spaces.Dict({"obs": self.env.observation_space})

     def reset(self):
         return {"obs": self.env.reset()}
diff --git a/tests/test_lstm.py b/tests/test_lstm.py
index 1e62e0e..0bfdb38 100644
--- a/tests/test_lstm.py
+++ b/tests/test_lstm.py
@@ -20,7 +20,7 @@ class ToDictWrapper(gym.Wrapper):

     def __init__(self, env):
         super().__init__(env)
-        self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space})
+        self.observation_space = spaces.Dict({"obs": self.env.observation_space})

     def reset(self):
         return {"obs": self.env.reset()}
diff --git a/tests/wrappers/test_action_masker.py b/tests/wrappers/test_action_masker.py
index e130959..be42987 100644
--- a/tests/wrappers/test_action_masker.py
+++ b/tests/wrappers/test_action_masker.py
@@ -1,7 +1,7 @@
 from typing import List

 import pytest
-from gym.spaces import Discrete
+from gym import spaces
 from stable_baselines3.common.envs import IdentityEnv

 from sb3_contrib.common.wrappers import ActionMasker
@@ -15,7 +15,7 @@ class IdentityEnvDiscrete(IdentityEnv):
         :param dim: the size of the dimensions you want to learn
         :param ep_length: the length of each episode in timesteps
         """
-        space = Discrete(dim)
+        space = spaces.Discrete(dim)
         self.useless_property = 1
         super().__init__(ep_length=ep_length, space=space)
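
Every hunk above applies the same convention announced in the changelog entry: import the submodule once with `from gym import spaces`, then refer to `spaces.Box`, `spaces.Discrete`, `spaces.Dict`, and so on, instead of reaching through the top-level `gym` module. The `tqc.py` hunk also fixes a subtle bug along the way: `(gym.spaces.Box)` is just a parenthesized class, not a tuple, so the trailing comma in `(spaces.Box,)` is what actually makes `supported_action_spaces` a one-element tuple. A minimal sketch of both points (not part of the commit; it only assumes `gym` is installed):

```python
from gym import spaces  # the convention this commit standardizes on

# Type checks go through the submodule instead of `gym.spaces.*`:
box = spaces.Box(low=-1.0, high=1.0, shape=(3,))
assert isinstance(box, spaces.Box)

# Why the trailing comma in the tqc.py hunk matters: without it, the
# parentheses are mere grouping, so `supported_action_spaces` would be
# a class rather than a tuple of classes.
not_a_tuple = (spaces.Box)  # == spaces.Box itself
one_tuple = (spaces.Box,)   # == a 1-element tuple
assert not_a_tuple is spaces.Box
assert isinstance(one_tuple, tuple)

# isinstance() accepts either a class or a tuple of classes, which is
# why the missing comma was a latent rather than a crashing bug:
assert isinstance(box, not_a_tuple) and isinstance(box, one_tuple)
```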