From 6bc8e426bf249a42b195ee2f4c132a91cbd3968d Mon Sep 17 00:00:00 2001 From: Alex Pasquali Date: Wed, 25 Jan 2023 16:28:27 +0100 Subject: [PATCH] Removed shared layers in mlp_extractor (#137) * Removed shared layers in mlp_extractor * Add ruff * Update version and add warning Co-authored-by: Antonin Raffin --- Makefile | 8 +++++ docs/misc/changelog.rst | 21 +++++++++++++ sb3_contrib/common/maskable/policies.py | 39 ++++++++---------------- sb3_contrib/common/recurrent/policies.py | 9 ++---- sb3_contrib/version.txt | 2 +- setup.py | 2 +- 6 files changed, 46 insertions(+), 35 deletions(-) diff --git a/Makefile b/Makefile index 1d3b794..4b5d230 100644 --- a/Makefile +++ b/Makefile @@ -19,6 +19,14 @@ lint: # exit-zero treats all errors as warnings. flake8 ${LINT_PATHS} --count --exit-zero --statistics +ruff: + # stop the build if there are Python syntax errors or undefined names + # see https://lintlyci.github.io/Flake8Rules/ + ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source + # exit-zero treats all errors as warnings. + ruff ${LINT_PATHS} --exit-zero --line-length 127 + + format: # Sort imports isort ${LINT_PATHS} diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e09dd94..a1b5a68 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,27 @@ Changelog ========== +Release 1.8.0a2 (WIP) +-------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Removed shared layers in ``mlp_extractor`` (@AlexPasqua) +- Upgraded to Stable-Baselines3 >= 1.8.0 + +New Features: +^^^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + + Release 1.7.0 (2023-01-10) -------------------------- diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index 3fb401a..b0b434b 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -47,8 +47,7 @@ class MaskableActorCriticPolicy(BasePolicy): observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - # TODO(antonin): update type annotation when we remove shared network support - net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, @@ -76,21 +75,15 @@ class MaskableActorCriticPolicy(BasePolicy): squash_output=False, ) - # Convert [dict()] to dict() as shared network are deprecated - if isinstance(net_arch, list) and len(net_arch) > 0: - if isinstance(net_arch[0], dict): - warnings.warn( - ( - "As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, " - "you should now pass directly a dictionary and not a list " - "(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])" - ), - ) - net_arch = net_arch[0] - else: - # Note: deprecation warning will be emitted - # by the MlpExtractor constructor - pass + if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict): + warnings.warn( + ( + "As shared layers in the mlp_extractor are removed since SB3 v1.8.0, " + "you should now pass directly a dictionary and not a list " + "(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])" + ), + ) + net_arch = net_arch[0] # Default network architecture, from stable-baselines if net_arch is None: @@ -112,12 +105,6 @@ class MaskableActorCriticPolicy(BasePolicy): else: self.pi_features_extractor = self.features_extractor self.vf_features_extractor = self.make_features_extractor() - # if the features extractor is not shared, there cannot be shared layers in the mlp_extractor - # TODO(antonin): update the check once we change net_arch behavior - if isinstance(net_arch, list) and len(net_arch) > 0: - raise ValueError( - "Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor" - ) # Action distribution self.action_dist = make_masked_proba_distribution(action_space) @@ -401,8 +388,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy): observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - # TODO(antonin): update type annotation when we remove shared network support - net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, @@ -456,8 +442,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy): observation_space: spaces.Dict, action_space: spaces.Space, lr_schedule: Schedule, - # TODO(antonin): update type annotation when we remove shared network support - net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 6721173..bc3b626 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -66,8 +66,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - # TODO(antonin): update type annotation when we remove shared network support - net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, @@ -476,8 +475,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy): observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - # TODO(antonin): update type annotation when we remove shared network support - net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, @@ -567,8 +565,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy): observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - # TODO(antonin): update type annotation when we remove shared network support - net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index bd8bf88..c3d22c0 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.7.0 +1.8.0a2 diff --git a/setup.py b/setup.py index c0867f3..0e48c57 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.7.0", + "stable_baselines3>=1.8.0a2", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin",