Dictionary Observations (#29)

* Add TQC support for new HER version

* Add dict obs support

* Add support for dict obs
Antonin RAFFIN 2021-05-11 13:24:31 +02:00 committed by GitHub
parent 61bfdbc00a
commit 3665695d1e
16 changed files with 447 additions and 27 deletions

docs/misc/changelog.rst

@@ -3,11 +3,14 @@
Changelog
==========
Release 1.1.0a5 (WIP)
Release 1.1.0a6 (WIP)
-------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Added support for Dictionary observation spaces (cf SB3 doc)
- Upgraded to Stable-Baselines3 >= 1.1.0
- Added proper handling of timeouts for off-policy algorithms (cf SB3 doc)
Bug Fixes:
^^^^^^^^^^
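
A minimal usage sketch of the headline changelog entry (assuming SB3 >= 1.1.0a6 and this commit; ``SimpleMultiObsEnv`` is the mixed image/vector env from SB3 that the new tests below also use):

from stable_baselines3.common.envs import SimpleMultiObsEnv

from sb3_contrib import TQC

# SimpleMultiObsEnv exposes a Dict observation space with an image ("img")
# and a vector ("vec") key; "MultiInputPolicy" selects the CombinedExtractor,
# which processes each key separately before concatenating the features.
env = SimpleMultiObsEnv(random_start=True, discrete_actions=False)
model = TQC("MultiInputPolicy", env, buffer_size=250, verbose=0)
model.learn(total_timesteps=256)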

docs/modules/qrdqn.rst

@@ -18,6 +18,7 @@ instead of predicting the mean return (DQN).
MlpPolicy
CnnPolicy
MultiInputPolicy
Notes
@@ -43,6 +44,7 @@ Discrete ✔ ✔
Box ❌ ✔
MultiDiscrete ❌ ✔
MultiBinary ❌ ✔
Dict ❌ ✔️
============= ====== ===========
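
The new ``Dict`` row means QR-DQN now accepts observation spaces like the following sketch (key names are illustrative, mirroring the tests added in this commit):

import numpy as np
from gym import spaces

# A Dict observation space mixing an image and a low-dimensional vector.
obs_space = spaces.Dict(
    {
        "img": spaces.Box(low=0, high=255, shape=(1, 64, 64), dtype=np.uint8),
        "vec": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32),
    }
)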
@@ -149,3 +151,6 @@ QR-DQN Policies
.. autoclass:: CnnPolicy
:members:
.. autoclass:: MultiInputPolicy
:members:

docs/modules/tqc.rst

@@ -18,6 +18,7 @@ It truncates the quantiles predicted by different networks (a bit as it is done
MlpPolicy
CnnPolicy
MultiInputPolicy
Notes
@@ -42,6 +43,7 @@ Discrete ❌ ✔️
Box ✔️ ✔️
MultiDiscrete ❌ ✔️
MultiBinary ❌ ✔️
Dict ❌ ✔️
============= ====== ===========

sb3_contrib/qrdqn/__init__.py

@@ -1,2 +1,2 @@
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.qrdqn.qrdqn import QRDQN

sb3_contrib/qrdqn/policies.py

@@ -3,7 +3,13 @@ from typing import Any, Dict, List, Optional, Type
import gym
import torch as th
from stable_baselines3.common.policies import BasePolicy, register_policy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
NatureCNN,
create_mlp,
)
from stable_baselines3.common.type_aliases import Schedule
from torch import nn
@@ -245,5 +251,54 @@ class CnnPolicy(QRDQNPolicy):
)
class MultiInputPolicy(QRDQNPolicy):
"""
Policy class for QR-DQN when using dict observations as input.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param n_quantiles: Number of quantiles
:param net_arch: The specification of the network architecture.
:param activation_fn: Activation function
:param features_extractor_class: Features extractor to use.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
n_quantiles: int = 200,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super(MultiInputPolicy, self).__init__(
observation_space,
action_space,
lr_schedule,
n_quantiles,
net_arch,
activation_fn,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
)
register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)

sb3_contrib/qrdqn/qrdqn.py

@@ -4,6 +4,7 @@ import gym
import numpy as np
import torch as th
from stable_baselines3.common import logger
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@@ -34,6 +35,9 @@ class QRDQN(OffPolicyAlgorithm):
(see ``train_freq`` and ``n_episodes_rollout``)
Set to ``-1`` to perform as many gradient steps as steps done in the environment
during the rollout.
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
If ``None``, it will be automatically selected.
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
@@ -59,13 +63,15 @@ class QRDQN(OffPolicyAlgorithm):
policy: Union[str, Type[QRDQNPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 5e-5,
buffer_size: int = 1000000,
buffer_size: int = 1000000, # 1e6
learning_starts: int = 50000,
batch_size: Optional[int] = 32,
tau: float = 1.0,
gamma: float = 0.99,
train_freq: int = 4,
gradient_steps: int = 1,
replay_buffer_class: Optional[ReplayBuffer] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
target_update_interval: int = 10000,
exploration_fraction: float = 0.005,
@@ -94,6 +100,8 @@
train_freq,
gradient_steps,
action_noise=None, # No action noise
replay_buffer_class=replay_buffer_class,
replay_buffer_kwargs=replay_buffer_kwargs,
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
verbose=verbose,
@@ -152,7 +160,7 @@
self._update_learning_rate(self.policy.optimizer)
losses = []
for gradient_step in range(gradient_steps):
for _ in range(gradient_steps):
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
@@ -213,7 +221,10 @@
"""
if not deterministic and np.random.rand() < self.exploration_rate:
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
n_batch = observation.shape[0]
if isinstance(self.observation_space, gym.spaces.Dict):
n_batch = observation[list(observation.keys())[0]].shape[0]
else:
n_batch = observation.shape[0]
action = np.array([self.action_space.sample() for _ in range(n_batch)])
else:
action = np.array(self.action_space.sample())
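
The new dict branch above is needed because a dict observation has no ``.shape``; the batch size is read from the first value instead. A standalone sketch of that logic:

import numpy as np

obs = {
    "img": np.zeros((8, 1, 64, 64), dtype=np.uint8),
    "vec": np.zeros((8, 2), dtype=np.float32),
}
# Same trick as in predict(): take the leading dimension of the first key.
n_batch = obs[list(obs.keys())[0]].shape[0]
assert n_batch == 8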

sb3_contrib/tqc/__init__.py

@@ -1,2 +1,2 @@
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.tqc.tqc import TQC

sb3_contrib/tqc/policies.py

@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import torch as th
@@ -7,11 +7,13 @@ from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
NatureCNN,
create_mlp,
get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import Schedule
from torch import nn as nn
# CAP the standard deviation of the actor
@@ -278,6 +280,8 @@ class TQCPolicy(BasePolicy):
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_quantiles: Number of quantiles for the critic.
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether or not to share the features extractor
between the actor and the critic (this saves computation time)
"""
@@ -286,8 +290,8 @@ class TQCPolicy(BasePolicy):
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Callable,
net_arch: Optional[List[int]] = None,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
@@ -353,7 +357,7 @@ class TQCPolicy(BasePolicy):
self._build(lr_schedule)
def _build(self, lr_schedule: Callable) -> None:
def _build(self, lr_schedule: Schedule) -> None:
self.actor = self.make_actor()
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
@@ -448,6 +452,8 @@ class CnnPolicy(TQCPolicy):
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_quantiles: Number of quantiles for the critic.
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether or not to share the features extractor
between the actor and the critic (this saves computation time)
"""
@@ -456,8 +462,8 @@ class CnnPolicy(TQCPolicy):
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Callable,
net_arch: Optional[List[int]] = None,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
@@ -495,5 +501,80 @@ class CnnPolicy(TQCPolicy):
)
class MultiInputPolicy(TQCPolicy):
"""
Policy class (with both actor and critic) for TQC when using dict observations as input.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use the ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). This keeps the variance
above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
:param features_extractor_class: Features extractor to use.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_quantiles: Number of quantiles for the critic.
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether or not to share the features extractor
between the actor and the critic (this saves computation time)
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
clip_mean: float = 2.0,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_quantiles: int = 25,
n_critics: int = 2,
share_features_extractor: bool = True,
):
super(MultiInputPolicy, self).__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
use_sde,
log_std_init,
sde_net_arch,
use_expln,
clip_mean,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
n_quantiles,
n_critics,
share_features_extractor,
)
register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)

sb3_contrib/tqc/tqc.py

@@ -4,6 +4,7 @@ import gym
import numpy as np
import torch as th
from stable_baselines3.common import logger
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
@@ -35,6 +36,9 @@ class TQC(OffPolicyAlgorithm):
:param gradient_steps: How many gradient updates to perform after each step
:param action_noise: the action noise type (None by default); this can help
with hard exploration problems. See common.noise for the different action noise types.
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
If ``None``, it will be automatically selected.
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
@@ -66,7 +70,7 @@ class TQC(OffPolicyAlgorithm):
policy: Union[str, Type[TQCPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Callable] = 3e-4,
buffer_size: int = int(1e6),
buffer_size: int = 1000000, # 1e6
learning_starts: int = 100,
batch_size: int = 256,
tau: float = 0.005,
@@ -74,6 +78,8 @@
train_freq: int = 1,
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[ReplayBuffer] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
ent_coef: Union[str, float] = "auto",
target_update_interval: int = 1,
@@ -103,7 +109,9 @@
gamma,
train_freq,
gradient_steps,
action_noise,
action_noise=action_noise,
replay_buffer_class=replay_buffer_class,
replay_buffer_kwargs=replay_buffer_kwargs,
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
verbose=verbose,
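
The new ``replay_buffer_class``/``replay_buffer_kwargs`` arguments exist mainly to plug in the ``HerReplayBuffer`` from the new HER version (first bullet of the commit message). A sketch, assuming SB3 >= 1.1.0a6 and its standard ``HerReplayBuffer`` options:

from stable_baselines3 import HerReplayBuffer
from stable_baselines3.common.envs import BitFlippingEnv

from sb3_contrib import TQC

# BitFlippingEnv is a GoalEnv with a Dict observation space.
env = BitFlippingEnv(n_bits=4, continuous=True)
model = TQC(
    "MultiInputPolicy",
    env,
    buffer_size=5000,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy="future",
        online_sampling=True,
        max_episode_length=4,  # BitFlippingEnv episodes last n_bits steps
    ),
)
model.learn(total_timesteps=500)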

sb3_contrib/version.txt

@@ -1 +1 @@
1.1.0a5
1.1.0a6

setup.cfg

@@ -7,11 +7,8 @@ license_file = LICENSE
env =
PYTHONHASHSEED=0
filterwarnings =
# Tensorboard/Tensorflow warnings
ignore:inspect.getargspec:DeprecationWarning:tensorflow
ignore:builtin type EagerTensor has no __module__ attribute:DeprecationWarning
ignore:The binary mode of fromstring is deprecated:DeprecationWarning
ignore::FutureWarning:tensorflow
# Tensorboard warnings
ignore::DeprecationWarning:tensorboard
# Gym warnings
ignore:Parameters to load are deprecated.:DeprecationWarning
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning

setup.py

@@ -62,7 +62,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.0",
"stable_baselines3>=1.1.0a6",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",

tests/test_cnn.py

@@ -4,7 +4,7 @@ from copy import deepcopy
import numpy as np
import pytest
import torch as th
from stable_baselines3.common.identity_env import FakeImageEnv
from stable_baselines3.common.envs import FakeImageEnv
from stable_baselines3.common.utils import zip_strict
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped

tests/test_dict_env.py (new file, 258 lines)

@@ -0,0 +1,258 @@
import gym
import numpy as np
import pytest
from gym import spaces
from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
from sb3_contrib import QRDQN, TQC
class DummyDictEnv(gym.Env):
"""Custom Environment for testing purposes only"""
metadata = {"render.modes": ["human"]}
def __init__(
self,
use_discrete_actions=False,
channel_last=False,
nested_dict_obs=False,
vec_only=False,
):
super().__init__()
if use_discrete_actions:
self.action_space = spaces.Discrete(3)
else:
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
N_CHANNELS = 1
HEIGHT = 64
WIDTH = 64
if channel_last:
obs_shape = (HEIGHT, WIDTH, N_CHANNELS)
else:
obs_shape = (N_CHANNELS, HEIGHT, WIDTH)
self.observation_space = spaces.Dict(
{
# Image obs
"img": spaces.Box(low=0, high=255, shape=obs_shape, dtype=np.uint8),
# Vector obs
"vec": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32),
# Discrete obs
"discrete": spaces.Discrete(4),
}
)
# For checking consistency with normal MlpPolicy
if vec_only:
self.observation_space = spaces.Dict(
{
# Vector obs
"vec": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32),
}
)
if nested_dict_obs:
# Add dictionary observation inside observation space
self.observation_space.spaces["nested-dict"] = spaces.Dict({"nested-dict-discrete": spaces.Discrete(4)})
def seed(self, seed=None):
if seed is not None:
self.observation_space.seed(seed)
def step(self, action):
reward = 0.0
done = False
return self.observation_space.sample(), reward, done, {}
def compute_reward(self, achieved_goal, desired_goal, info):
return np.zeros((len(achieved_goal),))
def reset(self):
return self.observation_space.sample()
def render(self, mode="human"):
pass
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
def test_consistency(model_class):
"""
Make sure that a dict obs with only a vector key is equivalent to using the flattened obs.
This notably ensures that the network architectures are the same.
"""
use_discrete_actions = model_class == QRDQN
dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True)
dict_env = gym.wrappers.TimeLimit(dict_env, 100)
env = gym.wrappers.FlattenObservation(dict_env)
dict_env.seed(10)
obs = dict_env.reset()
kwargs = {}
n_steps = 256
if model_class in {}:
kwargs = dict(
n_steps=128,
)
else:
# Avoid memory error when using replay buffer
# Reduce the size of the features and make learning faster
kwargs = dict(
buffer_size=250,
train_freq=8,
gradient_steps=1,
)
if model_class == QRDQN:
kwargs["learning_starts"] = 0
dict_model = model_class("MultiInputPolicy", dict_env, gamma=0.5, seed=1, **kwargs)
action_before_learning_1, _ = dict_model.predict(obs, deterministic=True)
dict_model.learn(total_timesteps=n_steps)
normal_model = model_class("MlpPolicy", env, gamma=0.5, seed=1, **kwargs)
action_before_learning_2, _ = normal_model.predict(obs["vec"], deterministic=True)
normal_model.learn(total_timesteps=n_steps)
action_1, _ = dict_model.predict(obs, deterministic=True)
action_2, _ = normal_model.predict(obs["vec"], deterministic=True)
assert np.allclose(action_before_learning_1, action_before_learning_2)
assert np.allclose(action_1, action_2)
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
@pytest.mark.parametrize("channel_last", [False, True])
def test_dict_spaces(model_class, channel_last):
"""
Additional tests to check observation space support
with mixed observations.
"""
use_discrete_actions = model_class not in [TQC]
env = DummyDictEnv(use_discrete_actions=use_discrete_actions, channel_last=channel_last)
env = gym.wrappers.TimeLimit(env, 100)
kwargs = {}
n_steps = 256
if model_class in {}:
kwargs = dict(
n_steps=128,
policy_kwargs=dict(
net_arch=[32],
features_extractor_kwargs=dict(cnn_output_dim=32),
),
)
else:
# Avoid memory error when using replay buffer
# Reduce the size of the features and make learning faster
kwargs = dict(
buffer_size=250,
policy_kwargs=dict(
net_arch=[32],
features_extractor_kwargs=dict(cnn_output_dim=32),
n_quantiles=20,
),
train_freq=8,
gradient_steps=1,
)
if model_class == QRDQN:
kwargs["learning_starts"] = 0
model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)
model.learn(total_timesteps=n_steps)
evaluate_policy(model, env, n_eval_episodes=5, warn=False)
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
@pytest.mark.parametrize("channel_last", [False, True])
def test_dict_vec_framestack(model_class, channel_last):
"""
Additional tests to check observation space support
for Dictionary spaces and VecEnvWrapper using MultiInputPolicy.
"""
use_discrete_actions = model_class not in [TQC]
channels_order = {"vec": None, "img": "last" if channel_last else "first"}
env = DummyVecEnv(
[lambda: SimpleMultiObsEnv(random_start=True, discrete_actions=use_discrete_actions, channel_last=channel_last)]
)
env = VecFrameStack(env, n_stack=3, channels_order=channels_order)
kwargs = {}
n_steps = 256
if model_class in {}:
kwargs = dict(
n_steps=128,
policy_kwargs=dict(
net_arch=[32],
features_extractor_kwargs=dict(cnn_output_dim=32),
),
)
else:
# Avoid memory error when using replay buffer
# Reduce the size of the features and make learning faster
kwargs = dict(
buffer_size=250,
policy_kwargs=dict(
net_arch=[32],
features_extractor_kwargs=dict(cnn_output_dim=32),
n_quantiles=20,
),
train_freq=8,
gradient_steps=1,
)
if model_class == QRDQN:
kwargs["learning_starts"] = 0
model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)
model.learn(total_timesteps=n_steps)
evaluate_policy(model, env, n_eval_episodes=5, warn=False)
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
def test_vec_normalize(model_class):
"""
Additional tests to check observation space support
for GoalEnv and VecNormalize using MultiInputPolicy.
"""
env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=4, continuous=not (model_class == QRDQN))])
env = VecNormalize(env)
kwargs = {}
n_steps = 256
if model_class in {}:
kwargs = dict(
n_steps=128,
policy_kwargs=dict(
net_arch=[32],
),
)
else:
# Avoid memory error when using replay buffer
# Reduce the size of the features and make learning faster
kwargs = dict(
buffer_size=250,
policy_kwargs=dict(
net_arch=[32],
),
train_freq=8,
gradient_steps=1,
)
if model_class == QRDQN:
kwargs["learning_starts"] = 0
model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)
model.learn(total_timesteps=n_steps)
evaluate_policy(model, env, n_eval_episodes=5, warn=False)

tests/test_save_load.py

@@ -8,7 +8,7 @@ import numpy as np
import pytest
import torch as th
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv
@@ -243,6 +243,7 @@ def test_save_load_replay_buffer(tmp_path, model_class):
assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions)
assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards)
assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones)
infos = [[{"TimeLimit.truncated": truncated}] for truncated in old_replay_buffer.timeouts]
# test extending replay buffer
model.replay_buffer.extend(
@@ -251,6 +252,7 @@ def test_save_load_replay_buffer(tmp_path, model_class):
old_replay_buffer.actions,
old_replay_buffer.rewards,
old_replay_buffer.dones,
infos,
)
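
The ``infos`` list built above carries the ``TimeLimit.truncated`` flag that SB3 >= 1.1 uses for the "proper handling of timeouts" mentioned in the changelog. The gist of that handling, as a sketch:

def is_true_terminal(done: bool, info: dict) -> bool:
    # A TimeLimit truncation ends the episode but is not a real terminal
    # state, so the TD target should still bootstrap from the next obs.
    return done and not info.get("TimeLimit.truncated", False)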

tests/test_wrappers.py

@@ -2,15 +2,13 @@ import gym
import numpy as np
import pytest
from gym import spaces
from stable_baselines3.common.bit_flipping_env import BitFlippingEnv
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.envs import BitFlippingEnv
from sb3_contrib.common.wrappers import TimeFeatureWrapper
class CustomGoalEnv(gym.GoalEnv):
"""docstring for CustomGoalEnv."""
def __init__(self):
super(CustomGoalEnv, self).__init__()
self.observation_space = spaces.Dict(