Deprecation of shared layers in `mlp_extractor` (#133)
* Deprecation of shared layers in mlp_extractor * Fix missing import * Reformat and update tests Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
7c4a249fa4
commit
b5aa9a47ce
|
|
@ -3,15 +3,22 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Release 1.7.0a11 (WIP)
|
Release 1.7.0a12 (WIP)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
Shared layers in MLP policy (``mlp_extractor``) are now deprecated for PPO, A2C and TRPO.
|
||||||
|
This feature will be removed in SB3 v1.8.0; from that version on, ``net_arch=[64, 64]``
|
||||||
|
will create **separate** networks with the same architecture, to be consistent with the off-policy algorithms.
|
||||||
|
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^
|
||||||
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
|
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
|
||||||
please use an ``EvalCallback`` instead
|
please use an ``EvalCallback`` instead
|
||||||
- Removed deprecated ``sde_net_arch`` parameter
|
- Removed deprecated ``sde_net_arch`` parameter
|
||||||
- Upgraded to Stable-Baselines3 >= 1.7.0a11
|
- Upgraded to Stable-Baselines3 >= 1.7.0a12
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
@ -30,6 +37,7 @@ Bug Fixes:
|
||||||
Deprecations:
|
Deprecations:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
- You should now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()``
|
- You should now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()``
|
||||||
|
- Deprecated shared layers in ``MlpExtractor`` (@AlexPasqua)
|
||||||
|
|
||||||
Others:
|
Others:
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import warnings
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
|
@ -46,7 +47,8 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
# TODO(antonin): update type annotation when we remove shared network support
|
||||||
|
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
ortho_init: bool = True,
|
ortho_init: bool = True,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||||
|
|
@ -74,12 +76,28 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
squash_output=False,
|
squash_output=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Convert [dict()] to dict() as shared networks are deprecated
|
||||||
|
if isinstance(net_arch, list) and len(net_arch) > 0:
|
||||||
|
if isinstance(net_arch[0], dict):
|
||||||
|
warnings.warn(
|
||||||
|
(
|
||||||
|
"As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, "
|
||||||
|
"you should now pass directly a dictionary and not a list "
|
||||||
|
"(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
net_arch = net_arch[0]
|
||||||
|
else:
|
||||||
|
# Note: deprecation warning will be emitted
|
||||||
|
# by the MlpExtractor constructor
|
||||||
|
pass
|
||||||
|
|
||||||
# Default network architecture, from stable-baselines
|
# Default network architecture, from stable-baselines
|
||||||
if net_arch is None:
|
if net_arch is None:
|
||||||
if features_extractor_class == NatureCNN:
|
if features_extractor_class == NatureCNN:
|
||||||
net_arch = []
|
net_arch = []
|
||||||
else:
|
else:
|
||||||
net_arch = [dict(pi=[64, 64], vf=[64, 64])]
|
net_arch = dict(pi=[64, 64], vf=[64, 64])
|
||||||
|
|
||||||
self.net_arch = net_arch
|
self.net_arch = net_arch
|
||||||
self.activation_fn = activation_fn
|
self.activation_fn = activation_fn
|
||||||
|
|
@ -95,7 +113,8 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
self.pi_features_extractor = self.features_extractor
|
self.pi_features_extractor = self.features_extractor
|
||||||
self.vf_features_extractor = self.make_features_extractor()
|
self.vf_features_extractor = self.make_features_extractor()
|
||||||
# if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
|
# if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
|
||||||
if len(net_arch) > 0 and not isinstance(net_arch[0], dict):
|
# TODO(antonin): update the check once we change net_arch behavior
|
||||||
|
if isinstance(net_arch, list) and len(net_arch) > 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
|
"Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
|
||||||
)
|
)
|
||||||
|
|
@ -382,7 +401,8 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
# TODO(antonin): update type annotation when we remove shared network support
|
||||||
|
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
ortho_init: bool = True,
|
ortho_init: bool = True,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||||
|
|
@ -436,7 +456,8 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
|
||||||
observation_space: spaces.Dict,
|
observation_space: spaces.Dict,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
# TODO(antonin): update type annotation when we remove shared network support
|
||||||
|
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
ortho_init: bool = True,
|
ortho_init: bool = True,
|
||||||
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
||||||
|
|
|
||||||
|
|
@ -66,7 +66,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
# TODO(antonin): update type annotation when we remove shared network support
|
||||||
|
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
ortho_init: bool = True,
|
ortho_init: bool = True,
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
|
|
@ -475,7 +476,8 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy):
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
# TODO(antonin): update type annotation when we remove shared network support
|
||||||
|
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
ortho_init: bool = True,
|
ortho_init: bool = True,
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
|
|
@ -565,7 +567,8 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy):
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
# TODO(antonin): update type annotation when we remove shared network support
|
||||||
|
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||||
ortho_init: bool = True,
|
ortho_init: bool = True,
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
|
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import gym
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.7.0a11
|
1.7.0a12
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.7.0a11",
|
"stable_baselines3>=1.7.0a12",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
|
|
@ -142,7 +142,7 @@ def test_dict_spaces(model_class, channel_last):
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
n_steps=128,
|
n_steps=128,
|
||||||
policy_kwargs=dict(
|
policy_kwargs=dict(
|
||||||
net_arch=[dict(pi=[32], vf=[32])],
|
net_arch=dict(pi=[32], vf=[32]),
|
||||||
features_extractor_kwargs=dict(cnn_output_dim=32),
|
features_extractor_kwargs=dict(cnn_output_dim=32),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
@ -191,7 +191,7 @@ def test_dict_vec_framestack(model_class, channel_last):
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
n_steps=128,
|
n_steps=128,
|
||||||
policy_kwargs=dict(
|
policy_kwargs=dict(
|
||||||
net_arch=[dict(pi=[32], vf=[32])],
|
net_arch=dict(pi=[32], vf=[32]),
|
||||||
features_extractor_kwargs=dict(cnn_output_dim=32),
|
features_extractor_kwargs=dict(cnn_output_dim=32),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
@ -234,7 +234,7 @@ def test_vec_normalize(model_class):
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
n_steps=128,
|
n_steps=128,
|
||||||
policy_kwargs=dict(
|
policy_kwargs=dict(
|
||||||
net_arch=[dict(pi=[32], vf=[32])],
|
net_arch=dict(pi=[32], vf=[32]),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -227,7 +227,7 @@ def test_ppo_lstm_performance():
|
||||||
max_grad_norm=1,
|
max_grad_norm=1,
|
||||||
gae_lambda=0.98,
|
gae_lambda=0.98,
|
||||||
policy_kwargs=dict(
|
policy_kwargs=dict(
|
||||||
net_arch=[dict(vf=[64])],
|
net_arch=dict(vf=[64], pi=[]),
|
||||||
lstm_hidden_size=64,
|
lstm_hidden_size=64,
|
||||||
ortho_init=False,
|
ortho_init=False,
|
||||||
enable_critic_lstm=True,
|
enable_critic_lstm=True,
|
||||||
|
|
|
||||||
|
|
@ -77,7 +77,7 @@ def test_trpo_params():
|
||||||
use_sde=True,
|
use_sde=True,
|
||||||
sub_sampling_factor=4,
|
sub_sampling_factor=4,
|
||||||
seed=0,
|
seed=0,
|
||||||
policy_kwargs=dict(net_arch=[dict(pi=[32], vf=[32])]),
|
policy_kwargs=dict(net_arch=dict(pi=[32], vf=[32])),
|
||||||
verbose=1,
|
verbose=1,
|
||||||
)
|
)
|
||||||
model.learn(total_timesteps=500)
|
model.learn(total_timesteps=500)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue