Update MaskablePPO docs (#150)
* MaskablePPO docs: added a warning about possible crashes caused by check_env in case of invalid actions
* Reformat with black 23
* Rephrase note on action sampling
* Fix action noise
* Update changelog

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
commit 376d9551de (parent 6bc8e426bf)
@@ -3,7 +3,7 @@
 Changelog
 ==========
 
-Release 1.8.0a2 (WIP)
+Release 1.8.0a4 (WIP)
 --------------------------
 
 Breaking Changes:
@@ -23,6 +23,10 @@ Deprecations:
 Others:
 ^^^^^^^
 
+Documentation:
+^^^^^^^^^^^^^^
+- Added warning about potential crashes caused by ``check_env`` in the ``MaskablePPO`` docs (@AlexPasqua)
+
 Release 1.7.0 (2023-01-10)
 --------------------------
 
@@ -83,6 +83,11 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` otherwise)
 If the environment implements the invalid action mask but using a different name, you can use the ``ActionMasker``
 to specify the name (see `PR #25 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/25>`_):
 
+.. note::
+  If you are using a custom environment and want to debug it with ``check_env``,
+  be aware that it will call ``step`` with a random action (sampled with ``action_space.sample()``),
+  without taking the invalid action mask into account (see `issue #145 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/145>`_).
+
 .. code-block:: python
 
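For readers landing here from issue #145, here is a minimal sketch (not part of this commit) of the situation the new note warns about: a strict environment whose ``step`` rejects invalid actions, exposed to MaskablePPO through ``ActionMasker``. The class ``MaskedToyEnv`` and the method ``valid_action_mask`` are made-up names, and the snippet assumes the pre-gymnasium ``gym`` API used by this release line. ``check_env`` samples actions with ``action_space.sample()`` and ignores the mask, so the assertion in ``step`` below can fail.

import numpy as np
import gym
from gym import spaces

from sb3_contrib.common.wrappers import ActionMasker


class MaskedToyEnv(gym.Env):
    """Hypothetical env: only action 0 is valid, and step() enforces it."""

    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

    def reset(self):
        return np.zeros(2, dtype=np.float32)

    def step(self, action):
        # A strict env may reject invalid actions outright; check_env can hit this
        # because it samples actions with action_space.sample(), ignoring the mask.
        assert action == 0, "invalid action"
        return np.zeros(2, dtype=np.float32), 0.0, True, {}

    def valid_action_mask(self):
        # True = valid, False = invalid
        return np.array([True, False, False])


# ActionMasker lets MaskablePPO find a mask function with a non-default name
env = ActionMasker(MaskedToyEnv(), lambda env: env.valid_action_mask())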
@@ -71,7 +71,6 @@ class ARS(BaseAlgorithm):
         device: Union[th.device, str] = "cpu",
         _init_setup_model: bool = True,
     ):
-
         super().__init__(
             policy,
             env,
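The hunk above and the similar ones that follow (down to the version bump) are the "Reformat with black 23" part of the commit: in every one of them the only change is the removal of a blank line that sits directly after a block opener (a def/class signature, for, if, while, or else line), which the black 23 run no longer keeps. A tiny before/after illustration, not taken from the repo:

# Before: blank line directly after the signature
def evaluate(policy):

    return policy()


# After running black 23: the blank line is dropped
def evaluate(policy):
    return policy()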
@@ -186,7 +185,6 @@ class ARS(BaseAlgorithm):
             results = async_eval.get_results()
 
             for weights_idx, (episode_rewards, episode_lengths) in results:
-
                 # Update reward to cancel out alive bonus if needed
                 candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths)
                 batch_steps += np.sum(episode_lengths)
@@ -214,7 +212,6 @@ class ARS(BaseAlgorithm):
         else:
             # Single process, synchronous version
             for weights_idx in range(self.pop_size):
-
                 # Load current candidate weights
                 train_policy.load_from_vector(candidate_weights[weights_idx].cpu())
                 # Evaluate the candidate
@@ -30,7 +30,6 @@ class ARSPolicy(BasePolicy):
         with_bias: bool = True,
         squash_output: bool = True,
     ):
-
         super().__init__(
             observation_space,
             action_space,
@@ -103,7 +102,6 @@ class ARSLinearPolicy(ARSPolicy):
         with_bias: bool = False,
         squash_output: bool = False,
     ):
-
         super().__init__(observation_space, action_space, net_arch=[], with_bias=with_bias, squash_output=squash_output)
 
 
@@ -185,7 +185,6 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
         indices = np.random.permutation(self.buffer_size * self.n_envs)
         # Prepare the data
         if not self.generator_ready:
-
             for key, obs in self.observations.items():
                 self.observations[key] = self.swap_and_flatten(obs)
 
@@ -205,7 +204,6 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
             start_idx += batch_size
 
     def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:
-
         return MaskableDictRolloutBufferSamples(
             observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
             actions=self.to_torch(self.actions[batch_inds]),
@@ -36,11 +36,9 @@ class MaskableEvalCallback(EvalCallback):
         self.use_masking = use_masking
 
     def _on_step(self) -> bool:
-
         continue_training = True
 
         if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
-
             # Sync training and eval env if there is VecNormalize
             if self.model.get_vec_normalize_env() is not None:
                 try:
@@ -108,7 +108,6 @@ def evaluate_policy( # noqa: C901
         current_lengths += 1
         for i in range(n_envs):
             if episode_counts[i] < episode_count_targets[i]:
-
                 # unpack values so that the callback can access the local variables
                 reward = rewards[i]
                 done = dones[i]
@@ -57,7 +57,6 @@ class MaskableActorCriticPolicy(BasePolicy):
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
-
         if optimizer_kwargs is None:
             optimizer_kwargs = {}
             # Small values to avoid NaN in Adam optimizer
@@ -462,7 +462,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
         callback.on_training_start(locals(), globals())
 
         while self.num_timesteps < total_timesteps:
-
             continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
 
             if continue_training is False:
@@ -123,7 +123,6 @@ class QRDQNPolicy(BasePolicy):
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
-
         super().__init__(
             observation_space,
             action_space,
@@ -91,7 +91,6 @@ class QRDQN(OffPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
     ):
-
         super().__init__(
             policy,
             env,
@@ -255,7 +254,6 @@ class QRDQN(OffPolicyAlgorithm):
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfQRDQN:
-
         return super().learn(
             total_timesteps=total_timesteps,
             callback=callback,
@@ -101,7 +101,6 @@ class TQC(OffPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
     ):
-
         super().__init__(
             policy,
             env,
@@ -292,7 +291,6 @@ class TQC(OffPolicyAlgorithm):
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfTQC:
-
         return super().learn(
             total_timesteps=total_timesteps,
             callback=callback,
@@ -99,7 +99,6 @@ class TRPO(OnPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
     ):
-
         super().__init__(
             policy,
             env,
@@ -235,7 +234,6 @@ class TRPO(OnPolicyAlgorithm):
 
         # This will only loop once (get all data in one go)
         for rollout_data in self.rollout_buffer.get(batch_size=None):
-
             # Optional: sub-sample data for faster computation
             if self.sub_sampling_factor > 1:
                 rollout_data = RolloutBufferSamples(
@@ -308,7 +306,6 @@ class TRPO(OnPolicyAlgorithm):
         with th.no_grad():
             # Line-search (backtracking)
             for _ in range(self.line_search_max_iter):
-
                 start_idx = 0
                 # Applying the scaled step direction
                 for param, original_param, shape in zip(actor_params, original_actor_params, grad_shape):
@@ -408,7 +405,6 @@ class TRPO(OnPolicyAlgorithm):
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfTRPO:
-
         return super().learn(
             total_timesteps=total_timesteps,
             callback=callback,
@@ -1 +1 @@
-1.8.0a2
+1.8.0a4
@@ -1,3 +1,4 @@
+import numpy as np
 import pytest
 from stable_baselines3.common.env_util import make_vec_env
 from stable_baselines3.common.noise import NormalActionNoise
@@ -25,7 +26,13 @@ def test_deterministic_training_common(algo):
     ars_multi = False
 
     if algo in [TQC]:
-        kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4})
+        kwargs.update(
+            {
+                "action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)),
+                "learning_starts": 100,
+                "train_freq": 4,
+            }
+        )
     else:
         if algo == QRDQN:
             env_id = "CartPole-v1"
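The "Fix action noise" change is the hunk above: ``NormalActionNoise`` is annotated to take array-valued mean and sigma (one entry per action dimension), so the scalar arguments ``(0.0, 0.1)`` are replaced with ``np.zeros(1)`` and ``0.1 * np.ones(1)``. A minimal usage sketch for a 1-D continuous action space:

import numpy as np

from stable_baselines3.common.noise import NormalActionNoise

# Mean and standard deviation are arrays with one entry per action dimension
n_actions = 1
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

# Each call draws a Gaussian noise vector of shape (n_actions,)
print(action_noise())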