Standardize the use of ``from gym import spaces`` (#131)
* Standardize from gym import spaces
* update changelog
* update issue template
* update version
* Update version
This commit is contained in:
parent
c9bd045d5c
commit
7c4a249fa4
|
|
@ -52,9 +52,8 @@ Describe the characteristic of your environment:
|
|||
* Versions of any other relevant libraries
|
||||
|
||||
You can use `sb3.get_system_info()` to print information about the relevant packages:
|
||||
```python
|
||||
import stable_baselines3 as sb3
|
||||
sb3.get_system_info()
|
||||
```sh
|
||||
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
|
||||
```
|
||||
|
||||
**Additional context**
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.7.0a10 (WIP)
|
||||
Release 1.7.0a11 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -11,7 +11,7 @@ Breaking Changes:
|
|||
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
|
||||
please use an ``EvalCallback`` instead
|
||||
- Removed deprecated ``sde_net_arch`` parameter
|
||||
- Upgraded to Stable-Baselines3 >= 1.7.0a10
|
||||
- Upgraded to Stable-Baselines3 >= 1.7.0a11
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -41,6 +41,7 @@ Others:
|
|||
- Removed ignores on Flake8 F401 (@ZikangXiong)
|
||||
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
|
||||
- Set tensors construction directly on the device
|
||||
- Standardized the use of ``from gym import spaces``
|
||||
|
||||
Release 1.6.2 (2022-10-10)
|
||||
--------------------------
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@ import warnings
|
|||
from functools import partial
|
||||
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import torch.nn.utils
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
|
|
@ -80,7 +80,7 @@ class ARS(BaseAlgorithm):
|
|||
policy_kwargs=policy_kwargs,
|
||||
verbose=verbose,
|
||||
device=device,
|
||||
supported_action_spaces=(gym.spaces.Box, gym.spaces.Discrete),
|
||||
supported_action_spaces=(spaces.Box, spaces.Discrete),
|
||||
support_multi_env=True,
|
||||
seed=seed,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
|
||||
|
|
@ -23,8 +23,8 @@ class ARSPolicy(BasePolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
with_bias: bool = True,
|
||||
|
|
@ -34,7 +34,7 @@ class ARSPolicy(BasePolicy):
|
|||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
squash_output=isinstance(action_space, gym.spaces.Box) and squash_output,
|
||||
squash_output=isinstance(action_space, spaces.Box) and squash_output,
|
||||
)
|
||||
|
||||
if net_arch is None:
|
||||
|
|
@ -45,12 +45,12 @@ class ARSPolicy(BasePolicy):
|
|||
self.features_dim = self.features_extractor.features_dim
|
||||
self.activation_fn = activation_fn
|
||||
|
||||
if isinstance(action_space, gym.spaces.Box):
|
||||
if isinstance(action_space, spaces.Box):
|
||||
action_dim = get_action_dim(action_space)
|
||||
actor_net = create_mlp(
|
||||
self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output
|
||||
)
|
||||
elif isinstance(action_space, gym.spaces.Discrete):
|
||||
elif isinstance(action_space, spaces.Discrete):
|
||||
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn, with_bias=with_bias)
|
||||
else:
|
||||
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
|
||||
|
|
@ -72,9 +72,9 @@ class ARSPolicy(BasePolicy):
|
|||
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
|
||||
|
||||
features = self.extract_features(obs, self.features_extractor)
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
return self.action_net(features)
|
||||
elif isinstance(self.action_space, gym.spaces.Discrete):
|
||||
elif isinstance(self.action_space, spaces.Discrete):
|
||||
logits = self.action_net(features)
|
||||
return th.argmax(logits, dim=1)
|
||||
else:
|
||||
|
|
@ -98,8 +98,8 @@ class ARSLinearPolicy(ARSPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
with_bias: bool = False,
|
||||
squash_output: bool = False,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.torch_layers import (
|
||||
BaseFeaturesExtractor,
|
||||
|
|
@ -43,8 +43,8 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
|
|
@ -281,7 +281,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
# Convert to numpy
|
||||
actions = actions.cpu().numpy()
|
||||
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
if self.squash_output:
|
||||
# Rescale to proper domain when using squashing
|
||||
actions = self.unscale_action(actions)
|
||||
|
|
@ -379,8 +379,8 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
|
|
@ -433,8 +433,8 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Dict,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Dict,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.distributions import Distribution
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||
from stable_baselines3.common.torch_layers import (
|
||||
|
|
@ -63,8 +63,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
|
|
@ -415,7 +415,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
# Convert to numpy
|
||||
actions = actions.cpu().numpy()
|
||||
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
if self.squash_output:
|
||||
# Rescale to proper domain when using squashing
|
||||
actions = self.unscale_action(actions)
|
||||
|
|
@ -472,8 +472,8 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
|
|
@ -562,8 +562,8 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import Dict, Union
|
|||
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
|
||||
|
||||
|
||||
|
|
@ -25,16 +26,14 @@ class TimeFeatureWrapper(gym.Wrapper):
|
|||
|
||||
def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False):
|
||||
assert isinstance(
|
||||
env.observation_space, (gym.spaces.Box, gym.spaces.Dict)
|
||||
), "`TimeFeatureWrapper` only supports `gym.spaces.Box` and `gym.spaces.Dict` (`gym.GoalEnv`) observation spaces."
|
||||
env.observation_space, (spaces.Box, spaces.Dict)
|
||||
), "`TimeFeatureWrapper` only supports `gym.spaces.Box` and `spaces.Dict` (`gym.GoalEnv`) observation spaces."
|
||||
|
||||
# Add a time feature to the observation
|
||||
if isinstance(env.observation_space, gym.spaces.Dict):
|
||||
if isinstance(env.observation_space, spaces.Dict):
|
||||
assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space"
|
||||
obs_space = env.observation_space.spaces["observation"]
|
||||
assert isinstance(
|
||||
obs_space, gym.spaces.Box
|
||||
), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
|
||||
assert isinstance(obs_space, spaces.Box), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
|
||||
obs_space = env.observation_space.spaces["observation"]
|
||||
else:
|
||||
obs_space = env.observation_space
|
||||
|
|
@ -45,10 +44,10 @@ class TimeFeatureWrapper(gym.Wrapper):
|
|||
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0]))
|
||||
self.dtype = obs_space.dtype
|
||||
|
||||
if isinstance(env.observation_space, gym.spaces.Dict):
|
||||
env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
|
||||
if isinstance(env.observation_space, spaces.Dict):
|
||||
env.observation_space.spaces["observation"] = spaces.Box(low=low, high=high, dtype=self.dtype)
|
||||
else:
|
||||
env.observation_space = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
|
||||
env.observation_space = spaces.Box(low=low, high=high, dtype=self.dtype)
|
||||
|
||||
super().__init__(env)
|
||||
|
||||
|
|
|
|||
|
|
@ -137,9 +137,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
self._setup_lr_schedule()
|
||||
self.set_random_seed(self.seed)
|
||||
|
||||
buffer_cls = (
|
||||
MaskableDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else MaskableRolloutBuffer
|
||||
)
|
||||
buffer_cls = MaskableDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else MaskableRolloutBuffer
|
||||
|
||||
self.policy = self.policy_class(
|
||||
self.observation_space,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import time
|
|||
from copy import deepcopy
|
||||
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
|
|
@ -139,9 +138,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
self._setup_lr_schedule()
|
||||
self.set_random_seed(self.seed)
|
||||
|
||||
buffer_cls = (
|
||||
RecurrentDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RecurrentRolloutBuffer
|
||||
)
|
||||
buffer_cls = RecurrentDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RecurrentRolloutBuffer
|
||||
|
||||
self.policy = self.policy_class(
|
||||
self.observation_space,
|
||||
|
|
@ -247,7 +244,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
# Rescale and perform action
|
||||
clipped_actions = actions
|
||||
# Clip the actions to avoid out of bound error
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
|
||||
new_obs, rewards, dones, infos = env.step(clipped_actions)
|
||||
|
|
@ -262,7 +259,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
self._update_info_buffer(infos)
|
||||
n_steps += 1
|
||||
|
||||
if isinstance(self.action_space, gym.spaces.Discrete):
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
# Reshape in case of discrete action
|
||||
actions = actions.reshape(-1, 1)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.torch_layers import (
|
||||
BaseFeaturesExtractor,
|
||||
|
|
@ -29,8 +29,8 @@ class QuantileNetwork(BasePolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
features_extractor: nn.Module,
|
||||
features_dim: int,
|
||||
n_quantiles: int = 200,
|
||||
|
|
@ -111,8 +111,8 @@ class QRDQNPolicy(BasePolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
n_quantiles: int = 200,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
|
|
@ -233,8 +233,8 @@ class CnnPolicy(QRDQNPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
n_quantiles: int = 200,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
|
|
@ -281,8 +281,8 @@ class MultiInputPolicy(QRDQNPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
n_quantiles: int = 200,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
|
|
@ -113,7 +113,7 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
seed=seed,
|
||||
sde_support=False,
|
||||
optimize_memory_usage=optimize_memory_usage,
|
||||
supported_action_spaces=(gym.spaces.Discrete,),
|
||||
supported_action_spaces=(spaces.Discrete,),
|
||||
support_multi_env=True,
|
||||
)
|
||||
|
||||
|
|
@ -235,7 +235,7 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
"""
|
||||
if not deterministic and np.random.rand() < self.exploration_rate:
|
||||
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
|
||||
if isinstance(self.observation_space, gym.spaces.Dict):
|
||||
if isinstance(self.observation_space, spaces.Dict):
|
||||
n_batch = observation[list(observation.keys())[0]].shape[0]
|
||||
else:
|
||||
n_batch = observation.shape[0]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||
from stable_baselines3.common.policies import BaseModel, BasePolicy
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
|
|
@ -46,8 +46,8 @@ class Actor(BasePolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
net_arch: List[int],
|
||||
features_extractor: nn.Module,
|
||||
features_dim: int,
|
||||
|
|
@ -194,8 +194,8 @@ class Critic(BaseModel):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
net_arch: List[int],
|
||||
features_extractor: nn.Module,
|
||||
features_dim: int,
|
||||
|
|
@ -268,8 +268,8 @@ class TQCPolicy(BasePolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
|
|
@ -446,8 +446,8 @@ class CnnPolicy(TQCPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
|
|
@ -515,8 +515,8 @@ class MultiInputPolicy(TQCPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
|
|
@ -125,7 +125,7 @@ class TQC(OffPolicyAlgorithm):
|
|||
sde_sample_freq=sde_sample_freq,
|
||||
use_sde_at_warmup=use_sde_at_warmup,
|
||||
optimize_memory_usage=optimize_memory_usage,
|
||||
supported_action_spaces=(gym.spaces.Box),
|
||||
supported_action_spaces=(spaces.Box,),
|
||||
support_multi_env=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.7.0a10
|
||||
1.7.0a11
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=1.7.0a10",
|
||||
"stable_baselines3>=1.7.0a11",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import random
|
|||
|
||||
import gym
|
||||
import pytest
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.callbacks import EventCallback, StopTrainingOnNoModelImprovement
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
||||
|
|
@ -27,7 +28,7 @@ class ToDictWrapper(gym.Wrapper):
|
|||
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space})
|
||||
self.observation_space = spaces.Dict({"obs": self.env.observation_space})
|
||||
|
||||
def reset(self):
|
||||
return {"obs": self.env.reset()}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class ToDictWrapper(gym.Wrapper):
|
|||
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space})
|
||||
self.observation_space = spaces.Dict({"obs": self.env.observation_space})
|
||||
|
||||
def reset(self):
|
||||
return {"obs": self.env.reset()}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import List
|
||||
|
||||
import pytest
|
||||
from gym.spaces import Discrete
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.envs import IdentityEnv
|
||||
|
||||
from sb3_contrib.common.wrappers import ActionMasker
|
||||
|
|
@ -15,7 +15,7 @@ class IdentityEnvDiscrete(IdentityEnv):
|
|||
:param dim: the size of the dimensions you want to learn
|
||||
:param ep_length: the length of each episode in timesteps
|
||||
"""
|
||||
space = Discrete(dim)
|
||||
space = spaces.Discrete(dim)
|
||||
self.useless_property = 1
|
||||
super().__init__(ep_length=ep_length, space=space)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue