Removed shared layers in mlp_extractor (#137)

* Removed shared layers in mlp_extractor

* Add ruff

* Update version and add warning

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Alex Pasquali 2023-01-25 16:28:27 +01:00 committed by GitHub
parent 1d0edd2dab
commit 6bc8e426bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 46 additions and 35 deletions

View File

@ -19,6 +19,14 @@ lint:
# exit-zero treats all errors as warnings.
flake8 ${LINT_PATHS} --count --exit-zero --statistics
ruff:
# stop the build if there are Python syntax errors or undefined names
# see https://lintlyci.github.io/Flake8Rules/
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
# exit-zero treats all errors as warnings.
ruff ${LINT_PATHS} --exit-zero --line-length 127
format:
# Sort imports
isort ${LINT_PATHS}

View File

@ -3,6 +3,27 @@
Changelog
==========
Release 1.8.0a2 (WIP)
--------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed shared layers in ``mlp_extractor`` (@AlexPasqua)
- Upgraded to Stable-Baselines3 >= 1.8.0
New Features:
^^^^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
Release 1.7.0 (2023-01-10)
--------------------------

View File

@ -47,8 +47,7 @@ class MaskableActorCriticPolicy(BasePolicy):
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
@ -76,21 +75,15 @@ class MaskableActorCriticPolicy(BasePolicy):
squash_output=False,
)
# Convert [dict()] to dict() as shared network are deprecated
if isinstance(net_arch, list) and len(net_arch) > 0:
if isinstance(net_arch[0], dict):
if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict):
warnings.warn(
(
"As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, "
"As shared layers in the mlp_extractor are removed since SB3 v1.8.0, "
"you should now pass directly a dictionary and not a list "
"(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
),
)
net_arch = net_arch[0]
else:
# Note: deprecation warning will be emitted
# by the MlpExtractor constructor
pass
# Default network architecture, from stable-baselines
if net_arch is None:
@ -112,12 +105,6 @@ class MaskableActorCriticPolicy(BasePolicy):
else:
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.make_features_extractor()
# if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
# TODO(antonin): update the check once we change net_arch behavior
if isinstance(net_arch, list) and len(net_arch) > 0:
raise ValueError(
"Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
)
# Action distribution
self.action_dist = make_masked_proba_distribution(action_space)
@ -401,8 +388,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
@ -456,8 +442,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
observation_space: spaces.Dict,
action_space: spaces.Space,
lr_schedule: Schedule,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,

View File

@ -66,8 +66,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
@ -476,8 +475,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
@ -567,8 +565,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,

View File

@ -1 +1 @@
1.7.0
1.8.0a2

View File

@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.7.0",
"stable_baselines3>=1.8.0a2",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",