diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 2bb377b..942c98a 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 1.7.0a2 (WIP) +Release 1.7.0a4 (WIP) -------------------------- Breaking Changes: @@ -11,10 +11,12 @@ Breaking Changes: - Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters, please use an ``EvalCallback`` instead - Removed deprecated ``sde_net_arch`` parameter +- Upgraded to Stable-Baselines3 >= 1.7.0a4 New Features: ^^^^^^^^^^^^^ - Introduced mypy type checking +- Added ``with_bias`` parameter to ``ARSPolicy`` Bug Fixes: ^^^^^^^^^^ @@ -29,6 +31,7 @@ Others: - Fixed flake8 config - Fixed ``sb3_contrib/common/utils.py`` type hint - Fixed ``sb3_contrib/common/recurrent/type_aliases.py`` type hint +- Fixed ``sb3_contrib/ars/policies.py`` type hint - Exposed modules in `__init__.py` with `__all__` attribute (@ZikangXiong) - Removed ignores on Flake8 F401 (@ZikangXiong) - Upgraded GitHub CI/setup-python to v4 and checkout to v3 diff --git a/sb3_contrib/ars/ars.py b/sb3_contrib/ars/ars.py index ce73698..ef498d1 100644 --- a/sb3_contrib/ars/ars.py +++ b/sb3_contrib/ars/ars.py @@ -13,6 +13,7 @@ from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.save_util import load_from_zip_file from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_schedule_fn, safe_mean @@ -348,3 +349,25 @@ class ARS(BaseAlgorithm): callback.on_training_end() return self + + def set_parameters( + self, + load_path_or_dict: Union[str, Dict[str, Dict]], + exact_match: bool = True, + device: Union[th.device, str] = "auto", + ) -> None: + # Patched set_parameters() to handle ARS linear policy saved with sb3-contrib < 1.7.0 + params = None + if isinstance(load_path_or_dict, dict): + params = load_path_or_dict + else: + _, params, _ = load_from_zip_file(load_path_or_dict, device=device) + + # Patch to load LinearPolicy saved using sb3-contrib < 1.7.0 + # See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/122#issuecomment-1331981230 + for name in {"weight", "bias"}: + if f"action_net.{name}" in params.get("policy", {}): + params["policy"][f"action_net.0.{name}"] = params["policy"][f"action_net.{name}"] + del params["policy"][f"action_net.{name}"] + + super().set_parameters(params, exact_match=exact_match) diff --git a/sb3_contrib/ars/policies.py b/sb3_contrib/ars/policies.py index 218cf0b..6ab17cc 100644 --- a/sb3_contrib/ars/policies.py +++ b/sb3_contrib/ars/policies.py @@ -16,6 +16,7 @@ class ARSPolicy(BasePolicy): :param action_space: The action space of the environment :param net_arch: Network architecture, defaults to a 2 layers MLP with 64 hidden nodes. :param activation_fn: Activation function + :param with_bias: If set to False, the layers will not learn an additive bias :param squash_output: For continuous actions, whether the output is squashed or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped. """ @@ -26,6 +27,7 @@ class ARSPolicy(BasePolicy): action_space: gym.spaces.Space, net_arch: Optional[List[int]] = None, activation_fn: Type[nn.Module] = nn.ReLU, + with_bias: bool = True, squash_output: bool = True, ): @@ -45,9 +47,11 @@ class ARSPolicy(BasePolicy): if isinstance(action_space, gym.spaces.Box): action_dim = get_action_dim(action_space) - actor_net = create_mlp(self.features_dim, action_dim, net_arch, activation_fn, squash_output=True) + actor_net = create_mlp( + self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output + ) elif isinstance(action_space, gym.spaces.Discrete): - actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn) + actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn, with_bias=with_bias) else: raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.") @@ -98,17 +102,7 @@ class ARSLinearPolicy(ARSPolicy): squash_output: bool = False, ): - super().__init__(observation_space, action_space, squash_output=squash_output) - - if isinstance(action_space, gym.spaces.Box): - action_dim = get_action_dim(action_space) - self.action_net = nn.Linear(self.features_dim, action_dim, bias=with_bias) - if squash_output: - self.action_net = nn.Sequential(self.action_net, nn.Tanh()) - elif isinstance(action_space, gym.spaces.Discrete): - self.action_net = nn.Linear(self.features_dim, action_space.n, bias=with_bias) - else: - raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.") + super().__init__(observation_space, action_space, net_arch=[], with_bias=with_bias, squash_output=squash_output) MlpPolicy = ARSPolicy diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index b895285..3f38852 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.7.0a2 +1.7.0a4 \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 54a7b44..3a7fa2a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,6 @@ show_error_codes = True exclude = (?x)( sb3_contrib/ppo_recurrent/ppo_recurrent.py$ | sb3_contrib/ars/ars.py$ - | sb3_contrib/ars/policies.py$ | sb3_contrib/qrdqn/qrdqn.py$ | sb3_contrib/qrdqn/policies.py$ | sb3_contrib/common/recurrent/policies.py$ diff --git a/setup.py b/setup.py index 14214f7..60fe1d3 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.7.0a0", + "stable_baselines3>=1.7.0a4", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin",