From 74e60381a6928384bac94def69da526488c30d19 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 27 Feb 2021 18:17:22 +0100 Subject: [PATCH] Upgrade Stable-Baselines3 (#19) * Upgrade Stable-Baselines3 * Fix policy saving/loading --- docs/misc/changelog.rst | 3 ++- sb3_contrib/qrdqn/policies.py | 8 ++++---- sb3_contrib/qrdqn/qrdqn.py | 7 ++----- sb3_contrib/tqc/policies.py | 8 ++++---- sb3_contrib/tqc/tqc.py | 7 ++----- sb3_contrib/version.txt | 2 +- setup.py | 2 +- 7 files changed, 16 insertions(+), 21 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index bc021e8..d9c2f00 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,11 +3,12 @@ Changelog ========== -Pre-Release 0.11.0a5 (WIP) +Pre-Release 0.11.0 (2021-02-27) ------------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ +- Upgraded to Stable-Baselines3 >= 0.11.0 New Features: ^^^^^^^^^^^^^ diff --git a/sb3_contrib/qrdqn/policies.py b/sb3_contrib/qrdqn/policies.py index db95299..48e94b8 100644 --- a/sb3_contrib/qrdqn/policies.py +++ b/sb3_contrib/qrdqn/policies.py @@ -68,8 +68,8 @@ class QuantileNetwork(BasePolicy): action = q_values.argmax(dim=1).reshape(-1) return action - def _get_data(self) -> Dict[str, Any]: - data = super()._get_data() + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() data.update( dict( @@ -176,8 +176,8 @@ class QRDQNPolicy(BasePolicy): def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: return self.quantile_net._predict(obs, deterministic=deterministic) - def _get_data(self) -> Dict[str, Any]: - data = super()._get_data() + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() data.update( dict( diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index c2c2f94..046f3a3 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -27,13 +27,12 @@ class QRDQN(OffPolicyAlgorithm): :param batch_size: Minibatch size for each gradient update :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update :param gamma: the discount factor - :param train_freq: Update the model every ``train_freq`` steps. Set to `-1` to disable. + :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit + like ``(5, "step")`` or ``(2, "episode")``. :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq`` and ``n_episodes_rollout``) Set to ``-1`` means to do as many gradient steps as steps done in the environment during the rollout. - :param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes. - Note that this cannot be used at the same time as ``train_freq``. Set to `-1` to disable. :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 @@ -66,7 +65,6 @@ class QRDQN(OffPolicyAlgorithm): gamma: float = 0.99, train_freq: int = 4, gradient_steps: int = 1, - n_episodes_rollout: int = -1, optimize_memory_usage: bool = False, target_update_interval: int = 10000, exploration_fraction: float = 0.005, @@ -94,7 +92,6 @@ class QRDQN(OffPolicyAlgorithm): gamma, train_freq, gradient_steps, - n_episodes_rollout, action_noise=None, # No action noise policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index 3807d44..07519b8 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -110,8 +110,8 @@ class Actor(BasePolicy): self.mu = nn.Linear(last_layer_dim, action_dim) self.log_std = nn.Linear(last_layer_dim, action_dim) - def _get_data(self) -> Dict[str, Any]: - data = super()._get_data() + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() data.update( dict( @@ -374,8 +374,8 @@ class TQCPolicy(BasePolicy): self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs) - def _get_data(self) -> Dict[str, Any]: - data = super()._get_data() + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() data.update( dict( diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 5bab2bd..44eb555 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -30,10 +30,9 @@ class TQC(OffPolicyAlgorithm): :param batch_size: Minibatch size for each gradient update :param tau: the soft update coefficient ("Polyak update", between 0 and 1) :param gamma: the discount factor - :param train_freq: Update the model every ``train_freq`` steps. + :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit + like ``(5, "step")`` or ``(2, "episode")``. :param gradient_steps: How many gradient update after each step - :param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes. - Note that this cannot be used at the same time as ``train_freq`` :param action_noise: the action noise type (None by default), this can help for hard exploration problem. Cf common.noise for the different action noise type. :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer @@ -74,7 +73,6 @@ class TQC(OffPolicyAlgorithm): gamma: float = 0.99, train_freq: int = 1, gradient_steps: int = 1, - n_episodes_rollout: int = -1, action_noise: Optional[ActionNoise] = None, optimize_memory_usage: bool = False, ent_coef: Union[str, float] = "auto", @@ -105,7 +103,6 @@ class TQC(OffPolicyAlgorithm): gamma, train_freq, gradient_steps, - n_episodes_rollout, action_noise, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index f23fba9..d9df1bb 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -0.11.0a5 +0.11.0 diff --git a/setup.py b/setup.py index 93b499d..df3da6d 100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3[tests,docs]>=0.11.0a2", + "stable_baselines3[tests,docs]>=0.11.0", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin",