Dictionary Observations (#29)

* Add TQC support for new HER version

* Add dict obs support

* Add support for dict obs
This commit is contained in:
Antonin RAFFIN 2021-05-11 13:24:31 +02:00 committed by GitHub
parent 61bfdbc00a
commit 3665695d1e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 447 additions and 27 deletions

View File

@ -3,11 +3,14 @@
Changelog Changelog
========== ==========
Release 1.1.0a5 (WIP) Release 1.1.0a6 (WIP)
------------------------------- -------------------------------
Breaking Changes: Breaking Changes:
^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^
- Added support for Dictionary observation spaces (cf SB3 doc)
- Upgraded to Stable-Baselines3 >= 1.1.0
- Added proper handling of timeouts for off-policy algorithms (cf SB3 doc)
Bug Fixes: Bug Fixes:
^^^^^^^^^^ ^^^^^^^^^^

View File

@ -18,6 +18,7 @@ instead of predicting the mean return (DQN).
MlpPolicy MlpPolicy
CnnPolicy CnnPolicy
MultiInputPolicy
Notes Notes
@ -43,6 +44,7 @@ Discrete ✔ ✔
Box ❌ ✔ Box ❌ ✔
MultiDiscrete ❌ ✔ MultiDiscrete ❌ ✔
MultiBinary ❌ ✔ MultiBinary ❌ ✔
Dict ❌ ✔️
============= ====== =========== ============= ====== ===========
@ -149,3 +151,6 @@ QR-DQN Policies
.. autoclass:: CnnPolicy .. autoclass:: CnnPolicy
:members: :members:
.. autoclass:: MultiInputPolicy
:members:

View File

@ -18,6 +18,7 @@ It truncates the quantiles predicted by different networks (a bit as it is done
MlpPolicy MlpPolicy
CnnPolicy CnnPolicy
MultiInputPolicy
Notes Notes
@ -42,6 +43,7 @@ Discrete ❌ ✔️
Box ✔️ ✔️ Box ✔️ ✔️
MultiDiscrete ❌ ✔️ MultiDiscrete ❌ ✔️
MultiBinary ❌ ✔️ MultiBinary ❌ ✔️
Dict ❌ ✔️
============= ====== =========== ============= ====== ===========

View File

@ -1,2 +1,2 @@
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.qrdqn.qrdqn import QRDQN from sb3_contrib.qrdqn.qrdqn import QRDQN

View File

@ -3,7 +3,13 @@ from typing import Any, Dict, List, Optional, Type
import gym import gym
import torch as th import torch as th
from stable_baselines3.common.policies import BasePolicy, register_policy from stable_baselines3.common.policies import BasePolicy, register_policy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
NatureCNN,
create_mlp,
)
from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.type_aliases import Schedule
from torch import nn from torch import nn
@ -245,5 +251,54 @@ class CnnPolicy(QRDQNPolicy):
) )
class MultiInputPolicy(QRDQNPolicy):
    """
    Policy class for QR-DQN when using dict observations as input.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param n_quantiles: Number of quantiles
    :param net_arch: The specification of the network architecture.
    :param activation_fn: Activation function
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        n_quantiles: int = 200,
        net_arch: Optional[List[int]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Identical to the base QRDQNPolicy: only the default features
        # extractor (CombinedExtractor) differs, to handle Dict observations.
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            n_quantiles=n_quantiles,
            net_arch=net_arch,
            activation_fn=activation_fn,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
        )
register_policy("MlpPolicy", MlpPolicy) register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy) register_policy("CnnPolicy", CnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)

View File

@ -4,6 +4,7 @@ import gym
import numpy as np import numpy as np
import torch as th import torch as th
from stable_baselines3.common import logger from stable_baselines3.common import logger
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.preprocessing import maybe_transpose from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@ -34,6 +35,9 @@ class QRDQN(OffPolicyAlgorithm):
(see ``train_freq`` and ``n_episodes_rollout``) (see ``train_freq`` and ``n_episodes_rollout``)
Set to ``-1`` means to do as many gradient steps as steps done in the environment Set to ``-1`` means to do as many gradient steps as steps done in the environment
during the rollout. during the rollout.
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
If ``None``, it will be automatically selected.
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity. at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
@ -59,13 +63,15 @@ class QRDQN(OffPolicyAlgorithm):
policy: Union[str, Type[QRDQNPolicy]], policy: Union[str, Type[QRDQNPolicy]],
env: Union[GymEnv, str], env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 5e-5, learning_rate: Union[float, Schedule] = 5e-5,
buffer_size: int = 1000000, buffer_size: int = 1000000, # 1e6
learning_starts: int = 50000, learning_starts: int = 50000,
batch_size: Optional[int] = 32, batch_size: Optional[int] = 32,
tau: float = 1.0, tau: float = 1.0,
gamma: float = 0.99, gamma: float = 0.99,
train_freq: int = 4, train_freq: int = 4,
gradient_steps: int = 1, gradient_steps: int = 1,
replay_buffer_class: Optional[ReplayBuffer] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False, optimize_memory_usage: bool = False,
target_update_interval: int = 10000, target_update_interval: int = 10000,
exploration_fraction: float = 0.005, exploration_fraction: float = 0.005,
@ -94,6 +100,8 @@ class QRDQN(OffPolicyAlgorithm):
train_freq, train_freq,
gradient_steps, gradient_steps,
action_noise=None, # No action noise action_noise=None, # No action noise
replay_buffer_class=replay_buffer_class,
replay_buffer_kwargs=replay_buffer_kwargs,
policy_kwargs=policy_kwargs, policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log, tensorboard_log=tensorboard_log,
verbose=verbose, verbose=verbose,
@ -152,7 +160,7 @@ class QRDQN(OffPolicyAlgorithm):
self._update_learning_rate(self.policy.optimizer) self._update_learning_rate(self.policy.optimizer)
losses = [] losses = []
for gradient_step in range(gradient_steps): for _ in range(gradient_steps):
# Sample replay buffer # Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
@ -213,6 +221,9 @@ class QRDQN(OffPolicyAlgorithm):
""" """
if not deterministic and np.random.rand() < self.exploration_rate: if not deterministic and np.random.rand() < self.exploration_rate:
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space): if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
if isinstance(self.observation_space, gym.spaces.Dict):
n_batch = observation[list(observation.keys())[0]].shape[0]
else:
n_batch = observation.shape[0] n_batch = observation.shape[0]
action = np.array([self.action_space.sample() for _ in range(n_batch)]) action = np.array([self.action_space.sample() for _ in range(n_batch)])
else: else:

View File

@ -1,2 +1,2 @@
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.tqc.tqc import TQC from sb3_contrib.tqc.tqc import TQC

View File

@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym import gym
import torch as th import torch as th
@ -7,11 +7,13 @@ from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_
from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import ( from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor, BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor, FlattenExtractor,
NatureCNN, NatureCNN,
create_mlp, create_mlp,
get_actor_critic_arch, get_actor_critic_arch,
) )
from stable_baselines3.common.type_aliases import Schedule
from torch import nn as nn from torch import nn as nn
# CAP the standard deviation of the actor # CAP the standard deviation of the actor
@ -278,6 +280,8 @@ class TQCPolicy(BasePolicy):
``th.optim.Adam`` by default ``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments, :param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer excluding the learning rate, to pass to the optimizer
:param n_quantiles: Number of quantiles for the critic.
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether to share or not the features extractor :param share_features_extractor: Whether to share or not the features extractor
between the actor and the critic (this saves computation time) between the actor and the critic (this saves computation time)
""" """
@ -286,8 +290,8 @@ class TQCPolicy(BasePolicy):
self, self,
observation_space: gym.spaces.Space, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space, action_space: gym.spaces.Space,
lr_schedule: Callable, lr_schedule: Schedule,
net_arch: Optional[List[int]] = None, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU, activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False, use_sde: bool = False,
log_std_init: float = -3, log_std_init: float = -3,
@ -353,7 +357,7 @@ class TQCPolicy(BasePolicy):
self._build(lr_schedule) self._build(lr_schedule)
def _build(self, lr_schedule: Callable) -> None: def _build(self, lr_schedule: Schedule) -> None:
self.actor = self.make_actor() self.actor = self.make_actor()
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
@ -448,6 +452,8 @@ class CnnPolicy(TQCPolicy):
``th.optim.Adam`` by default ``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments, :param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer excluding the learning rate, to pass to the optimizer
:param n_quantiles: Number of quantiles for the critic.
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether to share or not the features extractor :param share_features_extractor: Whether to share or not the features extractor
between the actor and the critic (this saves computation time) between the actor and the critic (this saves computation time)
""" """
@ -456,8 +462,8 @@ class CnnPolicy(TQCPolicy):
self, self,
observation_space: gym.spaces.Space, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space, action_space: gym.spaces.Space,
lr_schedule: Callable, lr_schedule: Schedule,
net_arch: Optional[List[int]] = None, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU, activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False, use_sde: bool = False,
log_std_init: float = -3, log_std_init: float = -3,
@ -495,5 +501,80 @@ class CnnPolicy(TQCPolicy):
) )
class MultiInputPolicy(TQCPolicy):
    """
    Policy class (with both actor and critic) for TQC,
    using a ``CombinedExtractor`` by default to handle dict observations.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param sde_net_arch: Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_quantiles: Number of quantiles for the critic.
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        n_quantiles: int = 25,
        n_critics: int = 2,
        share_features_extractor: bool = True,
    ):
        # Pure pass-through to TQCPolicy: only the features extractor
        # default (CombinedExtractor) changes, to support Dict observations.
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            use_sde=use_sde,
            log_std_init=log_std_init,
            sde_net_arch=sde_net_arch,
            use_expln=use_expln,
            clip_mean=clip_mean,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            n_quantiles=n_quantiles,
            n_critics=n_critics,
            share_features_extractor=share_features_extractor,
        )
