diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 86f87d6..ea3cc71 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,11 +3,14 @@ Changelog ========== -Release 1.1.0a5 (WIP) +Release 1.1.0a6 (WIP) ------------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ +- Added support for Dictionary observation spaces (cf SB3 doc) +- Upgraded to Stable-Baselines3 >= 1.1.0 +- Added proper handling of timeouts for off-policy algorithms (cf SB3 doc) Bug Fixes: ^^^^^^^^^^ diff --git a/docs/modules/qrdqn.rst b/docs/modules/qrdqn.rst index a648e27..79940d1 100644 --- a/docs/modules/qrdqn.rst +++ b/docs/modules/qrdqn.rst @@ -18,6 +18,7 @@ instead of predicting the mean return (DQN). MlpPolicy CnnPolicy + MultiInputPolicy Notes @@ -43,6 +44,7 @@ Discrete ✔ ✔ Box ❌ ✔ MultiDiscrete ❌ ✔ MultiBinary ❌ ✔ +Dict ❌ ✔️ ============= ====== =========== @@ -149,3 +151,6 @@ QR-DQN Policies .. autoclass:: CnnPolicy :members: + +.. autoclass:: MultiInputPolicy + :members: diff --git a/docs/modules/tqc.rst b/docs/modules/tqc.rst index 28279c7..9b9c80d 100644 --- a/docs/modules/tqc.rst +++ b/docs/modules/tqc.rst @@ -18,6 +18,7 @@ It truncates the quantiles predicted by different networks (a bit as it is done MlpPolicy CnnPolicy + MultiInputPolicy Notes @@ -42,6 +43,7 @@ Discrete ❌ ✔️ Box ✔️ ✔️ MultiDiscrete ❌ ✔️ MultiBinary ❌ ✔️ +Dict ❌ ✔️ ============= ====== =========== diff --git a/sb3_contrib/qrdqn/__init__.py b/sb3_contrib/qrdqn/__init__.py index 1d16fae..6ca676e 100644 --- a/sb3_contrib/qrdqn/__init__.py +++ b/sb3_contrib/qrdqn/__init__.py @@ -1,2 +1,2 @@ -from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy +from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from sb3_contrib.qrdqn.qrdqn import QRDQN diff --git a/sb3_contrib/qrdqn/policies.py b/sb3_contrib/qrdqn/policies.py index 30ea643..5776159 100644 --- a/sb3_contrib/qrdqn/policies.py +++ b/sb3_contrib/qrdqn/policies.py @@ -3,7 +3,13 @@ from typing import Any, Dict, List, Optional, Type import gym import torch as th from stable_baselines3.common.policies import BasePolicy, register_policy -from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + NatureCNN, + create_mlp, +) from stable_baselines3.common.type_aliases import Schedule from torch import nn @@ -245,5 +251,54 @@ class CnnPolicy(QRDQNPolicy): ) +class MultiInputPolicy(QRDQNPolicy): + """ + Policy class for QR-DQN when using dict observations as input. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param n_quantiles: Number of quantiles + :param net_arch: The specification of the network architecture. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. 
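+ :param features_extractor_kwargs: Keyword arguments to pass to the features extractor.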
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + n_quantiles: int = 200, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super(MultiInputPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + n_quantiles, + net_arch, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + register_policy("MlpPolicy", MlpPolicy) register_policy("CnnPolicy", CnnPolicy) +register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index b2c22db..d0f5c95 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -4,6 +4,7 @@ import gym import numpy as np import torch as th from stable_baselines3.common import logger +from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.preprocessing import maybe_transpose from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -34,6 +35,9 @@ class QRDQN(OffPolicyAlgorithm): (see ``train_freq`` and ``n_episodes_rollout``) Set to ``-1`` means to do as many gradient steps as steps done in the environment during the rollout. + :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). + If ``None``, it will be automatically selected. + :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. 
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 @@ -59,13 +63,15 @@ class QRDQN(OffPolicyAlgorithm): policy: Union[str, Type[QRDQNPolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 5e-5, - buffer_size: int = 1000000, + buffer_size: int = 1000000, # 1e6 learning_starts: int = 50000, batch_size: Optional[int] = 32, tau: float = 1.0, gamma: float = 0.99, train_freq: int = 4, gradient_steps: int = 1, + replay_buffer_class: Optional[ReplayBuffer] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, target_update_interval: int = 10000, exploration_fraction: float = 0.005, @@ -94,6 +100,8 @@ class QRDQN(OffPolicyAlgorithm): train_freq, gradient_steps, action_noise=None, # No action noise + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, verbose=verbose, @@ -152,7 +160,7 @@ class QRDQN(OffPolicyAlgorithm): self._update_learning_rate(self.policy.optimizer) losses = [] - for gradient_step in range(gradient_steps): + for _ in range(gradient_steps): # Sample replay buffer replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) @@ -213,7 +221,10 @@ class QRDQN(OffPolicyAlgorithm): """ if not deterministic and np.random.rand() < self.exploration_rate: if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space): - n_batch = observation.shape[0] + if isinstance(self.observation_space, gym.spaces.Dict): + n_batch = observation[list(observation.keys())[0]].shape[0] + else: + n_batch = observation.shape[0] action = np.array([self.action_space.sample() for _ in range(n_batch)]) else: action = np.array(self.action_space.sample()) diff --git a/sb3_contrib/tqc/__init__.py b/sb3_contrib/tqc/__init__.py index ef4b6bd..b8d9a71 100644 --- a/sb3_contrib/tqc/__init__.py +++ b/sb3_contrib/tqc/__init__.py @@ -1,2 +1,2 @@ -from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy +from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from sb3_contrib.tqc.tqc import TQC diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index 776a57d..d77b4c0 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type, Union import gym import torch as th @@ -7,11 +7,13 @@ from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_ from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, + CombinedExtractor, FlattenExtractor, NatureCNN, create_mlp, get_actor_critic_arch, ) +from stable_baselines3.common.type_aliases import Schedule from torch import nn as nn # CAP the standard deviation of the actor @@ -278,6 +280,8 @@ class TQCPolicy(BasePolicy): ``th.optim.Adam`` by default :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param n_quantiles: Number of quantiles for the critic. + :param n_critics: Number of critic networks to create. 
:param share_features_extractor: Whether to share or not the features extractor between the actor and the critic (this saves computation time) """ @@ -286,8 +290,8 @@ class TQCPolicy(BasePolicy): self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, - lr_schedule: Callable, - net_arch: Optional[List[int]] = None, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, use_sde: bool = False, log_std_init: float = -3, @@ -353,7 +357,7 @@ class TQCPolicy(BasePolicy): self._build(lr_schedule) - def _build(self, lr_schedule: Callable) -> None: + def _build(self, lr_schedule: Schedule) -> None: self.actor = self.make_actor() self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) @@ -448,6 +452,8 @@ class CnnPolicy(TQCPolicy): ``th.optim.Adam`` by default :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param n_quantiles: Number of quantiles for the critic. + :param n_critics: Number of critic networks to create. :param share_features_extractor: Whether to share or not the features extractor between the actor and the critic (this saves computation time) """ @@ -456,8 +462,8 @@ class CnnPolicy(TQCPolicy): self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, - lr_schedule: Callable, - net_arch: Optional[List[int]] = None, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, use_sde: bool = False, log_std_init: float = -3, @@ -495,5 +501,80 @@ class CnnPolicy(TQCPolicy): ) +class MultiInputPolicy(TQCPolicy): + """ + Policy class (with both actor and critic) for TQC. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param features_extractor_class: Features extractor to use. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param n_quantiles: Number of quantiles for the critic. + :param n_critics: Number of critic networks to create. 
+ :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critic (this saves computation time) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + use_sde: bool = False, + log_std_init: float = -3, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + clip_mean: float = 2.0, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_quantiles: int = 25, + n_critics: int = 2, + share_features_extractor: bool = True, + ): + super(MultiInputPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + use_sde, + log_std_init, + sde_net_arch, + use_expln, + clip_mean, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + n_quantiles, + n_critics, + share_features_extractor, + ) + + register_policy("MlpPolicy", MlpPolicy) register_policy("CnnPolicy", CnnPolicy) +register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 815916d..103c25e 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -4,6 +4,7 @@ import gym import numpy as np import torch as th from stable_baselines3.common import logger +from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback @@ -35,6 +36,9 @@ class TQC(OffPolicyAlgorithm): :param gradient_steps: How many gradient update after each step :param action_noise: the action noise type (None by default), this can help for hard exploration problem. Cf common.noise for the different action noise type. + :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). + If ``None``, it will be automatically selected. + :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. 
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 @@ -66,7 +70,7 @@ class TQC(OffPolicyAlgorithm): policy: Union[str, Type[TQCPolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Callable] = 3e-4, - buffer_size: int = int(1e6), + buffer_size: int = 1000000, # 1e6 learning_starts: int = 100, batch_size: int = 256, tau: float = 0.005, @@ -74,6 +78,8 @@ class TQC(OffPolicyAlgorithm): train_freq: int = 1, gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, + replay_buffer_class: Optional[ReplayBuffer] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, ent_coef: Union[str, float] = "auto", target_update_interval: int = 1, @@ -103,7 +109,9 @@ class TQC(OffPolicyAlgorithm): gamma, train_freq, gradient_steps, - action_noise, + action_noise=action_noise, + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, verbose=verbose, diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index c84ce18..1406d2f 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.1.0a5 +1.1.0a6 diff --git a/setup.cfg b/setup.cfg index 1f3fd5a..3ddbd03 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,11 +7,8 @@ license_file = LICENSE env = PYTHONHASHSEED=0 filterwarnings = - # Tensorboard/Tensorflow warnings - ignore:inspect.getargspec:DeprecationWarning:tensorflow - ignore:builtin type EagerTensor has no __module__ attribute:DeprecationWarning - ignore:The binary mode of fromstring is deprecated:DeprecationWarning - ignore::FutureWarning:tensorflow + # Tensorboard warnings + ignore::DeprecationWarning:tensorboard # Gym warnings ignore:Parameters to load are deprecated.:DeprecationWarning ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning diff --git a/setup.py b/setup.py index a78cfb4..1aad728 100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.0", + "stable_baselines3>=1.1.0a6", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 72488c3..6c27785 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -4,7 +4,7 @@ from copy import deepcopy import numpy as np import pytest import torch as th -from stable_baselines3.common.identity_env import FakeImageEnv +from stable_baselines3.common.envs import FakeImageEnv from stable_baselines3.common.utils import zip_strict from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py new file mode 100644 index 0000000..fda2725 --- /dev/null +++ b/tests/test_dict_env.py @@ -0,0 +1,258 @@ +import gym +import numpy as np +import pytest +from gym import spaces +from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize + +from sb3_contrib import QRDQN, TQC + + +class DummyDictEnv(gym.Env): + """Custom Environment for testing purposes only""" + + metadata = {"render.modes": ["human"]} + + def __init__( + self, + use_discrete_actions=False, + channel_last=False, + 
nested_dict_obs=False, + vec_only=False, + ): + super().__init__() + if use_discrete_actions: + self.action_space = spaces.Discrete(3) + else: + self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) + N_CHANNELS = 1 + HEIGHT = 64 + WIDTH = 64 + + if channel_last: + obs_shape = (HEIGHT, WIDTH, N_CHANNELS) + else: + obs_shape = (N_CHANNELS, HEIGHT, WIDTH) + + self.observation_space = spaces.Dict( + { + # Image obs + "img": spaces.Box(low=0, high=255, shape=obs_shape, dtype=np.uint8), + # Vector obs + "vec": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32), + # Discrete obs + "discrete": spaces.Discrete(4), + } + ) + + # For checking consistency with normal MlpPolicy + if vec_only: + self.observation_space = spaces.Dict( + { + # Vector obs + "vec": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32), + } + ) + + if nested_dict_obs: + # Add dictionary observation inside observation space + self.observation_space.spaces["nested-dict"] = spaces.Dict({"nested-dict-discrete": spaces.Discrete(4)}) + + def seed(self, seed=None): + if seed is not None: + self.observation_space.seed(seed) + + def step(self, action): + reward = 0.0 + done = False + return self.observation_space.sample(), reward, done, {} + + def compute_reward(self, achieved_goal, desired_goal, info): + return np.zeros((len(achieved_goal),)) + + def reset(self): + return self.observation_space.sample() + + def render(self, mode="human"): + pass + + +@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +def test_consistency(model_class): + """ + Make sure that dict obs with vector only vs using flatten obs is equivalent. + This ensures notable that the network architectures are the same. + """ + use_discrete_actions = model_class == QRDQN + dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True) + dict_env = gym.wrappers.TimeLimit(dict_env, 100) + env = gym.wrappers.FlattenObservation(dict_env) + dict_env.seed(10) + obs = dict_env.reset() + + kwargs = {} + n_steps = 256 + + if model_class in {}: + kwargs = dict( + n_steps=128, + ) + else: + # Avoid memory error when using replay buffer + # Reduce the size of the features and make learning faster + kwargs = dict( + buffer_size=250, + train_freq=8, + gradient_steps=1, + ) + if model_class == QRDQN: + kwargs["learning_starts"] = 0 + + dict_model = model_class("MultiInputPolicy", dict_env, gamma=0.5, seed=1, **kwargs) + action_before_learning_1, _ = dict_model.predict(obs, deterministic=True) + dict_model.learn(total_timesteps=n_steps) + + normal_model = model_class("MlpPolicy", env, gamma=0.5, seed=1, **kwargs) + action_before_learning_2, _ = normal_model.predict(obs["vec"], deterministic=True) + normal_model.learn(total_timesteps=n_steps) + + action_1, _ = dict_model.predict(obs, deterministic=True) + action_2, _ = normal_model.predict(obs["vec"], deterministic=True) + + assert np.allclose(action_before_learning_1, action_before_learning_2) + assert np.allclose(action_1, action_2) + + +@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("channel_last", [False, True]) +def test_dict_spaces(model_class, channel_last): + """ + Additional tests to check observation space support + with mixed observation. 
+ """ + use_discrete_actions = model_class not in [TQC] + env = DummyDictEnv(use_discrete_actions=use_discrete_actions, channel_last=channel_last) + env = gym.wrappers.TimeLimit(env, 100) + + kwargs = {} + n_steps = 256 + + if model_class in {}: + kwargs = dict( + n_steps=128, + policy_kwargs=dict( + net_arch=[32], + features_extractor_kwargs=dict(cnn_output_dim=32), + ), + ) + else: + # Avoid memory error when using replay buffer + # Reduce the size of the features and make learning faster + kwargs = dict( + buffer_size=250, + policy_kwargs=dict( + net_arch=[32], + features_extractor_kwargs=dict(cnn_output_dim=32), + n_quantiles=20, + ), + train_freq=8, + gradient_steps=1, + ) + if model_class == QRDQN: + kwargs["learning_starts"] = 0 + + model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs) + + model.learn(total_timesteps=n_steps) + + evaluate_policy(model, env, n_eval_episodes=5, warn=False) + + +@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("channel_last", [False, True]) +def test_dict_vec_framestack(model_class, channel_last): + """ + Additional tests to check observation space support + for Dictionary spaces and VecEnvWrapper using MultiInputPolicy. + """ + use_discrete_actions = model_class not in [TQC] + channels_order = {"vec": None, "img": "last" if channel_last else "first"} + env = DummyVecEnv( + [lambda: SimpleMultiObsEnv(random_start=True, discrete_actions=use_discrete_actions, channel_last=channel_last)] + ) + + env = VecFrameStack(env, n_stack=3, channels_order=channels_order) + + kwargs = {} + n_steps = 256 + + if model_class in {}: + kwargs = dict( + n_steps=128, + policy_kwargs=dict( + net_arch=[32], + features_extractor_kwargs=dict(cnn_output_dim=32), + ), + ) + else: + # Avoid memory error when using replay buffer + # Reduce the size of the features and make learning faster + kwargs = dict( + buffer_size=250, + policy_kwargs=dict( + net_arch=[32], + features_extractor_kwargs=dict(cnn_output_dim=32), + n_quantiles=20, + ), + train_freq=8, + gradient_steps=1, + ) + if model_class == QRDQN: + kwargs["learning_starts"] = 0 + + model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs) + + model.learn(total_timesteps=n_steps) + + evaluate_policy(model, env, n_eval_episodes=5, warn=False) + + +@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +def test_vec_normalize(model_class): + """ + Additional tests to check observation space support + for GoalEnv and VecNormalize using MultiInputPolicy. 
+ """ + env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=4, continuous=not (model_class == QRDQN))]) + env = VecNormalize(env) + + kwargs = {} + n_steps = 256 + + if model_class in {}: + kwargs = dict( + n_steps=128, + policy_kwargs=dict( + net_arch=[32], + ), + ) + else: + # Avoid memory error when using replay buffer + # Reduce the size of the features and make learning faster + kwargs = dict( + buffer_size=250, + policy_kwargs=dict( + net_arch=[32], + ), + train_freq=8, + gradient_steps=1, + ) + if model_class == QRDQN: + kwargs["learning_starts"] = 0 + + model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs) + + model.learn(total_timesteps=n_steps) + + evaluate_policy(model, env, n_eval_episodes=5, warn=False) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index c4047e0..716f202 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -8,7 +8,7 @@ import numpy as np import pytest import torch as th from stable_baselines3.common.base_class import BaseAlgorithm -from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox +from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv @@ -243,6 +243,7 @@ def test_save_load_replay_buffer(tmp_path, model_class): assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions) assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards) assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones) + infos = [[{"TimeLimit.truncated": truncated}] for truncated in old_replay_buffer.timeouts] # test extending replay buffer model.replay_buffer.extend( @@ -251,6 +252,7 @@ def test_save_load_replay_buffer(tmp_path, model_class): old_replay_buffer.actions, old_replay_buffer.rewards, old_replay_buffer.dones, + infos, ) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index dfd1e24..1744403 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -2,15 +2,13 @@ import gym import numpy as np import pytest from gym import spaces -from stable_baselines3.common.bit_flipping_env import BitFlippingEnv from stable_baselines3.common.env_checker import check_env +from stable_baselines3.common.envs import BitFlippingEnv from sb3_contrib.common.wrappers import TimeFeatureWrapper class CustomGoalEnv(gym.GoalEnv): - """docstring for CustomGoalEnv.""" - def __init__(self): super(CustomGoalEnv, self).__init__() self.observation_space = spaces.Dict(
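
Below is a short usage sketch of the two features this patch adds (Dict observation support via ``MultiInputPolicy`` and the pass-through ``replay_buffer_class``/``replay_buffer_kwargs`` arguments). It is not part of the diff: it assumes this branch plus stable_baselines3 >= 1.1.0a6 are installed, reuses SB3's ``SimpleMultiObsEnv`` and ``BitFlippingEnv`` (both already used in the tests above) as ready-made Dict-observation environments, and the import path for ``HerReplayBuffer`` and all hyperparameter values are illustrative only.

    from stable_baselines3 import HerReplayBuffer
    from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv

    from sb3_contrib import QRDQN, TQC

    # Dict observations (image + vector keys) with the new MultiInputPolicy.
    env = SimpleMultiObsEnv(random_start=False, discrete_actions=True)
    model = QRDQN("MultiInputPolicy", env, buffer_size=5_000, learning_starts=100, verbose=1)
    model.learn(total_timesteps=1_000)

    # New replay_buffer_class / replay_buffer_kwargs arguments, e.g. HER on a
    # goal-conditioned environment (TQC requires a continuous action space).
    env = BitFlippingEnv(n_bits=4, continuous=True)
    model = TQC(
        "MultiInputPolicy",
        env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs=dict(n_sampled_goal=4, online_sampling=True, max_episode_length=4),
        buffer_size=5_000,
        learning_starts=100,
        verbose=1,
    )
    model.learn(total_timesteps=1_000)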