Add `with_bias` parameter to `ARSPolicy` and fix `sb3_contrib/ars/policies.py` type hint (#122)
* Update contribution.md
* New loop struct to make mypy happy
* Update setup.cfg
* Update changelog
* Fix squash_output = False in ARS policy
* Add with_bias parameter to ARSPolicy
* Make ARSLinearPolicy a special case of ARSPolicy
* Remove ars_policy from mypy exclude
* Update changelog
* Update SB3 version
* Fix to load ARS linear policy saved with sb3-contrib < 1.7.0
* Fix test
* Turn docstring into comment

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
This commit is contained in:
parent
9a728513da
commit
6b23c6cfe3
|
|
@ -3,7 +3,7 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Release 1.7.0a2 (WIP)
|
Release 1.7.0a4 (WIP)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
|
|
@ -11,10 +11,12 @@ Breaking Changes:
|
||||||
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
|
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
|
||||||
please use an ``EvalCallback`` instead
|
please use an ``EvalCallback`` instead
|
||||||
- Removed deprecated ``sde_net_arch`` parameter
|
- Removed deprecated ``sde_net_arch`` parameter
|
||||||
|
- Upgraded to Stable-Baselines3 >= 1.7.0a4
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
- Introduced mypy type checking
|
- Introduced mypy type checking
|
||||||
|
- Added ``with_bias`` parameter to ``ARSPolicy``
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
@ -29,6 +31,7 @@ Others:
|
||||||
- Fixed flake8 config
|
- Fixed flake8 config
|
||||||
- Fixed ``sb3_contrib/common/utils.py`` type hint
|
- Fixed ``sb3_contrib/common/utils.py`` type hint
|
||||||
- Fixed ``sb3_contrib/common/recurrent/type_aliases.py`` type hint
|
- Fixed ``sb3_contrib/common/recurrent/type_aliases.py`` type hint
|
||||||
|
- Fixed ``sb3_contrib/ars/policies.py`` type hint
|
||||||
- Exposed modules in ``__init__.py`` with ``__all__`` attribute (@ZikangXiong)
|
- Exposed modules in ``__init__.py`` with ``__all__`` attribute (@ZikangXiong)
|
||||||
- Removed ignores on Flake8 F401 (@ZikangXiong)
|
- Removed ignores on Flake8 F401 (@ZikangXiong)
|
||||||
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
|
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from stable_baselines3.common.base_class import BaseAlgorithm
|
||||||
from stable_baselines3.common.callbacks import BaseCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
from stable_baselines3.common.evaluation import evaluate_policy
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
|
from stable_baselines3.common.save_util import load_from_zip_file
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
from stable_baselines3.common.utils import get_schedule_fn, safe_mean
|
from stable_baselines3.common.utils import get_schedule_fn, safe_mean
|
||||||
|
|
||||||
|
|
@ -348,3 +349,25 @@ class ARS(BaseAlgorithm):
|
||||||
callback.on_training_end()
|
callback.on_training_end()
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def set_parameters(
|
||||||
|
self,
|
||||||
|
load_path_or_dict: Union[str, Dict[str, Dict]],
|
||||||
|
exact_match: bool = True,
|
||||||
|
device: Union[th.device, str] = "auto",
|
||||||
|
) -> None:
|
||||||
|
# Patched set_parameters() to handle ARS linear policy saved with sb3-contrib < 1.7.0
|
||||||
|
params = None
|
||||||
|
if isinstance(load_path_or_dict, dict):
|
||||||
|
params = load_path_or_dict
|
||||||
|
else:
|
||||||
|
_, params, _ = load_from_zip_file(load_path_or_dict, device=device)
|
||||||
|
|
||||||
|
# Patch to load LinearPolicy saved using sb3-contrib < 1.7.0
|
||||||
|
# See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/122#issuecomment-1331981230
|
||||||
|
for name in {"weight", "bias"}:
|
||||||
|
if f"action_net.{name}" in params.get("policy", {}):
|
||||||
|
params["policy"][f"action_net.0.{name}"] = params["policy"][f"action_net.{name}"]
|
||||||
|
del params["policy"][f"action_net.{name}"]
|
||||||
|
|
||||||
|
super().set_parameters(params, exact_match=exact_match)
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ class ARSPolicy(BasePolicy):
|
||||||
:param action_space: The action space of the environment
|
:param action_space: The action space of the environment
|
||||||
:param net_arch: Network architecture, defaults to a 2-layer MLP with 64 hidden nodes.
|
:param net_arch: Network architecture, defaults to a 2-layer MLP with 64 hidden nodes.
|
||||||
:param activation_fn: Activation function
|
:param activation_fn: Activation function
|
||||||
|
:param with_bias: If set to False, the layers will not learn an additive bias
|
||||||
:param squash_output: For continuous actions, whether the output is squashed
|
:param squash_output: For continuous actions, whether the output is squashed
|
||||||
or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped.
|
or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped.
|
||||||
"""
|
"""
|
||||||
|
|
@ -26,6 +27,7 @@ class ARSPolicy(BasePolicy):
|
||||||
action_space: gym.spaces.Space,
|
action_space: gym.spaces.Space,
|
||||||
net_arch: Optional[List[int]] = None,
|
net_arch: Optional[List[int]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
|
with_bias: bool = True,
|
||||||
squash_output: bool = True,
|
squash_output: bool = True,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
|
@ -45,9 +47,11 @@ class ARSPolicy(BasePolicy):
|
||||||
|
|
||||||
if isinstance(action_space, gym.spaces.Box):
|
if isinstance(action_space, gym.spaces.Box):
|
||||||
action_dim = get_action_dim(action_space)
|
action_dim = get_action_dim(action_space)
|
||||||
actor_net = create_mlp(self.features_dim, action_dim, net_arch, activation_fn, squash_output=True)
|
actor_net = create_mlp(
|
||||||
|
self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output
|
||||||
|
)
|
||||||
elif isinstance(action_space, gym.spaces.Discrete):
|
elif isinstance(action_space, gym.spaces.Discrete):
|
||||||
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn)
|
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn, with_bias=with_bias)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
|
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
|
||||||
|
|
||||||
|
|
@ -98,17 +102,7 @@ class ARSLinearPolicy(ARSPolicy):
|
||||||
squash_output: bool = False,
|
squash_output: bool = False,
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(observation_space, action_space, squash_output=squash_output)
|
super().__init__(observation_space, action_space, net_arch=[], with_bias=with_bias, squash_output=squash_output)
|
||||||
|
|
||||||
if isinstance(action_space, gym.spaces.Box):
|
|
||||||
action_dim = get_action_dim(action_space)
|
|
||||||
self.action_net = nn.Linear(self.features_dim, action_dim, bias=with_bias)
|
|
||||||
if squash_output:
|
|
||||||
self.action_net = nn.Sequential(self.action_net, nn.Tanh())
|
|
||||||
elif isinstance(action_space, gym.spaces.Discrete):
|
|
||||||
self.action_net = nn.Linear(self.features_dim, action_space.n, bias=with_bias)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
|
|
||||||
|
|
||||||
|
|
||||||
MlpPolicy = ARSPolicy
|
MlpPolicy = ARSPolicy
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.7.0a2
|
1.7.0a4
|
||||||
|
|
@ -26,7 +26,6 @@ show_error_codes = True
|
||||||
exclude = (?x)(
|
exclude = (?x)(
|
||||||
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
|
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
|
||||||
| sb3_contrib/ars/ars.py$
|
| sb3_contrib/ars/ars.py$
|
||||||
| sb3_contrib/ars/policies.py$
|
|
||||||
| sb3_contrib/qrdqn/qrdqn.py$
|
| sb3_contrib/qrdqn/qrdqn.py$
|
||||||
| sb3_contrib/qrdqn/policies.py$
|
| sb3_contrib/qrdqn/policies.py$
|
||||||
| sb3_contrib/common/recurrent/policies.py$
|
| sb3_contrib/common/recurrent/policies.py$
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.7.0a0",
|
"stable_baselines3>=1.7.0a4",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue