Deprecation of shared layers in `mlp_extractor` (#133)

* Deprecation of shared layers in mlp_extractor

* Fix missing import

* Reformat and update tests

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Alex Pasquali 2023-01-05 10:42:22 +01:00 committed by GitHub
parent 7c4a249fa4
commit b5aa9a47ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 49 additions and 18 deletions

View File

@ -3,15 +3,22 @@
Changelog
==========
Release 1.7.0a11 (WIP)
Release 1.7.0a12 (WIP)
--------------------------
.. warning::
Shared layers in MLP policy (``mlp_extractor``) are now deprecated for PPO, A2C and TRPO.
This feature will be removed in SB3 v1.8.0, after which ``net_arch=[64, 64]``
will create **separate** networks with the same architecture, to be consistent with the off-policy algorithms.
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
please use an ``EvalCallback`` instead
- Removed deprecated ``sde_net_arch`` parameter
- Upgraded to Stable-Baselines3 >= 1.7.0a11
- Upgraded to Stable-Baselines3 >= 1.7.0a12
New Features:
^^^^^^^^^^^^^
@ -30,6 +37,7 @@ Bug Fixes:
Deprecations:
^^^^^^^^^^^^^
- You should now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()``
- Deprecated shared layers in ``MlpExtractor`` (@AlexPasqua)
Others:
^^^^^^^

View File

@ -1,3 +1,4 @@
import warnings
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, Union
@ -46,7 +47,8 @@ class MaskableActorCriticPolicy(BasePolicy):
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
@ -74,12 +76,28 @@ class MaskableActorCriticPolicy(BasePolicy):
squash_output=False,
)
# Convert [dict()] to dict() as shared network are deprecated
if isinstance(net_arch, list) and len(net_arch) > 0:
if isinstance(net_arch[0], dict):
warnings.warn(
(
"As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, "
"you should now pass directly a dictionary and not a list "
"(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
),
)
net_arch = net_arch[0]
else:
# Note: deprecation warning will be emitted
# by the MlpExtractor constructor
pass
# Default network architecture, from stable-baselines
if net_arch is None:
if features_extractor_class == NatureCNN:
net_arch = []
else:
net_arch = [dict(pi=[64, 64], vf=[64, 64])]
net_arch = dict(pi=[64, 64], vf=[64, 64])
self.net_arch = net_arch
self.activation_fn = activation_fn
@ -95,7 +113,8 @@ class MaskableActorCriticPolicy(BasePolicy):
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.make_features_extractor()
# if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
if len(net_arch) > 0 and not isinstance(net_arch[0], dict):
# TODO(antonin): update the check once we change net_arch behavior
if isinstance(net_arch, list) and len(net_arch) > 0:
raise ValueError(
"Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
)
@ -382,7 +401,8 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
@ -436,7 +456,8 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
observation_space: spaces.Dict,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,

View File

@ -66,7 +66,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
@ -475,7 +476,8 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
@ -565,7 +567,8 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,

View File

@ -3,7 +3,6 @@ import time
from collections import deque
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
import torch as th
from gym import spaces

View File

@ -1 +1 @@
1.7.0a11
1.7.0a12

View File

@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.7.0a11",
"stable_baselines3>=1.7.0a12",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",

View File

@ -142,7 +142,7 @@ def test_dict_spaces(model_class, channel_last):
kwargs = dict(
n_steps=128,
policy_kwargs=dict(
net_arch=[dict(pi=[32], vf=[32])],
net_arch=dict(pi=[32], vf=[32]),
features_extractor_kwargs=dict(cnn_output_dim=32),
),
)
@ -191,7 +191,7 @@ def test_dict_vec_framestack(model_class, channel_last):
kwargs = dict(
n_steps=128,
policy_kwargs=dict(
net_arch=[dict(pi=[32], vf=[32])],
net_arch=dict(pi=[32], vf=[32]),
features_extractor_kwargs=dict(cnn_output_dim=32),
),
)
@ -234,7 +234,7 @@ def test_vec_normalize(model_class):
kwargs = dict(
n_steps=128,
policy_kwargs=dict(
net_arch=[dict(pi=[32], vf=[32])],
net_arch=dict(pi=[32], vf=[32]),
),
)
else:

View File

@ -227,7 +227,7 @@ def test_ppo_lstm_performance():
max_grad_norm=1,
gae_lambda=0.98,
policy_kwargs=dict(
net_arch=[dict(vf=[64])],
net_arch=dict(vf=[64], pi=[]),
lstm_hidden_size=64,
ortho_init=False,
enable_critic_lstm=True,

View File

@ -77,7 +77,7 @@ def test_trpo_params():
use_sde=True,
sub_sampling_factor=4,
seed=0,
policy_kwargs=dict(net_arch=[dict(pi=[32], vf=[32])]),
policy_kwargs=dict(net_arch=dict(pi=[32], vf=[32])),
verbose=1,
)
model.learn(total_timesteps=500)