Dictionary Observations (#29)
* Add TQC support for new HER version * Add dict obs support * Add support for dict obs
This commit is contained in:
parent
61bfdbc00a
commit
3665695d1e
|
|
@ -3,11 +3,14 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.1.0a5 (WIP)
|
||||
Release 1.1.0a6 (WIP)
|
||||
-------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Added support for Dictionary observation spaces (cf SB3 doc)
|
||||
- Upgraded to Stable-Baselines3 >= 1.1.0
|
||||
- Added proper handling of timeouts for off-policy algorithms (cf SB3 doc)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ instead of predicting the mean return (DQN).
|
|||
|
||||
MlpPolicy
|
||||
CnnPolicy
|
||||
MultiInputPolicy
|
||||
|
||||
|
||||
Notes
|
||||
|
|
@ -43,6 +44,7 @@ Discrete ✔ ✔
|
|||
Box ❌ ✔
|
||||
MultiDiscrete ❌ ✔
|
||||
MultiBinary ❌ ✔
|
||||
Dict ❌ ✔️
|
||||
============= ====== ===========
|
||||
|
||||
|
||||
|
|
@ -149,3 +151,6 @@ QR-DQN Policies
|
|||
|
||||
.. autoclass:: CnnPolicy
|
||||
:members:
|
||||
|
||||
.. autoclass:: MultiInputPolicy
|
||||
:members:
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ It truncates the quantiles predicted by different networks (a bit as it is done
|
|||
|
||||
MlpPolicy
|
||||
CnnPolicy
|
||||
MultiInputPolicy
|
||||
|
||||
|
||||
Notes
|
||||
|
|
@ -42,6 +43,7 @@ Discrete ❌ ✔️
|
|||
Box ✔️ ✔️
|
||||
MultiDiscrete ❌ ✔️
|
||||
MultiBinary ❌ ✔️
|
||||
Dict ❌ ✔️
|
||||
============= ====== ===========
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy
|
||||
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||
from sb3_contrib.qrdqn.qrdqn import QRDQN
|
||||
|
|
|
|||
|
|
@ -3,7 +3,13 @@ from typing import Any, Dict, List, Optional, Type
|
|||
import gym
|
||||
import torch as th
|
||||
from stable_baselines3.common.policies import BasePolicy, register_policy
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
|
||||
from stable_baselines3.common.torch_layers import (
|
||||
BaseFeaturesExtractor,
|
||||
CombinedExtractor,
|
||||
FlattenExtractor,
|
||||
NatureCNN,
|
||||
create_mlp,
|
||||
)
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
from torch import nn
|
||||
|
||||
|
|
@ -245,5 +251,54 @@ class CnnPolicy(QRDQNPolicy):
|
|||
)
|
||||
|
||||
|
||||
class MultiInputPolicy(QRDQNPolicy):
    """
    QR-DQN policy for environments that return dictionary observations.

    Every argument is forwarded positionally, in declaration order, to
    ``QRDQNPolicy.__init__``; ``CombinedExtractor`` is the default features extractor.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param n_quantiles: Number of quantiles
    :param net_arch: The specification of the network architecture.
    :param activation_fn: Activation function
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Optional keyword arguments forwarded to the
        parent class (presumably used to build the features extractor -- confirm
        against ``QRDQNPolicy``).
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        n_quantiles: int = 200,
        net_arch: Optional[List[int]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Positional order must mirror the parent signature, so keep it as declared above.
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            n_quantiles,
            net_arch,
            activation_fn,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )
|
||||
|
||||
|
||||
register_policy("MlpPolicy", MlpPolicy)
|
||||
register_policy("CnnPolicy", CnnPolicy)
|
||||
register_policy("MultiInputPolicy", MultiInputPolicy)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import gym
|
|||
import numpy as np
|
||||
import torch as th
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.preprocessing import maybe_transpose
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
|
|
@ -34,6 +35,9 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
(see ``train_freq`` and ``n_episodes_rollout``)
|
||||
Setting it to ``-1`` means doing as many gradient steps as steps done in the environment
|
||||
during the rollout.
|
||||
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
|
||||
If ``None``, it will be automatically selected.
|
||||
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
|
||||
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
|
||||
at a cost of more complexity.
|
||||
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
|
||||
|
|
@ -59,13 +63,15 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
policy: Union[str, Type[QRDQNPolicy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, Schedule] = 5e-5,
|
||||
buffer_size: int = 1000000,
|
||||
buffer_size: int = 1000000, # 1e6
|
||||
learning_starts: int = 50000,
|
||||
batch_size: Optional[int] = 32,
|
||||
tau: float = 1.0,
|
||||
gamma: float = 0.99,
|
||||
train_freq: int = 4,
|
||||
gradient_steps: int = 1,
|
||||
replay_buffer_class: Optional[ReplayBuffer] = None,
|
||||
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
optimize_memory_usage: bool = False,
|
||||
target_update_interval: int = 10000,
|
||||
exploration_fraction: float = 0.005,
|
||||
|
|
@ -94,6 +100,8 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
train_freq,
|
||||
gradient_steps,
|
||||
action_noise=None, # No action noise
|
||||
replay_buffer_class=replay_buffer_class,
|
||||
replay_buffer_kwargs=replay_buffer_kwargs,
|
||||
policy_kwargs=policy_kwargs,
|
||||
tensorboard_log=tensorboard_log,
|
||||
verbose=verbose,
|
||||
|
|
@ -152,7 +160,7 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
self._update_learning_rate(self.policy.optimizer)
|
||||
|
||||
losses = []
|
||||
for gradient_step in range(gradient_steps):
|
||||
for _ in range(gradient_steps):
|
||||
# Sample replay buffer
|
||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
||||
|
||||
|
|
@ -213,6 +221,9 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
"""
|
||||
if not deterministic and np.random.rand() < self.exploration_rate:
|
||||
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
|
||||
if isinstance(self.observation_space, gym.spaces.Dict):
|
||||
n_batch = observation[list(observation.keys())[0]].shape[0]
|
||||
else:
|
||||
n_batch = observation.shape[0]
|
||||
action = np.array([self.action_space.sample() for _ in range(n_batch)])
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy
|
||||
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||
from sb3_contrib.tqc.tqc import TQC
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
|
|
@ -7,11 +7,13 @@ from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_
|
|||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.torch_layers import (
|
||||
BaseFeaturesExtractor,
|
||||
CombinedExtractor,
|
||||
FlattenExtractor,
|
||||
NatureCNN,
|
||||
create_mlp,
|
||||
get_actor_critic_arch,
|
||||
)
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
from torch import nn as nn
|
||||
|
||||
# CAP the standard deviation of the actor
|
||||
|
|
@ -278,6 +280,8 @@ class TQCPolicy(BasePolicy):
|
|||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
:param n_quantiles: Number of quantiles for the critic.
|
||||
:param n_critics: Number of critic networks to create.
|
||||
:param share_features_extractor: Whether to share or not the features extractor
|
||||
between the actor and the critic (this saves computation time)
|
||||
"""
|
||||
|
|
@ -286,8 +290,8 @@ class TQCPolicy(BasePolicy):
|
|||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: Callable,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
use_sde: bool = False,
|
||||
log_std_init: float = -3,
|
||||
|
|
@ -353,7 +357,7 @@ class TQCPolicy(BasePolicy):
|
|||
|
||||
self._build(lr_schedule)
|
||||
|
||||
def _build(self, lr_schedule: Callable) -> None:
|
||||
def _build(self, lr_schedule: Schedule) -> None:
|
||||
self.actor = self.make_actor()
|
||||
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
|
||||
|
|
@ -448,6 +452,8 @@ class CnnPolicy(TQCPolicy):
|
|||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
:param n_quantiles: Number of quantiles for the critic.
|
||||
:param n_critics: Number of critic networks to create.
|
||||
:param share_features_extractor: Whether to share or not the features extractor
|
||||
between the actor and the critic (this saves computation time)
|
||||
"""
|
||||
|
|
@ -456,8 +462,8 @@ class CnnPolicy(TQCPolicy):
|
|||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: Callable,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
use_sde: bool = False,
|
||||
log_std_init: float = -3,
|
||||
|
|
@ -495,5 +501,80 @@ class CnnPolicy(TQCPolicy):
|
|||
)
|
||||
|
||||
|
||||
class MultiInputPolicy(TQCPolicy):
    """
    TQC policy (actor and critic) for environments that return dictionary observations.

    Every argument is forwarded positionally, in declaration order, to
    ``TQCPolicy.__init__``; ``CombinedExtractor`` is the default features extractor.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param sde_net_arch: Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Optional keyword arguments forwarded to the
        parent class (presumably used to build the features extractor -- confirm
        against ``TQCPolicy``).
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_quantiles: Number of quantiles for the critic.
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        n_quantiles: int = 25,
        n_critics: int = 2,
        share_features_extractor: bool = True,
    ):
        # Positional order must mirror the parent signature, so keep it as declared above.
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            use_sde,
            log_std_init,
            sde_net_arch,
            use_expln,
            clip_mean,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
            n_quantiles,
            n_critics,
            share_features_extractor,
        )
|
||||
|
||||
|
||||
register_policy("MlpPolicy", MlpPolicy)
|
||||
register_policy("CnnPolicy", CnnPolicy)
|
||||
register_policy("MultiInputPolicy", MultiInputPolicy)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import gym
|
|||
import numpy as np
|
||||
import torch as th
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
||||
|
|
@ -35,6 +36,9 @@ class TQC(OffPolicyAlgorithm):
|
|||
:param gradient_steps: How many gradient updates to do after each step
|
||||
:param action_noise: the action noise type (None by default), this can help
|
||||
for hard exploration problem. Cf common.noise for the different action noise type.
|
||||
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
|
||||
If ``None``, it will be automatically selected.
|
||||
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
|
||||
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
|
||||
at a cost of more complexity.
|
||||
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
|
||||
|
|
@ -66,7 +70,7 @@ class TQC(OffPolicyAlgorithm):
|
|||
policy: Union[str, Type[TQCPolicy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, Callable] = 3e-4,
|
||||
buffer_size: int = int(1e6),
|
||||
buffer_size: int = 1000000, # 1e6
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 256,
|
||||
tau: float = 0.005,
|
||||
|
|
@ -74,6 +78,8 @@ class TQC(OffPolicyAlgorithm):
|
|||
train_freq: int = 1,
|
||||
gradient_steps: int = 1,
|
||||
action_noise: Optional[ActionNoise] = None,
|
||||
replay_buffer_class: Optional[ReplayBuffer] = None,
|
||||
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
optimize_memory_usage: bool = False,
|
||||
ent_coef: Union[str, float] = "auto",
|
||||
target_update_interval: int = 1,
|
||||
|
|
@ -103,7 +109,9 @@ class TQC(OffPolicyAlgorithm):
|
|||
gamma,
|
||||
train_freq,
|
||||
gradient_steps,
|
||||
action_noise,
|
||||
action_noise=action_noise,
|
||||
replay_buffer_class=replay_buffer_class,
|
||||
replay_buffer_kwargs=replay_buffer_kwargs,
|
||||
policy_kwargs=policy_kwargs,
|
||||
tensorboard_log=tensorboard_log,
|
||||
verbose=verbose,
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.1.0a5
|
||||
1.1.0a6
|
||||
|
|
|
|||
|
|
@ -7,11 +7,8 @@ license_file = LICENSE
|
|||
env =
|
||||
PYTHONHASHSEED=0
|
||||
filterwarnings =
|
||||
# Tensorboard/Tensorflow warnings
|
||||
ignore:inspect.getargspec:DeprecationWarning:tensorflow
|
||||
ignore:builtin type EagerTensor has no __module__ attribute:DeprecationWarning
|
||||
ignore:The binary mode of fromstring is deprecated:DeprecationWarning
|
||||
ignore::FutureWarning:tensorflow
|
||||
# Tensorboard warnings
|
||||
ignore::DeprecationWarning:tensorboard
|
||||
# Gym warnings
|
||||
ignore:Parameters to load are deprecated.:DeprecationWarning
|
||||
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -62,7 +62,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=1.0",
|
||||
"stable_baselines3>=1.1.0a6",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from copy import deepcopy
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch as th
|
||||
from stable_baselines3.common.identity_env import FakeImageEnv
|
||||
from stable_baselines3.common.envs import FakeImageEnv
|
||||
from stable_baselines3.common.utils import zip_strict
|
||||
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,258 @@
|
|||
import gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
|
||||
|
||||
from sb3_contrib import QRDQN, TQC
|
||||
|
||||
|
||||
class DummyDictEnv(gym.Env):
    """
    Custom environment with a dictionary observation space, for testing purposes only.

    Observations are random samples of the observation space, the reward is
    always ``0.0`` and ``step`` never reports ``done`` by itself.

    :param use_discrete_actions: Use a ``Discrete(3)`` action space instead of
        a 2-dimensional ``Box`` in [-1, 1].
    :param channel_last: Put the channel axis of the image observation last
        (H, W, C) instead of first (C, H, W).
    :param nested_dict_obs: Add a nested ``Dict`` space under the ``"nested-dict"`` key.
    :param vec_only: Only keep the ``"vec"`` entry in the observation space.
    """

    metadata = {"render.modes": ["human"]}

    def __init__(
        self,
        use_discrete_actions=False,
        channel_last=False,
        nested_dict_obs=False,
        vec_only=False,
    ):
        super().__init__()
        self.action_space = (
            spaces.Discrete(3) if use_discrete_actions else spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        )

        n_channels, height, width = 1, 64, 64
        img_shape = (height, width, n_channels) if channel_last else (n_channels, height, width)

        if vec_only:
            # For checking consistency with normal MlpPolicy
            sub_spaces = {
                # Vector obs
                "vec": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32),
            }
        else:
            sub_spaces = {
                # Image obs
                "img": spaces.Box(low=0, high=255, shape=img_shape, dtype=np.uint8),
                # Vector obs
                "vec": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32),
                # Discrete obs
                "discrete": spaces.Discrete(4),
            }
        self.observation_space = spaces.Dict(sub_spaces)

        if nested_dict_obs:
            # Add dictionary observation inside observation space
            nested_space = spaces.Dict({"nested-dict-discrete": spaces.Discrete(4)})
            self.observation_space.spaces["nested-dict"] = nested_space

    def seed(self, seed=None):
        """Seed the observation space; does nothing when ``seed`` is None."""
        if seed is None:
            return
        self.observation_space.seed(seed)

    def step(self, action):
        """Ignore ``action`` and return a random observation, zero reward, ``done=False`` and empty info."""
        next_obs = self.observation_space.sample()
        return next_obs, 0.0, False, {}

    def compute_reward(self, achieved_goal, desired_goal, info):
        """Return an array of zeros with one entry per element of ``achieved_goal``."""
        n_goals = len(achieved_goal)
        return np.zeros((n_goals,))

    def reset(self):
        """Return a random sample of the observation space."""
        return self.observation_space.sample()

    def render(self, mode="human"):
        """Rendering is not implemented (no-op)."""
        return None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
def test_consistency(model_class):
    """
    Make sure that dict obs with vector only vs using flatten obs is equivalent.
    This ensures notably that the network architectures are the same.
    """
    use_discrete_actions = model_class is QRDQN
    dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True)
    dict_env = gym.wrappers.TimeLimit(dict_env, 100)
    env = gym.wrappers.FlattenObservation(dict_env)
    dict_env.seed(10)
    obs = dict_env.reset()

    n_steps = 256

    # Both tested algorithms take replay-buffer arguments:
    # keep the buffer small to avoid memory errors and train rarely to keep the test fast.
    # (The former ``if model_class in {}`` branch tested membership in an empty dict
    # and could never run, so it was removed.)
    kwargs = dict(
        buffer_size=250,
        train_freq=8,
        gradient_steps=1,
    )
    if model_class is QRDQN:
        kwargs["learning_starts"] = 0

    # Same seed and hyperparameters for both models, so their predictions can be compared
    dict_model = model_class("MultiInputPolicy", dict_env, gamma=0.5, seed=1, **kwargs)
    action_before_learning_1, _ = dict_model.predict(obs, deterministic=True)
    dict_model.learn(total_timesteps=n_steps)

    normal_model = model_class("MlpPolicy", env, gamma=0.5, seed=1, **kwargs)
    action_before_learning_2, _ = normal_model.predict(obs["vec"], deterministic=True)
    normal_model.learn(total_timesteps=n_steps)

    action_1, _ = dict_model.predict(obs, deterministic=True)
    action_2, _ = normal_model.predict(obs["vec"], deterministic=True)

    assert np.allclose(action_before_learning_1, action_before_learning_2)
    assert np.allclose(action_1, action_2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
@pytest.mark.parametrize("channel_last", [False, True])
def test_dict_spaces(model_class, channel_last):
    """
    Additional tests to check observation space support
    with mixed observation.
    """
    use_discrete_actions = model_class is not TQC
    env = DummyDictEnv(use_discrete_actions=use_discrete_actions, channel_last=channel_last)
    env = gym.wrappers.TimeLimit(env, 100)

    n_steps = 256

    # Avoid memory error when using replay buffer
    # Reduce the size of the features and make learning faster
    # (The former ``if model_class in {}`` branch tested membership in an empty dict
    # and could never run, so it was removed.)
    kwargs = dict(
        buffer_size=250,
        policy_kwargs=dict(
            net_arch=[32],
            features_extractor_kwargs=dict(cnn_output_dim=32),
            n_quantiles=20,
        ),
        train_freq=8,
        gradient_steps=1,
    )
    if model_class is QRDQN:
        kwargs["learning_starts"] = 0

    model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)

    model.learn(total_timesteps=n_steps)

    # Smoke test: only checks that evaluation runs on the dict observations
    evaluate_policy(model, env, n_eval_episodes=5, warn=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
@pytest.mark.parametrize("channel_last", [False, True])
def test_dict_vec_framestack(model_class, channel_last):
    """
    Additional tests to check observation space support
    for Dictionary spaces and VecEnvWrapper using MultiInputPolicy.
    """
    use_discrete_actions = model_class is not TQC
    # Per-key channel order passed to VecFrameStack (None for the non-image "vec" key)
    channels_order = {"vec": None, "img": "last" if channel_last else "first"}
    env = DummyVecEnv(
        [lambda: SimpleMultiObsEnv(random_start=True, discrete_actions=use_discrete_actions, channel_last=channel_last)]
    )

    env = VecFrameStack(env, n_stack=3, channels_order=channels_order)

    n_steps = 256

    # Avoid memory error when using replay buffer
    # Reduce the size of the features and make learning faster
    # (The former ``if model_class in {}`` branch tested membership in an empty dict
    # and could never run, so it was removed.)
    kwargs = dict(
        buffer_size=250,
        policy_kwargs=dict(
            net_arch=[32],
            features_extractor_kwargs=dict(cnn_output_dim=32),
            n_quantiles=20,
        ),
        train_freq=8,
        gradient_steps=1,
    )
    if model_class is QRDQN:
        kwargs["learning_starts"] = 0

    model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)

    model.learn(total_timesteps=n_steps)

    # Smoke test: only checks that evaluation runs on the stacked dict observations
    evaluate_policy(model, env, n_eval_episodes=5, warn=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
def test_vec_normalize(model_class):
    """
    Additional tests to check observation space support
    for GoalEnv and VecNormalize using MultiInputPolicy.
    """
    # QRDQN gets the non-continuous variant of the env, TQC the continuous one
    env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=4, continuous=model_class is not QRDQN)])
    env = VecNormalize(env)

    n_steps = 256

    # Avoid memory error when using replay buffer
    # Reduce the size of the features and make learning faster
    # (The former ``if model_class in {}`` branch tested membership in an empty dict
    # and could never run, so it was removed.)
    kwargs = dict(
        buffer_size=250,
        policy_kwargs=dict(
            net_arch=[32],
        ),
        train_freq=8,
        gradient_steps=1,
    )
    if model_class is QRDQN:
        kwargs["learning_starts"] = 0

    model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)

    model.learn(total_timesteps=n_steps)

    # Smoke test: only checks that evaluation runs through VecNormalize
    evaluate_policy(model, env, n_eval_episodes=5, warn=False)
|
||||
|
|
@ -8,7 +8,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch as th
|
||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
||||
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
||||
from stable_baselines3.common.utils import get_device
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
|
|
@ -243,6 +243,7 @@ def test_save_load_replay_buffer(tmp_path, model_class):
|
|||
assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions)
|
||||
assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards)
|
||||
assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones)
|
||||
infos = [[{"TimeLimit.truncated": truncated}] for truncated in old_replay_buffer.timeouts]
|
||||
|
||||
# test extending replay buffer
|
||||
model.replay_buffer.extend(
|
||||
|
|
@ -251,6 +252,7 @@ def test_save_load_replay_buffer(tmp_path, model_class):
|
|||
old_replay_buffer.actions,
|
||||
old_replay_buffer.rewards,
|
||||
old_replay_buffer.dones,
|
||||
infos,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,15 +2,13 @@ import gym
|
|||
import numpy as np
|
||||
import pytest
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.bit_flipping_env import BitFlippingEnv
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
from stable_baselines3.common.envs import BitFlippingEnv
|
||||
|
||||
from sb3_contrib.common.wrappers import TimeFeatureWrapper
|
||||
|
||||
|
||||
class CustomGoalEnv(gym.GoalEnv):
|
||||
"""docstring for CustomGoalEnv."""
|
||||
|
||||
def __init__(self):
|
||||
super(CustomGoalEnv, self).__init__()
|
||||
self.observation_space = spaces.Dict(
|
||||
|
|
|
|||
Loading…
Reference in New Issue