Standardize the use of ``from gym import spaces`` (#131)
* Standardize ``from gym import spaces``
* update changelog
* update issue template
* update version
* Update version
This commit is contained in:
parent
c9bd045d5c
commit
7c4a249fa4
|
|
@ -52,9 +52,8 @@ Describe the characteristic of your environment:
|
||||||
* Versions of any other relevant libraries
|
* Versions of any other relevant libraries
|
||||||
|
|
||||||
You can use `sb3.get_system_info()` to print relevant packages info:
|
You can use `sb3.get_system_info()` to print relevant packages info:
|
||||||
```python
|
```sh
|
||||||
import stable_baselines3 as sb3
|
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
|
||||||
sb3.get_system_info()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Additional context**
|
**Additional context**
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Release 1.7.0a10 (WIP)
|
Release 1.7.0a11 (WIP)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
|
|
@ -11,7 +11,7 @@ Breaking Changes:
|
||||||
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
|
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
|
||||||
please use an ``EvalCallback`` instead
|
please use an ``EvalCallback`` instead
|
||||||
- Removed deprecated ``sde_net_arch`` parameter
|
- Removed deprecated ``sde_net_arch`` parameter
|
||||||
- Upgraded to Stable-Baselines3 >= 1.7.0a10
|
- Upgraded to Stable-Baselines3 >= 1.7.0a11
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
@ -41,6 +41,7 @@ Others:
|
||||||
- Removed ignores on Flake8 F401 (@ZikangXiong)
|
- Removed ignores on Flake8 F401 (@ZikangXiong)
|
||||||
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
|
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
|
||||||
- Set tensors construction directly on the device
|
- Set tensors construction directly on the device
|
||||||
|
- Standardized the use of ``from gym import spaces``
|
||||||
|
|
||||||
Release 1.6.2 (2022-10-10)
|
Release 1.6.2 (2022-10-10)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,10 @@ import warnings
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
||||||
|
|
||||||
import gym
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
import torch.nn.utils
|
import torch.nn.utils
|
||||||
|
from gym import spaces
|
||||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||||
from stable_baselines3.common.callbacks import BaseCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
from stable_baselines3.common.evaluation import evaluate_policy
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
|
|
@ -80,7 +80,7 @@ class ARS(BaseAlgorithm):
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
device=device,
|
device=device,
|
||||||
supported_action_spaces=(gym.spaces.Box, gym.spaces.Discrete),
|
supported_action_spaces=(spaces.Box, spaces.Discrete),
|
||||||
support_multi_env=True,
|
support_multi_env=True,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import Any, Dict, List, Optional, Type
|
from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
import gym
|
|
||||||
import torch as th
|
import torch as th
|
||||||
|
from gym import spaces
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.preprocessing import get_action_dim
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
|
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
|
||||||
|
|
@ -23,8 +23,8 @@ class ARSPolicy(BasePolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
net_arch: Optional[List[int]] = None,
|
net_arch: Optional[List[int]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
with_bias: bool = True,
|
with_bias: bool = True,
|
||||||
|
|
@ -34,7 +34,7 @@ class ARSPolicy(BasePolicy):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
observation_space,
|
observation_space,
|
||||||
action_space,
|
action_space,
|
||||||
squash_output=isinstance(action_space, gym.spaces.Box) and squash_output,
|
squash_output=isinstance(action_space, spaces.Box) and squash_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
if net_arch is None:
|
if net_arch is None:
|
||||||
|
|
@ -45,12 +45,12 @@ class ARSPolicy(BasePolicy):
|
||||||
self.features_dim = self.features_extractor.features_dim
|
self.features_dim = self.features_extractor.features_dim
|
||||||
self.activation_fn = activation_fn
|
self.activation_fn = activation_fn
|
||||||
|
|
||||||
if isinstance(action_space, gym.spaces.Box):
|
if isinstance(action_space, spaces.Box):
|
||||||
action_dim = get_action_dim(action_space)
|
action_dim = get_action_dim(action_space)
|
||||||
actor_net = create_mlp(
|
actor_net = create_mlp(
|
||||||
self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output
|
self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output
|
||||||
)
|
)
|
||||||
elif isinstance(action_space, gym.spaces.Discrete):
|
elif isinstance(action_space, spaces.Discrete):
|
||||||
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn, with_bias=with_bias)
|
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn, with_bias=with_bias)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
|
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
|
||||||
|
|
@ -72,9 +72,9 @@ class ARSPolicy(BasePolicy):
|
||||||
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
|
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
|
||||||
|
|
||||||
features = self.extract_features(obs, self.features_extractor)
|
features = self.extract_features(obs, self.features_extractor)
|
||||||
if isinstance(self.action_space, gym.spaces.Box):
|
if isinstance(self.action_space, spaces.Box):
|
||||||
return self.action_net(features)
|
return self.action_net(features)
|
||||||
elif isinstance(self.action_space, gym.spaces.Discrete):
|
elif isinstance(self.action_space, spaces.Discrete):
|
||||||
logits = self.action_net(features)
|
logits = self.action_net(features)
|
||||||
return th.argmax(logits, dim=1)
|
return th.argmax(logits, dim=1)
|
||||||
else:
|
else:
|
||||||
|
|
@ -98,8 +98,8 @@ class ARSLinearPolicy(ARSPolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
with_bias: bool = False,
|
with_bias: bool = False,
|
||||||
squash_output: bool = False,
|
squash_output: bool = False,
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import gym
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
|
from gym import spaces
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.torch_layers import (
|
from stable_baselines3.common.torch_layers import (
|
||||||
BaseFeaturesExtractor,
|
BaseFeaturesExtractor,
|
||||||
|
|
@ -43,8 +43,8 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
|
|
@ -281,7 +281,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
# Convert to numpy
|
# Convert to numpy
|
||||||
actions = actions.cpu().numpy()
|
actions = actions.cpu().numpy()
|
||||||
|
|
||||||
if isinstance(self.action_space, gym.spaces.Box):
|
if isinstance(self.action_space, spaces.Box):
|
||||||
if self.squash_output:
|
if self.squash_output:
|
||||||
# Rescale to proper domain when using squashing
|
# Rescale to proper domain when using squashing
|
||||||
actions = self.unscale_action(actions)
|
actions = self.unscale_action(actions)
|
||||||
|
|
@ -379,8 +379,8 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
|
|
@ -433,8 +433,8 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Dict,
|
observation_space: spaces.Dict,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import gym
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
|
from gym import spaces
|
||||||
from stable_baselines3.common.distributions import Distribution
|
from stable_baselines3.common.distributions import Distribution
|
||||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||||
from stable_baselines3.common.torch_layers import (
|
from stable_baselines3.common.torch_layers import (
|
||||||
|
|
@ -63,8 +63,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
|
|
@ -415,7 +415,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
# Convert to numpy
|
# Convert to numpy
|
||||||
actions = actions.cpu().numpy()
|
actions = actions.cpu().numpy()
|
||||||
|
|
||||||
if isinstance(self.action_space, gym.spaces.Box):
|
if isinstance(self.action_space, spaces.Box):
|
||||||
if self.squash_output:
|
if self.squash_output:
|
||||||
# Rescale to proper domain when using squashing
|
# Rescale to proper domain when using squashing
|
||||||
actions = self.unscale_action(actions)
|
actions = self.unscale_action(actions)
|
||||||
|
|
@ -472,8 +472,8 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
|
|
@ -562,8 +562,8 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ from typing import Dict, Union
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from gym import spaces
|
||||||
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
|
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -25,16 +26,14 @@ class TimeFeatureWrapper(gym.Wrapper):
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False):
|
def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
env.observation_space, (gym.spaces.Box, gym.spaces.Dict)
|
env.observation_space, (spaces.Box, spaces.Dict)
|
||||||
), "`TimeFeatureWrapper` only supports `gym.spaces.Box` and `gym.spaces.Dict` (`gym.GoalEnv`) observation spaces."
|
), "`TimeFeatureWrapper` only supports `gym.spaces.Box` and `spaces.Dict` (`gym.GoalEnv`) observation spaces."
|
||||||
|
|
||||||
# Add a time feature to the observation
|
# Add a time feature to the observation
|
||||||
if isinstance(env.observation_space, gym.spaces.Dict):
|
if isinstance(env.observation_space, spaces.Dict):
|
||||||
assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space"
|
assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space"
|
||||||
obs_space = env.observation_space.spaces["observation"]
|
obs_space = env.observation_space.spaces["observation"]
|
||||||
assert isinstance(
|
assert isinstance(obs_space, spaces.Box), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
|
||||||
obs_space, gym.spaces.Box
|
|
||||||
), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
|
|
||||||
obs_space = env.observation_space.spaces["observation"]
|
obs_space = env.observation_space.spaces["observation"]
|
||||||
else:
|
else:
|
||||||
obs_space = env.observation_space
|
obs_space = env.observation_space
|
||||||
|
|
@ -45,10 +44,10 @@ class TimeFeatureWrapper(gym.Wrapper):
|
||||||
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0]))
|
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0]))
|
||||||
self.dtype = obs_space.dtype
|
self.dtype = obs_space.dtype
|
||||||
|
|
||||||
if isinstance(env.observation_space, gym.spaces.Dict):
|
if isinstance(env.observation_space, spaces.Dict):
|
||||||
env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
|
env.observation_space.spaces["observation"] = spaces.Box(low=low, high=high, dtype=self.dtype)
|
||||||
else:
|
else:
|
||||||
env.observation_space = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
|
env.observation_space = spaces.Box(low=low, high=high, dtype=self.dtype)
|
||||||
|
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -137,9 +137,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
self._setup_lr_schedule()
|
self._setup_lr_schedule()
|
||||||
self.set_random_seed(self.seed)
|
self.set_random_seed(self.seed)
|
||||||
|
|
||||||
buffer_cls = (
|
buffer_cls = MaskableDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else MaskableRolloutBuffer
|
||||||
MaskableDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else MaskableRolloutBuffer
|
|
||||||
)
|
|
||||||
|
|
||||||
self.policy = self.policy_class(
|
self.policy = self.policy_class(
|
||||||
self.observation_space,
|
self.observation_space,
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ import time
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
||||||
|
|
||||||
import gym
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
|
|
@ -139,9 +138,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
self._setup_lr_schedule()
|
self._setup_lr_schedule()
|
||||||
self.set_random_seed(self.seed)
|
self.set_random_seed(self.seed)
|
||||||
|
|
||||||
buffer_cls = (
|
buffer_cls = RecurrentDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RecurrentRolloutBuffer
|
||||||
RecurrentDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RecurrentRolloutBuffer
|
|
||||||
)
|
|
||||||
|
|
||||||
self.policy = self.policy_class(
|
self.policy = self.policy_class(
|
||||||
self.observation_space,
|
self.observation_space,
|
||||||
|
|
@ -247,7 +244,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
# Rescale and perform action
|
# Rescale and perform action
|
||||||
clipped_actions = actions
|
clipped_actions = actions
|
||||||
# Clip the actions to avoid out of bound error
|
# Clip the actions to avoid out of bound error
|
||||||
if isinstance(self.action_space, gym.spaces.Box):
|
if isinstance(self.action_space, spaces.Box):
|
||||||
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||||
|
|
||||||
new_obs, rewards, dones, infos = env.step(clipped_actions)
|
new_obs, rewards, dones, infos = env.step(clipped_actions)
|
||||||
|
|
@ -262,7 +259,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
self._update_info_buffer(infos)
|
self._update_info_buffer(infos)
|
||||||
n_steps += 1
|
n_steps += 1
|
||||||
|
|
||||||
if isinstance(self.action_space, gym.spaces.Discrete):
|
if isinstance(self.action_space, spaces.Discrete):
|
||||||
# Reshape in case of discrete action
|
# Reshape in case of discrete action
|
||||||
actions = actions.reshape(-1, 1)
|
actions = actions.reshape(-1, 1)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import Any, Dict, List, Optional, Type
|
from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
import gym
|
|
||||||
import torch as th
|
import torch as th
|
||||||
|
from gym import spaces
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.torch_layers import (
|
from stable_baselines3.common.torch_layers import (
|
||||||
BaseFeaturesExtractor,
|
BaseFeaturesExtractor,
|
||||||
|
|
@ -29,8 +29,8 @@ class QuantileNetwork(BasePolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
features_extractor: nn.Module,
|
features_extractor: nn.Module,
|
||||||
features_dim: int,
|
features_dim: int,
|
||||||
n_quantiles: int = 200,
|
n_quantiles: int = 200,
|
||||||
|
|
@ -111,8 +111,8 @@ class QRDQNPolicy(BasePolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
n_quantiles: int = 200,
|
n_quantiles: int = 200,
|
||||||
net_arch: Optional[List[int]] = None,
|
net_arch: Optional[List[int]] = None,
|
||||||
|
|
@ -233,8 +233,8 @@ class CnnPolicy(QRDQNPolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
n_quantiles: int = 200,
|
n_quantiles: int = 200,
|
||||||
net_arch: Optional[List[int]] = None,
|
net_arch: Optional[List[int]] = None,
|
||||||
|
|
@ -281,8 +281,8 @@ class MultiInputPolicy(QRDQNPolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
n_quantiles: int = 200,
|
n_quantiles: int = 200,
|
||||||
net_arch: Optional[List[int]] = None,
|
net_arch: Optional[List[int]] = None,
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import gym
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
|
from gym import spaces
|
||||||
from stable_baselines3.common.buffers import ReplayBuffer
|
from stable_baselines3.common.buffers import ReplayBuffer
|
||||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
|
|
@ -113,7 +113,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
seed=seed,
|
seed=seed,
|
||||||
sde_support=False,
|
sde_support=False,
|
||||||
optimize_memory_usage=optimize_memory_usage,
|
optimize_memory_usage=optimize_memory_usage,
|
||||||
supported_action_spaces=(gym.spaces.Discrete,),
|
supported_action_spaces=(spaces.Discrete,),
|
||||||
support_multi_env=True,
|
support_multi_env=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -235,7 +235,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
"""
|
"""
|
||||||
if not deterministic and np.random.rand() < self.exploration_rate:
|
if not deterministic and np.random.rand() < self.exploration_rate:
|
||||||
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
|
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
|
||||||
if isinstance(self.observation_space, gym.spaces.Dict):
|
if isinstance(self.observation_space, spaces.Dict):
|
||||||
n_batch = observation[list(observation.keys())[0]].shape[0]
|
n_batch = observation[list(observation.keys())[0]].shape[0]
|
||||||
else:
|
else:
|
||||||
n_batch = observation.shape[0]
|
n_batch = observation.shape[0]
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import gym
|
|
||||||
import torch as th
|
import torch as th
|
||||||
|
from gym import spaces
|
||||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||||
from stable_baselines3.common.policies import BaseModel, BasePolicy
|
from stable_baselines3.common.policies import BaseModel, BasePolicy
|
||||||
from stable_baselines3.common.preprocessing import get_action_dim
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
|
|
@ -46,8 +46,8 @@ class Actor(BasePolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
net_arch: List[int],
|
net_arch: List[int],
|
||||||
features_extractor: nn.Module,
|
features_extractor: nn.Module,
|
||||||
features_dim: int,
|
features_dim: int,
|
||||||
|
|
@ -194,8 +194,8 @@ class Critic(BaseModel):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
net_arch: List[int],
|
net_arch: List[int],
|
||||||
features_extractor: nn.Module,
|
features_extractor: nn.Module,
|
||||||
features_dim: int,
|
features_dim: int,
|
||||||
|
|
@ -268,8 +268,8 @@ class TQCPolicy(BasePolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
|
|
@ -446,8 +446,8 @@ class CnnPolicy(TQCPolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
|
|
@ -515,8 +515,8 @@ class MultiInputPolicy(TQCPolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import gym
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
|
from gym import spaces
|
||||||
from stable_baselines3.common.buffers import ReplayBuffer
|
from stable_baselines3.common.buffers import ReplayBuffer
|
||||||
from stable_baselines3.common.noise import ActionNoise
|
from stable_baselines3.common.noise import ActionNoise
|
||||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||||
|
|
@ -125,7 +125,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
sde_sample_freq=sde_sample_freq,
|
sde_sample_freq=sde_sample_freq,
|
||||||
use_sde_at_warmup=use_sde_at_warmup,
|
use_sde_at_warmup=use_sde_at_warmup,
|
||||||
optimize_memory_usage=optimize_memory_usage,
|
optimize_memory_usage=optimize_memory_usage,
|
||||||
supported_action_spaces=(gym.spaces.Box),
|
supported_action_spaces=(spaces.Box,),
|
||||||
support_multi_env=True,
|
support_multi_env=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.7.0a10
|
1.7.0a11
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.7.0a10",
|
"stable_baselines3>=1.7.0a11",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import random
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import pytest
|
import pytest
|
||||||
|
from gym import spaces
|
||||||
from stable_baselines3.common.callbacks import EventCallback, StopTrainingOnNoModelImprovement
|
from stable_baselines3.common.callbacks import EventCallback, StopTrainingOnNoModelImprovement
|
||||||
from stable_baselines3.common.env_util import make_vec_env
|
from stable_baselines3.common.env_util import make_vec_env
|
||||||
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
||||||
|
|
@ -27,7 +28,7 @@ class ToDictWrapper(gym.Wrapper):
|
||||||
|
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space})
|
self.observation_space = spaces.Dict({"obs": self.env.observation_space})
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
return {"obs": self.env.reset()}
|
return {"obs": self.env.reset()}
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ class ToDictWrapper(gym.Wrapper):
|
||||||
|
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space})
|
self.observation_space = spaces.Dict({"obs": self.env.observation_space})
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
return {"obs": self.env.reset()}
|
return {"obs": self.env.reset()}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from gym.spaces import Discrete
|
from gym import spaces
|
||||||
from stable_baselines3.common.envs import IdentityEnv
|
from stable_baselines3.common.envs import IdentityEnv
|
||||||
|
|
||||||
from sb3_contrib.common.wrappers import ActionMasker
|
from sb3_contrib.common.wrappers import ActionMasker
|
||||||
|
|
@ -15,7 +15,7 @@ class IdentityEnvDiscrete(IdentityEnv):
|
||||||
:param dim: the size of the dimensions you want to learn
|
:param dim: the size of the dimensions you want to learn
|
||||||
:param ep_length: the length of each episode in timesteps
|
:param ep_length: the length of each episode in timesteps
|
||||||
"""
|
"""
|
||||||
space = Discrete(dim)
|
space = spaces.Discrete(dim)
|
||||||
self.useless_property = 1
|
self.useless_property = 1
|
||||||
super().__init__(ep_length=ep_length, space=space)
|
super().__init__(ep_length=ep_length, space=space)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue