Release v2.1.0 (#204)

* Release v2.1.0

* Fix mypy

* Fix warnings in tests
This commit is contained in:
Antonin RAFFIN 2023-08-17 22:17:12 +02:00 committed by Andreas Schaarschmidt
parent 4e99b74e90
commit cf4ed5fe65
10 changed files with 13 additions and 33 deletions

View File

@ -1,5 +1,5 @@
SHELL=/bin/bash SHELL=/bin/bash
LINT_PATHS=sb3_contrib/ tests/ setup.py LINT_PATHS=sb3_contrib/ tests/ setup.py docs/conf.py
pytest: pytest:
./scripts/run_tests.sh ./scripts/run_tests.sh

View File

@ -13,7 +13,7 @@
# #
import os import os
import sys import sys
from unittest.mock import MagicMock from typing import Dict
# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support # We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
# PyEnchant. # PyEnchant.
@ -36,21 +36,6 @@ except ImportError:
sys.path.insert(0, os.path.abspath("..")) sys.path.insert(0, os.path.abspath(".."))
class Mock(MagicMock):
__subclasses__ = []
@classmethod
def __getattr__(cls, name):
return MagicMock()
# Mock modules that require C modules
# Note: because of that we cannot test examples using CI
# 'torch', 'torch.nn', 'torch.nn.functional',
# Do not mock modules for now; we will need to do that for Read the Docs later
MOCK_MODULES = []
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
# Read version from file # Read version from file
version_file = os.path.join(os.path.dirname(__file__), "../sb3_contrib", "version.txt") version_file = os.path.join(os.path.dirname(__file__), "../sb3_contrib", "version.txt")
with open(version_file) as file_handler: with open(version_file) as file_handler:
@ -59,7 +44,7 @@ with open(version_file) as file_handler:
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
project = "Stable Baselines3 - Contrib" project = "Stable Baselines3 - Contrib"
copyright = "2022, Stable Baselines3" copyright = "2023, Stable Baselines3"
author = "Stable Baselines3 Contributors" author = "Stable Baselines3 Contributors"
# The short X.Y version # The short X.Y version
@ -171,7 +156,7 @@ htmlhelp_basename = "StableBaselines3doc"
# -- Options for LaTeX output ------------------------------------------------ # -- Options for LaTeX output ------------------------------------------------
latex_elements = { latex_elements: Dict[str, str] = {
# The paper size ('letterpaper' or 'a4paper'). # The paper size ('letterpaper' or 'a4paper').
# #
# 'papersize': 'letterpaper', # 'papersize': 'letterpaper',

View File

@ -4,13 +4,14 @@ Changelog
========== ==========
Release 2.1.0a0 (WIP) Release 2.1.0 (2023-08-17)
-------------------------- --------------------------
Breaking Changes: Breaking Changes:
^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^
- Removed Python 3.7 support - Removed Python 3.7 support
- SB3 now requires PyTorch > 1.13 - SB3 now requires PyTorch > 1.13
- Upgraded to Stable-Baselines3 >= 2.1.0
New Features: New Features:
^^^^^^^^^^^^^ ^^^^^^^^^^^^^
@ -18,7 +19,7 @@ New Features:
Bug Fixes: Bug Fixes:
^^^^^^^^^^ ^^^^^^^^^^
- Fixed MaskablePPO ignoring stats_window_size argument - Fixed MaskablePPO ignoring ``stats_window_size`` argument
Deprecations: Deprecations:
^^^^^^^^^^^^^ ^^^^^^^^^^^^^

View File

@ -389,7 +389,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
observation, vectorized_env = self.obs_to_tensor(observation) observation, vectorized_env = self.obs_to_tensor(observation)
if isinstance(observation, dict): if isinstance(observation, dict):
n_envs = observation[list(observation.keys())[0]].shape[0] n_envs = observation[next(iter(observation.keys()))].shape[0]
else: else:
n_envs = observation.shape[0] n_envs = observation.shape[0]
# state : (n_layers, n_envs, dim) # state : (n_layers, n_envs, dim)

View File

@ -1,4 +1,3 @@
import multiprocessing
import multiprocessing as mp import multiprocessing as mp
from collections import defaultdict from collections import defaultdict
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union

View File

@ -256,7 +256,7 @@ class QRDQN(OffPolicyAlgorithm):
if not deterministic and np.random.rand() < self.exploration_rate: if not deterministic and np.random.rand() < self.exploration_rate:
if self.policy.is_vectorized_observation(observation): if self.policy.is_vectorized_observation(observation):
if isinstance(observation, dict): if isinstance(observation, dict):
n_batch = observation[list(observation.keys())[0]].shape[0] n_batch = observation[next(iter(observation.keys()))].shape[0]
else: else:
n_batch = observation.shape[0] n_batch = observation.shape[0]
action = np.array([self.action_space.sample() for _ in range(n_batch)]) action = np.array([self.action_space.sample() for _ in range(n_batch)])

View File

@ -1 +1 @@
2.1.0a0 2.1.0

View File

@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")], packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]}, package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[ install_requires=[
"stable_baselines3>=2.0.0", "stable_baselines3>=2.1.0",
], ],
description="Contrib package of Stable Baselines3, experimental code.", description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin", author="Antonin Raffin",

View File

@ -62,10 +62,6 @@ class DummyDictEnv(gym.Env):
# Add dictionary observation inside observation space # Add dictionary observation inside observation space
self.observation_space.spaces["nested-dict"] = spaces.Dict({"nested-dict-discrete": spaces.Discrete(4)}) self.observation_space.spaces["nested-dict"] = spaces.Dict({"nested-dict-discrete": spaces.Discrete(4)})
def seed(self, seed=None):
if seed is not None:
self.observation_space.seed(seed)
def step(self, action): def step(self, action):
reward = 0.0 reward = 0.0
done = truncated = False done = truncated = False
@ -103,8 +99,7 @@ def test_consistency(model_class):
dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True) dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True)
dict_env = gym.wrappers.TimeLimit(dict_env, 100) dict_env = gym.wrappers.TimeLimit(dict_env, 100)
env = gym.wrappers.FlattenObservation(dict_env) env = gym.wrappers.FlattenObservation(dict_env)
dict_env.seed(10) obs, _ = dict_env.reset(seed=10)
obs, _ = dict_env.reset()
kwargs = {} kwargs = {}
n_steps = 256 n_steps = 256

View File

@ -76,5 +76,5 @@ def test_action_masks_returns_expected_result():
# Only one valid action expected # Only one valid action expected
masks = env.action_masks() masks = env.action_masks()
masks[env.state] = not masks[env.state] # Bit-flip the one expected valid action masks[env.unwrapped.state] = not masks[env.unwrapped.state] # Bit-flip the one expected valid action
assert all([not mask for mask in masks]) assert all([not mask for mask in masks])