Update MaskablePPO docs (#150)
* MaskablePPO docs Added a warning about possible crashes caused by chack_env in case of invalid actions. * Reformat with black 23 * Rephrase note on action sampling * Fix action noise * Update changelog --------- Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
6bc8e426bf
commit
376d9551de
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.8.0a2 (WIP)
|
||||
Release 1.8.0a4 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -23,6 +23,10 @@ Deprecations:
|
|||
Others:
|
||||
^^^^^^^
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
- Added warning about potential crashes caused by ``check_env`` in the ``MaskablePPO`` docs (@AlexPasqua)
|
||||
|
||||
|
||||
Release 1.7.0 (2023-01-10)
|
||||
--------------------------
|
||||
|
|
|
|||
|
|
@ -83,6 +83,11 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe
|
|||
If the environment implements the invalid action mask but using a different name, you can use the ``ActionMasker``
|
||||
to specify the name (see `PR #25 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/25>`_):
|
||||
|
||||
.. note::
|
||||
If you are using a custom environment and you want to debug it with ``check_env``,
|
||||
it will execute the method ``step`` passing a random action to it (using ``action_space.sample()``),
|
||||
without taking into account the invalid actions mask (see `issue #145 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/145>`_).
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
|
|
|||
|
|
@ -71,7 +71,6 @@ class ARS(BaseAlgorithm):
|
|||
device: Union[th.device, str] = "cpu",
|
||||
_init_setup_model: bool = True,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
policy,
|
||||
env,
|
||||
|
|
@ -186,7 +185,6 @@ class ARS(BaseAlgorithm):
|
|||
results = async_eval.get_results()
|
||||
|
||||
for weights_idx, (episode_rewards, episode_lengths) in results:
|
||||
|
||||
# Update reward to cancel out alive bonus if needed
|
||||
candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths)
|
||||
batch_steps += np.sum(episode_lengths)
|
||||
|
|
@ -214,7 +212,6 @@ class ARS(BaseAlgorithm):
|
|||
else:
|
||||
# Single process, synchronous version
|
||||
for weights_idx in range(self.pop_size):
|
||||
|
||||
# Load current candidate weights
|
||||
train_policy.load_from_vector(candidate_weights[weights_idx].cpu())
|
||||
# Evaluate the candidate
|
||||
|
|
|
|||
|
|
@ -30,7 +30,6 @@ class ARSPolicy(BasePolicy):
|
|||
with_bias: bool = True,
|
||||
squash_output: bool = True,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
|
|
@ -103,7 +102,6 @@ class ARSLinearPolicy(ARSPolicy):
|
|||
with_bias: bool = False,
|
||||
squash_output: bool = False,
|
||||
):
|
||||
|
||||
super().__init__(observation_space, action_space, net_arch=[], with_bias=with_bias, squash_output=squash_output)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -185,7 +185,6 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
|||
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
||||
# Prepare the data
|
||||
if not self.generator_ready:
|
||||
|
||||
for key, obs in self.observations.items():
|
||||
self.observations[key] = self.swap_and_flatten(obs)
|
||||
|
||||
|
|
@ -205,7 +204,6 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
|||
start_idx += batch_size
|
||||
|
||||
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:
|
||||
|
||||
return MaskableDictRolloutBufferSamples(
|
||||
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
|
||||
actions=self.to_torch(self.actions[batch_inds]),
|
||||
|
|
|
|||
|
|
@ -36,11 +36,9 @@ class MaskableEvalCallback(EvalCallback):
|
|||
self.use_masking = use_masking
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
|
||||
continue_training = True
|
||||
|
||||
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
|
||||
|
||||
# Sync training and eval env if there is VecNormalize
|
||||
if self.model.get_vec_normalize_env() is not None:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -108,7 +108,6 @@ def evaluate_policy( # noqa: C901
|
|||
current_lengths += 1
|
||||
for i in range(n_envs):
|
||||
if episode_counts[i] < episode_count_targets[i]:
|
||||
|
||||
# unpack values so that the callback can access the local variables
|
||||
reward = rewards[i]
|
||||
done = dones[i]
|
||||
|
|
|
|||
|
|
@ -57,7 +57,6 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
|
||||
if optimizer_kwargs is None:
|
||||
optimizer_kwargs = {}
|
||||
# Small values to avoid NaN in Adam optimizer
|
||||
|
|
|
|||
|
|
@ -462,7 +462,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
callback.on_training_start(locals(), globals())
|
||||
|
||||
while self.num_timesteps < total_timesteps:
|
||||
|
||||
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
|
||||
|
||||
if continue_training is False:
|
||||
|
|
|
|||
|
|
@ -123,7 +123,6 @@ class QRDQNPolicy(BasePolicy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
|
|
|
|||
|
|
@ -91,7 +91,6 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
device: Union[th.device, str] = "auto",
|
||||
_init_setup_model: bool = True,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
policy,
|
||||
env,
|
||||
|
|
@ -255,7 +254,6 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfQRDQN:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
|
|
|
|||
|
|
@ -101,7 +101,6 @@ class TQC(OffPolicyAlgorithm):
|
|||
device: Union[th.device, str] = "auto",
|
||||
_init_setup_model: bool = True,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
policy,
|
||||
env,
|
||||
|
|
@ -292,7 +291,6 @@ class TQC(OffPolicyAlgorithm):
|
|||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfTQC:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
|
|
|
|||
|
|
@ -99,7 +99,6 @@ class TRPO(OnPolicyAlgorithm):
|
|||
device: Union[th.device, str] = "auto",
|
||||
_init_setup_model: bool = True,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
policy,
|
||||
env,
|
||||
|
|
@ -235,7 +234,6 @@ class TRPO(OnPolicyAlgorithm):
|
|||
|
||||
# This will only loop once (get all data in one go)
|
||||
for rollout_data in self.rollout_buffer.get(batch_size=None):
|
||||
|
||||
# Optional: sub-sample data for faster computation
|
||||
if self.sub_sampling_factor > 1:
|
||||
rollout_data = RolloutBufferSamples(
|
||||
|
|
@ -308,7 +306,6 @@ class TRPO(OnPolicyAlgorithm):
|
|||
with th.no_grad():
|
||||
# Line-search (backtracking)
|
||||
for _ in range(self.line_search_max_iter):
|
||||
|
||||
start_idx = 0
|
||||
# Applying the scaled step direction
|
||||
for param, original_param, shape in zip(actor_params, original_actor_params, grad_shape):
|
||||
|
|
@ -408,7 +405,6 @@ class TRPO(OnPolicyAlgorithm):
|
|||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfTRPO:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.8.0a2
|
||||
1.8.0a4
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.noise import NormalActionNoise
|
||||
|
|
@ -25,7 +26,13 @@ def test_deterministic_training_common(algo):
|
|||
ars_multi = False
|
||||
|
||||
if algo in [TQC]:
|
||||
kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4})
|
||||
kwargs.update(
|
||||
{
|
||||
"action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)),
|
||||
"learning_starts": 100,
|
||||
"train_freq": 4,
|
||||
}
|
||||
)
|
||||
else:
|
||||
if algo == QRDQN:
|
||||
env_id = "CartPole-v1"
|
||||
|
|
|
|||
Loading…
Reference in New Issue