Update MaskablePPO docs (#150)

* MaskablePPO docs

Added a warning about possible crashes caused by check_env in case of invalid actions.

* Reformat with black 23

* Rephrase note on action sampling

* Fix action noise

* Update changelog

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Alex Pasquali 2023-02-13 14:31:49 +01:00 committed by GitHub
parent 6bc8e426bf
commit 376d9551de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 19 additions and 24 deletions

View File

@ -3,7 +3,7 @@
Changelog
==========
Release 1.8.0a2 (WIP)
Release 1.8.0a4 (WIP)
--------------------------
Breaking Changes:
@ -23,6 +23,10 @@ Deprecations:
Others:
^^^^^^^
Documentation:
^^^^^^^^^^^^^^
- Added warning about potential crashes caused by ``check_env`` in the ``MaskablePPO`` docs (@AlexPasqua)
Release 1.7.0 (2023-01-10)
--------------------------

View File

@ -83,6 +83,11 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe
If the environment implements the invalid action mask but using a different name, you can use the ``ActionMasker``
to specify the name (see `PR #25 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/25>`_):
.. note::
If you are using a custom environment and you want to debug it with ``check_env``,
it will call the ``step`` method with a random action (sampled via ``action_space.sample()``),
without taking the invalid actions mask into account (see `issue #145 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/145>`_).
.. code-block:: python

View File

@ -71,7 +71,6 @@ class ARS(BaseAlgorithm):
device: Union[th.device, str] = "cpu",
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
@ -186,7 +185,6 @@ class ARS(BaseAlgorithm):
results = async_eval.get_results()
for weights_idx, (episode_rewards, episode_lengths) in results:
# Update reward to cancel out alive bonus if needed
candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths)
batch_steps += np.sum(episode_lengths)
@ -214,7 +212,6 @@ class ARS(BaseAlgorithm):
else:
# Single process, synchronous version
for weights_idx in range(self.pop_size):
# Load current candidate weights
train_policy.load_from_vector(candidate_weights[weights_idx].cpu())
# Evaluate the candidate

View File

@ -30,7 +30,6 @@ class ARSPolicy(BasePolicy):
with_bias: bool = True,
squash_output: bool = True,
):
super().__init__(
observation_space,
action_space,
@ -103,7 +102,6 @@ class ARSLinearPolicy(ARSPolicy):
with_bias: bool = False,
squash_output: bool = False,
):
super().__init__(observation_space, action_space, net_arch=[], with_bias=with_bias, squash_output=squash_output)

View File

@ -185,7 +185,6 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
if not self.generator_ready:
for key, obs in self.observations.items():
self.observations[key] = self.swap_and_flatten(obs)
@ -205,7 +204,6 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
start_idx += batch_size
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:
return MaskableDictRolloutBufferSamples(
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
actions=self.to_torch(self.actions[batch_inds]),

View File

@ -36,11 +36,9 @@ class MaskableEvalCallback(EvalCallback):
self.use_masking = use_masking
def _on_step(self) -> bool:
continue_training = True
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
# Sync training and eval env if there is VecNormalize
if self.model.get_vec_normalize_env() is not None:
try:

View File

@ -108,7 +108,6 @@ def evaluate_policy( # noqa: C901
current_lengths += 1
for i in range(n_envs):
if episode_counts[i] < episode_count_targets[i]:
# unpack values so that the callback can access the local variables
reward = rewards[i]
done = dones[i]

View File

@ -57,7 +57,6 @@ class MaskableActorCriticPolicy(BasePolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
if optimizer_kwargs is None:
optimizer_kwargs = {}
# Small values to avoid NaN in Adam optimizer

View File

@ -462,7 +462,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
callback.on_training_start(locals(), globals())
while self.num_timesteps < total_timesteps:
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
if continue_training is False:

View File

@ -123,7 +123,6 @@ class QRDQNPolicy(BasePolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(
observation_space,
action_space,

View File

@ -91,7 +91,6 @@ class QRDQN(OffPolicyAlgorithm):
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
@ -255,7 +254,6 @@ class QRDQN(OffPolicyAlgorithm):
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfQRDQN:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,

View File

@ -101,7 +101,6 @@ class TQC(OffPolicyAlgorithm):
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
@ -292,7 +291,6 @@ class TQC(OffPolicyAlgorithm):
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfTQC:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,

View File

@ -99,7 +99,6 @@ class TRPO(OnPolicyAlgorithm):
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
@ -235,7 +234,6 @@ class TRPO(OnPolicyAlgorithm):
# This will only loop once (get all data in one go)
for rollout_data in self.rollout_buffer.get(batch_size=None):
# Optional: sub-sample data for faster computation
if self.sub_sampling_factor > 1:
rollout_data = RolloutBufferSamples(
@ -308,7 +306,6 @@ class TRPO(OnPolicyAlgorithm):
with th.no_grad():
# Line-search (backtracking)
for _ in range(self.line_search_max_iter):
start_idx = 0
# Applying the scaled step direction
for param, original_param, shape in zip(actor_params, original_actor_params, grad_shape):
@ -408,7 +405,6 @@ class TRPO(OnPolicyAlgorithm):
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfTRPO:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,

View File

@ -1 +1 @@
1.8.0a2
1.8.0a4

View File

@ -1,3 +1,4 @@
import numpy as np
import pytest
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.noise import NormalActionNoise
@ -25,7 +26,13 @@ def test_deterministic_training_common(algo):
ars_multi = False
if algo in [TQC]:
kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4})
kwargs.update(
{
"action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)),
"learning_starts": 100,
"train_freq": 4,
}
)
else:
if algo == QRDQN:
env_id = "CartPole-v1"