Removed shared layers in mlp_extractor (#137)
* Removed shared layers in mlp_extractor
* Add ruff
* Update version and add warning

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
1d0edd2dab
commit
6bc8e426bf
8
Makefile
8
Makefile
|
|
@ -19,6 +19,14 @@ lint:
|
||||||
# exit-zero treats all errors as warnings.
|
# exit-zero treats all errors as warnings.
|
||||||
flake8 ${LINT_PATHS} --count --exit-zero --statistics
|
flake8 ${LINT_PATHS} --count --exit-zero --statistics
|
||||||
|
|
||||||
|
ruff:
|
||||||
|
# stop the build if there are Python syntax errors or undefined names
|
||||||
|
# see https://lintlyci.github.io/Flake8Rules/
|
||||||
|
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
|
||||||
|
# exit-zero treats all errors as warnings.
|
||||||
|
ruff ${LINT_PATHS} --exit-zero --line-length 127
|
||||||
|
|
||||||
|
|
||||||
format:
|
format:
|
||||||
# Sort imports
|
# Sort imports
|
||||||
isort ${LINT_PATHS}
|
isort ${LINT_PATHS}
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,27 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
Release 1.8.0a2 (WIP)
|
||||||
|
--------------------------
|
||||||
|
|
||||||
|
Breaking Changes:
|
||||||
|
^^^^^^^^^^^^^^^^^
|
||||||
|
- Removed shared layers in ``mlp_extractor`` (@AlexPasqua)
|
||||||
|
- Upgraded to Stable-Baselines3 >= 1.8.0
|
||||||
|
|
||||||
|
New Features:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Bug Fixes:
|
||||||
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
Deprecations:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Others:
|
||||||
|
^^^^^^^
|
||||||
|
|
||||||
|
|
||||||
Release 1.7.0 (2023-01-10)
|
Release 1.7.0 (2023-01-10)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -47,8 +47,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
# TODO(antonin): update type annotation when we remove shared network support
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
ortho_init: bool = True,
|
ortho_init: bool = True,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||||
|
|
@ -76,21 +75,15 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
squash_output=False,
|
squash_output=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert [dict()] to dict() as shared network are deprecated
|
if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict):
|
||||||
if isinstance(net_arch, list) and len(net_arch) > 0:
|
|
||||||
if isinstance(net_arch[0], dict):
|
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
(
|
(
|
||||||
"As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, "
|
"As shared layers in the mlp_extractor are removed since SB3 v1.8.0, "
|
||||||
"you should now pass directly a dictionary and not a list "
|
"you should now pass directly a dictionary and not a list "
|
||||||
"(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
|
"(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
net_arch = net_arch[0]
|
net_arch = net_arch[0]
|
||||||
else:
|
|
||||||
# Note: deprecation warning will be emitted
|
|
||||||
# by the MlpExtractor constructor
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Default network architecture, from stable-baselines
|
# Default network architecture, from stable-baselines
|
||||||
if net_arch is None:
|
if net_arch is None:
|
||||||
|
|
@ -112,12 +105,6 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
else:
|
else:
|
||||||
self.pi_features_extractor = self.features_extractor
|
self.pi_features_extractor = self.features_extractor
|
||||||
self.vf_features_extractor = self.make_features_extractor()
|
self.vf_features_extractor = self.make_features_extractor()
|
||||||
# if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
|
|
||||||
# TODO(antonin): update the check once we change net_arch behavior
|
|
||||||
if isinstance(net_arch, list) and len(net_arch) > 0:
|
|
||||||
raise ValueError(
|
|
||||||
"Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Action distribution
|
# Action distribution
|
||||||
self.action_dist = make_masked_proba_distribution(action_space)
|
self.action_dist = make_masked_proba_distribution(action_space)
|
||||||
|
|
@ -401,8 +388,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
# TODO(antonin): update type annotation when we remove shared network support
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
ortho_init: bool = True,
|
ortho_init: bool = True,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||||
|
|
@ -456,8 +442,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
|
||||||
observation_space: spaces.Dict,
|
observation_space: spaces.Dict,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
# TODO(antonin): update type annotation when we remove shared network support
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
ortho_init: bool = True,
|
ortho_init: bool = True,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
||||||
|
|
|
||||||
|
|
@ -66,8 +66,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
# TODO(antonin): update type annotation when we remove shared network support
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
ortho_init: bool = True,
|
ortho_init: bool = True,
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
|
|
@ -476,8 +475,7 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
# TODO(antonin): update type annotation when we remove shared network support
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
ortho_init: bool = True,
|
ortho_init: bool = True,
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
|
|
@ -567,8 +565,7 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
# TODO(antonin): update type annotation when we remove shared network support
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
ortho_init: bool = True,
|
ortho_init: bool = True,
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.7.0
|
1.8.0a2
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.7.0",
|
"stable_baselines3>=1.8.0a2",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue