Standardize the use of ``from gym import spaces`` (#131)

* Standardize from gym import spaces

* update changelog

* update issue template

* update version

* Update version
This commit is contained in:
Quentin Gallouédec 2023-01-02 15:35:00 +01:00 committed by GitHub
parent c9bd045d5c
commit 7c4a249fa4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 77 additions and 82 deletions

View File

@@ -52,9 +52,8 @@ Describe the characteristic of your environment:
* Versions of any other relevant libraries * Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print relevant packages info: You can use `sb3.get_system_info()` to print relevant packages info:
```python ```sh
import stable_baselines3 as sb3 python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
sb3.get_system_info()
``` ```
**Additional context** **Additional context**

View File

@@ -3,7 +3,7 @@
Changelog Changelog
========== ==========
Release 1.7.0a10 (WIP) Release 1.7.0a11 (WIP)
-------------------------- --------------------------
Breaking Changes: Breaking Changes:
@@ -11,7 +11,7 @@ Breaking Changes:
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters, - Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
please use an ``EvalCallback`` instead please use an ``EvalCallback`` instead
- Removed deprecated ``sde_net_arch`` parameter - Removed deprecated ``sde_net_arch`` parameter
- Upgraded to Stable-Baselines3 >= 1.7.0a10 - Upgraded to Stable-Baselines3 >= 1.7.0a11
New Features: New Features:
^^^^^^^^^^^^^ ^^^^^^^^^^^^^
@@ -41,6 +41,7 @@ Others:
- Removed ignores on Flake8 F401 (@ZikangXiong) - Removed ignores on Flake8 F401 (@ZikangXiong)
- Upgraded GitHub CI/setup-python to v4 and checkout to v3 - Upgraded GitHub CI/setup-python to v4 and checkout to v3
- Set tensors construction directly on the device - Set tensors construction directly on the device
- Standardized the use of ``from gym import spaces``
Release 1.6.2 (2022-10-10) Release 1.6.2 (2022-10-10)
-------------------------- --------------------------

View File

@@ -5,10 +5,10 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Dict, Optional, Type, TypeVar, Union from typing import Any, Dict, Optional, Type, TypeVar, Union
import gym
import numpy as np import numpy as np
import torch as th import torch as th
import torch.nn.utils import torch.nn.utils
from gym import spaces
from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.evaluation import evaluate_policy
@@ -80,7 +80,7 @@ class ARS(BaseAlgorithm):
policy_kwargs=policy_kwargs, policy_kwargs=policy_kwargs,
verbose=verbose, verbose=verbose,
device=device, device=device,
supported_action_spaces=(gym.spaces.Box, gym.spaces.Discrete), supported_action_spaces=(spaces.Box, spaces.Discrete),
support_multi_env=True, support_multi_env=True,
seed=seed, seed=seed,
) )

View File

@@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Type from typing import Any, Dict, List, Optional, Type
import gym
import torch as th import torch as th
from gym import spaces
from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
@@ -23,8 +23,8 @@ class ARSPolicy(BasePolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
net_arch: Optional[List[int]] = None, net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU, activation_fn: Type[nn.Module] = nn.ReLU,
with_bias: bool = True, with_bias: bool = True,
@@ -34,7 +34,7 @@ class ARSPolicy(BasePolicy):
super().__init__( super().__init__(
observation_space, observation_space,
action_space, action_space,
squash_output=isinstance(action_space, gym.spaces.Box) and squash_output, squash_output=isinstance(action_space, spaces.Box) and squash_output,
) )
if net_arch is None: if net_arch is None:
@@ -45,12 +45,12 @@ class ARSPolicy(BasePolicy):
self.features_dim = self.features_extractor.features_dim self.features_dim = self.features_extractor.features_dim
self.activation_fn = activation_fn self.activation_fn = activation_fn
if isinstance(action_space, gym.spaces.Box): if isinstance(action_space, spaces.Box):
action_dim = get_action_dim(action_space) action_dim = get_action_dim(action_space)
actor_net = create_mlp( actor_net = create_mlp(
self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output
) )
elif isinstance(action_space, gym.spaces.Discrete): elif isinstance(action_space, spaces.Discrete):
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn, with_bias=with_bias) actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn, with_bias=with_bias)
else: else:
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.") raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
@@ -72,9 +72,9 @@ class ARSPolicy(BasePolicy):
assert isinstance(self.features_extractor, BaseFeaturesExtractor) assert isinstance(self.features_extractor, BaseFeaturesExtractor)
features = self.extract_features(obs, self.features_extractor) features = self.extract_features(obs, self.features_extractor)
if isinstance(self.action_space, gym.spaces.Box): if isinstance(self.action_space, spaces.Box):
return self.action_net(features) return self.action_net(features)
elif isinstance(self.action_space, gym.spaces.Discrete): elif isinstance(self.action_space, spaces.Discrete):
logits = self.action_net(features) logits = self.action_net(features)
return th.argmax(logits, dim=1) return th.argmax(logits, dim=1)
else: else:
@@ -98,8 +98,8 @@ class ARSLinearPolicy(ARSPolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
with_bias: bool = False, with_bias: bool = False,
squash_output: bool = False, squash_output: bool = False,
): ):