register_policy("MlpPolicy", MlpPolicy) register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy) register_policy("CnnPolicy", CnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)

View File

@ -4,6 +4,7 @@ import gym
import numpy as np import numpy as np
import torch as th import torch as th
from stable_baselines3.common import logger from stable_baselines3.common import logger
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
@ -35,6 +36,9 @@ class TQC(OffPolicyAlgorithm):
:param gradient_steps: How many gradient update after each step :param gradient_steps: How many gradient update after each step
:param action_noise: the action noise type (None by default), this can help :param action_noise: the action noise type (None by default), this can help
for hard exploration problem. Cf common.noise for the different action noise type. for hard exploration problem. Cf common.noise for the different action noise type.
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
If ``None``, it will be automatically selected.
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity. at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
@ -66,7 +70,7 @@ class TQC(OffPolicyAlgorithm):
policy: Union[str, Type[TQCPolicy]], policy: Union[str, Type[TQCPolicy]],
env: Union[GymEnv, str], env: Union[GymEnv, str],
learning_rate: Union[float, Callable] = 3e-4, learning_rate: Union[float, Callable] = 3e-4,
buffer_size: int = int(1e6), buffer_size: int = 1000000, # 1e6
learning_starts: int = 100, learning_starts: int = 100,
batch_size: int = 256, batch_size: int = 256,
tau: float = 0.005, tau: float = 0.005,
@ -74,6 +78,8 @@ class TQC(OffPolicyAlgorithm):
train_freq: int = 1, train_freq: int = 1,
gradient_steps: int = 1, gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None, action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[ReplayBuffer] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False, optimize_memory_usage: bool = False,
ent_coef: Union[str, float] = "auto", ent_coef: Union[str, float] = "auto",
target_update_interval: int = 1, target_update_interval: int = 1,
@ -103,7 +109,9 @@ class TQC(OffPolicyAlgorithm):
gamma, gamma,
train_freq, train_freq,
gradient_steps, gradient_steps,
action_noise, action_noise=action_noise,
replay_buffer_class=replay_buffer_class,
replay_buffer_kwargs=replay_buffer_kwargs,
policy_kwargs=policy_kwargs, policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log, tensorboard_log=tensorboard_log,
verbose=verbose, verbose=verbose,

View File

@ -1 +1 @@
1.1.0a5 1.1.0a6

View File

@ -7,11 +7,8 @@ license_file = LICENSE
env = env =
PYTHONHASHSEED=0 PYTHONHASHSEED=0
filterwarnings = filterwarnings =
# Tensorboard/Tensorflow warnings # Tensorboard warnings
ignore:inspect.getargspec:DeprecationWarning:tensorflow ignore::DeprecationWarning:tensorboard
ignore:builtin type EagerTensor has no __module__ attribute:DeprecationWarning
ignore:The binary mode of fromstring is deprecated:DeprecationWarning
ignore::FutureWarning:tensorflow
# Gym warnings # Gym warnings
ignore:Parameters to load are deprecated.:DeprecationWarning ignore:Parameters to load are deprecated.:DeprecationWarning
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning

View File

@ -62,7 +62,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")], packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]}, package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[ install_requires=[
"stable_baselines3>=1.0", "stable_baselines3>=1.1.0a6",
], ],
description="Contrib package of Stable Baselines3, experimental code.", description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin", author="Antonin Raffin",

View File

@ -4,7 +4,7 @@ from copy import deepcopy
import numpy as np import numpy as np
import pytest import pytest
import torch as th import torch as th
from stable_baselines3.common.identity_env import FakeImageEnv from stable_baselines3.common.envs import FakeImageEnv
from stable_baselines3.common.utils import zip_strict from stable_baselines3.common.utils import zip_strict
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped

258
tests/test_dict_env.py Normal file
View File

@ -0,0 +1,258 @@
import gym
import numpy as np
import pytest
from gym import spaces
from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
from sb3_contrib import QRDQN, TQC
class DummyDictEnv(gym.Env):
    """Custom Environment for testing purposes only"""

    metadata = {"render.modes": ["human"]}

    def __init__(
        self,
        use_discrete_actions=False,
        channel_last=False,
        nested_dict_obs=False,
        vec_only=False,
    ):
        super().__init__()
        # Discrete actions (e.g. for QR-DQN) or continuous ones otherwise
        if use_discrete_actions:
            self.action_space = spaces.Discrete(3)
        else:
            self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)

        n_channels, height, width = 1, 64, 64
        # Channel-first image by default, channel-last on demand
        img_shape = (height, width, n_channels) if channel_last else (n_channels, height, width)

        self.observation_space = spaces.Dict(
            {
                # Image obs
                "img": spaces.Box(low=0, high=255, shape=img_shape, dtype=np.uint8),
                # Vector obs
                "vec": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32),
                # Discrete obs
                "discrete": spaces.Discrete(4),
            }
        )

        # For checking consistency with normal MlpPolicy
        if vec_only:
            self.observation_space = spaces.Dict(
                {
                    # Vector obs
                    "vec": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32),
                }
            )

        if nested_dict_obs:
            # Add dictionary observation inside observation space
            self.observation_space.spaces["nested-dict"] = spaces.Dict({"nested-dict-discrete": spaces.Discrete(4)})

    def seed(self, seed=None):
        if seed is not None:
            self.observation_space.seed(seed)

    def step(self, action):
        # Dummy transition: random obs, zero reward, never done
        return self.observation_space.sample(), 0.0, False, {}

    def compute_reward(self, achieved_goal, desired_goal, info):
        return np.zeros((len(achieved_goal),))

    def reset(self):
        return self.observation_space.sample()

    def render(self, mode="human"):
        pass
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
def test_consistency(model_class):
    """
    Make sure that dict obs with vector only vs using flatten obs is equivalent.
    This ensures notably that the network architectures are the same.
    """
    use_discrete_actions = model_class == QRDQN
    dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True)
    dict_env = gym.wrappers.TimeLimit(dict_env, 100)
    # Flattened version of the same env, to compare against MlpPolicy
    env = gym.wrappers.FlattenObservation(dict_env)
    dict_env.seed(10)
    obs = dict_env.reset()

    n_steps = 256
    # NOTE: the original `if model_class in {}` branch was dead code
    # (membership in an empty set is always False) — removed.
    # Avoid memory error when using replay buffer
    # Reduce the size of the features and make learning faster
    kwargs = dict(
        buffer_size=250,
        train_freq=8,
        gradient_steps=1,
    )
    if model_class == QRDQN:
        kwargs["learning_starts"] = 0

    dict_model = model_class("MultiInputPolicy", dict_env, gamma=0.5, seed=1, **kwargs)
    action_before_learning_1, _ = dict_model.predict(obs, deterministic=True)
    dict_model.learn(total_timesteps=n_steps)

    normal_model = model_class("MlpPolicy", env, gamma=0.5, seed=1, **kwargs)
    action_before_learning_2, _ = normal_model.predict(obs["vec"], deterministic=True)
    normal_model.learn(total_timesteps=n_steps)

    action_1, _ = dict_model.predict(obs, deterministic=True)
    action_2, _ = normal_model.predict(obs["vec"], deterministic=True)

    # Same seed => same actions with both observation formats,
    # before and after learning
    assert np.allclose(action_before_learning_1, action_before_learning_2)
    assert np.allclose(action_1, action_2)
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
@pytest.mark.parametrize("channel_last", [False, True])
def test_dict_spaces(model_class, channel_last):
    """
    Additional tests to check observation space support
    with mixed observation.
    """
    use_discrete_actions = model_class not in [TQC]
    env = DummyDictEnv(use_discrete_actions=use_discrete_actions, channel_last=channel_last)
    env = gym.wrappers.TimeLimit(env, 100)

    n_steps = 256
    # NOTE: the original `if model_class in {}` branch was dead code
    # (membership in an empty set is always False) — removed.
    # Avoid memory error when using replay buffer
    # Reduce the size of the features and make learning faster
    kwargs = dict(
        buffer_size=250,
        policy_kwargs=dict(
            net_arch=[32],
            features_extractor_kwargs=dict(cnn_output_dim=32),
            n_quantiles=20,
        ),
        train_freq=8,
        gradient_steps=1,
    )
    if model_class == QRDQN:
        kwargs["learning_starts"] = 0

    model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)

    model.learn(total_timesteps=n_steps)

    evaluate_policy(model, env, n_eval_episodes=5, warn=False)
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
@pytest.mark.parametrize("channel_last", [False, True])
def test_dict_vec_framestack(model_class, channel_last):
    """
    Additional tests to check observation space support
    for Dictionary spaces and VecEnvWrapper using MultiInputPolicy.
    """
    use_discrete_actions = model_class not in [TQC]
    # Stack only the image along its channel axis, leave the vector as-is
    channels_order = {"vec": None, "img": "last" if channel_last else "first"}
    env = DummyVecEnv(
        [lambda: SimpleMultiObsEnv(random_start=True, discrete_actions=use_discrete_actions, channel_last=channel_last)]
    )

    env = VecFrameStack(env, n_stack=3, channels_order=channels_order)

    n_steps = 256
    # NOTE: the original `if model_class in {}` branch was dead code
    # (membership in an empty set is always False) — removed.
    # Avoid memory error when using replay buffer
    # Reduce the size of the features and make learning faster
    kwargs = dict(
        buffer_size=250,
        policy_kwargs=dict(
            net_arch=[32],
            features_extractor_kwargs=dict(cnn_output_dim=32),
            n_quantiles=20,
        ),
        train_freq=8,
        gradient_steps=1,
    )
    if model_class == QRDQN:
        kwargs["learning_starts"] = 0

    model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)

    model.learn(total_timesteps=n_steps)

    evaluate_policy(model, env, n_eval_episodes=5, warn=False)
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
def test_vec_normalize(model_class):
    """
    Additional tests to check observation space support
    for GoalEnv and VecNormalize using MultiInputPolicy.
    """
    # QR-DQN needs discrete actions, TQC continuous ones
    env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=4, continuous=model_class != QRDQN)])
    env = VecNormalize(env)

    n_steps = 256
    # NOTE: the original `if model_class in {}` branch was dead code
    # (membership in an empty set is always False) — removed.
    # Avoid memory error when using replay buffer
    # Reduce the size of the features and make learning faster
    kwargs = dict(
        buffer_size=250,
        policy_kwargs=dict(
            net_arch=[32],
        ),
        train_freq=8,
        gradient_steps=1,
    )
    if model_class == QRDQN:
        kwargs["learning_starts"] = 0

    model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)

    model.learn(total_timesteps=n_steps)

    evaluate_policy(model, env, n_eval_episodes=5, warn=False)

