Release v2.1.0 (#204)
* Release v2.1.0 * Fix mypy * Fix warnings in tests
This commit is contained in:
parent
4e99b74e90
commit
cf4ed5fe65
2
Makefile
2
Makefile
|
|
@ -1,5 +1,5 @@
|
||||||
SHELL=/bin/bash
|
SHELL=/bin/bash
|
||||||
LINT_PATHS=sb3_contrib/ tests/ setup.py
|
LINT_PATHS=sb3_contrib/ tests/ setup.py docs/conf.py
|
||||||
|
|
||||||
pytest:
|
pytest:
|
||||||
./scripts/run_tests.sh
|
./scripts/run_tests.sh
|
||||||
|
|
|
||||||
21
docs/conf.py
21
docs/conf.py
|
|
@ -13,7 +13,7 @@
|
||||||
#
|
#
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from unittest.mock import MagicMock
|
from typing import Dict
|
||||||
|
|
||||||
# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
|
# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
|
||||||
# PyEnchant.
|
# PyEnchant.
|
||||||
|
|
@ -36,21 +36,6 @@ except ImportError:
|
||||||
sys.path.insert(0, os.path.abspath(".."))
|
sys.path.insert(0, os.path.abspath(".."))
|
||||||
|
|
||||||
|
|
||||||
class Mock(MagicMock):
|
|
||||||
__subclasses__ = []
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __getattr__(cls, name):
|
|
||||||
return MagicMock()
|
|
||||||
|
|
||||||
|
|
||||||
# Mock modules that requires C modules
|
|
||||||
# Note: because of that we cannot test examples using CI
|
|
||||||
# 'torch', 'torch.nn', 'torch.nn.functional',
|
|
||||||
# DO not mock modules for now, we will need to do that for read the docs later
|
|
||||||
MOCK_MODULES = []
|
|
||||||
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
|
|
||||||
|
|
||||||
# Read version from file
|
# Read version from file
|
||||||
version_file = os.path.join(os.path.dirname(__file__), "../sb3_contrib", "version.txt")
|
version_file = os.path.join(os.path.dirname(__file__), "../sb3_contrib", "version.txt")
|
||||||
with open(version_file) as file_handler:
|
with open(version_file) as file_handler:
|
||||||
|
|
@ -59,7 +44,7 @@ with open(version_file) as file_handler:
|
||||||
# -- Project information -----------------------------------------------------
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
project = "Stable Baselines3 - Contrib"
|
project = "Stable Baselines3 - Contrib"
|
||||||
copyright = "2022, Stable Baselines3"
|
copyright = "2023, Stable Baselines3"
|
||||||
author = "Stable Baselines3 Contributors"
|
author = "Stable Baselines3 Contributors"
|
||||||
|
|
||||||
# The short X.Y version
|
# The short X.Y version
|
||||||
|
|
@ -171,7 +156,7 @@ htmlhelp_basename = "StableBaselines3doc"
|
||||||
|
|
||||||
# -- Options for LaTeX output ------------------------------------------------
|
# -- Options for LaTeX output ------------------------------------------------
|
||||||
|
|
||||||
latex_elements = {
|
latex_elements: Dict[str, str] = {
|
||||||
# The paper size ('letterpaper' or 'a4paper').
|
# The paper size ('letterpaper' or 'a4paper').
|
||||||
#
|
#
|
||||||
# 'papersize': 'letterpaper',
|
# 'papersize': 'letterpaper',
|
||||||
|
|
|
||||||
|
|
@ -4,13 +4,14 @@ Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
|
||||||
Release 2.1.0a0 (WIP)
|
Release 2.1.0 (2023-08-17)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^
|
||||||
- Removed Python 3.7 support
|
- Removed Python 3.7 support
|
||||||
- SB3 now requires PyTorch > 1.13
|
- SB3 now requires PyTorch > 1.13
|
||||||
|
- Upgraded to Stable-Baselines3 >= 2.1.0
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
@ -18,7 +19,7 @@ New Features:
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
- Fixed MaskablePPO ignoring stats_window_size argument
|
- Fixed MaskablePPO ignoring ``stats_window_size`` argument
|
||||||
|
|
||||||
Deprecations:
|
Deprecations:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -389,7 +389,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
observation, vectorized_env = self.obs_to_tensor(observation)
|
observation, vectorized_env = self.obs_to_tensor(observation)
|
||||||
|
|
||||||
if isinstance(observation, dict):
|
if isinstance(observation, dict):
|
||||||
n_envs = observation[list(observation.keys())[0]].shape[0]
|
n_envs = observation[next(iter(observation.keys()))].shape[0]
|
||||||
else:
|
else:
|
||||||
n_envs = observation.shape[0]
|
n_envs = observation.shape[0]
|
||||||
# state : (n_layers, n_envs, dim)
|
# state : (n_layers, n_envs, dim)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import multiprocessing
|
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
|
||||||
|
|
@ -256,7 +256,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
if not deterministic and np.random.rand() < self.exploration_rate:
|
if not deterministic and np.random.rand() < self.exploration_rate:
|
||||||
if self.policy.is_vectorized_observation(observation):
|
if self.policy.is_vectorized_observation(observation):
|
||||||
if isinstance(observation, dict):
|
if isinstance(observation, dict):
|
||||||
n_batch = observation[list(observation.keys())[0]].shape[0]
|
n_batch = observation[next(iter(observation.keys()))].shape[0]
|
||||||
else:
|
else:
|
||||||
n_batch = observation.shape[0]
|
n_batch = observation.shape[0]
|
||||||
action = np.array([self.action_space.sample() for _ in range(n_batch)])
|
action = np.array([self.action_space.sample() for _ in range(n_batch)])
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
2.1.0a0
|
2.1.0
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=2.0.0",
|
"stable_baselines3>=2.1.0",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
|
|
@ -62,10 +62,6 @@ class DummyDictEnv(gym.Env):
|
||||||
# Add dictionary observation inside observation space
|
# Add dictionary observation inside observation space
|
||||||
self.observation_space.spaces["nested-dict"] = spaces.Dict({"nested-dict-discrete": spaces.Discrete(4)})
|
self.observation_space.spaces["nested-dict"] = spaces.Dict({"nested-dict-discrete": spaces.Discrete(4)})
|
||||||
|
|
||||||
def seed(self, seed=None):
|
|
||||||
if seed is not None:
|
|
||||||
self.observation_space.seed(seed)
|
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
reward = 0.0
|
reward = 0.0
|
||||||
done = truncated = False
|
done = truncated = False
|
||||||
|
|
@ -103,8 +99,7 @@ def test_consistency(model_class):
|
||||||
dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True)
|
dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True)
|
||||||
dict_env = gym.wrappers.TimeLimit(dict_env, 100)
|
dict_env = gym.wrappers.TimeLimit(dict_env, 100)
|
||||||
env = gym.wrappers.FlattenObservation(dict_env)
|
env = gym.wrappers.FlattenObservation(dict_env)
|
||||||
dict_env.seed(10)
|
obs, _ = dict_env.reset(seed=10)
|
||||||
obs, _ = dict_env.reset()
|
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
n_steps = 256
|
n_steps = 256
|
||||||
|
|
|
||||||
|
|
@ -76,5 +76,5 @@ def test_action_masks_returns_expected_result():
|
||||||
|
|
||||||
# Only one valid action expected
|
# Only one valid action expected
|
||||||
masks = env.action_masks()
|
masks = env.action_masks()
|
||||||
masks[env.state] = not masks[env.state] # Bit-flip the one expected valid action
|
masks[env.unwrapped.state] = not masks[env.unwrapped.state] # Bit-flip the one expected valid action
|
||||||
assert all([not mask for mask in masks])
|
assert all([not mask for mask in masks])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue