Release v2.1.0 (#204)
* Release v2.1.0 * Fix mypy * Fix warnings in tests
This commit is contained in:
parent
4e99b74e90
commit
cf4ed5fe65
2
Makefile
2
Makefile
|
|
@ -1,5 +1,5 @@
|
|||
SHELL=/bin/bash
|
||||
LINT_PATHS=sb3_contrib/ tests/ setup.py
|
||||
LINT_PATHS=sb3_contrib/ tests/ setup.py docs/conf.py
|
||||
|
||||
pytest:
|
||||
./scripts/run_tests.sh
|
||||
|
|
|
|||
21
docs/conf.py
21
docs/conf.py
|
|
@ -13,7 +13,7 @@
|
|||
#
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
from typing import Dict
|
||||
|
||||
# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
|
||||
# PyEnchant.
|
||||
|
|
@ -36,21 +36,6 @@ except ImportError:
|
|||
sys.path.insert(0, os.path.abspath(".."))
|
||||
|
||||
|
||||
class Mock(MagicMock):
|
||||
__subclasses__ = []
|
||||
|
||||
@classmethod
|
||||
def __getattr__(cls, name):
|
||||
return MagicMock()
|
||||
|
||||
|
||||
# Mock modules that requires C modules
|
||||
# Note: because of that we cannot test examples using CI
|
||||
# 'torch', 'torch.nn', 'torch.nn.functional',
|
||||
# DO not mock modules for now, we will need to do that for read the docs later
|
||||
MOCK_MODULES = []
|
||||
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
|
||||
|
||||
# Read version from file
|
||||
version_file = os.path.join(os.path.dirname(__file__), "../sb3_contrib", "version.txt")
|
||||
with open(version_file) as file_handler:
|
||||
|
|
@ -59,7 +44,7 @@ with open(version_file) as file_handler:
|
|||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = "Stable Baselines3 - Contrib"
|
||||
copyright = "2022, Stable Baselines3"
|
||||
copyright = "2023, Stable Baselines3"
|
||||
author = "Stable Baselines3 Contributors"
|
||||
|
||||
# The short X.Y version
|
||||
|
|
@ -171,7 +156,7 @@ htmlhelp_basename = "StableBaselines3doc"
|
|||
|
||||
# -- Options for LaTeX output ------------------------------------------------
|
||||
|
||||
latex_elements = {
|
||||
latex_elements: Dict[str, str] = {
|
||||
# The paper size ('letterpaper' or 'a4paper').
|
||||
#
|
||||
# 'papersize': 'letterpaper',
|
||||
|
|
|
|||
|
|
@ -4,13 +4,14 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 2.1.0a0 (WIP)
|
||||
Release 2.1.0 (2023-08-17)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Removed Python 3.7 support
|
||||
- SB3 now requires PyTorch > 1.13
|
||||
- Upgraded to Stable-Baselines3 >= 2.1.0
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -18,7 +19,7 @@ New Features:
|
|||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed MaskablePPO ignoring stats_window_size argument
|
||||
- Fixed MaskablePPO ignoring ``stats_window_size`` argument
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -389,7 +389,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
observation, vectorized_env = self.obs_to_tensor(observation)
|
||||
|
||||
if isinstance(observation, dict):
|
||||
n_envs = observation[list(observation.keys())[0]].shape[0]
|
||||
n_envs = observation[next(iter(observation.keys()))].shape[0]
|
||||
else:
|
||||
n_envs = observation.shape[0]
|
||||
# state : (n_layers, n_envs, dim)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import multiprocessing
|
||||
import multiprocessing as mp
|
||||
from collections import defaultdict
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
|
|
|||
|
|
@ -256,7 +256,7 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
if not deterministic and np.random.rand() < self.exploration_rate:
|
||||
if self.policy.is_vectorized_observation(observation):
|
||||
if isinstance(observation, dict):
|
||||
n_batch = observation[list(observation.keys())[0]].shape[0]
|
||||
n_batch = observation[next(iter(observation.keys()))].shape[0]
|
||||
else:
|
||||
n_batch = observation.shape[0]
|
||||
action = np.array([self.action_space.sample() for _ in range(n_batch)])
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.1.0a0
|
||||
2.1.0
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=2.0.0",
|
||||
"stable_baselines3>=2.1.0",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
|
|
@ -62,10 +62,6 @@ class DummyDictEnv(gym.Env):
|
|||
# Add dictionary observation inside observation space
|
||||
self.observation_space.spaces["nested-dict"] = spaces.Dict({"nested-dict-discrete": spaces.Discrete(4)})
|
||||
|
||||
def seed(self, seed=None):
|
||||
if seed is not None:
|
||||
self.observation_space.seed(seed)
|
||||
|
||||
def step(self, action):
|
||||
reward = 0.0
|
||||
done = truncated = False
|
||||
|
|
@ -103,8 +99,7 @@ def test_consistency(model_class):
|
|||
dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True)
|
||||
dict_env = gym.wrappers.TimeLimit(dict_env, 100)
|
||||
env = gym.wrappers.FlattenObservation(dict_env)
|
||||
dict_env.seed(10)
|
||||
obs, _ = dict_env.reset()
|
||||
obs, _ = dict_env.reset(seed=10)
|
||||
|
||||
kwargs = {}
|
||||
n_steps = 256
|
||||
|
|
|
|||
|
|
@ -76,5 +76,5 @@ def test_action_masks_returns_expected_result():
|
|||
|
||||
# Only one valid action expected
|
||||
masks = env.action_masks()
|
||||
masks[env.state] = not masks[env.state] # Bit-flip the one expected valid action
|
||||
masks[env.unwrapped.state] = not masks[env.unwrapped.state] # Bit-flip the one expected valid action
|
||||
assert all([not mask for mask in masks])
|
||||
|
|
|
|||
Loading…
Reference in New Issue