Update MaskablePPO docs (#150)

* MaskablePPO docs

Added a warning about possible crashes caused by check_env in case of invalid actions.

* Reformat with black 23

* Rephrase note on action sampling

* Fix action noise

* Update changelog

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Alex Pasquali 2023-02-13 14:31:49 +01:00 committed by GitHub
parent 6bc8e426bf
commit 376d9551de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 19 additions and 24 deletions

View File

@ -3,7 +3,7 @@
Changelog Changelog
========== ==========
Release 1.8.0a2 (WIP) Release 1.8.0a4 (WIP)
-------------------------- --------------------------
Breaking Changes: Breaking Changes:
@ -23,6 +23,10 @@ Deprecations:
Others: Others:
^^^^^^^ ^^^^^^^
Documentation:
^^^^^^^^^^^^^^
- Added warning about potential crashes caused by ``check_env`` in the ``MaskablePPO`` docs (@AlexPasqua)
Release 1.7.0 (2023-01-10) Release 1.7.0 (2023-01-10)
-------------------------- --------------------------

View File

@ -83,6 +83,11 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe
If the environment implements the invalid action mask but using a different name, you can use the ``ActionMasker`` If the environment implements the invalid action mask but using a different name, you can use the ``ActionMasker``
to specify the name (see `PR #25 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/25>`_): to specify the name (see `PR #25 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/25>`_):
.. note::
If you are using a custom environment and you want to debug it with ``check_env``,
it will execute the method ``step`` passing a random action to it (using ``action_space.sample()``),
without taking into account the invalid actions mask (see `issue #145 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/145>`_).
.. code-block:: python .. code-block:: python

View File

@ -71,7 +71,6 @@ class ARS(BaseAlgorithm):
device: Union[th.device, str] = "cpu", device: Union[th.device, str] = "cpu",
_init_setup_model: bool = True, _init_setup_model: bool = True,
): ):
super().__init__( super().__init__(
policy, policy,
env, env,
@ -186,7 +185,6 @@ class ARS(BaseAlgorithm):
results = async_eval.get_results() results = async_eval.get_results()
for weights_idx, (episode_rewards, episode_lengths) in results: for weights_idx, (episode_rewards, episode_lengths) in results:
# Update reward to cancel out alive bonus if needed # Update reward to cancel out alive bonus if needed
candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths) candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths)
batch_steps += np.sum(episode_lengths) batch_steps += np.sum(episode_lengths)
@ -214,7 +212,6 @@ class ARS(BaseAlgorithm):
else: else:
# Single process, synchronous version # Single process, synchronous version
for weights_idx in range(self.pop_size): for weights_idx in range(self.pop_size):
# Load current candidate weights # Load current candidate weights
train_policy.load_from_vector(candidate_weights[weights_idx].cpu()) train_policy.load_from_vector(candidate_weights[weights_idx].cpu())
# Evaluate the candidate # Evaluate the candidate

View File

@ -30,7 +30,6 @@ class ARSPolicy(BasePolicy):
with_bias: bool = True, with_bias: bool = True,
squash_output: bool = True, squash_output: bool = True,
): ):
super().__init__( super().__init__(
observation_space, observation_space,
action_space, action_space,
@ -103,7 +102,6 @@ class ARSLinearPolicy(ARSPolicy):
with_bias: bool = False, with_bias: bool = False,
squash_output: bool = False, squash_output: bool = False,
): ):
super().__init__(observation_space, action_space, net_arch=[], with_bias=with_bias, squash_output=squash_output) super().__init__(observation_space, action_space, net_arch=[], with_bias=with_bias, squash_output=squash_output)

View File

