Upgrade Stable-Baselines3 (#19)
* Upgrade Stable-Baselines3 * Fix policy saving/loading
This commit is contained in:
parent
b15cc3dbcf
commit
74e60381a6
|
|
@ -3,11 +3,12 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Pre-Release 0.11.0a5 (WIP)
|
Pre-Release 0.11.0 (2021-02-27)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^
|
||||||
|
- Upgraded to Stable-Baselines3 >= 0.11.0
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -68,8 +68,8 @@ class QuantileNetwork(BasePolicy):
|
||||||
action = q_values.argmax(dim=1).reshape(-1)
|
action = q_values.argmax(dim=1).reshape(-1)
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def _get_data(self) -> Dict[str, Any]:
|
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||||
data = super()._get_data()
|
data = super()._get_constructor_parameters()
|
||||||
|
|
||||||
data.update(
|
data.update(
|
||||||
dict(
|
dict(
|
||||||
|
|
@ -176,8 +176,8 @@ class QRDQNPolicy(BasePolicy):
|
||||||
def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||||
return self.quantile_net._predict(obs, deterministic=deterministic)
|
return self.quantile_net._predict(obs, deterministic=deterministic)
|
||||||
|
|
||||||
def _get_data(self) -> Dict[str, Any]:
|
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||||
data = super()._get_data()
|
data = super()._get_constructor_parameters()
|
||||||
|
|
||||||
data.update(
|
data.update(
|
||||||
dict(
|
dict(
|
||||||
|
|
|
||||||
|
|
@ -27,13 +27,12 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
:param batch_size: Minibatch size for each gradient update
|
:param batch_size: Minibatch size for each gradient update
|
||||||
:param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
|
:param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
|
||||||
:param gamma: the discount factor
|
:param gamma: the discount factor
|
||||||
:param train_freq: Update the model every ``train_freq`` steps. Set to `-1` to disable.
|
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
|
||||||
|
like ``(5, "step")`` or ``(2, "episode")``.
|
||||||
:param gradient_steps: How many gradient steps to do after each rollout
|
:param gradient_steps: How many gradient steps to do after each rollout
|
||||||
(see ``train_freq`` and ``n_episodes_rollout``)
|
(see ``train_freq``)
|
||||||
Setting this to ``-1`` performs as many gradient steps as steps done in the environment
|
Setting this to ``-1`` performs as many gradient steps as steps done in the environment
|
||||||
during the rollout.
|
during the rollout.
|
||||||
:param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes.
|
|
||||||
Note that this cannot be used at the same time as ``train_freq``. Set to `-1` to disable.
|
|
||||||
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
|
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
|
||||||
at a cost of more complexity.
|
at a cost of more complexity.
|
||||||
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
|
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
|
||||||
|
|
@ -66,7 +65,6 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
gamma: float = 0.99,
|
gamma: float = 0.99,
|
||||||
train_freq: int = 4,
|
train_freq: int = 4,
|
||||||
gradient_steps: int = 1,
|
gradient_steps: int = 1,
|
||||||
n_episodes_rollout: int = -1,
|
|
||||||
optimize_memory_usage: bool = False,
|
optimize_memory_usage: bool = False,
|
||||||
target_update_interval: int = 10000,
|
target_update_interval: int = 10000,
|
||||||
exploration_fraction: float = 0.005,
|
exploration_fraction: float = 0.005,
|
||||||
|
|
@ -94,7 +92,6 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
gamma,
|
gamma,
|
||||||
train_freq,
|
train_freq,
|
||||||
gradient_steps,
|
gradient_steps,
|
||||||
n_episodes_rollout,
|
|
||||||
action_noise=None, # No action noise
|
action_noise=None, # No action noise
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
tensorboard_log=tensorboard_log,
|
tensorboard_log=tensorboard_log,
|
||||||
|
|
|
||||||
|
|
@ -110,8 +110,8 @@ class Actor(BasePolicy):
|
||||||
self.mu = nn.Linear(last_layer_dim, action_dim)
|
self.mu = nn.Linear(last_layer_dim, action_dim)
|
||||||
self.log_std = nn.Linear(last_layer_dim, action_dim)
|
self.log_std = nn.Linear(last_layer_dim, action_dim)
|
||||||
|
|
||||||
def _get_data(self) -> Dict[str, Any]:
|
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||||
data = super()._get_data()
|
data = super()._get_constructor_parameters()
|
||||||
|
|
||||||
data.update(
|
data.update(
|
||||||
dict(
|
dict(
|
||||||
|
|
@ -374,8 +374,8 @@ class TQCPolicy(BasePolicy):
|
||||||
|
|
||||||
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
|
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||||
|
|
||||||
def _get_data(self) -> Dict[str, Any]:
|
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||||
data = super()._get_data()
|
data = super()._get_constructor_parameters()
|
||||||
|
|
||||||
data.update(
|
data.update(
|
||||||
dict(
|
dict(
|
||||||
|
|
|
||||||
|
|
@ -30,10 +30,9 @@ class TQC(OffPolicyAlgorithm):
|
||||||
:param batch_size: Minibatch size for each gradient update
|
:param batch_size: Minibatch size for each gradient update
|
||||||
:param tau: the soft update coefficient ("Polyak update", between 0 and 1)
|
:param tau: the soft update coefficient ("Polyak update", between 0 and 1)
|
||||||
:param gamma: the discount factor
|
:param gamma: the discount factor
|
||||||
:param train_freq: Update the model every ``train_freq`` steps.
|
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
|
||||||
|
like ``(5, "step")`` or ``(2, "episode")``.
|
||||||
:param gradient_steps: How many gradient updates to perform after each step
|
:param gradient_steps: How many gradient updates to perform after each step
|
||||||
:param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes.
|
|
||||||
Note that this cannot be used at the same time as ``train_freq``
|
|
||||||
:param action_noise: the action noise type (None by default), this can help
|
:param action_noise: the action noise type (None by default), this can help
|
||||||
for hard exploration problems. See common.noise for the different action noise types.
|
for hard exploration problems. See common.noise for the different action noise types.
|
||||||
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
|
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
|
||||||
|
|
@ -74,7 +73,6 @@ class TQC(OffPolicyAlgorithm):
|
||||||
gamma: float = 0.99,
|
gamma: float = 0.99,
|
||||||
train_freq: int = 1,
|
train_freq: int = 1,
|
||||||
gradient_steps: int = 1,
|
gradient_steps: int = 1,
|
||||||
n_episodes_rollout: int = -1,
|
|
||||||
action_noise: Optional[ActionNoise] = None,
|
action_noise: Optional[ActionNoise] = None,
|
||||||
optimize_memory_usage: bool = False,
|
optimize_memory_usage: bool = False,
|
||||||
ent_coef: Union[str, float] = "auto",
|
ent_coef: Union[str, float] = "auto",
|
||||||
|
|
@ -105,7 +103,6 @@ class TQC(OffPolicyAlgorithm):
|
||||||
gamma,
|
gamma,
|
||||||
train_freq,
|
train_freq,
|
||||||
gradient_steps,
|
gradient_steps,
|
||||||
n_episodes_rollout,
|
|
||||||
action_noise,
|
action_noise,
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
tensorboard_log=tensorboard_log,
|
tensorboard_log=tensorboard_log,
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
0.11.0a5
|
0.11.0
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -62,7 +62,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3[tests,docs]>=0.11.0a2",
|
"stable_baselines3[tests,docs]>=0.11.0",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue