From cf4ed5fe65350758bf63f461a25d969cb50cf1fd Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 17 Aug 2023 22:17:12 +0200 Subject: [PATCH] Release v2.1.0 (#204) * Release v2.1.0 * Fix mypy * Fix warnings in tests --- Makefile | 2 +- docs/conf.py | 21 +++------------------ docs/misc/changelog.rst | 5 +++-- sb3_contrib/common/recurrent/policies.py | 2 +- sb3_contrib/common/vec_env/async_eval.py | 1 - sb3_contrib/qrdqn/qrdqn.py | 2 +- sb3_contrib/version.txt | 2 +- setup.py | 2 +- tests/test_dict_env.py | 7 +------ tests/wrappers/test_action_masker.py | 2 +- 10 files changed, 13 insertions(+), 33 deletions(-) diff --git a/Makefile b/Makefile index 0f53f5b..9643530 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ SHELL=/bin/bash -LINT_PATHS=sb3_contrib/ tests/ setup.py +LINT_PATHS=sb3_contrib/ tests/ setup.py docs/conf.py pytest: ./scripts/run_tests.sh diff --git a/docs/conf.py b/docs/conf.py index d62bfc9..efd8cc6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,7 +13,7 @@ # import os import sys -from unittest.mock import MagicMock +from typing import Dict # We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support # PyEnchant. @@ -36,21 +36,6 @@ except ImportError: sys.path.insert(0, os.path.abspath("..")) -class Mock(MagicMock): - __subclasses__ = [] - - @classmethod - def __getattr__(cls, name): - return MagicMock() - - -# Mock modules that requires C modules -# Note: because of that we cannot test examples using CI -# 'torch', 'torch.nn', 'torch.nn.functional', -# DO not mock modules for now, we will need to do that for read the docs later -MOCK_MODULES = [] -sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) - # Read version from file version_file = os.path.join(os.path.dirname(__file__), "../sb3_contrib", "version.txt") with open(version_file) as file_handler: @@ -59,7 +44,7 @@ with open(version_file) as file_handler: # -- Project information ----------------------------------------------------- project = "Stable Baselines3 - Contrib" -copyright = "2022, Stable Baselines3" +copyright = "2023, Stable Baselines3" author = "Stable Baselines3 Contributors" # The short X.Y version @@ -171,7 +156,7 @@ htmlhelp_basename = "StableBaselines3doc" # -- Options for LaTeX output ------------------------------------------------ -latex_elements = { +latex_elements: Dict[str, str] = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 29f832e..8c93bae 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,13 +4,14 @@ Changelog ========== -Release 2.1.0a0 (WIP) +Release 2.1.0 (2023-08-17) -------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ - Removed Python 3.7 support - SB3 now requires PyTorch > 1.13 +- Upgraded to Stable-Baselines3 >= 2.1.0 New Features: ^^^^^^^^^^^^^ @@ -18,7 +19,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ -- Fixed MaskablePPO ignoring stats_window_size argument +- Fixed MaskablePPO ignoring ``stats_window_size`` argument Deprecations: ^^^^^^^^^^^^^ diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index cf0c326..5e5090b 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -389,7 +389,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): observation, vectorized_env = self.obs_to_tensor(observation) if isinstance(observation, dict): - n_envs = observation[list(observation.keys())[0]].shape[0] + n_envs = observation[next(iter(observation.keys()))].shape[0] else: n_envs = observation.shape[0] # state : (n_layers, n_envs, dim) diff --git a/sb3_contrib/common/vec_env/async_eval.py b/sb3_contrib/common/vec_env/async_eval.py index f2500ad..9eed7b5 100644 --- a/sb3_contrib/common/vec_env/async_eval.py +++ b/sb3_contrib/common/vec_env/async_eval.py @@ -1,4 +1,3 @@ -import multiprocessing import multiprocessing as mp from collections import defaultdict from typing import Callable, List, Optional, Tuple, Union diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index 97177be..c7dbacd 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -256,7 +256,7 @@ class QRDQN(OffPolicyAlgorithm): if not deterministic and np.random.rand() < self.exploration_rate: if self.policy.is_vectorized_observation(observation): if isinstance(observation, dict): - n_batch = observation[list(observation.keys())[0]].shape[0] + n_batch = observation[next(iter(observation.keys()))].shape[0] else: n_batch = observation.shape[0] action = np.array([self.action_space.sample() for _ in range(n_batch)]) diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index ecaf4ee..7ec1d6d 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -2.1.0a0 +2.1.0 diff --git a/setup.py b/setup.py index 5beb233..07b1ab5 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=2.0.0", + "stable_baselines3>=2.1.0", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 9722da3..a276709 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -62,10 +62,6 @@ class DummyDictEnv(gym.Env): # Add dictionary observation inside observation space self.observation_space.spaces["nested-dict"] = spaces.Dict({"nested-dict-discrete": spaces.Discrete(4)}) - def seed(self, seed=None): - if seed is not None: - self.observation_space.seed(seed) - def step(self, action): reward = 0.0 done = truncated = False @@ -103,8 +99,7 @@ def test_consistency(model_class): dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True) dict_env = gym.wrappers.TimeLimit(dict_env, 100) env = gym.wrappers.FlattenObservation(dict_env) - dict_env.seed(10) - obs, _ = dict_env.reset() + obs, _ = dict_env.reset(seed=10) kwargs = {} n_steps = 256 diff --git a/tests/wrappers/test_action_masker.py b/tests/wrappers/test_action_masker.py index ebfb83c..afd6d34 100644 --- a/tests/wrappers/test_action_masker.py +++ b/tests/wrappers/test_action_masker.py @@ -76,5 +76,5 @@ def test_action_masks_returns_expected_result(): # Only one valid action expected masks = env.action_masks() - masks[env.state] = not masks[env.state] # Bit-flip the one expected valid action + masks[env.unwrapped.state] = not masks[env.unwrapped.state] # Bit-flip the one expected valid action assert all([not mask for mask in masks])