Compare commits
10 Commits
ebb74c44ec
...
adfbeb1b01
| Author | SHA1 | Date |
|---|---|---|
|
|
adfbeb1b01 | |
|
|
17cd797d5c | |
|
|
34eceaf382 | |
|
|
89d0113037 | |
|
|
7dd6c39fba | |
|
|
cd31e89e26 | |
|
|
bc3c0a9595 | |
|
|
3f0c5088b3 | |
|
|
1553b66ee4 | |
|
|
94a5daab02 |
10
Makefile
10
Makefile
|
|
@ -12,23 +12,23 @@ type: mypy
|
|||
lint:
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
# see https://www.flake8rules.com/
|
||||
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
|
||||
ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
|
||||
# exit-zero treats all errors as warnings.
|
||||
ruff ${LINT_PATHS} --exit-zero
|
||||
|
||||
ruff check ${LINT_PATHS} --exit-zero
|
||||
|
||||
format:
|
||||
# Sort imports
|
||||
ruff --select I ${LINT_PATHS} --fix
|
||||
ruff check --select I ${LINT_PATHS} --fix
|
||||
# Reformat using black
|
||||
black ${LINT_PATHS}
|
||||
|
||||
check-codestyle:
|
||||
# Sort imports
|
||||
ruff --select I ${LINT_PATHS}
|
||||
ruff check --select I ${LINT_PATHS}
|
||||
# Reformat using black
|
||||
black --check ${LINT_PATHS}
|
||||
|
||||
|
||||
commit-checks: format type lint
|
||||
|
||||
doc:
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
# add these directories to sys.path here. If the directory is relative to the
|
||||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||
#
|
||||
import datetime
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict
|
||||
|
|
@ -44,7 +45,7 @@ with open(version_file) as file_handler:
|
|||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = "Stable Baselines3 - Contrib"
|
||||
copyright = "2023, Stable Baselines3"
|
||||
copyright = f"2021-{datetime.date.today().year}, Stable Baselines3"
|
||||
author = "Stable Baselines3 Contributors"
|
||||
|
||||
# The short X.Y version
|
||||
|
|
|
|||
|
|
@ -35,6 +35,12 @@ MaskablePPO
|
|||
|
||||
Train a PPO agent with invalid action masking on a toy environment.
|
||||
|
||||
.. warning::
|
||||
You must use ``MaskableEvalCallback`` from ``sb3_contrib.common.maskable.callbacks`` instead of the base ``EvalCallback`` to properly evaluate a model with action masks.
|
||||
Similarly, you must use ``evaluate_policy`` from ``sb3_contrib.common.maskable.evaluation`` instead of the SB3 one.
|
||||
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from sb3_contrib import MaskablePPO
|
||||
|
|
|
|||
|
|
@ -3,13 +3,55 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.3.0 (2024-03-31)
|
||||
--------------------------
|
||||
|
||||
Release 2.2.0 (2023-11-16)
|
||||
**New default hyperparameters for QR-DQN**
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Upgraded to Stable-Baselines3 >= 2.3.0
|
||||
- The default ``learning_starts`` parameter of ``QRDQN`` has been changed to be consistent with the other off-policy algorithms
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to Atari default hyperparameters
|
||||
# model = QRDQN("MlpPolicy", env, learning_starts=50_000)
|
||||
# SB3 >= 2.3.0:
|
||||
model = QRDQN("MlpPolicy", env, learning_starts=100)
|
||||
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to MaskablePPO
|
||||
- Log success rate ``rollout/success_rate`` when available for on-policy algorithms
|
||||
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Fixed ``train_freq`` type annotation for tqc and qrdqn (@Armandpl)
|
||||
- Fixed ``sb3_contrib/common/maskable/*.py`` type annotations
|
||||
- Fixed ``sb3_contrib/ppo_mask/ppo_mask.py`` type annotations
|
||||
- Fixed ``sb3_contrib/common/vec_env/async_eval.py`` type annotations
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
- Add some additional notes about ``MaskablePPO`` (evaluation and multi-process) (@icheered)
|
||||
|
||||
|
||||
Release 2.2.1 (2023-11-17)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Upgraded to Stable-Baselines3 >= 2.2.0
|
||||
- Upgraded to Stable-Baselines3 >= 2.2.1
|
||||
- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff now require a minimum version
|
||||
- Dropped ``x is False`` in favor of ``not x``, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle)
|
||||
|
||||
|
|
@ -512,4 +554,4 @@ Contributors:
|
|||
-------------
|
||||
|
||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
||||
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher
|
||||
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl
|
||||
|
|
|
|||
|
|
@ -45,6 +45,16 @@ MultiBinary ✔️ ✔️
|
|||
Dict ❌ ✔️
|
||||
============= ====== ===========
|
||||
|
||||
.. warning::
|
||||
You must use ``MaskableEvalCallback`` from ``sb3_contrib.common.maskable.callbacks`` instead of the base ``EvalCallback`` to properly evaluate a model with action masks.
|
||||
Similarly, you must use ``evaluate_policy`` from ``sb3_contrib.common.maskable.evaluation`` instead of the SB3 one.
|
||||
|
||||
|
||||
.. warning::
|
||||
In order to use ``SubprocVecEnv`` with ``MaskablePPO``, you must implement the ``action_masks`` inside the environment (``ActionMasker`` cannot be used).
|
||||
You can have a look at the `built-in environments with invalid action masks <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/sb3_contrib/common/envs/invalid_actions_env.py>`_ to have a working example.
|
||||
|
||||
|
||||
|
||||
Example
|
||||
-------
|
||||
|
|
@ -58,6 +68,8 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe
|
|||
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
|
||||
from sb3_contrib.common.maskable.evaluation import evaluate_policy
|
||||
from sb3_contrib.common.maskable.utils import get_action_masks
|
||||
# This is a drop-in replacement for EvalCallback
|
||||
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
|
||||
|
||||
|
||||
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
|
||||
|
|
|
|||
|
|
@ -3,16 +3,18 @@
|
|||
line-length = 127
|
||||
# Assume Python 3.8
|
||||
target-version = "py38"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "B", "UP", "C90", "RUF"]
|
||||
# B028: Ignore explicit stacklevel
|
||||
# RUF013: Too many false positives (implicit optional)
|
||||
ignore = ["B028", "RUF013"]
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
# ClassVar, implicit optional check not needed for tests
|
||||
"./tests/*.py"= ["RUF012", "RUF013"]
|
||||
|
||||
[tool.ruff.mccabe]
|
||||
[tool.ruff.lint.mccabe]
|
||||
# Unlike Flake8, ruff defaults to a complexity level of 10.
|
||||
max-complexity = 15
|
||||
|
||||
|
|
@ -28,12 +30,6 @@ exclude = """(?x)(
|
|||
| sb3_contrib/ars/ars.py$
|
||||
| sb3_contrib/common/recurrent/policies.py$
|
||||
| sb3_contrib/common/recurrent/buffers.py$
|
||||
| sb3_contrib/common/maskable/distributions.py$
|
||||
| sb3_contrib/common/maskable/callbacks.py$
|
||||
| sb3_contrib/common/maskable/policies.py$
|
||||
| sb3_contrib/common/maskable/buffers.py$
|
||||
| sb3_contrib/common/vec_env/async_eval.py$
|
||||
| sb3_contrib/ppo_mask/ppo_mask.py$
|
||||
| tests/test_train_eval_mode.py$
|
||||
)"""
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class MaskableRolloutBufferSamples(NamedTuple):
|
|||
action_masks: th.Tensor
|
||||
|
||||
|
||||
class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples):
|
||||
class MaskableDictRolloutBufferSamples(NamedTuple):
|
||||
observations: TensorDict
|
||||
actions: th.Tensor
|
||||
old_values: th.Tensor
|
||||
|
|
@ -42,6 +42,8 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
|||
:param n_envs: Number of parallel environments
|
||||
"""
|
||||
|
||||
action_masks: np.ndarray
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
buffer_size: int,
|
||||
|
|
@ -53,7 +55,6 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
|||
n_envs: int = 1,
|
||||
):
|
||||
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
|
||||
self.action_masks = None
|
||||
|
||||
def reset(self) -> None:
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
|
|
@ -61,6 +62,10 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
|||
elif isinstance(self.action_space, spaces.MultiDiscrete):
|
||||
mask_dims = sum(self.action_space.nvec)
|
||||
elif isinstance(self.action_space, spaces.MultiBinary):
|
||||
assert isinstance(self.action_space.n, int), (
|
||||
f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
|
||||
"You can flatten it instead."
|
||||
)
|
||||
mask_dims = 2 * self.action_space.n # One mask per binary outcome
|
||||
else:
|
||||
raise ValueError(f"Unsupported action space {type(self.action_space)}")
|
||||
|
|
@ -79,7 +84,7 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
|||
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]:
|
||||
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]: # type: ignore[override]
|
||||
assert self.full, ""
|
||||
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
||||
# Prepare the data
|
||||
|
|
@ -105,7 +110,7 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
|||
yield self._get_samples(indices[start_idx : start_idx + batch_size])
|
||||
start_idx += batch_size
|
||||
|
||||
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples:
|
||||
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples: # type: ignore[override]
|
||||
data = (
|
||||
self.observations[batch_inds],
|
||||
self.actions[batch_inds],
|
||||
|
|
@ -143,17 +148,18 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
|||
:param n_envs: Number of parallel environments
|
||||
"""
|
||||
|
||||
action_masks: np.ndarray
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
buffer_size: int,
|
||||
observation_space: spaces.Space,
|
||||
observation_space: spaces.Dict,
|
||||
action_space: spaces.Space,
|
||||
device: Union[th.device, str] = "auto",
|
||||
gae_lambda: float = 1,
|
||||
gamma: float = 0.99,
|
||||
n_envs: int = 1,
|
||||
):
|
||||
self.action_masks = None
|
||||
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)
|
||||
|
||||
def reset(self) -> None:
|
||||
|
|
@ -162,6 +168,10 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
|||
elif isinstance(self.action_space, spaces.MultiDiscrete):
|
||||
mask_dims = sum(self.action_space.nvec)
|
||||
elif isinstance(self.action_space, spaces.MultiBinary):
|
||||
assert isinstance(self.action_space.n, int), (
|
||||
f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
|
||||
"You can flatten it instead."
|
||||
)
|
||||
mask_dims = 2 * self.action_space.n # One mask per binary outcome
|
||||
else:
|
||||
raise ValueError(f"Unsupported action space {type(self.action_space)}")
|
||||
|
|
@ -180,7 +190,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
|||
|
||||
super().add(*args, **kwargs)
|
||||
|
||||
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]:
|
||||
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]: # type: ignore[override]
|
||||
assert self.full, ""
|
||||
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
||||
# Prepare the data
|
||||
|
|
@ -203,7 +213,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
|||
yield self._get_samples(indices[start_idx : start_idx + batch_size])
|
||||
start_idx += batch_size
|
||||
|
||||
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:
|
||||
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples: # type: ignore[override]
|
||||
return MaskableDictRolloutBufferSamples(
|
||||
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
|
||||
actions=self.to_torch(self.actions[batch_inds]),
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ class MaskableEvalCallback(EvalCallback):
|
|||
|
||||
# Note that evaluate_policy() has been patched to support masking
|
||||
episode_rewards, episode_lengths = evaluate_policy(
|
||||
self.model,
|
||||
self.model, # type: ignore[arg-type]
|
||||
self.eval_env,
|
||||
n_eval_episodes=self.n_eval_episodes,
|
||||
render=self.render,
|
||||
|
|
@ -67,6 +67,8 @@ class MaskableEvalCallback(EvalCallback):
|
|||
)
|
||||
|
||||
if self.log_path is not None:
|
||||
assert isinstance(episode_rewards, list)
|
||||
assert isinstance(episode_lengths, list)
|
||||
self.evaluations_timesteps.append(self.num_timesteps)
|
||||
self.evaluations_results.append(episode_rewards)
|
||||
self.evaluations_length.append(episode_lengths)
|
||||
|
|
@ -87,7 +89,7 @@ class MaskableEvalCallback(EvalCallback):
|
|||
|
||||
mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
|
||||
mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
|
||||
self.last_mean_reward = mean_reward
|
||||
self.last_mean_reward = float(mean_reward)
|
||||
|
||||
if self.verbose > 0:
|
||||
print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
|
||||
|
|
@ -111,7 +113,7 @@ class MaskableEvalCallback(EvalCallback):
|
|||
print("New best mean reward!")
|
||||
if self.best_model_save_path is not None:
|
||||
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
|
||||
self.best_mean_reward = mean_reward
|
||||
self.best_mean_reward = float(mean_reward)
|
||||
# Trigger callback on new best model, if needed
|
||||
if self.callback_on_new_best is not None:
|
||||
continue_training = self.callback_on_new_best.on_step()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Tuple, TypeVar
|
||||
from typing import List, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -13,6 +13,7 @@ SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistributi
|
|||
SelfMaskableMultiCategoricalDistribution = TypeVar(
|
||||
"SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution"
|
||||
)
|
||||
MaybeMasks = Union[th.Tensor, np.ndarray, None]
|
||||
|
||||
|
||||
class MaskableCategorical(Categorical):
|
||||
|
|
@ -36,14 +37,14 @@ class MaskableCategorical(Categorical):
|
|||
probs: Optional[th.Tensor] = None,
|
||||
logits: Optional[th.Tensor] = None,
|
||||
validate_args: Optional[bool] = None,
|
||||
masks: Optional[np.ndarray] = None,
|
||||
masks: MaybeMasks = None,
|
||||
):
|
||||
self.masks: Optional[th.Tensor] = None
|
||||
super().__init__(probs, logits, validate_args)
|
||||
self._original_logits = self.logits
|
||||
self.apply_masking(masks)
|
||||
|
||||
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
|
||||
def apply_masking(self, masks: MaybeMasks) -> None:
|
||||
"""
|
||||
Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0.
|
||||
|
||||
|
|
@ -84,7 +85,7 @@ class MaskableCategorical(Categorical):
|
|||
|
||||
class MaskableDistribution(Distribution, ABC):
|
||||
@abstractmethod
|
||||
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
|
||||
def apply_masking(self, masks: MaybeMasks) -> None:
|
||||
"""
|
||||
Eliminate ("mask out") chosen distribution outcomes by setting their probability to 0.
|
||||
|
||||
|
|
@ -94,6 +95,13 @@ class MaskableDistribution(Distribution, ABC):
|
|||
previously applied masking is removed, and the original logits are restored.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def proba_distribution_net(self, *args, **kwargs) -> nn.Module:
|
||||
"""Create the layers and parameters that represent the distribution.
|
||||
|
||||
Subclasses must define this, but the arguments and return type vary between
|
||||
concrete classes."""
|
||||
|
||||
|
||||
class MaskableCategoricalDistribution(MaskableDistribution):
|
||||
"""
|
||||
|
|
@ -154,7 +162,7 @@ class MaskableCategoricalDistribution(MaskableDistribution):
|
|||
log_prob = self.log_prob(actions)
|
||||
return actions, log_prob
|
||||
|
||||
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
|
||||
def apply_masking(self, masks: MaybeMasks) -> None:
|
||||
assert self.distribution is not None, "Must set distribution parameters"
|
||||
self.distribution.apply_masking(masks)
|
||||
|
||||
|
|
@ -192,7 +200,7 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
|
|||
reshaped_logits = action_logits.view(-1, sum(self.action_dims))
|
||||
|
||||
self.distributions = [
|
||||
MaskableCategorical(logits=split) for split in th.split(reshaped_logits, tuple(self.action_dims), dim=1)
|
||||
MaskableCategorical(logits=split) for split in th.split(reshaped_logits, list(self.action_dims), dim=1)
|
||||
]
|
||||
return self
|
||||
|
||||
|
|
@ -229,18 +237,16 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
|
|||
log_prob = self.log_prob(actions)
|
||||
return actions, log_prob
|
||||
|
||||
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
|
||||
def apply_masking(self, masks: MaybeMasks) -> None:
|
||||
assert len(self.distributions) > 0, "Must set distribution parameters"
|
||||
|
||||
split_masks = [None] * len(self.distributions)
|
||||
if masks is not None:
|
||||
masks = th.as_tensor(masks)
|
||||
|
||||
masks_tensor = th.as_tensor(masks)
|
||||
# Restructure shape to align with logits
|
||||
masks = masks.view(-1, sum(self.action_dims))
|
||||
|
||||
masks_tensor = masks_tensor.view(-1, sum(self.action_dims))
|
||||
# Then split columnwise for each discrete action
|
||||
split_masks = th.split(masks, tuple(self.action_dims), dim=1)
|
||||
split_masks = th.split(masks_tensor, list(self.action_dims), dim=1) # type: ignore[assignment]
|
||||
|
||||
for distribution, mask in zip(self.distributions, split_masks):
|
||||
distribution.apply_masking(mask)
|
||||
|
|
@ -268,10 +274,13 @@ def make_masked_proba_distribution(action_space: spaces.Space) -> MaskableDistri
|
|||
"""
|
||||
|
||||
if isinstance(action_space, spaces.Discrete):
|
||||
return MaskableCategoricalDistribution(action_space.n)
|
||||
return MaskableCategoricalDistribution(int(action_space.n))
|
||||
elif isinstance(action_space, spaces.MultiDiscrete):
|
||||
return MaskableMultiCategoricalDistribution(action_space.nvec)
|
||||
return MaskableMultiCategoricalDistribution(list(action_space.nvec))
|
||||
elif isinstance(action_space, spaces.MultiBinary):
|
||||
assert isinstance(
|
||||
action_space.n, int
|
||||
), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
|
||||
return MaskableBernoulliDistribution(action_space.n)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from stable_baselines3.common.torch_layers import (
|
|||
MlpExtractor,
|
||||
NatureCNN,
|
||||
)
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
|
||||
from torch import nn
|
||||
|
||||
from sb3_contrib.common.maskable.distributions import MaskableDistribution, make_masked_proba_distribution
|
||||
|
|
@ -141,8 +141,8 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
log_prob = distribution.log_prob(actions)
|
||||
return actions, values, log_prob
|
||||
|
||||
def extract_features(
|
||||
self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None
|
||||
def extract_features( # type: ignore[override]
|
||||
self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None
|
||||
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
|
||||
"""
|
||||
Preprocess the observation if needed and extract features.
|
||||
|
|
@ -233,7 +233,11 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
module.apply(partial(self.init_weights, gain=gain))
|
||||
|
||||
# Setup optimizer with initial learning rate
|
||||
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
self.optimizer = self.optimizer_class(
|
||||
self.parameters(),
|
||||
lr=lr_schedule(1), # type: ignore[call-arg]
|
||||
**self.optimizer_kwargs,
|
||||
)
|
||||
|
||||
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution:
|
||||
"""
|
||||
|
|
@ -245,9 +249,9 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
action_logits = self.action_net(latent_pi)
|
||||
return self.action_dist.proba_distribution(action_logits=action_logits)
|
||||
|
||||
def _predict(
|
||||
def _predict( # type: ignore[override]
|
||||
self,
|
||||
observation: th.Tensor,
|
||||
observation: PyTorchObs,
|
||||
deterministic: bool = False,
|
||||
action_masks: Optional[np.ndarray] = None,
|
||||
) -> th.Tensor:
|
||||
|
|
@ -284,35 +288,45 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
# Switch to eval mode (this affects batch norm / dropout)
|
||||
self.set_training_mode(False)
|
||||
|
||||
observation, vectorized_env = self.obs_to_tensor(observation)
|
||||
# Check that the user does not mix the Gym and VecEnv APIs (a common mistake)
|
||||
# Tuple obs are not supported by SB3, so we can safely do that check
|
||||
if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict):
|
||||
raise ValueError(
|
||||
"You have passed a tuple to the predict() function instead of a Numpy array or a Dict. "
|
||||
"You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) "
|
||||
"vs `obs = vec_env.reset()` (SB3 VecEnv). "
|
||||
"See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 "
|
||||
"and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
|
||||
)
|
||||
|
||||
obs_tensor, vectorized_env = self.obs_to_tensor(observation)
|
||||
|
||||
with th.no_grad():
|
||||
actions = self._predict(observation, deterministic=deterministic, action_masks=action_masks)
|
||||
actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks)
|
||||
# Convert to numpy
|
||||
actions = actions.cpu().numpy()
|
||||
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
if self.squash_output:
|
||||
# Rescale to proper domain when using squashing
|
||||
actions = self.unscale_action(actions)
|
||||
actions = self.unscale_action(actions) # type: ignore[assignment, arg-type]
|
||||
else:
|
||||
# Actions could be on arbitrary scale, so clip the actions to avoid
|
||||
# out of bound error (e.g. if sampling from a Gaussian distribution)
|
||||
actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
actions = np.clip(actions, self.action_space.low, self.action_space.high) # type: ignore[assignment, arg-type]
|
||||
|
||||
if not vectorized_env:
|
||||
if state is not None:
|
||||
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
|
||||
assert isinstance(actions, np.ndarray)
|
||||
actions = actions.squeeze(axis=0)
|
||||
|
||||
return actions, None
|
||||
return actions, state # type: ignore[return-value]
|
||||
|
||||
def evaluate_actions(
|
||||
self,
|
||||
obs: th.Tensor,
|
||||
actions: th.Tensor,
|
||||
action_masks: Optional[np.ndarray] = None,
|
||||
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
|
||||
action_masks: Optional[th.Tensor] = None,
|
||||
) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
|
||||
"""
|
||||
Evaluate actions according to the current policy,
|
||||
given the observations.
|
||||
|
|
@ -337,7 +351,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
values = self.value_net(latent_vf)
|
||||
return values, log_prob, distribution.entropy()
|
||||
|
||||
def get_distribution(self, obs: th.Tensor, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
|
||||
def get_distribution(self, obs: PyTorchObs, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
|
||||
"""
|
||||
Get the current policy distribution given the observations.
|
||||
|
||||
|
|
@ -352,7 +366,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
distribution.apply_masking(action_masks)
|
||||
return distribution
|
||||
|
||||
def predict_values(self, obs: th.Tensor) -> th.Tensor:
|
||||
def predict_values(self, obs: PyTorchObs) -> th.Tensor:
|
||||
"""
|
||||
Get the estimated values according to the current policy given the observations.
|
||||
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
|||
Do a forward pass in the LSTM network.
|
||||
|
||||
:param features: Input tensor
|
||||
:param lstm_states: previous cell and hidden states of the LSTM
|
||||
:param lstm_states: previous hidden and cell states of the LSTM, respectively
|
||||
:param episode_starts: Indicates when a new episode starts,
|
||||
in that case, we need to reset LSTM states.
|
||||
:param lstm: LSTM object.
|
||||
|
|
|
|||
|
|
@ -59,10 +59,11 @@ def _worker(
|
|||
remote.send(vec_env.seed(seed=data))
|
||||
elif cmd == "set_options":
|
||||
# Note: the options will only be effective at the next reset
|
||||
remote.send(vec_env.set_options(data))
|
||||
remote.send(vec_env.set_options(data)) # type: ignore[func-returns-value]
|
||||
elif cmd == "get_obs_rms":
|
||||
remote.send(obs_rms)
|
||||
elif cmd == "sync_obs_rms":
|
||||
assert vec_normalize is not None, "Tried to call `sync_obs_rms` when not using VecNormalize"
|
||||
vec_normalize.obs_rms = data
|
||||
obs_rms = data
|
||||
elif cmd == "close":
|
||||
|
|
@ -130,7 +131,7 @@ class AsyncEval:
|
|||
n_eval_episodes,
|
||||
)
|
||||
# daemon=True: if the main process crashes, we should not cause things to hang
|
||||
process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error
|
||||
process = ctx.Process(target=_worker, args=args, daemon=True) # type: ignore[attr-defined]
|
||||
process.start()
|
||||
self.processes.append(process)
|
||||
work_remote.close()
|
||||
|
|
@ -157,6 +158,10 @@ class AsyncEval:
|
|||
:param seed: The seed for the pseudo-random generators.
|
||||
:return:
|
||||
"""
|
||||
if seed is None:
|
||||
# Do nothing
|
||||
return []
|
||||
|
||||
for idx, remote in enumerate(self.remotes):
|
||||
remote.send(("seed", seed + idx))
|
||||
return [remote.recv() for remote in self.remotes]
|
||||
|
|
|
|||
|
|
@ -1,18 +1,14 @@
|
|||
import sys
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common import utils
|
||||
from stable_baselines3.common.buffers import RolloutBuffer
|
||||
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
|
||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
|
@ -74,6 +70,8 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
"CnnPolicy": CnnPolicy,
|
||||
"MultiInputPolicy": MultiInputPolicy,
|
||||
}
|
||||
policy: MaskableActorCriticPolicy # type: ignore[assignment]
|
||||
rollout_buffer: MaskableRolloutBuffer # type: ignore[assignment]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -91,6 +89,8 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
ent_coef: float = 0.0,
|
||||
vf_coef: float = 0.5,
|
||||
max_grad_norm: float = 0.5,
|
||||
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
|
||||
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
target_kl: Optional[float] = None,
|
||||
stats_window_size: int = 100,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
|
|
@ -101,7 +101,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
_init_setup_model: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
policy,
|
||||
policy, # type: ignore[arg-type]
|
||||
env,
|
||||
learning_rate=learning_rate,
|
||||
n_steps=n_steps,
|
||||
|
|
@ -112,6 +112,8 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
max_grad_norm=max_grad_norm,
|
||||
use_sde=False,
|
||||
sde_sample_freq=-1,
|
||||
rollout_buffer_class=rollout_buffer_class,
|
||||
rollout_buffer_kwargs=rollout_buffer_kwargs,
|
||||
stats_window_size=stats_window_size,
|
||||
tensorboard_log=tensorboard_log,
|
||||
policy_kwargs=policy_kwargs,
|
||||
|
|
@ -140,20 +142,24 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
self._setup_lr_schedule()
|
||||
self.set_random_seed(self.seed)
|
||||
|
||||
buffer_cls = MaskableDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else MaskableRolloutBuffer
|
||||
|
||||
self.policy = self.policy_class(
|
||||
self.policy = self.policy_class( # type: ignore[assignment]
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
self.lr_schedule,
|
||||
**self.policy_kwargs, # pytype:disable=not-instantiable
|
||||
**self.policy_kwargs,
|
||||
)
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
if not isinstance(self.policy, MaskableActorCriticPolicy):
|
||||
raise ValueError("Policy must subclass MaskableActorCriticPolicy")
|
||||
|
||||
self.rollout_buffer = buffer_cls(
|
||||
if self.rollout_buffer_class is None:
|
||||
if isinstance(self.observation_space, spaces.Dict):
|
||||
self.rollout_buffer_class = MaskableDictRolloutBuffer
|
||||
else:
|
||||
self.rollout_buffer_class = MaskableRolloutBuffer
|
||||
|
||||
self.rollout_buffer = self.rollout_buffer_class( # type: ignore[assignment]
|
||||
self.n_steps,
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
|
|
@ -161,6 +167,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
gamma=self.gamma,
|
||||
gae_lambda=self.gae_lambda,
|
||||
n_envs=self.n_envs,
|
||||
**self.rollout_buffer_kwargs,
|
||||
)
|
||||
|
||||
# Initialize schedules for policy/value clipping
|
||||
|
|
@ -171,86 +178,6 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
|
||||
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
|
||||
|
||||
def _init_callback(
|
||||
self,
|
||||
callback: MaybeCallback,
|
||||
use_masking: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> BaseCallback:
|
||||
"""
|
||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
||||
:param use_masking: Whether or not to use invalid action masks during evaluation
|
||||
:param progress_bar: Display a progress bar using tqdm and rich.
|
||||
:return: The initialized callback wrapping `callback` (plus a progress bar callback when ``progress_bar`` is set).
|
||||
"""
|
||||
# Convert a list of callbacks into a callback
|
||||
if isinstance(callback, list):
|
||||
callback = CallbackList(callback)
|
||||
|
||||
# Convert functional callback to object
|
||||
if not isinstance(callback, BaseCallback):
|
||||
callback = ConvertCallback(callback)
|
||||
|
||||
# Add progress bar callback
|
||||
if progress_bar:
|
||||
callback = CallbackList([callback, ProgressBarCallback()])
|
||||
|
||||
callback.init_callback(self)
|
||||
return callback
|
||||
|
||||
def _setup_learn(
|
||||
self,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
tb_log_name: str = "run",
|
||||
use_masking: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> Tuple[int, BaseCallback]:
|
||||
"""
|
||||
Initialize different variables needed for training.
|
||||
|
||||
:param total_timesteps: The total number of samples (env steps) to train on
|
||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
||||
:param reset_num_timesteps: Whether or not to reset the ``num_timesteps`` attribute
|
||||
:param tb_log_name: the name of the run for tensorboard log
|
||||
:param use_masking: Whether or not to use invalid action masks during training
|
||||
:param progress_bar: Display a progress bar using tqdm and rich.
|
||||
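:return: The total number of timesteps (offset by the current ``num_timesteps`` when not resetting) and the initialized callback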
|
||||
"""
|
||||
|
||||
self.start_time = time.time_ns()
|
||||
if self.ep_info_buffer is None or reset_num_timesteps:
|
||||
# Initialize buffers if they don't exist, or reinitialize if resetting counters
|
||||
self.ep_info_buffer = deque(maxlen=self._stats_window_size)
|
||||
self.ep_success_buffer = deque(maxlen=self._stats_window_size)
|
||||
|
||||
if reset_num_timesteps:
|
||||
self.num_timesteps = 0
|
||||
self._episode_num = 0
|
||||
else:
|
||||
# Make sure training timesteps are ahead of the internal counter
|
||||
total_timesteps += self.num_timesteps
|
||||
self._total_timesteps = total_timesteps
|
||||
self._num_timesteps_at_start = self.num_timesteps
|
||||
|
||||
# Avoid resetting the environment when calling ``.learn()`` consecutive times
|
||||
if reset_num_timesteps or self._last_obs is None:
|
||||
self._last_obs = self.env.reset()
|
||||
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
|
||||
# Retrieve unnormalized observation for saving into the buffer
|
||||
if self._vec_normalize_env is not None:
|
||||
self._last_original_obs = self._vec_normalize_env.get_original_obs()
|
||||
|
||||
# Configure logger's outputs if no logger was passed
|
||||
if not self._custom_logger:
|
||||
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
|
||||
|
||||
# Convert the callback(s) into a single initialized callback
|
||||
callback = self._init_callback(callback, use_masking, progress_bar)
|
||||
|
||||
return total_timesteps, callback
|
||||
|
||||
def collect_rollouts(
|
||||
self,
|
||||
env: VecEnv,
|
||||
|
|
@ -312,7 +239,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
if not callback.on_step():
|
||||
return False
|
||||
|
||||
self._update_info_buffer(infos)
|
||||
self._update_info_buffer(infos, dones)
|
||||
n_steps += 1
|
||||
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
|
|
@ -341,14 +268,14 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
log_probs,
|
||||
action_masks=action_masks,
|
||||
)
|
||||
self._last_obs = new_obs
|
||||
self._last_obs = new_obs # type: ignore[assignment]
|
||||
self._last_episode_starts = dones
|
||||
|
||||
with th.no_grad():
|
||||
# Compute value for the last timestep
|
||||
# Masking is not needed here, the choice of action doesn't matter.
|
||||
# We only want the value of the current observation.
|
||||
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))
|
||||
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) # type: ignore[arg-type]
|
||||
|
||||
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
|
||||
|
||||
|
|
@ -356,9 +283,9 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
|
||||
return True
|
||||
|
||||
def predict(
|
||||
def predict( # type: ignore[override]
|
||||
self,
|
||||
observation: np.ndarray,
|
||||
observation: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||
episode_start: Optional[np.ndarray] = None,
|
||||
deterministic: bool = False,
|
||||
|
|
@ -388,10 +315,10 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
# Compute current clip range
|
||||
clip_range = self.clip_range(self._current_progress_remaining)
|
||||
clip_range = self.clip_range(self._current_progress_remaining) # type: ignore[operator]
|
||||
# Optional: clip range for the value function
|
||||
if self.clip_range_vf is not None:
|
||||
clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
|
||||
clip_range_vf = self.clip_range_vf(self._current_progress_remaining) # type: ignore[operator]
|
||||
|
||||
entropy_losses = []
|
||||
pg_losses, value_losses = [], []
|
||||
|
|
@ -499,7 +426,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
if self.clip_range_vf is not None:
|
||||
self.logger.record("train/clip_range_vf", clip_range_vf)
|
||||
|
||||
def learn(
|
||||
def learn( # type: ignore[override]
|
||||
self: SelfMaskablePPO,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
|
|
@ -516,12 +443,13 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
callback,
|
||||
reset_num_timesteps,
|
||||
tb_log_name,
|
||||
use_masking,
|
||||
progress_bar,
|
||||
)
|
||||
|
||||
callback.on_training_start(locals(), globals())
|
||||
|
||||
assert self.env is not None
|
||||
|
||||
while self.num_timesteps < total_timesteps:
|
||||
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
|
||||
|
||||
|
|
@ -533,16 +461,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
|
||||
# Display training infos
|
||||
if log_interval is not None and iteration % log_interval == 0:
|
||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("time/fps", fps)
|
||||
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
||||
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
self.logger.dump(step=self.num_timesteps)
|
||||
self._dump_logs(iteration)
|
||||
|
||||
self.train()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import sys
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
|
||||
|
||||
|
|
@ -11,7 +9,7 @@ from stable_baselines3.common.callbacks import BaseCallback
|
|||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
|
||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
|
||||
from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer
|
||||
|
|
@ -149,7 +147,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
self.action_space,
|
||||
self.lr_schedule,
|
||||
use_sde=self.use_sde,
|
||||
**self.policy_kwargs, # pytype:disable=not-instantiable
|
||||
**self.policy_kwargs,
|
||||
)
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
|
|
@ -260,7 +258,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
if not callback.on_step():
|
||||
return False
|
||||
|
||||
self._update_info_buffer(infos)
|
||||
self._update_info_buffer(infos, dones)
|
||||
n_steps += 1
|
||||
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
|
|
@ -453,42 +451,11 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
|||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfRecurrentPPO:
|
||||
iteration = 0
|
||||
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
total_timesteps,
|
||||
callback,
|
||||
reset_num_timesteps,
|
||||
tb_log_name,
|
||||
progress_bar,
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
log_interval=log_interval,
|
||||
tb_log_name=tb_log_name,
|
||||
reset_num_timesteps=reset_num_timesteps,
|
||||
progress_bar=progress_bar,
|
||||
)
|
||||
|
||||
callback.on_training_start(locals(), globals())
|
||||
|
||||
while self.num_timesteps < total_timesteps:
|
||||
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
|
||||
|
||||
if not continue_training:
|
||||
break
|
||||
|
||||
iteration += 1
|
||||
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
|
||||
|
||||
# Display training infos
|
||||
if log_interval is not None and iteration % log_interval == 0:
|
||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("time/fps", fps)
|
||||
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
||||
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
self.logger.dump(step=self.num_timesteps)
|
||||
|
||||
self.train()
|
||||
|
||||
callback.on_training_end()
|
||||
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -20,7 +20,8 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
"""
|
||||
Quantile Regression Deep Q-Network (QR-DQN)
|
||||
Paper: https://arxiv.org/abs/1710.10044
|
||||
Default hyperparameters are taken from the paper and are tuned for Atari games.
|
||||
Default hyperparameters are taken from the paper and are tuned for Atari games
|
||||
(except for the ``learning_starts`` parameter).
|
||||
|
||||
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||
:param env: The environment to learn from (if registered in Gym, can be str)
|
||||
|
|
@ -77,11 +78,11 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, Schedule] = 5e-5,
|
||||
buffer_size: int = 1000000, # 1e6
|
||||
learning_starts: int = 50000,
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 32,
|
||||
tau: float = 1.0,
|
||||
gamma: float = 0.99,
|
||||
train_freq: int = 4,
|
||||
train_freq: Union[int, Tuple[int, str]] = 4,
|
||||
gradient_steps: int = 1,
|
||||
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
||||
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ class TQC(OffPolicyAlgorithm):
|
|||
batch_size: int = 256,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
train_freq: int = 1,
|
||||
train_freq: Union[int, Tuple[int, str]] = 1,
|
||||
gradient_steps: int = 1,
|
||||
action_noise: Optional[ActionNoise] = None,
|
||||
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.2.0
|
||||
2.3.0
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=2.2.0,<3.0",
|
||||
"stable_baselines3>=2.3.0,<3.0",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
Loading…
Reference in New Issue