Upgrade Stable-Baselines3 (#19)

* Upgrade Stable-Baselines3

* Fix policy saving/loading
This commit is contained in:
Antonin RAFFIN 2021-02-27 18:17:22 +01:00 committed by GitHub
parent b15cc3dbcf
commit 74e60381a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 16 additions and 21 deletions

View File

@ -3,11 +3,12 @@
Changelog Changelog
========== ==========
Pre-Release 0.11.0a5 (WIP) Pre-Release 0.11.0 (2021-02-27)
------------------------------- -------------------------------
Breaking Changes: Breaking Changes:
^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 0.11.0
New Features: New Features:
^^^^^^^^^^^^^ ^^^^^^^^^^^^^

View File

@ -68,8 +68,8 @@ class QuantileNetwork(BasePolicy):
action = q_values.argmax(dim=1).reshape(-1) action = q_values.argmax(dim=1).reshape(-1)
return action return action
def _get_data(self) -> Dict[str, Any]: def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_data() data = super()._get_constructor_parameters()
data.update( data.update(
dict( dict(
@ -176,8 +176,8 @@ class QRDQNPolicy(BasePolicy):
def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
return self.quantile_net._predict(obs, deterministic=deterministic) return self.quantile_net._predict(obs, deterministic=deterministic)
def _get_data(self) -> Dict[str, Any]: def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_data() data = super()._get_constructor_parameters()
data.update( data.update(
dict( dict(

View File

@ -27,13 +27,12 @@ class QRDQN(OffPolicyAlgorithm):
:param batch_size: Minibatch size for each gradient update :param batch_size: Minibatch size for each gradient update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
:param gamma: the discount factor :param gamma: the discount factor
:param train_freq: Update the model every ``train_freq`` steps. Set to `-1` to disable. :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
like ``(5, "step")`` or ``(2, "episode")``.
:param gradient_steps: How many gradient steps to do after each rollout :param gradient_steps: How many gradient steps to do after each rollout
(see ``train_freq`` and ``n_episodes_rollout``) (see ``train_freq`` and ``n_episodes_rollout``)
Set to ``-1`` means to do as many gradient steps as steps done in the environment Set to ``-1`` means to do as many gradient steps as steps done in the environment
during the rollout. during the rollout.
:param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes.
Note that this cannot be used at the same time as ``train_freq``. Set to `-1` to disable.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity. at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
@ -66,7 +65,6 @@ class QRDQN(OffPolicyAlgorithm):
gamma: float = 0.99, gamma: float = 0.99,
train_freq: int = 4, train_freq: int = 4,
gradient_steps: int = 1, gradient_steps: int = 1,
n_episodes_rollout: int = -1,
optimize_memory_usage: bool = False, optimize_memory_usage: bool = False,
target_update_interval: int = 10000, target_update_interval: int = 10000,
exploration_fraction: float = 0.005, exploration_fraction: float = 0.005,
@ -94,7 +92,6 @@ class QRDQN(OffPolicyAlgorithm):
gamma, gamma,
train_freq, train_freq,
gradient_steps, gradient_steps,
n_episodes_rollout,
action_noise=None, # No action noise action_noise=None, # No action noise
policy_kwargs=policy_kwargs, policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log, tensorboard_log=tensorboard_log,

View File

@ -110,8 +110,8 @@ class Actor(BasePolicy):
self.mu = nn.Linear(last_layer_dim, action_dim) self.mu = nn.Linear(last_layer_dim, action_dim)
self.log_std = nn.Linear(last_layer_dim, action_dim) self.log_std = nn.Linear(last_layer_dim, action_dim)
def _get_data(self) -> Dict[str, Any]: def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_data() data = super()._get_constructor_parameters()
data.update( data.update(
dict( dict(
@ -374,8 +374,8 @@ class TQCPolicy(BasePolicy):
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs) self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
def _get_data(self) -> Dict[str, Any]: def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_data() data = super()._get_constructor_parameters()
data.update( data.update(
dict( dict(

View File

@ -30,10 +30,9 @@ class TQC(OffPolicyAlgorithm):
:param batch_size: Minibatch size for each gradient update :param batch_size: Minibatch size for each gradient update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1) :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
:param gamma: the discount factor :param gamma: the discount factor
:param train_freq: Update the model every ``train_freq`` steps. :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
like ``(5, "step")`` or ``(2, "episode")``.
:param gradient_steps: How many gradient update after each step :param gradient_steps: How many gradient update after each step
:param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes.
Note that this cannot be used at the same time as ``train_freq``
:param action_noise: the action noise type (None by default), this can help :param action_noise: the action noise type (None by default), this can help
for hard exploration problem. Cf common.noise for the different action noise type. for hard exploration problem. Cf common.noise for the different action noise type.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
@ -74,7 +73,6 @@ class TQC(OffPolicyAlgorithm):
gamma: float = 0.99, gamma: float = 0.99,
train_freq: int = 1, train_freq: int = 1,
gradient_steps: int = 1, gradient_steps: int = 1,
n_episodes_rollout: int = -1,
action_noise: Optional[ActionNoise] = None, action_noise: Optional[ActionNoise] = None,
optimize_memory_usage: bool = False, optimize_memory_usage: bool = False,
ent_coef: Union[str, float] = "auto", ent_coef: Union[str, float] = "auto",
@ -105,7 +103,6 @@ class TQC(OffPolicyAlgorithm):
gamma, gamma,
train_freq, train_freq,
gradient_steps, gradient_steps,
n_episodes_rollout,
action_noise, action_noise,
policy_kwargs=policy_kwargs, policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log, tensorboard_log=tensorboard_log,

View File

@ -1 +1 @@
0.11.0a5 0.11.0

View File

@ -62,7 +62,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")], packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]}, package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[ install_requires=[
"stable_baselines3[tests,docs]>=0.11.0a2", "stable_baselines3[tests,docs]>=0.11.0",
], ],
description="Contrib package of Stable Baselines3, experimental code.", description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin", author="Antonin Raffin",