Deprecation of shared layers in `mlp_extractor` (#133)
* Deprecation of shared layers in mlp_extractor * Fix missing import * Reformat and update tests Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
7c4a249fa4
commit
b5aa9a47ce
|
|
@ -3,15 +3,22 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.7.0a11 (WIP)
|
||||
Release 1.7.0a12 (WIP)
|
||||
--------------------------
|
||||
|
||||
.. warning::
|
||||
|
||||
Shared layers in MLP policy (``mlp_extractor``) are now deprecated for PPO, A2C and TRPO.
|
||||
This feature will be removed in SB3 v1.8.0, after which ``net_arch=[64, 64]``
|
||||
will create **separate** networks with the same architecture, to be consistent with the off-policy algorithms.
|
||||
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
|
||||
please use an ``EvalCallback`` instead
|
||||
- Removed deprecated ``sde_net_arch`` parameter
|
||||
- Upgraded to Stable-Baselines3 >= 1.7.0a11
|
||||
- Upgraded to Stable-Baselines3 >= 1.7.0a12
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -30,6 +37,7 @@ Bug Fixes:
|
|||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
- You should now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()``
|
||||
- Deprecated shared layers in ``MlpExtractor`` (@AlexPasqua)
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import warnings
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
|
|
@ -46,7 +47,8 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
# TODO(antonin): update type annotation when we remove shared network support
|
||||
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
ortho_init: bool = True,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||
|
|
@ -74,12 +76,28 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
squash_output=False,
|
||||
)
|
||||
|
||||
# Convert [dict()] to dict() as shared networks are deprecated
|
||||
if isinstance(net_arch, list) and len(net_arch) > 0:
|
||||
if isinstance(net_arch[0], dict):
|
||||
warnings.warn(
|
||||
(
|
||||
"As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, "
|
||||
"you should now pass directly a dictionary and not a list "
|
||||
"(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
|
||||
),
|
||||
)
|
||||
net_arch = net_arch[0]
|
||||
else:
|
||||
# Note: deprecation warning will be emitted
|
||||
# by the MlpExtractor constructor
|
||||
pass
|
||||
|
||||
# Default network architecture, from stable-baselines
|
||||
if net_arch is None:
|
||||
if features_extractor_class == NatureCNN:
|
||||
net_arch = []
|
||||
else:
|
||||
net_arch = [dict(pi=[64, 64], vf=[64, 64])]
|
||||
net_arch = dict(pi=[64, 64], vf=[64, 64])
|
||||
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
|
|
@ -95,7 +113,8 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
self.pi_features_extractor = self.features_extractor
|
||||
self.vf_features_extractor = self.make_features_extractor()
|
||||
# if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
|
||||
if len(net_arch) > 0 and not isinstance(net_arch[0], dict):
|
||||
# TODO(antonin): update the check once we change net_arch behavior
|
||||
if isinstance(net_arch, list) and len(net_arch) > 0:
|
||||
raise ValueError(
|
||||
"Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
|
||||
)
|
||||
|
|
@ -382,7 +401,8 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
|
|||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
# TODO(antonin): update type annotation when we remove shared network support
|
||||
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
ortho_init: bool = True,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||
|
|
@ -436,7 +456,8 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
|
|||
observation_space: spaces.Dict,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
# TODO(antonin): update type annotation when we remove shared network support
|
||||
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
ortho_init: bool = True,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
||||
|
|
|
|||
|
|
@ -66,7 +66,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
# TODO(antonin): update type annotation when we remove shared network support
|
||||
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
ortho_init: bool = True,
|
||||
use_sde: bool = False,
|
||||
|
|
@ -475,7 +476,8 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
|||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
# TODO(antonin): update type annotation when we remove shared network support
|
||||
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
ortho_init: bool = True,
|
||||
use_sde: bool = False,
|
||||
|
|
@ -565,7 +567,8 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
|||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
# TODO(antonin): update type annotation when we remove shared network support
|
||||
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
ortho_init: bool = True,
|
||||
use_sde: bool = False,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import time
|
|||
from collections import deque
|
||||
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.7.0a11
|
||||
1.7.0a12
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=1.7.0a11",
|
||||
"stable_baselines3>=1.7.0a12",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
|
|
@ -142,7 +142,7 @@ def test_dict_spaces(model_class, channel_last):
|
|||
kwargs = dict(
|
||||
n_steps=128,
|
||||
policy_kwargs=dict(
|
||||
net_arch=[dict(pi=[32], vf=[32])],
|
||||
net_arch=dict(pi=[32], vf=[32]),
|
||||
features_extractor_kwargs=dict(cnn_output_dim=32),
|
||||
),
|
||||
)
|
||||
|
|
@ -191,7 +191,7 @@ def test_dict_vec_framestack(model_class, channel_last):
|
|||
kwargs = dict(
|
||||
n_steps=128,
|
||||
policy_kwargs=dict(
|
||||
net_arch=[dict(pi=[32], vf=[32])],
|
||||
net_arch=dict(pi=[32], vf=[32]),
|
||||
features_extractor_kwargs=dict(cnn_output_dim=32),
|
||||
),
|
||||
)
|
||||
|
|
@ -234,7 +234,7 @@ def test_vec_normalize(model_class):
|
|||
kwargs = dict(
|
||||
n_steps=128,
|
||||
policy_kwargs=dict(
|
||||
net_arch=[dict(pi=[32], vf=[32])],
|
||||
net_arch=dict(pi=[32], vf=[32]),
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -227,7 +227,7 @@ def test_ppo_lstm_performance():
|
|||
max_grad_norm=1,
|
||||
gae_lambda=0.98,
|
||||
policy_kwargs=dict(
|
||||
net_arch=[dict(vf=[64])],
|
||||
net_arch=dict(vf=[64], pi=[]),
|
||||
lstm_hidden_size=64,
|
||||
ortho_init=False,
|
||||
enable_critic_lstm=True,
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ def test_trpo_params():
|
|||
use_sde=True,
|
||||
sub_sampling_factor=4,
|
||||
seed=0,
|
||||
policy_kwargs=dict(net_arch=[dict(pi=[32], vf=[32])]),
|
||||
policy_kwargs=dict(net_arch=dict(pi=[32], vf=[32])),
|
||||
verbose=1,
|
||||
)
|
||||
model.learn(total_timesteps=500)
|
||||
|
|
|
|||
Loading…
Reference in New Issue