View File

@@ -1,9 +1,9 @@
from functools import partial from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, Union from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import numpy as np import numpy as np
import torch as th import torch as th
from gym import spaces
from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import ( from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor, BaseFeaturesExtractor,
@@ -43,8 +43,8 @@ class MaskableActorCriticPolicy(BasePolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
lr_schedule: Schedule, lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh, activation_fn: Type[nn.Module] = nn.Tanh,
@@ -281,7 +281,7 @@ class MaskableActorCriticPolicy(BasePolicy):
# Convert to numpy # Convert to numpy
actions = actions.cpu().numpy() actions = actions.cpu().numpy()
if isinstance(self.action_space, gym.spaces.Box): if isinstance(self.action_space, spaces.Box):
if self.squash_output: if self.squash_output:
# Rescale to proper domain when using squashing # Rescale to proper domain when using squashing
actions = self.unscale_action(actions) actions = self.unscale_action(actions)
@@ -379,8 +379,8 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
lr_schedule: Schedule, lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh, activation_fn: Type[nn.Module] = nn.Tanh,
@@ -433,8 +433,8 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Dict, observation_space: spaces.Dict,
action_space: gym.spaces.Space, action_space: spaces.Space,
lr_schedule: Schedule, lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh, activation_fn: Type[nn.Module] = nn.Tanh,

View File

@@ -1,8 +1,8 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import numpy as np import numpy as np
import torch as th import torch as th
from gym import spaces
from stable_baselines3.common.distributions import Distribution from stable_baselines3.common.distributions import Distribution
from stable_baselines3.common.policies import ActorCriticPolicy from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.torch_layers import ( from stable_baselines3.common.torch_layers import (
@@ -63,8 +63,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
lr_schedule: Schedule, lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh, activation_fn: Type[nn.Module] = nn.Tanh,
@@ -415,7 +415,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
# Convert to numpy # Convert to numpy
actions = actions.cpu().numpy() actions = actions.cpu().numpy()
if isinstance(self.action_space, gym.spaces.Box): if isinstance(self.action_space, spaces.Box):
if self.squash_output: if self.squash_output:
# Rescale to proper domain when using squashing # Rescale to proper domain when using squashing
actions = self.unscale_action(actions) actions = self.unscale_action(actions)
@@ -472,8 +472,8 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
lr_schedule: Schedule, lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh, activation_fn: Type[nn.Module] = nn.Tanh,
@@ -562,8 +562,8 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
lr_schedule: Schedule, lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh, activation_fn: Type[nn.Module] = nn.Tanh,

View File

@@ -2,6 +2,7 @@ from typing import Dict, Union
import gym import gym
import numpy as np import numpy as np
from gym import spaces
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
@@ -25,16 +26,14 @@ class TimeFeatureWrapper(gym.Wrapper):
def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False): def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False):
assert isinstance( assert isinstance(
env.observation_space, (gym.spaces.Box, gym.spaces.Dict) env.observation_space, (spaces.Box, spaces.Dict)
), "`TimeFeatureWrapper` only supports `gym.spaces.Box` and `gym.spaces.Dict` (`gym.GoalEnv`) observation spaces." ), "`TimeFeatureWrapper` only supports `gym.spaces.Box` and `spaces.Dict` (`gym.GoalEnv`) observation spaces."
# Add a time feature to the observation # Add a time feature to the observation
if isinstance(env.observation_space, gym.spaces.Dict): if isinstance(env.observation_space, spaces.Dict):
assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space" assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space"
obs_space = env.observation_space.spaces["observation"] obs_space = env.observation_space.spaces["observation"]
assert isinstance( assert isinstance(obs_space, spaces.Box), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
obs_space, gym.spaces.Box
), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
obs_space = env.observation_space.spaces["observation"] obs_space = env.observation_space.spaces["observation"]
else: else:
obs_space = env.observation_space obs_space = env.observation_space
@@ -45,10 +44,10 @@ class TimeFeatureWrapper(gym.Wrapper):
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0])) low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0]))
self.dtype = obs_space.dtype self.dtype = obs_space.dtype
if isinstance(env.observation_space, gym.spaces.Dict): if isinstance(env.observation_space, spaces.Dict):
env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=self.dtype) env.observation_space.spaces["observation"] = spaces.Box(low=low, high=high, dtype=self.dtype)
else: else:
env.observation_space = gym.spaces.Box(low=low, high=high, dtype=self.dtype) env.observation_space = spaces.Box(low=low, high=high, dtype=self.dtype)
super().__init__(env) super().__init__(env)

View File

@@ -137,9 +137,7 @@ class MaskablePPO(OnPolicyAlgorithm):
self._setup_lr_schedule() self._setup_lr_schedule()
self.set_random_seed(self.seed) self.set_random_seed(self.seed)
buffer_cls = ( buffer_cls = MaskableDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else MaskableRolloutBuffer
MaskableDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else MaskableRolloutBuffer
)
self.policy = self.policy_class( self.policy = self.policy_class(
self.observation_space, self.observation_space,

View File

@@ -3,7 +3,6 @@ import time
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, Optional, Type, TypeVar, Union from typing import Any, Dict, Optional, Type, TypeVar, Union
import gym
import numpy as np import numpy as np
import torch as th import torch as th
from gym import spaces from gym import spaces
@@ -139,9 +138,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
self._setup_lr_schedule() self._setup_lr_schedule()
self.set_random_seed(self.seed) self.set_random_seed(self.seed)
buffer_cls = ( buffer_cls = RecurrentDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RecurrentRolloutBuffer
RecurrentDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RecurrentRolloutBuffer
)
self.policy = self.policy_class( self.policy = self.policy_class(
self.observation_space, self.observation_space,
@@ -247,7 +244,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
# Rescale and perform action # Rescale and perform action
clipped_actions = actions clipped_actions = actions
# Clip the actions to avoid out of bound error # Clip the actions to avoid out of bound error
if isinstance(self.action_space, gym.spaces.Box): if isinstance(self.action_space, spaces.Box):
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
new_obs, rewards, dones, infos = env.step(clipped_actions) new_obs, rewards, dones, infos = env.step(clipped_actions)
@@ -262,7 +259,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
self._update_info_buffer(infos) self._update_info_buffer(infos)
n_steps += 1 n_steps += 1
if isinstance(self.action_space, gym.spaces.Discrete): if isinstance(self.action_space, spaces.Discrete):
# Reshape in case of discrete action # Reshape in case of discrete action
actions = actions.reshape(-1, 1) actions = actions.reshape(-1, 1)

View File

@@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Type from typing import Any, Dict, List, Optional, Type
import gym
import torch as th import torch as th
from gym import spaces
from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import ( from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor, BaseFeaturesExtractor,
@@ -29,8 +29,8 @@ class QuantileNetwork(BasePolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
features_extractor: nn.Module, features_extractor: nn.Module,
features_dim: int, features_dim: int,
n_quantiles: int = 200, n_quantiles: int = 200,
@@ -111,8 +111,8 @@ class QRDQNPolicy(BasePolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
lr_schedule: Schedule, lr_schedule: Schedule,
n_quantiles: int = 200, n_quantiles: int = 200,
net_arch: Optional[List[int]] = None, net_arch: Optional[List[int]] = None,
@@ -233,8 +233,8 @@ class CnnPolicy(QRDQNPolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
lr_schedule: Schedule, lr_schedule: Schedule,
n_quantiles: int = 200, n_quantiles: int = 200,
net_arch: Optional[List[int]] = None, net_arch: Optional[List[int]] = None,
@@ -281,8 +281,8 @@ class MultiInputPolicy(QRDQNPolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
lr_schedule: Schedule, lr_schedule: Schedule,
n_quantiles: int = 200, n_quantiles: int = 200,
net_arch: Optional[List[int]] = None, net_arch: Optional[List[int]] = None,

View File

@@ -1,8 +1,8 @@
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np import numpy as np
import torch as th import torch as th
from gym import spaces
from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.policies import BasePolicy
@@ -113,7 +113,7 @@ class QRDQN(OffPolicyAlgorithm):
seed=seed, seed=seed,
sde_support=False, sde_support=False,
optimize_memory_usage=optimize_memory_usage, optimize_memory_usage=optimize_memory_usage,
supported_action_spaces=(gym.spaces.Discrete,), supported_action_spaces=(spaces.Discrete,),
support_multi_env=True, support_multi_env=True,
) )
@@ -235,7 +235,7 @@ class QRDQN(OffPolicyAlgorithm):
""" """
if not deterministic and np.random.rand() < self.exploration_rate: if not deterministic and np.random.rand() < self.exploration_rate:
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space): if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
if isinstance(self.observation_space, gym.spaces.Dict): if isinstance(self.observation_space, spaces.Dict):
n_batch = observation[list(observation.keys())[0]].shape[0] n_batch = observation[list(observation.keys())[0]].shape[0]
else: else:
n_batch = observation.shape[0] n_batch = observation.shape[0]

View File

@@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import torch as th import torch as th
from gym import spaces
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BaseModel, BasePolicy from stable_baselines3.common.policies import BaseModel, BasePolicy
from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.preprocessing import get_action_dim
@@ -46,8 +46,8 @@ class Actor(BasePolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
net_arch: List[int], net_arch: List[int],
features_extractor: nn.Module, features_extractor: nn.Module,
features_dim: int, features_dim: int,
@@ -194,8 +194,8 @@ class Critic(BaseModel):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
net_arch: List[int], net_arch: List[int],
features_extractor: nn.Module, features_extractor: nn.Module,
features_dim: int, features_dim: int,
@@ -268,8 +268,8 @@ class TQCPolicy(BasePolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
lr_schedule: Schedule, lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU, activation_fn: Type[nn.Module] = nn.ReLU,
@@ -446,8 +446,8 @@ class CnnPolicy(TQCPolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
lr_schedule: Schedule, lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU, activation_fn: Type[nn.Module] = nn.ReLU,
@@ -515,8 +515,8 @@ class MultiInputPolicy(TQCPolicy):
def __init__( def __init__(
self, self,
observation_space: gym.spaces.Space, observation_space: spaces.Space,
action_space: gym.spaces.Space, action_space: spaces.Space,
lr_schedule: Schedule, lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU, activation_fn: Type[nn.Module] = nn.ReLU,

View File

@@ -1,8 +1,8 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np import numpy as np
import torch as th import torch as th
from gym import spaces
from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
@@ -125,7 +125,7 @@ class TQC(OffPolicyAlgorithm):
sde_sample_freq=sde_sample_freq, sde_sample_freq=sde_sample_freq,
use_sde_at_warmup=use_sde_at_warmup, use_sde_at_warmup=use_sde_at_warmup,
optimize_memory_usage=optimize_memory_usage, optimize_memory_usage=optimize_memory_usage,
supported_action_spaces=(gym.spaces.Box), supported_action_spaces=(spaces.Box,),
support_multi_env=True, support_multi_env=True,
) )

View File

@@ -1 +1 @@
1.7.0a10 1.7.0a11

View File

@@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")], packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]}, package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[ install_requires=[
"stable_baselines3>=1.7.0a10", "stable_baselines3>=1.7.0a11",
], ],
description="Contrib package of Stable Baselines3, experimental code.", description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin", author="Antonin Raffin",

View File

@@ -2,6 +2,7 @@ import random
import gym import gym
import pytest import pytest
from gym import spaces
from stable_baselines3.common.callbacks import EventCallback, StopTrainingOnNoModelImprovement from stable_baselines3.common.callbacks import EventCallback, StopTrainingOnNoModelImprovement
from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
@@ -27,7 +28,7 @@ class ToDictWrapper(gym.Wrapper):
def __init__(self, env): def __init__(self, env):
super().__init__(env) super().__init__(env)
self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space}) self.observation_space = spaces.Dict({"obs": self.env.observation_space})
def reset(self): def reset(self):
return {"obs": self.env.reset()} return {"obs": self.env.reset()}

View File

@@ -20,7 +20,7 @@ class ToDictWrapper(gym.Wrapper):
def __init__(self, env): def __init__(self, env):
super().__init__(env) super().__init__(env)
self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space}) self.observation_space = spaces.Dict({"obs": self.env.observation_space})
def reset(self): def reset(self):
return {"obs": self.env.reset()} return {"obs": self.env.reset()}

View File

@@ -1,7 +1,7 @@
from typing import List from typing import List
import pytest import pytest
from gym.spaces import Discrete from gym import spaces
from stable_baselines3.common.envs import IdentityEnv from stable_baselines3.common.envs import IdentityEnv
from sb3_contrib.common.wrappers import ActionMasker from sb3_contrib.common.wrappers import ActionMasker
@@ -15,7 +15,7 @@ class IdentityEnvDiscrete(IdentityEnv):
:param dim: the size of the dimensions you want to learn :param dim: the size of the dimensions you want to learn
:param ep_length: the length of each episode in timesteps :param ep_length: the length of each episode in timesteps
""" """
space = Discrete(dim) space = spaces.Discrete(dim)
self.useless_property = 1 self.useless_property = 1
super().__init__(ep_length=ep_length, space=space) super().__init__(ep_length=ep_length, space=space)