@ -185,7 +185,6 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
indices = np.random.permutation(self.buffer_size * self.n_envs) indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data # Prepare the data
if not self.generator_ready: if not self.generator_ready:
for key, obs in self.observations.items(): for key, obs in self.observations.items():
self.observations[key] = self.swap_and_flatten(obs) self.observations[key] = self.swap_and_flatten(obs)
@ -205,7 +204,6 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
start_idx += batch_size start_idx += batch_size
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples: def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:
return MaskableDictRolloutBufferSamples( return MaskableDictRolloutBufferSamples(
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
actions=self.to_torch(self.actions[batch_inds]), actions=self.to_torch(self.actions[batch_inds]),

View File

@ -36,11 +36,9 @@ class MaskableEvalCallback(EvalCallback):
self.use_masking = use_masking self.use_masking = use_masking
def _on_step(self) -> bool: def _on_step(self) -> bool:
continue_training = True continue_training = True
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
# Sync training and eval env if there is VecNormalize # Sync training and eval env if there is VecNormalize
if self.model.get_vec_normalize_env() is not None: if self.model.get_vec_normalize_env() is not None:
try: try:

View File

@ -108,7 +108,6 @@ def evaluate_policy( # noqa: C901
current_lengths += 1 current_lengths += 1
for i in range(n_envs): for i in range(n_envs):
if episode_counts[i] < episode_count_targets[i]: if episode_counts[i] < episode_count_targets[i]:
# unpack values so that the callback can access the local variables # unpack values so that the callback can access the local variables
reward = rewards[i] reward = rewards[i]
done = dones[i] done = dones[i]

View File

@ -57,7 +57,6 @@ class MaskableActorCriticPolicy(BasePolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None, optimizer_kwargs: Optional[Dict[str, Any]] = None,
): ):
if optimizer_kwargs is None: if optimizer_kwargs is None:
optimizer_kwargs = {} optimizer_kwargs = {}
# Small values to avoid NaN in Adam optimizer # Small values to avoid NaN in Adam optimizer

View File

@ -462,7 +462,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
callback.on_training_start(locals(), globals()) callback.on_training_start(locals(), globals())
while self.num_timesteps < total_timesteps: while self.num_timesteps < total_timesteps:
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
if continue_training is False: if continue_training is False:

View File

@ -123,7 +123,6 @@ class QRDQNPolicy(BasePolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None, optimizer_kwargs: Optional[Dict[str, Any]] = None,
): ):
super().__init__( super().__init__(
observation_space, observation_space,
action_space, action_space,

View File

@ -91,7 +91,6 @@ class QRDQN(OffPolicyAlgorithm):
device: Union[th.device, str] = "auto", device: Union[th.device, str] = "auto",
_init_setup_model: bool = True, _init_setup_model: bool = True,
): ):
super().__init__( super().__init__(
policy, policy,
env, env,
@ -255,7 +254,6 @@ class QRDQN(OffPolicyAlgorithm):
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
progress_bar: bool = False, progress_bar: bool = False,
) -> SelfQRDQN: ) -> SelfQRDQN:
return super().learn( return super().learn(
total_timesteps=total_timesteps, total_timesteps=total_timesteps,
callback=callback, callback=callback,

View File

@ -101,7 +101,6 @@ class TQC(OffPolicyAlgorithm):
device: Union[th.device, str] = "auto", device: Union[th.device, str] = "auto",
_init_setup_model: bool = True, _init_setup_model: bool = True,
): ):
super().__init__( super().__init__(
policy, policy,
env, env,
@ -292,7 +291,6 @@ class TQC(OffPolicyAlgorithm):
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
progress_bar: bool = False, progress_bar: bool = False,
) -> SelfTQC: ) -> SelfTQC:
return super().learn( return super().learn(
total_timesteps=total_timesteps, total_timesteps=total_timesteps,
callback=callback, callback=callback,

View File

@ -99,7 +99,6 @@ class TRPO(OnPolicyAlgorithm):
device: Union[th.device, str] = "auto", device: Union[th.device, str] = "auto",
_init_setup_model: bool = True, _init_setup_model: bool = True,
): ):
super().__init__( super().__init__(
policy, policy,
env, env,
@ -235,7 +234,6 @@ class TRPO(OnPolicyAlgorithm):
# This will only loop once (get all data in one go) # This will only loop once (get all data in one go)
for rollout_data in self.rollout_buffer.get(batch_size=None): for rollout_data in self.rollout_buffer.get(batch_size=None):
# Optional: sub-sample data for faster computation # Optional: sub-sample data for faster computation
if self.sub_sampling_factor > 1: if self.sub_sampling_factor > 1:
rollout_data = RolloutBufferSamples( rollout_data = RolloutBufferSamples(
@ -308,7 +306,6 @@ class TRPO(OnPolicyAlgorithm):
with th.no_grad(): with th.no_grad():
# Line-search (backtracking) # Line-search (backtracking)
for _ in range(self.line_search_max_iter): for _ in range(self.line_search_max_iter):
start_idx = 0 start_idx = 0
# Applying the scaled step direction # Applying the scaled step direction
for param, original_param, shape in zip(actor_params, original_actor_params, grad_shape): for param, original_param, shape in zip(actor_params, original_actor_params, grad_shape):
@ -408,7 +405,6 @@ class TRPO(OnPolicyAlgorithm):
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
progress_bar: bool = False, progress_bar: bool = False,
) -> SelfTRPO: ) -> SelfTRPO:
return super().learn( return super().learn(
total_timesteps=total_timesteps, total_timesteps=total_timesteps,
callback=callback, callback=callback,

View File

@ -1 +1 @@
1.8.0a2 1.8.0a4

View File

@ -1,3 +1,4 @@
import numpy as np
import pytest import pytest
from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.noise import NormalActionNoise from stable_baselines3.common.noise import NormalActionNoise
@ -25,7 +26,13 @@ def test_deterministic_training_common(algo):
ars_multi = False ars_multi = False
if algo in [TQC]: if algo in [TQC]:
kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4}) kwargs.update(
{
"action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)),
"learning_starts": 100,
"train_freq": 4,
}
)
else: else:
if algo == QRDQN: if algo == QRDQN:
env_id = "CartPole-v1" env_id = "CartPole-v1"