Compare commits

...

10 Commits

Author SHA1 Message Date
Antonin RAFFIN adfbeb1b01 Fix typo in changelog (#237) 2024-05-06 14:20:28 +01:00
Antonin RAFFIN 17cd797d5c Release v2.3.0 (#236) 2024-05-06 14:20:28 +01:00
Antonin RAFFIN 34eceaf382 Log success rate for PPO variants (#235) 2024-05-06 14:20:28 +01:00
Antonin RAFFIN 89d0113037 Update ruff and SB3 dependencies (#232) 2024-05-06 14:20:28 +01:00
Antonin RAFFIN 7dd6c39fba Fix MaskablePPO type annotations (#233)
* Fix `sb3_contrib/common/maskable/callbacks.py` type annotations

* Fix `sb3_contrib/common/vec_env/async_eval.py` type annotations

* Fix `sb3_contrib/common/maskable/buffers.py` type hints

* Fix `sb3_contrib/common/maskable/distributions.py` type hints

* Fix `sb3_contrib/common/maskable/policies.py` type hints

* Fix `sb3_contrib/ppo_mask/ppo_mask.py` type hints

* Update changelog and fix type hints
2024-05-06 14:20:28 +01:00
Armand du Parc Locmaria cd31e89e26 Fix `train_freq` type annotation for TQC and QR-DQN (#229)
* fix train_freq type for tqc and qrdn

* fix typo

* Update changelog

---------

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
2024-05-06 14:20:28 +01:00
Tjeerd Bakker bc3c0a9595 Add notes about MaskablePPO (#227)
* Add note about MaskableEvalCallback

* Update doc

---------

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
2024-05-06 14:20:28 +01:00
Antonin RAFFIN 3f0c5088b3 Update QRDQN defaults (#225) 2024-05-06 14:20:28 +01:00
Rogério Júnior 1553b66ee4 Update `_process_sequence()` docstring (#219)
Small modification in the documentation of `_process_sequence` to reflect the correct order (`hidden_state`, `cell_state`) of `lstm_states`.

See order here: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
2024-05-06 14:20:28 +01:00
Antonin RAFFIN 94a5daab02 Update SB3 version (#217) 2024-05-06 14:20:28 +01:00
18 changed files with 206 additions and 222 deletions


@ -12,23 +12,23 @@ type: mypy
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://www.flake8rules.com/
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
ruff ${LINT_PATHS} --exit-zero
ruff check ${LINT_PATHS} --exit-zero
format:
# Sort imports
ruff --select I ${LINT_PATHS} --fix
ruff check --select I ${LINT_PATHS} --fix
# Reformat using black
black ${LINT_PATHS}
check-codestyle:
# Sort imports
ruff --select I ${LINT_PATHS}
ruff check --select I ${LINT_PATHS}
# Reformat using black
black --check ${LINT_PATHS}
commit-checks: format type lint
doc:


@ -11,6 +11,7 @@
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import datetime
import os
import sys
from typing import Dict
@ -44,7 +45,7 @@ with open(version_file) as file_handler:
# -- Project information -----------------------------------------------------
project = "Stable Baselines3 - Contrib"
copyright = "2023, Stable Baselines3"
copyright = f"2021-{datetime.date.today().year}, Stable Baselines3"
author = "Stable Baselines3 Contributors"
# The short X.Y version
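This change builds the copyright string from the current year rather than hard-coding it. A minimal standalone sketch of the same pattern (the start year and project name are taken from the diff above):

```python
import datetime

# Build a copyright range that always ends at the current year,
# mirroring the Sphinx conf.py change above.
start_year = 2021
copyright_notice = f"{start_year}-{datetime.date.today().year}, Stable Baselines3"
print(copyright_notice)
```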


@ -35,6 +35,12 @@ MaskablePPO
Train a PPO agent with invalid action masking on a toy environment.
.. warning::
You must use ``MaskableEvalCallback`` from ``sb3_contrib.common.maskable.callbacks`` instead of the base ``EvalCallback`` to properly evaluate a model with action masks.
Similarly, you must use ``evaluate_policy`` from ``sb3_contrib.common.maskable.evaluation`` instead of the SB3 one.
.. code-block:: python
from sb3_contrib import MaskablePPO


@ -3,13 +3,55 @@
Changelog
==========
Release 2.3.0 (2024-03-31)
--------------------------
Release 2.2.0 (2023-11-16)
**New defaults hyperparameters for QR-DQN**
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 2.3.0
- The default ``learning_starts`` parameter of ``QRDQN`` has been changed to be consistent with the other off-policy algorithms
.. code-block:: python
# SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to the Atari default hyperparameters
# model = QRDQN("MlpPolicy", env, learning_starts=50_000)
# SB3 >= 2.3.0:
model = QRDQN("MlpPolicy", env, learning_starts=100)
New Features:
^^^^^^^^^^^^^
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to MaskablePPO
- Log success rate ``rollout/success_rate`` when available for on-policy algorithms
Bug Fixes:
^^^^^^^^^^
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
- Fixed ``train_freq`` type annotation for TQC and QR-DQN (@Armandpl)
- Fixed ``sb3_contrib/common/maskable/*.py`` type annotations
- Fixed ``sb3_contrib/ppo_mask/ppo_mask.py`` type annotations
- Fixed ``sb3_contrib/common/vec_env/async_eval.py`` type annotations
Documentation:
^^^^^^^^^^^^^^
- Added notes about ``MaskablePPO`` (evaluation and multi-process) (@icheered)
Release 2.2.1 (2023-11-17)
--------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 2.2.0
- Upgraded to Stable-Baselines3 >= 2.2.1
- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version
- Dropped ``x is False`` in favor of ``not x``, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle)
@ -512,4 +554,4 @@ Contributors:
-------------
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl


@ -45,6 +45,16 @@ MultiBinary ✔️ ✔️
Dict ❌ ✔️
============= ====== ===========
.. warning::
You must use ``MaskableEvalCallback`` from ``sb3_contrib.common.maskable.callbacks`` instead of the base ``EvalCallback`` to properly evaluate a model with action masks.
Similarly, you must use ``evaluate_policy`` from ``sb3_contrib.common.maskable.evaluation`` instead of the SB3 one.
.. warning::
In order to use ``SubprocVecEnv`` with ``MaskablePPO``, you must implement the ``action_masks`` method inside the environment (``ActionMasker`` cannot be used).
You can have a look at the `built-in environments with invalid action masks <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/sb3_contrib/common/envs/invalid_actions_env.py>`_ for a working example.
Example
-------
@ -58,6 +68,8 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.utils import get_action_masks
# This is a drop-in replacement for EvalCallback
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)


@ -3,16 +3,18 @@
line-length = 127
# Assume Python 3.8
target-version = "py38"
[tool.ruff.lint]
select = ["E", "F", "B", "UP", "C90", "RUF"]
# B028: Ignore explicit stacklevel
# RUF013: Too many false positives (implicit optional)
ignore = ["B028", "RUF013"]
[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
# ClassVar, implicit optional check not needed for tests
"./tests/*.py"= ["RUF012", "RUF013"]
[tool.ruff.mccabe]
[tool.ruff.lint.mccabe]
# Unlike Flake8, ruff defaults to a complexity level of 10.
max-complexity = 15
@ -28,12 +30,6 @@ exclude = """(?x)(
| sb3_contrib/ars/ars.py$
| sb3_contrib/common/recurrent/policies.py$
| sb3_contrib/common/recurrent/buffers.py$
| sb3_contrib/common/maskable/distributions.py$
| sb3_contrib/common/maskable/callbacks.py$
| sb3_contrib/common/maskable/policies.py$
| sb3_contrib/common/maskable/buffers.py$
| sb3_contrib/common/vec_env/async_eval.py$
| sb3_contrib/ppo_mask/ppo_mask.py$
| tests/test_train_eval_mode.py$
)"""


@ -18,7 +18,7 @@ class MaskableRolloutBufferSamples(NamedTuple):
action_masks: th.Tensor
class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples):
class MaskableDictRolloutBufferSamples(NamedTuple):
observations: TensorDict
actions: th.Tensor
old_values: th.Tensor
@ -42,6 +42,8 @@ class MaskableRolloutBuffer(RolloutBuffer):
:param n_envs: Number of parallel environments
"""
action_masks: np.ndarray
def __init__(
self,
buffer_size: int,
@ -53,7 +55,6 @@ class MaskableRolloutBuffer(RolloutBuffer):
n_envs: int = 1,
):
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
self.action_masks = None
def reset(self) -> None:
if isinstance(self.action_space, spaces.Discrete):
@ -61,6 +62,10 @@ class MaskableRolloutBuffer(RolloutBuffer):
elif isinstance(self.action_space, spaces.MultiDiscrete):
mask_dims = sum(self.action_space.nvec)
elif isinstance(self.action_space, spaces.MultiBinary):
assert isinstance(self.action_space.n, int), (
f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
"You can flatten it instead."
)
mask_dims = 2 * self.action_space.n # One mask per binary outcome
else:
raise ValueError(f"Unsupported action space {type(self.action_space)}")
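The `reset()` logic above derives the width of the stored action mask from the action space type. A pure-Python sketch of that dimension rule, using hypothetical stand-in classes rather than the real Gymnasium spaces:

```python
# Hypothetical stand-ins for the Gymnasium space types, just to
# illustrate how the mask width is computed in reset() above.
class Discrete:
    def __init__(self, n): self.n = n

class MultiDiscrete:
    def __init__(self, nvec): self.nvec = nvec

class MultiBinary:
    def __init__(self, n): self.n = n

def mask_dims(space):
    if isinstance(space, Discrete):
        return space.n                 # one mask entry per action
    if isinstance(space, MultiDiscrete):
        return sum(space.nvec)         # entries for every sub-action choice
    if isinstance(space, MultiBinary):
        # Mirrors the assert added in the diff: only 1-D MultiBinary is supported
        assert isinstance(space.n, int), "Multi-dimensional MultiBinary is not supported"
        return 2 * space.n             # one mask per binary outcome
    raise ValueError(f"Unsupported action space {type(space)}")

print(mask_dims(Discrete(4)))            # → 4
print(mask_dims(MultiDiscrete([3, 2])))  # → 5
print(mask_dims(MultiBinary(3)))         # → 6
```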
@ -79,7 +84,7 @@ class MaskableRolloutBuffer(RolloutBuffer):
super().add(*args, **kwargs)
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]:
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]: # type: ignore[override]
assert self.full, ""
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
@ -105,7 +110,7 @@ class MaskableRolloutBuffer(RolloutBuffer):
yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size
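`get()` shuffles all stored transitions once and then yields contiguous slices of the permutation. The indexing pattern can be sketched with the standard library alone (the helper name is illustrative; the real method yields buffer samples, not raw indices, and uses `np.random.permutation`):

```python
import random

def minibatch_indices(buffer_size, n_envs, batch_size, seed=0):
    """Yield shuffled index minibatches, mirroring the slicing in get()."""
    rng = random.Random(seed)  # stands in for np.random.permutation
    indices = list(range(buffer_size * n_envs))
    rng.shuffle(indices)
    start_idx = 0
    while start_idx < len(indices):
        yield indices[start_idx : start_idx + batch_size]
        start_idx += batch_size

batches = list(minibatch_indices(buffer_size=8, n_envs=2, batch_size=5))
# Every transition index appears exactly once across all minibatches;
# the last batch may be smaller than batch_size.
```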
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples:
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples: # type: ignore[override]
data = (
self.observations[batch_inds],
self.actions[batch_inds],
@ -143,17 +148,18 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
:param n_envs: Number of parallel environments
"""
action_masks: np.ndarray
def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
observation_space: spaces.Dict,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
self.action_masks = None
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)
def reset(self) -> None:
@ -162,6 +168,10 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
elif isinstance(self.action_space, spaces.MultiDiscrete):
mask_dims = sum(self.action_space.nvec)
elif isinstance(self.action_space, spaces.MultiBinary):
assert isinstance(self.action_space.n, int), (
f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
"You can flatten it instead."
)
mask_dims = 2 * self.action_space.n # One mask per binary outcome
else:
raise ValueError(f"Unsupported action space {type(self.action_space)}")
@ -180,7 +190,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
super().add(*args, **kwargs)
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]:
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]: # type: ignore[override]
assert self.full, ""
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
@ -203,7 +213,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples: # type: ignore[override]
return MaskableDictRolloutBufferSamples(
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
actions=self.to_torch(self.actions[batch_inds]),


@ -55,7 +55,7 @@ class MaskableEvalCallback(EvalCallback):
# Note that evaluate_policy() has been patched to support masking
episode_rewards, episode_lengths = evaluate_policy(
self.model,
self.model, # type: ignore[arg-type]
self.eval_env,
n_eval_episodes=self.n_eval_episodes,
render=self.render,
@ -67,6 +67,8 @@ class MaskableEvalCallback(EvalCallback):
)
if self.log_path is not None:
assert isinstance(episode_rewards, list)
assert isinstance(episode_lengths, list)
self.evaluations_timesteps.append(self.num_timesteps)
self.evaluations_results.append(episode_rewards)
self.evaluations_length.append(episode_lengths)
@ -87,7 +89,7 @@ class MaskableEvalCallback(EvalCallback):
mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
self.last_mean_reward = mean_reward
self.last_mean_reward = float(mean_reward)
if self.verbose > 0:
print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
@ -111,7 +113,7 @@ class MaskableEvalCallback(EvalCallback):
print("New best mean reward!")
if self.best_model_save_path is not None:
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
self.best_mean_reward = mean_reward
self.best_mean_reward = float(mean_reward)
# Trigger callback on new best model, if needed
if self.callback_on_new_best is not None:
continue_training = self.callback_on_new_best.on_step()
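The two changes above cast the NumPy scalar means to plain `float` before storing them on the callback. The same bookkeeping, sketched with the standard library's `statistics` module instead of NumPy (the helper name is illustrative):

```python
import statistics

def summarize_rewards(episode_rewards):
    """Return (mean, std) of episode rewards as plain Python floats,
    mirroring the float(...) casts added to MaskableEvalCallback."""
    mean_reward = float(statistics.mean(episode_rewards))
    std_reward = float(statistics.pstdev(episode_rewards))  # population std, like np.std
    return mean_reward, std_reward

mean_reward, std_reward = summarize_rewards([1.0, 2.0, 3.0, 4.0])
print(f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
```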


@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, TypeVar
from typing import List, Optional, Tuple, TypeVar, Union
import numpy as np
import torch as th
@ -13,6 +13,7 @@ SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistributi
SelfMaskableMultiCategoricalDistribution = TypeVar(
"SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution"
)
MaybeMasks = Union[th.Tensor, np.ndarray, None]
class MaskableCategorical(Categorical):
@ -36,14 +37,14 @@ class MaskableCategorical(Categorical):
probs: Optional[th.Tensor] = None,
logits: Optional[th.Tensor] = None,
validate_args: Optional[bool] = None,
masks: Optional[np.ndarray] = None,
masks: MaybeMasks = None,
):
self.masks: Optional[th.Tensor] = None
super().__init__(probs, logits, validate_args)
self._original_logits = self.logits
self.apply_masking(masks)
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
def apply_masking(self, masks: MaybeMasks) -> None:
"""
Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0.
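Masking works by driving the logits of disallowed outcomes to negative infinity before the softmax, so their probability becomes exactly zero. A pure-Python sketch of that mechanism, assuming at least one action remains valid (the real implementation operates on `torch` tensors and restores the original logits when the mask is removed):

```python
import math

def masked_softmax(logits, masks):
    """masks[i] is True when outcome i is allowed; disallowed logits -> -inf."""
    masked = [l if m else float("-inf") for l, m in zip(logits, masks)]
    # Subtract the max finite logit for numerical stability
    max_logit = max(x for x in masked if x != float("-inf"))
    exps = [math.exp(x - max_logit) if x != float("-inf") else 0.0 for x in masked]
    total = sum(exps)
    return [e / total for e in exps]

probs = masked_softmax([1.0, 2.0, 3.0], [True, False, True])
# The masked action gets exactly zero probability; the rest renormalize.
```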
@ -84,7 +85,7 @@ class MaskableCategorical(Categorical):
class MaskableDistribution(Distribution, ABC):
@abstractmethod
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
def apply_masking(self, masks: MaybeMasks) -> None:
"""
Eliminate ("mask out") chosen distribution outcomes by setting their probability to 0.
@ -94,6 +95,13 @@ class MaskableDistribution(Distribution, ABC):
previously applied masking is removed, and the original logits are restored.
"""
@abstractmethod
def proba_distribution_net(self, *args, **kwargs) -> nn.Module:
"""Create the layers and parameters that represent the distribution.
Subclasses must define this, but the arguments and return type vary between
concrete classes."""
class MaskableCategoricalDistribution(MaskableDistribution):
"""
@ -154,7 +162,7 @@ class MaskableCategoricalDistribution(MaskableDistribution):
log_prob = self.log_prob(actions)
return actions, log_prob
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
def apply_masking(self, masks: MaybeMasks) -> None:
assert self.distribution is not None, "Must set distribution parameters"
self.distribution.apply_masking(masks)
@ -192,7 +200,7 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
reshaped_logits = action_logits.view(-1, sum(self.action_dims))
self.distributions = [
MaskableCategorical(logits=split) for split in th.split(reshaped_logits, tuple(self.action_dims), dim=1)
MaskableCategorical(logits=split) for split in th.split(reshaped_logits, list(self.action_dims), dim=1)
]
return self
@ -229,18 +237,16 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
log_prob = self.log_prob(actions)
return actions, log_prob
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
def apply_masking(self, masks: MaybeMasks) -> None:
assert len(self.distributions) > 0, "Must set distribution parameters"
split_masks = [None] * len(self.distributions)
if masks is not None:
masks = th.as_tensor(masks)
masks_tensor = th.as_tensor(masks)
# Restructure shape to align with logits
masks = masks.view(-1, sum(self.action_dims))
masks_tensor = masks_tensor.view(-1, sum(self.action_dims))
# Then split columnwise for each discrete action
split_masks = th.split(masks, tuple(self.action_dims), dim=1)
split_masks = th.split(masks_tensor, list(self.action_dims), dim=1) # type: ignore[assignment]
for distribution, mask in zip(self.distributions, split_masks):
distribution.apply_masking(mask)
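`apply_masking` above reshapes the joint mask to one row per sample and splits it column-wise, one slice per sub-action. The splitting step in plain Python lists (a hypothetical helper mirroring `th.split(masks_tensor, list(self.action_dims), dim=1)` for a single sample):

```python
def split_mask(flat_mask, action_dims):
    """Split one sample's flat mask into per-sub-action slices."""
    assert len(flat_mask) == sum(action_dims), "mask width must match the space"
    slices, start = [], 0
    for dim in action_dims:
        slices.append(flat_mask[start : start + dim])
        start += dim
    return slices

# MultiDiscrete([3, 2]): the first sub-action has 3 choices, the second 2.
print(split_mask([True, True, False, True, False], [3, 2]))
# → [[True, True, False], [True, False]]
```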
@ -268,10 +274,13 @@ def make_masked_proba_distribution(action_space: spaces.Space) -> MaskableDistri
"""
if isinstance(action_space, spaces.Discrete):
return MaskableCategoricalDistribution(action_space.n)
return MaskableCategoricalDistribution(int(action_space.n))
elif isinstance(action_space, spaces.MultiDiscrete):
return MaskableMultiCategoricalDistribution(action_space.nvec)
return MaskableMultiCategoricalDistribution(list(action_space.nvec))
elif isinstance(action_space, spaces.MultiBinary):
assert isinstance(
action_space.n, int
), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
return MaskableBernoulliDistribution(action_space.n)
else:
raise NotImplementedError(


@ -13,7 +13,7 @@ from stable_baselines3.common.torch_layers import (
MlpExtractor,
NatureCNN,
)
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
from torch import nn
from sb3_contrib.common.maskable.distributions import MaskableDistribution, make_masked_proba_distribution
@ -141,8 +141,8 @@ class MaskableActorCriticPolicy(BasePolicy):
log_prob = distribution.log_prob(actions)
return actions, values, log_prob
def extract_features(
self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None
def extract_features( # type: ignore[override]
self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
"""
Preprocess the observation if needed and extract features.
@ -233,7 +233,11 @@ class MaskableActorCriticPolicy(BasePolicy):
module.apply(partial(self.init_weights, gain=gain))
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
self.optimizer = self.optimizer_class(
self.parameters(),
lr=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs,
)
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution:
"""
@ -245,9 +249,9 @@ class MaskableActorCriticPolicy(BasePolicy):
action_logits = self.action_net(latent_pi)
return self.action_dist.proba_distribution(action_logits=action_logits)
def _predict(
def _predict( # type: ignore[override]
self,
observation: th.Tensor,
observation: PyTorchObs,
deterministic: bool = False,
action_masks: Optional[np.ndarray] = None,
) -> th.Tensor:
@ -284,35 +288,45 @@ class MaskableActorCriticPolicy(BasePolicy):
# Switch to eval mode (this affects batch norm / dropout)
self.set_training_mode(False)
observation, vectorized_env = self.obs_to_tensor(observation)
# Check for common mistake that the user does not mix Gym/VecEnv API
# Tuple obs are not supported by SB3, so we can safely do that check
if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict):
raise ValueError(
"You have passed a tuple to the predict() function instead of a Numpy array or a Dict. "
"You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) "
"vs `obs = vec_env.reset()` (SB3 VecEnv). "
"See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 "
"and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
)
obs_tensor, vectorized_env = self.obs_to_tensor(observation)
with th.no_grad():
actions = self._predict(observation, deterministic=deterministic, action_masks=action_masks)
actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks)
# Convert to numpy
actions = actions.cpu().numpy()
if isinstance(self.action_space, spaces.Box):
if self.squash_output:
# Rescale to proper domain when using squashing
actions = self.unscale_action(actions)
actions = self.unscale_action(actions) # type: ignore[assignment, arg-type]
else:
# Actions could be on arbitrary scale, so clip the actions to avoid
# out of bound error (e.g. if sampling from a Gaussian distribution)
actions = np.clip(actions, self.action_space.low, self.action_space.high)
actions = np.clip(actions, self.action_space.low, self.action_space.high) # type: ignore[assignment, arg-type]
if not vectorized_env:
if state is not None:
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
assert isinstance(actions, np.ndarray)
actions = actions.squeeze(axis=0)
return actions, None
return actions, state # type: ignore[return-value]
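The guard added to `predict()` catches a common mistake: passing Gym's `(obs, info)` reset tuple where a VecEnv observation is expected. The check in isolation, as a hypothetical standalone function:

```python
def check_not_gym_reset_tuple(observation):
    """Raise if the observation looks like Gym's (obs, info) reset tuple,
    mirroring the guard added to MaskableActorCriticPolicy.predict()."""
    if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict):
        raise ValueError(
            "You are probably mixing the Gym API (`obs, info = env.reset()`) "
            "with the SB3 VecEnv API (`obs = vec_env.reset()`)."
        )
    return observation

check_not_gym_reset_tuple([0.0, 1.0])  # a plain observation passes through
try:
    check_not_gym_reset_tuple(([0.0, 1.0], {}))  # Gym-style (obs, info) pair
except ValueError as exc:
    print("caught:", exc)
```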
def evaluate_actions(
self,
obs: th.Tensor,
actions: th.Tensor,
action_masks: Optional[np.ndarray] = None,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
action_masks: Optional[th.Tensor] = None,
) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
"""
Evaluate actions according to the current policy,
given the observations.
@ -337,7 +351,7 @@ class MaskableActorCriticPolicy(BasePolicy):
values = self.value_net(latent_vf)
return values, log_prob, distribution.entropy()
def get_distribution(self, obs: th.Tensor, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
def get_distribution(self, obs: PyTorchObs, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
"""
Get the current policy distribution given the observations.
@ -352,7 +366,7 @@ class MaskableActorCriticPolicy(BasePolicy):
distribution.apply_masking(action_masks)
return distribution
def predict_values(self, obs: th.Tensor) -> th.Tensor:
def predict_values(self, obs: PyTorchObs) -> th.Tensor:
"""
Get the estimated values according to the current policy given the observations.


@ -170,7 +170,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
Do a forward pass in the LSTM network.
:param features: Input tensor
:param lstm_states: previous cell and hidden states of the LSTM
:param lstm_states: previous hidden and cell states of the LSTM, respectively
:param episode_starts: Indicates when a new episode starts,
in that case, we need to reset LSTM states.
:param lstm: LSTM object.
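The corrected docstring pins down the `(hidden_state, cell_state)` ordering (matching `torch.nn.LSTM`) and the reset-on-episode-start behaviour. A pure-Python sketch of that boundary handling, with lists standing in for the LSTM state tensors (the names and identity step function are illustrative):

```python
def advance_states(lstm_states, episode_starts, lstm_step):
    """Zero the recurrent state wherever a new episode starts, then step.
    lstm_states is (hidden_state, cell_state), hidden first, as in torch.nn.LSTM."""
    hidden_state, cell_state = lstm_states
    hidden_state = [0.0 if start else h for h, start in zip(hidden_state, episode_starts)]
    cell_state = [0.0 if start else c for c, start in zip(cell_state, episode_starts)]
    return lstm_step(hidden_state, cell_state)

# Two parallel envs; the second one just started a new episode.
new_hidden, new_cell = advance_states(
    ([1.5, 2.5], [0.5, 0.75]),
    [False, True],
    lambda h, c: (h, c),  # identity stand-in for the real LSTM step
)
```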


@ -59,10 +59,11 @@ def _worker(
remote.send(vec_env.seed(seed=data))
elif cmd == "set_options":
# Note: the options will only be effective at the next reset
remote.send(vec_env.set_options(data))
remote.send(vec_env.set_options(data)) # type: ignore[func-returns-value]
elif cmd == "get_obs_rms":
remote.send(obs_rms)
elif cmd == "sync_obs_rms":
assert vec_normalize is not None, "Tried to call `sync_obs_rms` when not using VecNormalize"
vec_normalize.obs_rms = data
obs_rms = data
elif cmd == "close":
@ -130,7 +131,7 @@ class AsyncEval:
n_eval_episodes,
)
# daemon=True: if the main process crashes, we should not cause things to hang
process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error
process = ctx.Process(target=_worker, args=args, daemon=True) # type: ignore[attr-defined]
process.start()
self.processes.append(process)
work_remote.close()
@ -157,6 +158,10 @@ class AsyncEval:
:param seed: The seed for the pseudo-random generators.
:return:
"""
if seed is None:
# Do nothing
return []
for idx, remote in enumerate(self.remotes):
remote.send(("seed", seed + idx))
return [remote.recv() for remote in self.remotes]
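The early return added to `seed()` skips all worker communication when no seed is given; otherwise each worker process receives `seed + idx`. The seed derivation without the multiprocessing plumbing (hypothetical helper):

```python
def worker_seeds(seed, n_workers):
    """Mirror AsyncEval.seed(): no-op on None, otherwise give each
    worker a distinct, deterministic seed offset by its index."""
    if seed is None:
        # Do nothing, matching the early return added in the diff
        return []
    return [seed + idx for idx in range(n_workers)]

print(worker_seeds(None, 4))  # → []
print(worker_seeds(42, 4))    # → [42, 43, 44, 45]
```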


@ -1,18 +1,14 @@
import sys
import time
from collections import deque
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor
from stable_baselines3.common.vec_env import VecEnv
from torch.nn import functional as F
@ -74,6 +70,8 @@ class MaskablePPO(OnPolicyAlgorithm):
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
}
policy: MaskableActorCriticPolicy # type: ignore[assignment]
rollout_buffer: MaskableRolloutBuffer # type: ignore[assignment]
def __init__(
self,
@ -91,6 +89,8 @@ class MaskablePPO(OnPolicyAlgorithm):
ent_coef: float = 0.0,
vf_coef: float = 0.5,
max_grad_norm: float = 0.5,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
target_kl: Optional[float] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
@ -101,7 +101,7 @@ class MaskablePPO(OnPolicyAlgorithm):
_init_setup_model: bool = True,
):
super().__init__(
policy,
policy, # type: ignore[arg-type]
env,
learning_rate=learning_rate,
n_steps=n_steps,
@ -112,6 +112,8 @@ class MaskablePPO(OnPolicyAlgorithm):
max_grad_norm=max_grad_norm,
use_sde=False,
sde_sample_freq=-1,
rollout_buffer_class=rollout_buffer_class,
rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
policy_kwargs=policy_kwargs,
@ -140,20 +142,24 @@ class MaskablePPO(OnPolicyAlgorithm):
self._setup_lr_schedule()
self.set_random_seed(self.seed)
buffer_cls = MaskableDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else MaskableRolloutBuffer
self.policy = self.policy_class(
self.policy = self.policy_class( # type: ignore[assignment]
self.observation_space,
self.action_space,
self.lr_schedule,
**self.policy_kwargs, # pytype:disable=not-instantiable
**self.policy_kwargs,
)
self.policy = self.policy.to(self.device)
if not isinstance(self.policy, MaskableActorCriticPolicy):
raise ValueError("Policy must subclass MaskableActorCriticPolicy")
self.rollout_buffer = buffer_cls(
if self.rollout_buffer_class is None:
if isinstance(self.observation_space, spaces.Dict):
self.rollout_buffer_class = MaskableDictRolloutBuffer
else:
self.rollout_buffer_class = MaskableRolloutBuffer
self.rollout_buffer = self.rollout_buffer_class( # type: ignore[assignment]
self.n_steps,
self.observation_space,
self.action_space,
@ -161,6 +167,7 @@ class MaskablePPO(OnPolicyAlgorithm):
gamma=self.gamma,
gae_lambda=self.gae_lambda,
n_envs=self.n_envs,
**self.rollout_buffer_kwargs,
)
# Initialize schedules for policy/value clipping
@ -171,86 +178,6 @@ class MaskablePPO(OnPolicyAlgorithm):
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
def _init_callback(
self,
callback: MaybeCallback,
use_masking: bool = True,
progress_bar: bool = False,
) -> BaseCallback:
"""
:param callback: Callback(s) called at every step with state of the algorithm.
:param use_masking: Whether or not to use invalid action masks during evaluation
:param progress_bar: Display a progress bar using tqdm and rich.
:return: A hybrid callback calling `callback` and performing evaluation.
"""
# Convert a list of callbacks into a callback
if isinstance(callback, list):
callback = CallbackList(callback)
# Convert functional callback to object
if not isinstance(callback, BaseCallback):
callback = ConvertCallback(callback)
# Add progress bar callback
if progress_bar:
callback = CallbackList([callback, ProgressBarCallback()])
callback.init_callback(self)
return callback
def _setup_learn(
self,
total_timesteps: int,
callback: MaybeCallback = None,
reset_num_timesteps: bool = True,
tb_log_name: str = "run",
use_masking: bool = True,
progress_bar: bool = False,
) -> Tuple[int, BaseCallback]:
"""
Initialize different variables needed for training.
:param total_timesteps: The total number of samples (env steps) to train on
:param callback: Callback(s) called at every step with state of the algorithm.
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
:param tb_log_name: the name of the run for tensorboard log
:param use_masking: Whether or not to use invalid action masks during training
:param progress_bar: Display a progress bar using tqdm and rich.
:return:
"""
self.start_time = time.time_ns()
if self.ep_info_buffer is None or reset_num_timesteps:
# Initialize buffers if they don't exist, or reinitialize if resetting counters
self.ep_info_buffer = deque(maxlen=self._stats_window_size)
self.ep_success_buffer = deque(maxlen=self._stats_window_size)
if reset_num_timesteps:
self.num_timesteps = 0
self._episode_num = 0
else:
# Make sure training timesteps are ahead of the internal counter
total_timesteps += self.num_timesteps
self._total_timesteps = total_timesteps
self._num_timesteps_at_start = self.num_timesteps
# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:
self._last_obs = self.env.reset()
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:
self._last_original_obs = self._vec_normalize_env.get_original_obs()
# Configure logger's outputs if no logger was passed
if not self._custom_logger:
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
# Create eval callback if needed
callback = self._init_callback(callback, use_masking, progress_bar)
return total_timesteps, callback
def collect_rollouts(
self,
env: VecEnv,
@ -312,7 +239,7 @@ class MaskablePPO(OnPolicyAlgorithm):
if not callback.on_step():
return False
self._update_info_buffer(infos)
self._update_info_buffer(infos, dones)
n_steps += 1
if isinstance(self.action_space, spaces.Discrete):
@ -341,14 +268,14 @@ class MaskablePPO(OnPolicyAlgorithm):
log_probs,
action_masks=action_masks,
)
self._last_obs = new_obs
self._last_obs = new_obs # type: ignore[assignment]
self._last_episode_starts = dones
with th.no_grad():
# Compute value for the last timestep
# Masking is not needed here, the choice of action doesn't matter.
# We only want the value of the current observation.
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) # type: ignore[arg-type]
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
@ -356,9 +283,9 @@ class MaskablePPO(OnPolicyAlgorithm):
return True
def predict(
def predict( # type: ignore[override]
self,
observation: np.ndarray,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
@@ -388,10 +315,10 @@ class MaskablePPO(OnPolicyAlgorithm):
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)
# Compute current clip range
-clip_range = self.clip_range(self._current_progress_remaining)
+clip_range = self.clip_range(self._current_progress_remaining) # type: ignore[operator]
# Optional: clip range for the value function
if self.clip_range_vf is not None:
-clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
+clip_range_vf = self.clip_range_vf(self._current_progress_remaining) # type: ignore[operator]
entropy_losses = []
pg_losses, value_losses = [], []
@@ -499,7 +426,7 @@ class MaskablePPO(OnPolicyAlgorithm):
if self.clip_range_vf is not None:
self.logger.record("train/clip_range_vf", clip_range_vf)
-def learn(
+def learn( # type: ignore[override]
self: SelfMaskablePPO,
total_timesteps: int,
callback: MaybeCallback = None,
@@ -516,12 +443,13 @@ class MaskablePPO(OnPolicyAlgorithm):
callback,
reset_num_timesteps,
tb_log_name,
use_masking,
progress_bar,
)
callback.on_training_start(locals(), globals())
+assert self.env is not None
while self.num_timesteps < total_timesteps:
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
@@ -533,16 +461,7 @@ class MaskablePPO(OnPolicyAlgorithm):
# Display training infos
if log_interval is not None and iteration % log_interval == 0:
-time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
-fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
-self.logger.record("time/iterations", iteration, exclude="tensorboard")
-if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
-self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
-self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
-self.logger.record("time/fps", fps)
-self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
-self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
-self.logger.dump(step=self.num_timesteps)
+self._dump_logs(iteration)
self.train()
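The type-annotation fixes above all sit around MaskablePPO's core trick: suppressing invalid actions before the policy samples. A minimal numpy sketch of that idea (a hypothetical `masked_probs` helper, not the library's `MaskableCategorical` implementation):

```python
import numpy as np

def masked_probs(logits: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Zero out the probability of invalid actions via a large negative logit."""
    masked = np.where(mask, logits, -1e8)
    # Subtract the max for numerical stability before the softmax
    exp = np.exp(masked - masked.max())
    return exp / exp.sum()

logits = np.array([1.0, 2.0, 0.5])
mask = np.array([True, False, True])  # action 1 is currently illegal
probs = masked_probs(logits, mask)
```

The masked action ends up with (numerically) zero probability, so it can never be sampled, while the remaining probabilities still sum to one.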


@@ -1,5 +1,3 @@
-import sys
-import time
from copy import deepcopy
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
@@ -11,7 +9,7 @@ from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
+from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor
from stable_baselines3.common.vec_env import VecEnv
from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer
@@ -149,7 +147,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
self.action_space,
self.lr_schedule,
use_sde=self.use_sde,
-**self.policy_kwargs, # pytype:disable=not-instantiable
+**self.policy_kwargs,
)
self.policy = self.policy.to(self.device)
@@ -260,7 +258,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
if not callback.on_step():
return False
-self._update_info_buffer(infos)
+self._update_info_buffer(infos, dones)
n_steps += 1
if isinstance(self.action_space, spaces.Discrete):
@@ -453,42 +451,11 @@ class RecurrentPPO(OnPolicyAlgorithm):
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfRecurrentPPO:
-iteration = 0
-total_timesteps, callback = self._setup_learn(
-total_timesteps,
-callback,
-reset_num_timesteps,
-tb_log_name,
-progress_bar,
+return super().learn(
+total_timesteps=total_timesteps,
+callback=callback,
+log_interval=log_interval,
+tb_log_name=tb_log_name,
+reset_num_timesteps=reset_num_timesteps,
+progress_bar=progress_bar,
)
-callback.on_training_start(locals(), globals())
-while self.num_timesteps < total_timesteps:
-continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
-if not continue_training:
-break
-iteration += 1
-self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
-# Display training infos
-if log_interval is not None and iteration % log_interval == 0:
-time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
-fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
-self.logger.record("time/iterations", iteration, exclude="tensorboard")
-if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
-self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
-self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
-self.logger.record("time/fps", fps)
-self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
-self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
-self.logger.dump(step=self.num_timesteps)
-self.train()
-callback.on_training_end()
-return self
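The `learn()` change above drops RecurrentPPO's hand-rolled training loop in favour of the parent class implementation, so the logging and callback logic lives in one place. In plain Python, the delegation pattern looks like this (toy `Base`/`Recurrent` classes, purely illustrative):

```python
class Base:
    """Stand-in for OnPolicyAlgorithm: owns the single shared training loop."""

    def learn(self, total_timesteps: int, log_interval: int = 1) -> str:
        return f"trained for {total_timesteps} steps, logging every {log_interval}"


class Recurrent(Base):
    """Stand-in for RecurrentPPO: forwards instead of duplicating the loop."""

    def learn(self, total_timesteps: int, log_interval: int = 1) -> str:
        # Any subclass-specific setup would go here; the loop itself
        # (fps bookkeeping, logger.record calls) lives only in Base.
        return super().learn(total_timesteps=total_timesteps, log_interval=log_interval)
```

One benefit visible in the diff: the duplicated logging block (and its `sys`/`time` imports) disappears entirely from the subclass.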


@@ -20,7 +20,8 @@ class QRDQN(OffPolicyAlgorithm):
"""
Quantile Regression Deep Q-Network (QR-DQN)
Paper: https://arxiv.org/abs/1710.10044
-Default hyperparameters are taken from the paper and are tuned for Atari games.
+Default hyperparameters are taken from the paper and are tuned for Atari games
+(except for the ``learning_starts`` parameter).
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
@@ -77,11 +78,11 @@ class QRDQN(OffPolicyAlgorithm):
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 5e-5,
buffer_size: int = 1000000, # 1e6
-learning_starts: int = 50000,
+learning_starts: int = 100,
batch_size: int = 32,
tau: float = 1.0,
gamma: float = 0.99,
-train_freq: int = 4,
+train_freq: Union[int, Tuple[int, str]] = 4,
gradient_steps: int = 1,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
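The widened `train_freq: Union[int, Tuple[int, str]]` annotation reflects that off-policy algorithms accept either a bare step count or a `(frequency, unit)` pair. A sketch of the normalization (a hypothetical `convert_train_freq` helper mirroring, but not copying, SB3's internal conversion):

```python
from collections import namedtuple
from typing import Tuple, Union

TrainFreq = namedtuple("TrainFreq", ["frequency", "unit"])

def convert_train_freq(train_freq: Union[int, Tuple[int, str]]) -> TrainFreq:
    """Normalize an int or a (frequency, unit) pair into a TrainFreq tuple."""
    if isinstance(train_freq, int):
        # A bare int means "train every N environment steps"
        train_freq = (train_freq, "step")
    frequency, unit = train_freq
    if unit not in ("step", "episode"):
        raise ValueError(f"Unknown train_freq unit: {unit!r}")
    return TrainFreq(frequency, unit)
```

With this reading, QR-DQN's default `train_freq=4` and a value like `(1, "episode")` both normalize cleanly.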


@@ -88,7 +88,7 @@ class TQC(OffPolicyAlgorithm):
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
-train_freq: int = 1,
+train_freq: Union[int, Tuple[int, str]] = 1,
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
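TQC gets the same `train_freq` widening as QR-DQN. For context, the algorithm's distinguishing trick (controlled by its `top_quantiles_to_drop_per_net` parameter) is truncating the pooled critics' quantile distribution before forming the target; a toy numpy sketch (hypothetical helper, not the library code):

```python
import numpy as np

def truncated_target(quantiles: np.ndarray, drop_per_net: int) -> float:
    """Pool quantiles from all critics, drop the highest ones, average the rest."""
    n_nets = quantiles.shape[0]
    pooled = np.sort(quantiles.reshape(-1))
    # Dropping the top quantiles counteracts critic overestimation bias
    kept = pooled[: pooled.size - drop_per_net * n_nets]
    return float(kept.mean())

# Two critics with three quantile estimates each
quantiles = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
target = truncated_target(quantiles, drop_per_net=1)
```

Here the two highest of the six pooled quantiles are discarded, so the target averages the remaining four.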


@@ -1 +1 @@
-2.2.0
+2.3.0


@@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
-"stable_baselines3>=2.2.0,<3.0",
+"stable_baselines3>=2.3.0,<3.0",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",