Add `with_bias` parameter to `ARSPolicy` and fix `sb3_contrib/ars/policies.py` type hint (#122)
* Update contribution.md * New loop struct to make mypy happy * Update setup.cfg * Update changelog * fix squash_output = False in ARS policy * Add with_bias parameter to ARSPolicy * Make ARSLinearPolicy a special case of ARSPolicy * Remove ars_policy from mypy exclude * Update changelog * Update SB3 version * Fix to save ARS linear policy saved with sb3-contrib < 1.7.0 * Fix test * Turn docstring into comment Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org> Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
This commit is contained in:
parent
9a728513da
commit
6b23c6cfe3
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.7.0a2 (WIP)
|
||||
Release 1.7.0a4 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -11,10 +11,12 @@ Breaking Changes:
|
|||
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
|
||||
please use an ``EvalCallback`` instead
|
||||
- Removed deprecated ``sde_net_arch`` parameter
|
||||
- Upgraded to Stable-Baselines3 >= 1.7.0a4
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Introduced mypy type checking
|
||||
- Added ``with_bias`` parameter to ``ARSPolicy``
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
@ -29,6 +31,7 @@ Others:
|
|||
- Fixed flake8 config
|
||||
- Fixed ``sb3_contrib/common/utils.py`` type hint
|
||||
- Fixed ``sb3_contrib/common/recurrent/type_aliases.py`` type hint
|
||||
- Fixed ``sb3_contrib/ars/policies.py`` type hint
|
||||
- Exposed modules in `__init__.py` with `__all__` attribute (@ZikangXiong)
|
||||
- Removed ignores on Flake8 F401 (@ZikangXiong)
|
||||
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from stable_baselines3.common.base_class import BaseAlgorithm
|
|||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.save_util import load_from_zip_file
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import get_schedule_fn, safe_mean
|
||||
|
||||
|
|
@ -348,3 +349,25 @@ class ARS(BaseAlgorithm):
|
|||
callback.on_training_end()
|
||||
|
||||
return self
|
||||
|
||||
def set_parameters(
|
||||
self,
|
||||
load_path_or_dict: Union[str, Dict[str, Dict]],
|
||||
exact_match: bool = True,
|
||||
device: Union[th.device, str] = "auto",
|
||||
) -> None:
|
||||
# Patched set_parameters() to handle ARS linear policy saved with sb3-contrib < 1.7.0
|
||||
params = None
|
||||
if isinstance(load_path_or_dict, dict):
|
||||
params = load_path_or_dict
|
||||
else:
|
||||
_, params, _ = load_from_zip_file(load_path_or_dict, device=device)
|
||||
|
||||
# Patch to load LinearPolicy saved using sb3-contrib < 1.7.0
|
||||
# See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/122#issuecomment-1331981230
|
||||
for name in {"weight", "bias"}:
|
||||
if f"action_net.{name}" in params.get("policy", {}):
|
||||
params["policy"][f"action_net.0.{name}"] = params["policy"][f"action_net.{name}"]
|
||||
del params["policy"][f"action_net.{name}"]
|
||||
|
||||
super().set_parameters(params, exact_match=exact_match)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ class ARSPolicy(BasePolicy):
|
|||
:param action_space: The action space of the environment
|
||||
:param net_arch: Network architecture, defaults to a 2 layers MLP with 64 hidden nodes.
|
||||
:param activation_fn: Activation function
|
||||
:param with_bias: If set to False, the layers will not learn an additive bias
|
||||
:param squash_output: For continuous actions, whether the output is squashed
|
||||
or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped.
|
||||
"""
|
||||
|
|
@ -26,6 +27,7 @@ class ARSPolicy(BasePolicy):
|
|||
action_space: gym.spaces.Space,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
with_bias: bool = True,
|
||||
squash_output: bool = True,
|
||||
):
|
||||
|
||||
|
|
@ -45,9 +47,11 @@ class ARSPolicy(BasePolicy):
|
|||
|
||||
if isinstance(action_space, gym.spaces.Box):
|
||||
action_dim = get_action_dim(action_space)
|
||||
actor_net = create_mlp(self.features_dim, action_dim, net_arch, activation_fn, squash_output=True)
|
||||
actor_net = create_mlp(
|
||||
self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output
|
||||
)
|
||||
elif isinstance(action_space, gym.spaces.Discrete):
|
||||
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn)
|
||||
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn, with_bias=with_bias)
|
||||
else:
|
||||
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
|
||||
|
||||
|
|
@ -98,17 +102,7 @@ class ARSLinearPolicy(ARSPolicy):
|
|||
squash_output: bool = False,
|
||||
):
|
||||
|
||||
super().__init__(observation_space, action_space, squash_output=squash_output)
|
||||
|
||||
if isinstance(action_space, gym.spaces.Box):
|
||||
action_dim = get_action_dim(action_space)
|
||||
self.action_net = nn.Linear(self.features_dim, action_dim, bias=with_bias)
|
||||
if squash_output:
|
||||
self.action_net = nn.Sequential(self.action_net, nn.Tanh())
|
||||
elif isinstance(action_space, gym.spaces.Discrete):
|
||||
self.action_net = nn.Linear(self.features_dim, action_space.n, bias=with_bias)
|
||||
else:
|
||||
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
|
||||
super().__init__(observation_space, action_space, net_arch=[], with_bias=with_bias, squash_output=squash_output)
|
||||
|
||||
|
||||
MlpPolicy = ARSPolicy
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.7.0a2
|
||||
1.7.0a4
|
||||
|
|
@ -26,7 +26,6 @@ show_error_codes = True
|
|||
exclude = (?x)(
|
||||
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
|
||||
| sb3_contrib/ars/ars.py$
|
||||
| sb3_contrib/ars/policies.py$
|
||||
| sb3_contrib/qrdqn/qrdqn.py$
|
||||
| sb3_contrib/qrdqn/policies.py$
|
||||
| sb3_contrib/common/recurrent/policies.py$
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=1.7.0a0",
|
||||
"stable_baselines3>=1.7.0a4",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
Loading…
Reference in New Issue