Add `with_bias` parameter to `ARSPolicy` and fix `sb3_contrib/ars/policies.py` type hint (#122)

* Update contribution.md

* New loop struct to make mypy happy

* Update setup.cfg

* Update changelog

* fix squash_output = False in ARS policy

* Add with_bias parameter to ARSPolicy

* Make ARSLinearPolicy a special case of ARSPolicy

* Remove ars_policy from mypy exclude

* Update changelog

* Update SB3 version

* Fix to save ARS linear policy saved with sb3-contrib < 1.7.0

* Fix test

* Turn docstring into comment

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
This commit is contained in:
Quentin Gallouédec 2022-12-12 13:22:09 +01:00 committed by GitHub
parent 9a728513da
commit 6b23c6cfe3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 36 additions and 17 deletions

View File

@ -3,7 +3,7 @@
Changelog
==========
Release 1.7.0a2 (WIP)
Release 1.7.0a4 (WIP)
--------------------------
Breaking Changes:
@ -11,10 +11,12 @@ Breaking Changes:
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
please use an ``EvalCallback`` instead
- Removed deprecated ``sde_net_arch`` parameter
- Upgraded to Stable-Baselines3 >= 1.7.0a4
New Features:
^^^^^^^^^^^^^
- Introduced mypy type checking
- Added ``with_bias`` parameter to ``ARSPolicy``
Bug Fixes:
^^^^^^^^^^
@ -29,6 +31,7 @@ Others:
- Fixed flake8 config
- Fixed ``sb3_contrib/common/utils.py`` type hint
- Fixed ``sb3_contrib/common/recurrent/type_aliases.py`` type hint
- Fixed ``sb3_contrib/ars/policies.py`` type hint
- Exposed modules in `__init__.py` with `__all__` attribute (@ZikangXiong)
- Removed ignores on Flake8 F401 (@ZikangXiong)
- Upgraded GitHub CI/setup-python to v4 and checkout to v3

View File

@ -13,6 +13,7 @@ from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.save_util import load_from_zip_file
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_schedule_fn, safe_mean
@ -348,3 +349,25 @@ class ARS(BaseAlgorithm):
callback.on_training_end()
return self
def set_parameters(
self,
load_path_or_dict: Union[str, Dict[str, Dict]],
exact_match: bool = True,
device: Union[th.device, str] = "auto",
) -> None:
# Patched set_parameters() to handle ARS linear policy saved with sb3-contrib < 1.7.0
params = None
if isinstance(load_path_or_dict, dict):
params = load_path_or_dict
else:
_, params, _ = load_from_zip_file(load_path_or_dict, device=device)
# Patch to load LinearPolicy saved using sb3-contrib < 1.7.0
# See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/122#issuecomment-1331981230
for name in {"weight", "bias"}:
if f"action_net.{name}" in params.get("policy", {}):
params["policy"][f"action_net.0.{name}"] = params["policy"][f"action_net.{name}"]
del params["policy"][f"action_net.{name}"]
super().set_parameters(params, exact_match=exact_match)

View File

@ -16,6 +16,7 @@ class ARSPolicy(BasePolicy):
:param action_space: The action space of the environment
:param net_arch: Network architecture, defaults to a 2 layers MLP with 64 hidden nodes.
:param activation_fn: Activation function
:param with_bias: If set to False, the layers will not learn an additive bias
:param squash_output: For continuous actions, whether the output is squashed
or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped.
"""
@ -26,6 +27,7 @@ class ARSPolicy(BasePolicy):
action_space: gym.spaces.Space,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
with_bias: bool = True,
squash_output: bool = True,
):
@ -45,9 +47,11 @@ class ARSPolicy(BasePolicy):
if isinstance(action_space, gym.spaces.Box):
action_dim = get_action_dim(action_space)
actor_net = create_mlp(self.features_dim, action_dim, net_arch, activation_fn, squash_output=True)
actor_net = create_mlp(
self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output
)
elif isinstance(action_space, gym.spaces.Discrete):
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn)
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn, with_bias=with_bias)
else:
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
@ -98,17 +102,7 @@ class ARSLinearPolicy(ARSPolicy):
squash_output: bool = False,
):
super().__init__(observation_space, action_space, squash_output=squash_output)
if isinstance(action_space, gym.spaces.Box):
action_dim = get_action_dim(action_space)
self.action_net = nn.Linear(self.features_dim, action_dim, bias=with_bias)
if squash_output:
self.action_net = nn.Sequential(self.action_net, nn.Tanh())
elif isinstance(action_space, gym.spaces.Discrete):
self.action_net = nn.Linear(self.features_dim, action_space.n, bias=with_bias)
else:
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
super().__init__(observation_space, action_space, net_arch=[], with_bias=with_bias, squash_output=squash_output)
MlpPolicy = ARSPolicy

View File

@ -1 +1 @@
1.7.0a2
1.7.0a4

View File

@ -26,7 +26,6 @@ show_error_codes = True
exclude = (?x)(
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
| sb3_contrib/ars/ars.py$
| sb3_contrib/ars/policies.py$
| sb3_contrib/qrdqn/qrdqn.py$
| sb3_contrib/qrdqn/policies.py$
| sb3_contrib/common/recurrent/policies.py$

View File

@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.7.0a0",
"stable_baselines3>=1.7.0a4",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",