Standardize the use of ``from gym import spaces`` (#131)

* Standardize from gym import spaces

* update changelog

* update issue template

* update version

* Update version
Quentin Gallouédec 2023-01-02 15:35:00 +01:00 committed by GitHub
parent c9bd045d5c
commit 7c4a249fa4
18 changed files with 77 additions and 82 deletions
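
The change itself is mechanical: import the submodule once with ``from gym import spaces`` and drop the ``gym.spaces.`` prefix at every use site. A minimal before/after sketch of the convention being standardized (the ``Box`` arguments are illustrative):

```python
# Before: spaces reached through the top-level package
import gym

space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,))

# After: the submodule is imported once and used directly
from gym import spaces

space = spaces.Box(low=-1.0, high=1.0, shape=(3,))
assert isinstance(space, spaces.Box)
```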

View File

@ -52,9 +52,8 @@ Describe the characteristic of your environment:
* Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print relevant packages info:
```python
import stable_baselines3 as sb3
sb3.get_system_info()
```
```sh
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
```
**Additional context**

View File

@ -3,7 +3,7 @@
Changelog
==========
Release 1.7.0a10 (WIP)
Release 1.7.0a11 (WIP)
--------------------------
Breaking Changes:
@ -11,7 +11,7 @@ Breaking Changes:
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
please use an ``EvalCallback`` instead
- Removed deprecated ``sde_net_arch`` parameter
- Upgraded to Stable-Baselines3 >= 1.7.0a10
- Upgraded to Stable-Baselines3 >= 1.7.0a11
New Features:
^^^^^^^^^^^^^
@ -41,6 +41,7 @@ Others:
- Removed ignores on Flake8 F401 (@ZikangXiong)
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
- Set tensors construction directly on the device
- Standardized the use of ``from gym import spaces``
Release 1.6.2 (2022-10-10)
--------------------------
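
As a migration note for the removed ``eval_env``/``eval_freq``/``n_eval_episodes`` parameters listed under Breaking Changes above, a hedged sketch of the suggested ``EvalCallback`` replacement (environment id, frequency, and timestep budget are illustrative):

```python
import gym
from stable_baselines3.common.callbacks import EvalCallback
from sb3_contrib import TQC

# Evaluation is now configured through a callback, not constructor arguments
eval_env = gym.make("Pendulum-v1")
eval_callback = EvalCallback(eval_env, eval_freq=500, n_eval_episodes=5)

model = TQC("MlpPolicy", "Pendulum-v1")
model.learn(total_timesteps=5_000, callback=eval_callback)
```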

View File

@ -5,10 +5,10 @@ import warnings
from functools import partial
from typing import Any, Dict, Optional, Type, TypeVar, Union
import gym
import numpy as np
import torch as th
import torch.nn.utils
from gym import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
@ -80,7 +80,7 @@ class ARS(BaseAlgorithm):
policy_kwargs=policy_kwargs,
verbose=verbose,
device=device,
supported_action_spaces=(gym.spaces.Box, gym.spaces.Discrete),
supported_action_spaces=(spaces.Box, spaces.Discrete),
support_multi_env=True,
seed=seed,
)

View File

@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Type
import gym
import torch as th
from gym import spaces
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
@ -23,8 +23,8 @@ class ARSPolicy(BasePolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
with_bias: bool = True,
@ -34,7 +34,7 @@ class ARSPolicy(BasePolicy):
super().__init__(
observation_space,
action_space,
squash_output=isinstance(action_space, gym.spaces.Box) and squash_output,
squash_output=isinstance(action_space, spaces.Box) and squash_output,
)
if net_arch is None:
@ -45,12 +45,12 @@ class ARSPolicy(BasePolicy):
self.features_dim = self.features_extractor.features_dim
self.activation_fn = activation_fn
if isinstance(action_space, gym.spaces.Box):
if isinstance(action_space, spaces.Box):
action_dim = get_action_dim(action_space)
actor_net = create_mlp(
self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output
)
elif isinstance(action_space, gym.spaces.Discrete):
elif isinstance(action_space, spaces.Discrete):
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn, with_bias=with_bias)
else:
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
@ -72,9 +72,9 @@ class ARSPolicy(BasePolicy):
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
features = self.extract_features(obs, self.features_extractor)
if isinstance(self.action_space, gym.spaces.Box):
if isinstance(self.action_space, spaces.Box):
return self.action_net(features)
elif isinstance(self.action_space, gym.spaces.Discrete):
elif isinstance(self.action_space, spaces.Discrete):
logits = self.action_net(features)
return th.argmax(logits, dim=1)
else:
@ -98,8 +98,8 @@ class ARSLinearPolicy(ARSPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
with_bias: bool = False,
squash_output: bool = False,
):

View File

@ -1,9 +1,9 @@
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
@ -43,8 +43,8 @@ class MaskableActorCriticPolicy(BasePolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
@ -281,7 +281,7 @@ class MaskableActorCriticPolicy(BasePolicy):
# Convert to numpy
actions = actions.cpu().numpy()
if isinstance(self.action_space, gym.spaces.Box):
if isinstance(self.action_space, spaces.Box):
if self.squash_output:
# Rescale to proper domain when using squashing
actions = self.unscale_action(actions)
@ -379,8 +379,8 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
@ -433,8 +433,8 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
def __init__(
self,
observation_space: gym.spaces.Dict,
action_space: gym.spaces.Space,
observation_space: spaces.Dict,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,

View File

@ -1,8 +1,8 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.distributions import Distribution
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.torch_layers import (
@ -63,8 +63,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
@ -415,7 +415,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
# Convert to numpy
actions = actions.cpu().numpy()
if isinstance(self.action_space, gym.spaces.Box):
if isinstance(self.action_space, spaces.Box):
if self.squash_output:
# Rescale to proper domain when using squashing
actions = self.unscale_action(actions)
@ -472,8 +472,8 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
@ -562,8 +562,8 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,

View File

@ -2,6 +2,7 @@ from typing import Dict, Union
import gym
import numpy as np
from gym import spaces
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
@ -25,16 +26,14 @@ class TimeFeatureWrapper(gym.Wrapper):
def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False):
assert isinstance(
env.observation_space, (gym.spaces.Box, gym.spaces.Dict)
), "`TimeFeatureWrapper` only supports `gym.spaces.Box` and `gym.spaces.Dict` (`gym.GoalEnv`) observation spaces."
env.observation_space, (spaces.Box, spaces.Dict)
), "`TimeFeatureWrapper` only supports `gym.spaces.Box` and `spaces.Dict` (`gym.GoalEnv`) observation spaces."
# Add a time feature to the observation
if isinstance(env.observation_space, gym.spaces.Dict):
if isinstance(env.observation_space, spaces.Dict):
assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space"
obs_space = env.observation_space.spaces["observation"]
assert isinstance(
obs_space, gym.spaces.Box
), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
assert isinstance(obs_space, spaces.Box), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
obs_space = env.observation_space.spaces["observation"]
else:
obs_space = env.observation_space
@ -45,10 +44,10 @@ class TimeFeatureWrapper(gym.Wrapper):
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0]))
self.dtype = obs_space.dtype
if isinstance(env.observation_space, gym.spaces.Dict):
env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
if isinstance(env.observation_space, spaces.Dict):
env.observation_space.spaces["observation"] = spaces.Box(low=low, high=high, dtype=self.dtype)
else:
env.observation_space = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
env.observation_space = spaces.Box(low=low, high=high, dtype=self.dtype)
super().__init__(env)
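
For context on the hunk above, a minimal usage sketch of ``TimeFeatureWrapper``, which appends a time feature in ``[0, 1]`` to a ``Box`` observation (the environment id is illustrative, and ``max_steps`` should match the true episode length):

```python
import gym
from sb3_contrib.common.wrappers import TimeFeatureWrapper

env = gym.make("Pendulum-v1")
env = TimeFeatureWrapper(env, max_steps=200)

obs = env.reset()
# The observation gains one extra dimension: the remaining-time feature
assert obs.shape == env.observation_space.shape
```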

View File

@ -137,9 +137,7 @@ class MaskablePPO(OnPolicyAlgorithm):
self._setup_lr_schedule()
self.set_random_seed(self.seed)
buffer_cls = (
MaskableDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else MaskableRolloutBuffer
)
buffer_cls = MaskableDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else MaskableRolloutBuffer
self.policy = self.policy_class(
self.observation_space,

View File

@ -3,7 +3,6 @@ import time
from copy import deepcopy
from typing import Any, Dict, Optional, Type, TypeVar, Union
import gym
import numpy as np
import torch as th
from gym import spaces
@ -139,9 +138,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
self._setup_lr_schedule()
self.set_random_seed(self.seed)
buffer_cls = (
RecurrentDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RecurrentRolloutBuffer
)
buffer_cls = RecurrentDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RecurrentRolloutBuffer
self.policy = self.policy_class(
self.observation_space,
@ -247,7 +244,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
# Rescale and perform action
clipped_actions = actions
# Clip the actions to avoid out of bound error
if isinstance(self.action_space, gym.spaces.Box):
if isinstance(self.action_space, spaces.Box):
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
new_obs, rewards, dones, infos = env.step(clipped_actions)
@ -262,7 +259,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
self._update_info_buffer(infos)
n_steps += 1
if isinstance(self.action_space, gym.spaces.Discrete):
if isinstance(self.action_space, spaces.Discrete):
# Reshape in case of discrete action
actions = actions.reshape(-1, 1)

View File

@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Type
import gym
import torch as th
from gym import spaces
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
@ -29,8 +29,8 @@ class QuantileNetwork(BasePolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
features_extractor: nn.Module,
features_dim: int,
n_quantiles: int = 200,
@ -111,8 +111,8 @@ class QRDQNPolicy(BasePolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
n_quantiles: int = 200,
net_arch: Optional[List[int]] = None,
@ -233,8 +233,8 @@ class CnnPolicy(QRDQNPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
n_quantiles: int = 200,
net_arch: Optional[List[int]] = None,
@ -281,8 +281,8 @@ class MultiInputPolicy(QRDQNPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
n_quantiles: int = 200,
net_arch: Optional[List[int]] = None,

View File

@ -1,8 +1,8 @@
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
@ -113,7 +113,7 @@ class QRDQN(OffPolicyAlgorithm):
seed=seed,
sde_support=False,
optimize_memory_usage=optimize_memory_usage,
supported_action_spaces=(gym.spaces.Discrete,),
supported_action_spaces=(spaces.Discrete,),
support_multi_env=True,
)
@ -235,7 +235,7 @@ class QRDQN(OffPolicyAlgorithm):
"""
if not deterministic and np.random.rand() < self.exploration_rate:
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
if isinstance(self.observation_space, gym.spaces.Dict):
if isinstance(self.observation_space, spaces.Dict):
n_batch = observation[list(observation.keys())[0]].shape[0]
else:
n_batch = observation.shape[0]

View File

@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import torch as th
from gym import spaces
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BaseModel, BasePolicy
from stable_baselines3.common.preprocessing import get_action_dim
@ -46,8 +46,8 @@ class Actor(BasePolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
net_arch: List[int],
features_extractor: nn.Module,
features_dim: int,
@ -194,8 +194,8 @@ class Critic(BaseModel):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
net_arch: List[int],
features_extractor: nn.Module,
features_dim: int,
@ -268,8 +268,8 @@ class TQCPolicy(BasePolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
@ -446,8 +446,8 @@ class CnnPolicy(TQCPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
@ -515,8 +515,8 @@ class MultiInputPolicy(TQCPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,

View File

@ -1,8 +1,8 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
@ -125,7 +125,7 @@ class TQC(OffPolicyAlgorithm):
sde_sample_freq=sde_sample_freq,
use_sde_at_warmup=use_sde_at_warmup,
optimize_memory_usage=optimize_memory_usage,
supported_action_spaces=(gym.spaces.Box),
supported_action_spaces=(spaces.Box,),
support_multi_env=True,
)
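
One hunk here is more than the rename: the old ``supported_action_spaces=(gym.spaces.Box)`` was not a tuple, since parentheses without a trailing comma are plain grouping; the new ``(spaces.Box,)`` actually builds one. ``isinstance`` accepts either a single class or a tuple of classes, which is why the old form worked by accident. A quick sketch:

```python
from gym import spaces

assert (spaces.Box) is spaces.Box        # grouping only: still the bare class
assert isinstance((spaces.Box,), tuple)  # trailing comma: a one-element tuple

# isinstance tolerates both forms, so the latent bug never surfaced
box = spaces.Box(low=-1.0, high=1.0, shape=(1,))
assert isinstance(box, spaces.Box) and isinstance(box, (spaces.Box,))
```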

View File

@ -1 +1 @@
1.7.0a10
1.7.0a11

View File

@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.7.0a10",
"stable_baselines3>=1.7.0a11",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",

View File

@ -2,6 +2,7 @@ import random
import gym
import pytest
from gym import spaces
from stable_baselines3.common.callbacks import EventCallback, StopTrainingOnNoModelImprovement
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
@ -27,7 +28,7 @@ class ToDictWrapper(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space})
self.observation_space = spaces.Dict({"obs": self.env.observation_space})
def reset(self):
return {"obs": self.env.reset()}

View File

@ -20,7 +20,7 @@ class ToDictWrapper(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space})
self.observation_space = spaces.Dict({"obs": self.env.observation_space})
def reset(self):
return {"obs": self.env.reset()}

View File

@ -1,7 +1,7 @@
from typing import List
import pytest
from gym.spaces import Discrete
from gym import spaces
from stable_baselines3.common.envs import IdentityEnv
from sb3_contrib.common.wrappers import ActionMasker
@ -15,7 +15,7 @@ class IdentityEnvDiscrete(IdentityEnv):
:param dim: the size of the dimensions you want to learn
:param ep_length: the length of each episode in timesteps
"""
space = Discrete(dim)
space = spaces.Discrete(dim)
self.useless_property = 1
super().__init__(ep_length=ep_length, space=space)
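
For context on the fixture above, a hedged sketch of wiring ``ActionMasker`` to an env such as ``IdentityEnvDiscrete`` (the always-valid ``mask_fn`` is a placeholder, and default constructor arguments are assumed):

```python
import numpy as np
from sb3_contrib.common.wrappers import ActionMasker

def mask_fn(env) -> np.ndarray:
    # Placeholder: marks every action valid; a real mask would zero out invalid ones
    return np.ones(env.action_space.n, dtype=bool)

env = ActionMasker(IdentityEnvDiscrete(), mask_fn)
```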