View File

@ -8,7 +8,7 @@ import numpy as np
import pytest import pytest
import torch as th import torch as th
from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.utils import get_device from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv from stable_baselines3.common.vec_env import DummyVecEnv
@ -243,6 +243,7 @@ def test_save_load_replay_buffer(tmp_path, model_class):
assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions) assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions)
assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards) assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards)
assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones) assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones)
infos = [[{"TimeLimit.truncated": truncated}] for truncated in old_replay_buffer.timeouts]
# test extending replay buffer # test extending replay buffer
model.replay_buffer.extend( model.replay_buffer.extend(
@ -251,6 +252,7 @@ def test_save_load_replay_buffer(tmp_path, model_class):
old_replay_buffer.actions, old_replay_buffer.actions,
old_replay_buffer.rewards, old_replay_buffer.rewards,
old_replay_buffer.dones, old_replay_buffer.dones,
infos,
) )

View File

@ -2,15 +2,13 @@ import gym
import numpy as np import numpy as np
import pytest import pytest
from gym import spaces from gym import spaces
from stable_baselines3.common.bit_flipping_env import BitFlippingEnv
from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.envs import BitFlippingEnv
from sb3_contrib.common.wrappers import TimeFeatureWrapper from sb3_contrib.common.wrappers import TimeFeatureWrapper
class CustomGoalEnv(gym.GoalEnv): class CustomGoalEnv(gym.GoalEnv):
"""docstring for CustomGoalEnv."""
def __init__(self): def __init__(self):
super(CustomGoalEnv, self).__init__() super(CustomGoalEnv, self).__init__()
self.observation_space = spaces.Dict( self.observation_space = spaces.Dict(