Compare commits
10 Commits
ebb74c44ec
...
adfbeb1b01
| Author | SHA1 | Date |
|---|---|---|
|
|
adfbeb1b01 | |
|
|
17cd797d5c | |
|
|
34eceaf382 | |
|
|
89d0113037 | |
|
|
7dd6c39fba | |
|
|
cd31e89e26 | |
|
|
bc3c0a9595 | |
|
|
3f0c5088b3 | |
|
|
1553b66ee4 | |
|
|
94a5daab02 |
10
Makefile
10
Makefile
|
|
@ -12,23 +12,23 @@ type: mypy
|
||||||
lint:
|
lint:
|
||||||
# stop the build if there are Python syntax errors or undefined names
|
# stop the build if there are Python syntax errors or undefined names
|
||||||
# see https://www.flake8rules.com/
|
# see https://www.flake8rules.com/
|
||||||
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
|
ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
|
||||||
# exit-zero treats all errors as warnings.
|
# exit-zero treats all errors as warnings.
|
||||||
ruff ${LINT_PATHS} --exit-zero
|
ruff check ${LINT_PATHS} --exit-zero
|
||||||
|
|
||||||
|
|
||||||
format:
|
format:
|
||||||
# Sort imports
|
# Sort imports
|
||||||
ruff --select I ${LINT_PATHS} --fix
|
ruff check --select I ${LINT_PATHS} --fix
|
||||||
# Reformat using black
|
# Reformat using black
|
||||||
black ${LINT_PATHS}
|
black ${LINT_PATHS}
|
||||||
|
|
||||||
check-codestyle:
|
check-codestyle:
|
||||||
# Sort imports
|
# Sort imports
|
||||||
ruff --select I ${LINT_PATHS}
|
ruff check --select I ${LINT_PATHS}
|
||||||
# Reformat using black
|
# Reformat using black
|
||||||
black --check ${LINT_PATHS}
|
black --check ${LINT_PATHS}
|
||||||
|
|
||||||
|
|
||||||
commit-checks: format type lint
|
commit-checks: format type lint
|
||||||
|
|
||||||
doc:
|
doc:
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@
|
||||||
# add these directories to sys.path here. If the directory is relative to the
|
# add these directories to sys.path here. If the directory is relative to the
|
||||||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||||
#
|
#
|
||||||
|
import datetime
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
@ -44,7 +45,7 @@ with open(version_file) as file_handler:
|
||||||
# -- Project information -----------------------------------------------------
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
project = "Stable Baselines3 - Contrib"
|
project = "Stable Baselines3 - Contrib"
|
||||||
copyright = "2023, Stable Baselines3"
|
copyright = f"2021-{datetime.date.today().year}, Stable Baselines3"
|
||||||
author = "Stable Baselines3 Contributors"
|
author = "Stable Baselines3 Contributors"
|
||||||
|
|
||||||
# The short X.Y version
|
# The short X.Y version
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,12 @@ MaskablePPO
|
||||||
|
|
||||||
Train a PPO with invalid action masking agent on a toy environment.
|
Train a PPO with invalid action masking agent on a toy environment.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
You must use ``MaskableEvalCallback`` from ``sb3_contrib.common.maskable.callbacks`` instead of the base ``EvalCallback`` to properly evaluate a model with action masks.
|
||||||
|
Similarly, you must use ``evaluate_policy`` from ``sb3_contrib.common.maskable.evaluation`` instead of the SB3 one.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from sb3_contrib import MaskablePPO
|
from sb3_contrib import MaskablePPO
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,55 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
Release 2.3.0 (2024-03-31)
|
||||||
|
--------------------------
|
||||||
|
|
||||||
Release 2.2.0 (2023-11-16)
|
**New default hyperparameters for QR-DQN**
|
||||||
|
|
||||||
|
Breaking Changes:
|
||||||
|
^^^^^^^^^^^^^^^^^
|
||||||
|
- Upgraded to Stable-Baselines3 >= 2.3.0
|
||||||
|
- The default ``learning_starts`` parameter of ``QRDQN`` has been changed to be consistent with the other off-policy algorithms
|
||||||
|
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to the Atari default hyperparameters
|
||||||
|
# model = QRDQN("MlpPolicy", env, learning_starts=50_000)
|
||||||
|
# SB3 >= 2.3.0:
|
||||||
|
model = QRDQN("MlpPolicy", env, learning_starts=100)
|
||||||
|
|
||||||
|
|
||||||
|
New Features:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to MaskablePPO
|
||||||
|
- Log success rate ``rollout/success_rate`` when available for on-policy algorithms
|
||||||
|
|
||||||
|
|
||||||
|
Bug Fixes:
|
||||||
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
Deprecations:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Others:
|
||||||
|
^^^^^^^
|
||||||
|
- Fixed ``train_freq`` type annotation for tqc and qrdqn (@Armandpl)
|
||||||
|
- Fixed ``sb3_contrib/common/maskable/*.py`` type annotations
|
||||||
|
- Fixed ``sb3_contrib/ppo_mask/ppo_mask.py`` type annotations
|
||||||
|
- Fixed ``sb3_contrib/common/vec_env/async_eval.py`` type annotations
|
||||||
|
|
||||||
|
Documentation:
|
||||||
|
^^^^^^^^^^^^^^
|
||||||
|
- Add some additional notes about ``MaskablePPO`` (evaluation and multi-process) (@icheered)
|
||||||
|
|
||||||
|
|
||||||
|
Release 2.2.1 (2023-11-17)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^
|
||||||
- Upgraded to Stable-Baselines3 >= 2.2.0
|
- Upgraded to Stable-Baselines3 >= 2.2.1
|
||||||
- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version
|
- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version
|
||||||
- Dropped ``x is False`` in favor of ``not x``, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle)
|
- Dropped ``x is False`` in favor of ``not x``, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle)
|
||||||
|
|
||||||
|
|
@ -512,4 +554,4 @@ Contributors:
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
||||||
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher
|
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,16 @@ MultiBinary ✔️ ✔️
|
||||||
Dict ❌ ✔️
|
Dict ❌ ✔️
|
||||||
============= ====== ===========
|
============= ====== ===========
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
You must use ``MaskableEvalCallback`` from ``sb3_contrib.common.maskable.callbacks`` instead of the base ``EvalCallback`` to properly evaluate a model with action masks.
|
||||||
|
Similarly, you must use ``evaluate_policy`` from ``sb3_contrib.common.maskable.evaluation`` instead of the SB3 one.
|
||||||
|
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
In order to use ``SubprocVecEnv`` with ``MaskablePPO``, you must implement the ``action_masks`` inside the environment (``ActionMasker`` cannot be used).
|
||||||
|
You can have a look at the `built-in environments with invalid action masks <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/sb3_contrib/common/envs/invalid_actions_env.py>`_ to have a working example.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Example
|
Example
|
||||||
-------
|
-------
|
||||||
|
|
@ -58,6 +68,8 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe
|
||||||
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
|
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
|
||||||
from sb3_contrib.common.maskable.evaluation import evaluate_policy
|
from sb3_contrib.common.maskable.evaluation import evaluate_policy
|
||||||
from sb3_contrib.common.maskable.utils import get_action_masks
|
from sb3_contrib.common.maskable.utils import get_action_masks
|
||||||
|
# This is a drop-in replacement for EvalCallback
|
||||||
|
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
|
||||||
|
|
||||||
|
|
||||||
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
|
env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
|
||||||
|
|
|
||||||
|
|
@ -3,16 +3,18 @@
|
||||||
line-length = 127
|
line-length = 127
|
||||||
# Assume Python 3.8
|
# Assume Python 3.8
|
||||||
target-version = "py38"
|
target-version = "py38"
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
select = ["E", "F", "B", "UP", "C90", "RUF"]
|
select = ["E", "F", "B", "UP", "C90", "RUF"]
|
||||||
# B028: Ignore explicit stacklevel`
|
# B028: Ignore explicit stacklevel`
|
||||||
# RUF013: Too many false positives (implicit optional)
|
# RUF013: Too many false positives (implicit optional)
|
||||||
ignore = ["B028", "RUF013"]
|
ignore = ["B028", "RUF013"]
|
||||||
|
|
||||||
[tool.ruff.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
# ClassVar, implicit optional check not needed for tests
|
# ClassVar, implicit optional check not needed for tests
|
||||||
"./tests/*.py"= ["RUF012", "RUF013"]
|
"./tests/*.py"= ["RUF012", "RUF013"]
|
||||||
|
|
||||||
[tool.ruff.mccabe]
|
[tool.ruff.lint.mccabe]
|
||||||
# Unlike Flake8, ruff default to a complexity level of 10.
|
# Unlike Flake8, ruff default to a complexity level of 10.
|
||||||
max-complexity = 15
|
max-complexity = 15
|
||||||
|
|
||||||
|
|
@ -28,12 +30,6 @@ exclude = """(?x)(
|
||||||
| sb3_contrib/ars/ars.py$
|
| sb3_contrib/ars/ars.py$
|
||||||
| sb3_contrib/common/recurrent/policies.py$
|
| sb3_contrib/common/recurrent/policies.py$
|
||||||
| sb3_contrib/common/recurrent/buffers.py$
|
| sb3_contrib/common/recurrent/buffers.py$
|
||||||
| sb3_contrib/common/maskable/distributions.py$
|
|
||||||
| sb3_contrib/common/maskable/callbacks.py$
|
|
||||||
| sb3_contrib/common/maskable/policies.py$
|
|
||||||
| sb3_contrib/common/maskable/buffers.py$
|
|
||||||
| sb3_contrib/common/vec_env/async_eval.py$
|
|
||||||
| sb3_contrib/ppo_mask/ppo_mask.py$
|
|
||||||
| tests/test_train_eval_mode.py$
|
| tests/test_train_eval_mode.py$
|
||||||
)"""
|
)"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ class MaskableRolloutBufferSamples(NamedTuple):
|
||||||
action_masks: th.Tensor
|
action_masks: th.Tensor
|
||||||
|
|
||||||
|
|
||||||
class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples):
|
class MaskableDictRolloutBufferSamples(NamedTuple):
|
||||||
observations: TensorDict
|
observations: TensorDict
|
||||||
actions: th.Tensor
|
actions: th.Tensor
|
||||||
old_values: th.Tensor
|
old_values: th.Tensor
|
||||||
|
|
@ -42,6 +42,8 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
||||||
:param n_envs: Number of parallel environments
|
:param n_envs: Number of parallel environments
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
action_masks: np.ndarray
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
buffer_size: int,
|
buffer_size: int,
|
||||||
|
|
@ -53,7 +55,6 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
||||||
n_envs: int = 1,
|
n_envs: int = 1,
|
||||||
):
|
):
|
||||||
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
|
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
|
||||||
self.action_masks = None
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
if isinstance(self.action_space, spaces.Discrete):
|
if isinstance(self.action_space, spaces.Discrete):
|
||||||
|
|
@ -61,6 +62,10 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
||||||
elif isinstance(self.action_space, spaces.MultiDiscrete):
|
elif isinstance(self.action_space, spaces.MultiDiscrete):
|
||||||
mask_dims = sum(self.action_space.nvec)
|
mask_dims = sum(self.action_space.nvec)
|
||||||
elif isinstance(self.action_space, spaces.MultiBinary):
|
elif isinstance(self.action_space, spaces.MultiBinary):
|
||||||
|
assert isinstance(self.action_space.n, int), (
|
||||||
|
f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
|
||||||
|
"You can flatten it instead."
|
||||||
|
)
|
||||||
mask_dims = 2 * self.action_space.n # One mask per binary outcome
|
mask_dims = 2 * self.action_space.n # One mask per binary outcome
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported action space {type(self.action_space)}")
|
raise ValueError(f"Unsupported action space {type(self.action_space)}")
|
||||||
|
|
@ -79,7 +84,7 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
||||||
|
|
||||||
super().add(*args, **kwargs)
|
super().add(*args, **kwargs)
|
||||||
|
|
||||||
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]:
|
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]: # type: ignore[override]
|
||||||
assert self.full, ""
|
assert self.full, ""
|
||||||
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
||||||
# Prepare the data
|
# Prepare the data
|
||||||
|
|
@ -105,7 +110,7 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
||||||
yield self._get_samples(indices[start_idx : start_idx + batch_size])
|
yield self._get_samples(indices[start_idx : start_idx + batch_size])
|
||||||
start_idx += batch_size
|
start_idx += batch_size
|
||||||
|
|
||||||
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples:
|
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples: # type: ignore[override]
|
||||||
data = (
|
data = (
|
||||||
self.observations[batch_inds],
|
self.observations[batch_inds],
|
||||||
self.actions[batch_inds],
|
self.actions[batch_inds],
|
||||||
|
|
@ -143,17 +148,18 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
||||||
:param n_envs: Number of parallel environments
|
:param n_envs: Number of parallel environments
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
action_masks: np.ndarray
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
buffer_size: int,
|
buffer_size: int,
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Dict,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
device: Union[th.device, str] = "auto",
|
device: Union[th.device, str] = "auto",
|
||||||
gae_lambda: float = 1,
|
gae_lambda: float = 1,
|
||||||
gamma: float = 0.99,
|
gamma: float = 0.99,
|
||||||
n_envs: int = 1,
|
n_envs: int = 1,
|
||||||
):
|
):
|
||||||
self.action_masks = None
|
|
||||||
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)
|
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
|
@ -162,6 +168,10 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
||||||
elif isinstance(self.action_space, spaces.MultiDiscrete):
|
elif isinstance(self.action_space, spaces.MultiDiscrete):
|
||||||
mask_dims = sum(self.action_space.nvec)
|
mask_dims = sum(self.action_space.nvec)
|
||||||
elif isinstance(self.action_space, spaces.MultiBinary):
|
elif isinstance(self.action_space, spaces.MultiBinary):
|
||||||
|
assert isinstance(self.action_space.n, int), (
|
||||||
|
f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
|
||||||
|
"You can flatten it instead."
|
||||||
|
)
|
||||||
mask_dims = 2 * self.action_space.n # One mask per binary outcome
|
mask_dims = 2 * self.action_space.n # One mask per binary outcome
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported action space {type(self.action_space)}")
|
raise ValueError(f"Unsupported action space {type(self.action_space)}")
|
||||||
|
|
@ -180,7 +190,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
||||||
|
|
||||||
super().add(*args, **kwargs)
|
super().add(*args, **kwargs)
|
||||||
|
|
||||||
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]:
|
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]: # type: ignore[override]
|
||||||
assert self.full, ""
|
assert self.full, ""
|
||||||
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
||||||
# Prepare the data
|
# Prepare the data
|
||||||
|
|
@ -203,7 +213,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
||||||
yield self._get_samples(indices[start_idx : start_idx + batch_size])
|
yield self._get_samples(indices[start_idx : start_idx + batch_size])
|
||||||
start_idx += batch_size
|
start_idx += batch_size
|
||||||
|
|
||||||
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:
|
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples: # type: ignore[override]
|
||||||
return MaskableDictRolloutBufferSamples(
|
return MaskableDictRolloutBufferSamples(
|
||||||
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
|
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
|
||||||
actions=self.to_torch(self.actions[batch_inds]),
|
actions=self.to_torch(self.actions[batch_inds]),
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ class MaskableEvalCallback(EvalCallback):
|
||||||
|
|
||||||
# Note that evaluate_policy() has been patched to support masking
|
# Note that evaluate_policy() has been patched to support masking
|
||||||
episode_rewards, episode_lengths = evaluate_policy(
|
episode_rewards, episode_lengths = evaluate_policy(
|
||||||
self.model,
|
self.model, # type: ignore[arg-type]
|
||||||
self.eval_env,
|
self.eval_env,
|
||||||
n_eval_episodes=self.n_eval_episodes,
|
n_eval_episodes=self.n_eval_episodes,
|
||||||
render=self.render,
|
render=self.render,
|
||||||
|
|
@ -67,6 +67,8 @@ class MaskableEvalCallback(EvalCallback):
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.log_path is not None:
|
if self.log_path is not None:
|
||||||
|
assert isinstance(episode_rewards, list)
|
||||||
|
assert isinstance(episode_lengths, list)
|
||||||
self.evaluations_timesteps.append(self.num_timesteps)
|
self.evaluations_timesteps.append(self.num_timesteps)
|
||||||
self.evaluations_results.append(episode_rewards)
|
self.evaluations_results.append(episode_rewards)
|
||||||
self.evaluations_length.append(episode_lengths)
|
self.evaluations_length.append(episode_lengths)
|
||||||
|
|
@ -87,7 +89,7 @@ class MaskableEvalCallback(EvalCallback):
|
||||||
|
|
||||||
mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
|
mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
|
||||||
mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
|
mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
|
||||||
self.last_mean_reward = mean_reward
|
self.last_mean_reward = float(mean_reward)
|
||||||
|
|
||||||
if self.verbose > 0:
|
if self.verbose > 0:
|
||||||
print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
|
print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
|
||||||
|
|
@ -111,7 +113,7 @@ class MaskableEvalCallback(EvalCallback):
|
||||||
print("New best mean reward!")
|
print("New best mean reward!")
|
||||||
if self.best_model_save_path is not None:
|
if self.best_model_save_path is not None:
|
||||||
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
|
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
|
||||||
self.best_mean_reward = mean_reward
|
self.best_mean_reward = float(mean_reward)
|
||||||
# Trigger callback on new best model, if needed
|
# Trigger callback on new best model, if needed
|
||||||
if self.callback_on_new_best is not None:
|
if self.callback_on_new_best is not None:
|
||||||
continue_training = self.callback_on_new_best.on_step()
|
continue_training = self.callback_on_new_best.on_step()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Optional, Tuple, TypeVar
|
from typing import List, Optional, Tuple, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
|
|
@ -13,6 +13,7 @@ SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistributi
|
||||||
SelfMaskableMultiCategoricalDistribution = TypeVar(
|
SelfMaskableMultiCategoricalDistribution = TypeVar(
|
||||||
"SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution"
|
"SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution"
|
||||||
)
|
)
|
||||||
|
MaybeMasks = Union[th.Tensor, np.ndarray, None]
|
||||||
|
|
||||||
|
|
||||||
class MaskableCategorical(Categorical):
|
class MaskableCategorical(Categorical):
|
||||||
|
|
@ -36,14 +37,14 @@ class MaskableCategorical(Categorical):
|
||||||
probs: Optional[th.Tensor] = None,
|
probs: Optional[th.Tensor] = None,
|
||||||
logits: Optional[th.Tensor] = None,
|
logits: Optional[th.Tensor] = None,
|
||||||
validate_args: Optional[bool] = None,
|
validate_args: Optional[bool] = None,
|
||||||
masks: Optional[np.ndarray] = None,
|
masks: MaybeMasks = None,
|
||||||
):
|
):
|
||||||
self.masks: Optional[th.Tensor] = None
|
self.masks: Optional[th.Tensor] = None
|
||||||
super().__init__(probs, logits, validate_args)
|
super().__init__(probs, logits, validate_args)
|
||||||
self._original_logits = self.logits
|
self._original_logits = self.logits
|
||||||
self.apply_masking(masks)
|
self.apply_masking(masks)
|
||||||
|
|
||||||
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
|
def apply_masking(self, masks: MaybeMasks) -> None:
|
||||||
"""
|
"""
|
||||||
Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0.
|
Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0.
|
||||||
|
|
||||||
|
|
@ -84,7 +85,7 @@ class MaskableCategorical(Categorical):
|
||||||
|
|
||||||
class MaskableDistribution(Distribution, ABC):
|
class MaskableDistribution(Distribution, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
|
def apply_masking(self, masks: MaybeMasks) -> None:
|
||||||
"""
|
"""
|
||||||
Eliminate ("mask out") chosen distribution outcomes by setting their probability to 0.
|
Eliminate ("mask out") chosen distribution outcomes by setting their probability to 0.
|
||||||
|
|
||||||
|
|
@ -94,6 +95,13 @@ class MaskableDistribution(Distribution, ABC):
|
||||||
previously applied masking is removed, and the original logits are restored.
|
previously applied masking is removed, and the original logits are restored.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def proba_distribution_net(self, *args, **kwargs) -> nn.Module:
|
||||||
|
"""Create the layers and parameters that represent the distribution.
|
||||||
|
|
||||||
|
Subclasses must define this, but the arguments and return type vary between
|
||||||
|
concrete classes."""
|
||||||
|
|
||||||
|
|
||||||
class MaskableCategoricalDistribution(MaskableDistribution):
|
class MaskableCategoricalDistribution(MaskableDistribution):
|
||||||
"""
|
"""
|
||||||
|
|
@ -154,7 +162,7 @@ class MaskableCategoricalDistribution(MaskableDistribution):
|
||||||
log_prob = self.log_prob(actions)
|
log_prob = self.log_prob(actions)
|
||||||
return actions, log_prob
|
return actions, log_prob
|
||||||
|
|
||||||
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
|
def apply_masking(self, masks: MaybeMasks) -> None:
|
||||||
assert self.distribution is not None, "Must set distribution parameters"
|
assert self.distribution is not None, "Must set distribution parameters"
|
||||||
self.distribution.apply_masking(masks)
|
self.distribution.apply_masking(masks)
|
||||||
|
|
||||||
|
|
@ -192,7 +200,7 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
|
||||||
reshaped_logits = action_logits.view(-1, sum(self.action_dims))
|
reshaped_logits = action_logits.view(-1, sum(self.action_dims))
|
||||||
|
|
||||||
self.distributions = [
|
self.distributions = [
|
||||||
MaskableCategorical(logits=split) for split in th.split(reshaped_logits, tuple(self.action_dims), dim=1)
|
MaskableCategorical(logits=split) for split in th.split(reshaped_logits, list(self.action_dims), dim=1)
|
||||||
]
|
]
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -229,18 +237,16 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
|
||||||
log_prob = self.log_prob(actions)
|
log_prob = self.log_prob(actions)
|
||||||
return actions, log_prob
|
return actions, log_prob
|
||||||
|
|
||||||
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
|
def apply_masking(self, masks: MaybeMasks) -> None:
|
||||||
assert len(self.distributions) > 0, "Must set distribution parameters"
|
assert len(self.distributions) > 0, "Must set distribution parameters"
|
||||||
|
|
||||||
split_masks = [None] * len(self.distributions)
|
split_masks = [None] * len(self.distributions)
|
||||||
if masks is not None:
|
if masks is not None:
|
||||||
masks = th.as_tensor(masks)
|
masks_tensor = th.as_tensor(masks)
|
||||||
|
|
||||||
# Restructure shape to align with logits
|
# Restructure shape to align with logits
|
||||||
masks = masks.view(-1, sum(self.action_dims))
|
masks_tensor = masks_tensor.view(-1, sum(self.action_dims))
|
||||||
|
|
||||||
# Then split columnwise for each discrete action
|
# Then split columnwise for each discrete action
|
||||||
split_masks = th.split(masks, tuple(self.action_dims), dim=1)
|
split_masks = th.split(masks_tensor, list(self.action_dims), dim=1) # type: ignore[assignment]
|
||||||
|
|
||||||
for distribution, mask in zip(self.distributions, split_masks):
|
for distribution, mask in zip(self.distributions, split_masks):
|
||||||
distribution.apply_masking(mask)
|
distribution.apply_masking(mask)
|
||||||
|
|
@ -268,10 +274,13 @@ def make_masked_proba_distribution(action_space: spaces.Space) -> MaskableDistri
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(action_space, spaces.Discrete):
|
if isinstance(action_space, spaces.Discrete):
|
||||||
return MaskableCategoricalDistribution(action_space.n)
|
return MaskableCategoricalDistribution(int(action_space.n))
|
||||||
elif isinstance(action_space, spaces.MultiDiscrete):
|
elif isinstance(action_space, spaces.MultiDiscrete):
|
||||||
return MaskableMultiCategoricalDistribution(action_space.nvec)
|
return MaskableMultiCategoricalDistribution(list(action_space.nvec))
|
||||||
elif isinstance(action_space, spaces.MultiBinary):
|
elif isinstance(action_space, spaces.MultiBinary):
|
||||||
|
assert isinstance(
|
||||||
|
action_space.n, int
|
||||||
|
), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
|
||||||
return MaskableBernoulliDistribution(action_space.n)
|
return MaskableBernoulliDistribution(action_space.n)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from stable_baselines3.common.torch_layers import (
|
||||||
MlpExtractor,
|
MlpExtractor,
|
||||||
NatureCNN,
|
NatureCNN,
|
||||||
)
|
)
|
||||||
from stable_baselines3.common.type_aliases import Schedule
|
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from sb3_contrib.common.maskable.distributions import MaskableDistribution, make_masked_proba_distribution
|
from sb3_contrib.common.maskable.distributions import MaskableDistribution, make_masked_proba_distribution
|
||||||
|
|
@ -141,8 +141,8 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
log_prob = distribution.log_prob(actions)
|
log_prob = distribution.log_prob(actions)
|
||||||
return actions, values, log_prob
|
return actions, values, log_prob
|
||||||
|
|
||||||
def extract_features(
|
def extract_features( # type: ignore[override]
|
||||||
self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None
|
self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None
|
||||||
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
|
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Preprocess the observation if needed and extract features.
|
Preprocess the observation if needed and extract features.
|
||||||
|
|
@ -233,7 +233,11 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
module.apply(partial(self.init_weights, gain=gain))
|
module.apply(partial(self.init_weights, gain=gain))
|
||||||
|
|
||||||
# Setup optimizer with initial learning rate
|
# Setup optimizer with initial learning rate
|
||||||
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
self.optimizer = self.optimizer_class(
|
||||||
|
self.parameters(),
|
||||||
|
lr=lr_schedule(1), # type: ignore[call-arg]
|
||||||
|
**self.optimizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution:
|
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution:
|
||||||
"""
|
"""
|
||||||
|
|
@ -245,9 +249,9 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
action_logits = self.action_net(latent_pi)
|
action_logits = self.action_net(latent_pi)
|
||||||
return self.action_dist.proba_distribution(action_logits=action_logits)
|
return self.action_dist.proba_distribution(action_logits=action_logits)
|
||||||
|
|
||||||
def _predict(
|
def _predict( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
observation: th.Tensor,
|
observation: PyTorchObs,
|
||||||
deterministic: bool = False,
|
deterministic: bool = False,
|
||||||
action_masks: Optional[np.ndarray] = None,
|
action_masks: Optional[np.ndarray] = None,
|
||||||
) -> th.Tensor:
|
) -> th.Tensor:
|
||||||
|
|
@ -284,35 +288,45 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
# Switch to eval mode (this affects batch norm / dropout)
|
# Switch to eval mode (this affects batch norm / dropout)
|
||||||
self.set_training_mode(False)
|
self.set_training_mode(False)
|
||||||
|
|
||||||
observation, vectorized_env = self.obs_to_tensor(observation)
|
# Check for common mistake that the user does not mix Gym/VecEnv API
|
||||||
|
# Tuple obs are not supported by SB3, so we can safely do that check
|
||||||
|
if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict):
|
||||||
|
raise ValueError(
|
||||||
|
"You have passed a tuple to the predict() function instead of a Numpy array or a Dict. "
|
||||||
|
"You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) "
|
||||||
|
"vs `obs = vec_env.reset()` (SB3 VecEnv). "
|
||||||
|
"See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 "
|
||||||
|
"and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
|
||||||
|
)
|
||||||
|
|
||||||
|
obs_tensor, vectorized_env = self.obs_to_tensor(observation)
|
||||||
|
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
actions = self._predict(observation, deterministic=deterministic, action_masks=action_masks)
|
actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks)
|
||||||
# Convert to numpy
|
# Convert to numpy
|
||||||
actions = actions.cpu().numpy()
|
actions = actions.cpu().numpy()
|
||||||
|
|
||||||
if isinstance(self.action_space, spaces.Box):
|
if isinstance(self.action_space, spaces.Box):
|
||||||
if self.squash_output:
|
if self.squash_output:
|
||||||
# Rescale to proper domain when using squashing
|
# Rescale to proper domain when using squashing
|
||||||
actions = self.unscale_action(actions)
|
actions = self.unscale_action(actions) # type: ignore[assignment, arg-type]
|
||||||
else:
|
else:
|
||||||
# Actions could be on arbitrary scale, so clip the actions to avoid
|
# Actions could be on arbitrary scale, so clip the actions to avoid
|
||||||
# out of bound error (e.g. if sampling from a Gaussian distribution)
|
# out of bound error (e.g. if sampling from a Gaussian distribution)
|
||||||
actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
actions = np.clip(actions, self.action_space.low, self.action_space.high) # type: ignore[assignment, arg-type]
|
||||||
|
|
||||||
if not vectorized_env:
|
if not vectorized_env:
|
||||||
if state is not None:
|
assert isinstance(actions, np.ndarray)
|
||||||
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
|
|
||||||
actions = actions.squeeze(axis=0)
|
actions = actions.squeeze(axis=0)
|
||||||
|
|
||||||
return actions, None
|
return actions, state # type: ignore[return-value]
|
||||||
|
|
||||||
def evaluate_actions(
|
def evaluate_actions(
|
||||||
self,
|
self,
|
||||||
obs: th.Tensor,
|
obs: th.Tensor,
|
||||||
actions: th.Tensor,
|
actions: th.Tensor,
|
||||||
action_masks: Optional[np.ndarray] = None,
|
action_masks: Optional[th.Tensor] = None,
|
||||||
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
|
) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Evaluate actions according to the current policy,
|
Evaluate actions according to the current policy,
|
||||||
given the observations.
|
given the observations.
|
||||||
|
|
@ -337,7 +351,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
values = self.value_net(latent_vf)
|
values = self.value_net(latent_vf)
|
||||||
return values, log_prob, distribution.entropy()
|
return values, log_prob, distribution.entropy()
|
||||||
|
|
||||||
def get_distribution(self, obs: th.Tensor, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
|
def get_distribution(self, obs: PyTorchObs, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
|
||||||
"""
|
"""
|
||||||
Get the current policy distribution given the observations.
|
Get the current policy distribution given the observations.
|
||||||
|
|
||||||
|
|
@ -352,7 +366,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
distribution.apply_masking(action_masks)
|
distribution.apply_masking(action_masks)
|
||||||
return distribution
|
return distribution
|
||||||
|
|
||||||
def predict_values(self, obs: th.Tensor) -> th.Tensor:
|
def predict_values(self, obs: PyTorchObs) -> th.Tensor:
|
||||||
"""
|
"""
|
||||||
Get the estimated values according to the current policy given the observations.
|
Get the estimated values according to the current policy given the observations.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -170,7 +170,7 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
|
||||||
Do a forward pass in the LSTM network.
|
Do a forward pass in the LSTM network.
|
||||||
|
|
||||||
:param features: Input tensor
|
:param features: Input tensor
|
||||||
:param lstm_states: previous cell and hidden states of the LSTM
|
:param lstm_states: previous hidden and cell states of the LSTM, respectively
|
||||||
:param episode_starts: Indicates when a new episode starts,
|
:param episode_starts: Indicates when a new episode starts,
|
||||||
in that case, we need to reset LSTM states.
|
in that case, we need to reset LSTM states.
|
||||||
:param lstm: LSTM object.
|
:param lstm: LSTM object.
|
||||||
|
|
|
||||||
|
|
@ -59,10 +59,11 @@ def _worker(
|
||||||
remote.send(vec_env.seed(seed=data))
|
remote.send(vec_env.seed(seed=data))
|
||||||
elif cmd == "set_options":
|
elif cmd == "set_options":
|
||||||
# Note: the options will only be effective at the next reset
|
# Note: the options will only be effective at the next reset
|
||||||
remote.send(vec_env.set_options(data))
|
remote.send(vec_env.set_options(data)) # type: ignore[func-returns-value]
|
||||||
elif cmd == "get_obs_rms":
|
elif cmd == "get_obs_rms":
|
||||||
remote.send(obs_rms)
|
remote.send(obs_rms)
|
||||||
elif cmd == "sync_obs_rms":
|
elif cmd == "sync_obs_rms":
|
||||||
|
assert vec_normalize is not None, "Tried to call `sync_obs_rms` when not using VecNormalize"
|
||||||
vec_normalize.obs_rms = data
|
vec_normalize.obs_rms = data
|
||||||
obs_rms = data
|
obs_rms = data
|
||||||
elif cmd == "close":
|
elif cmd == "close":
|
||||||
|
|
@ -130,7 +131,7 @@ class AsyncEval:
|
||||||
n_eval_episodes,
|
n_eval_episodes,
|
||||||
)
|
)
|
||||||
# daemon=True: if the main process crashes, we should not cause things to hang
|
# daemon=True: if the main process crashes, we should not cause things to hang
|
||||||
process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error
|
process = ctx.Process(target=_worker, args=args, daemon=True) # type: ignore[attr-defined]
|
||||||
process.start()
|
process.start()
|
||||||
self.processes.append(process)
|
self.processes.append(process)
|
||||||
work_remote.close()
|
work_remote.close()
|
||||||
|
|
@ -157,6 +158,10 @@ class AsyncEval:
|
||||||
:param seed: The seed for the pseudo-random generators.
|
:param seed: The seed for the pseudo-random generators.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
if seed is None:
|
||||||
|
# Do nothing
|
||||||
|
return []
|
||||||
|
|
||||||
for idx, remote in enumerate(self.remotes):
|
for idx, remote in enumerate(self.remotes):
|
||||||
remote.send(("seed", seed + idx))
|
remote.send(("seed", seed + idx))
|
||||||
return [remote.recv() for remote in self.remotes]
|
return [remote.recv() for remote in self.remotes]
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,14 @@
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from collections import deque
|
|
||||||
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union
|
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common import utils
|
|
||||||
from stable_baselines3.common.buffers import RolloutBuffer
|
from stable_baselines3.common.buffers import RolloutBuffer
|
||||||
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
|
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor
|
||||||
from stable_baselines3.common.vec_env import VecEnv
|
from stable_baselines3.common.vec_env import VecEnv
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
@ -74,6 +70,8 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
"CnnPolicy": CnnPolicy,
|
"CnnPolicy": CnnPolicy,
|
||||||
"MultiInputPolicy": MultiInputPolicy,
|
"MultiInputPolicy": MultiInputPolicy,
|
||||||
}
|
}
|
||||||
|
policy: MaskableActorCriticPolicy # type: ignore[assignment]
|
||||||
|
rollout_buffer: MaskableRolloutBuffer # type: ignore[assignment]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -91,6 +89,8 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
ent_coef: float = 0.0,
|
ent_coef: float = 0.0,
|
||||||
vf_coef: float = 0.5,
|
vf_coef: float = 0.5,
|
||||||
max_grad_norm: float = 0.5,
|
max_grad_norm: float = 0.5,
|
||||||
|
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
|
||||||
|
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
target_kl: Optional[float] = None,
|
target_kl: Optional[float] = None,
|
||||||
stats_window_size: int = 100,
|
stats_window_size: int = 100,
|
||||||
tensorboard_log: Optional[str] = None,
|
tensorboard_log: Optional[str] = None,
|
||||||
|
|
@ -101,7 +101,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
_init_setup_model: bool = True,
|
_init_setup_model: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
policy,
|
policy, # type: ignore[arg-type]
|
||||||
env,
|
env,
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
n_steps=n_steps,
|
n_steps=n_steps,
|
||||||
|
|
@ -112,6 +112,8 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
max_grad_norm=max_grad_norm,
|
max_grad_norm=max_grad_norm,
|
||||||
use_sde=False,
|
use_sde=False,
|
||||||
sde_sample_freq=-1,
|
sde_sample_freq=-1,
|
||||||
|
rollout_buffer_class=rollout_buffer_class,
|
||||||
|
rollout_buffer_kwargs=rollout_buffer_kwargs,
|
||||||
stats_window_size=stats_window_size,
|
stats_window_size=stats_window_size,
|
||||||
tensorboard_log=tensorboard_log,
|
tensorboard_log=tensorboard_log,
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
|
|
@ -140,20 +142,24 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
self._setup_lr_schedule()
|
self._setup_lr_schedule()
|
||||||
self.set_random_seed(self.seed)
|
self.set_random_seed(self.seed)
|
||||||
|
|
||||||
buffer_cls = MaskableDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else MaskableRolloutBuffer
|
self.policy = self.policy_class( # type: ignore[assignment]
|
||||||
|
|
||||||
self.policy = self.policy_class(
|
|
||||||
self.observation_space,
|
self.observation_space,
|
||||||
self.action_space,
|
self.action_space,
|
||||||
self.lr_schedule,
|
self.lr_schedule,
|
||||||
**self.policy_kwargs, # pytype:disable=not-instantiable
|
**self.policy_kwargs,
|
||||||
)
|
)
|
||||||
self.policy = self.policy.to(self.device)
|
self.policy = self.policy.to(self.device)
|
||||||
|
|
||||||
if not isinstance(self.policy, MaskableActorCriticPolicy):
|
if not isinstance(self.policy, MaskableActorCriticPolicy):
|
||||||
raise ValueError("Policy must subclass MaskableActorCriticPolicy")
|
raise ValueError("Policy must subclass MaskableActorCriticPolicy")
|
||||||
|
|
||||||
self.rollout_buffer = buffer_cls(
|
if self.rollout_buffer_class is None:
|
||||||
|
if isinstance(self.observation_space, spaces.Dict):
|
||||||
|
self.rollout_buffer_class = MaskableDictRolloutBuffer
|
||||||
|
else:
|
||||||
|
self.rollout_buffer_class = MaskableRolloutBuffer
|
||||||
|
|
||||||
|
self.rollout_buffer = self.rollout_buffer_class( # type: ignore[assignment]
|
||||||
self.n_steps,
|
self.n_steps,
|
||||||
self.observation_space,
|
self.observation_space,
|
||||||
self.action_space,
|
self.action_space,
|
||||||
|
|
@ -161,6 +167,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
gamma=self.gamma,
|
gamma=self.gamma,
|
||||||
gae_lambda=self.gae_lambda,
|
gae_lambda=self.gae_lambda,
|
||||||
n_envs=self.n_envs,
|
n_envs=self.n_envs,
|
||||||
|
**self.rollout_buffer_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize schedules for policy/value clipping
|
# Initialize schedules for policy/value clipping
|
||||||
|
|
@ -171,86 +178,6 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
|
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
|
||||||
|
|
||||||
def _init_callback(
|
|
||||||
self,
|
|
||||||
callback: MaybeCallback,
|
|
||||||
use_masking: bool = True,
|
|
||||||
progress_bar: bool = False,
|
|
||||||
) -> BaseCallback:
|
|
||||||
"""
|
|
||||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
|
||||||
:param use_masking: Whether or not to use invalid action masks during evaluation
|
|
||||||
:param progress_bar: Display a progress bar using tqdm and rich.
|
|
||||||
:return: A hybrid callback calling `callback` and performing evaluation.
|
|
||||||
"""
|
|
||||||
# Convert a list of callbacks into a callback
|
|
||||||
if isinstance(callback, list):
|
|
||||||
callback = CallbackList(callback)
|
|
||||||
|
|
||||||
# Convert functional callback to object
|
|
||||||
if not isinstance(callback, BaseCallback):
|
|
||||||
callback = ConvertCallback(callback)
|
|
||||||
|
|
||||||
# Add progress bar callback
|
|
||||||
if progress_bar:
|
|
||||||
callback = CallbackList([callback, ProgressBarCallback()])
|
|
||||||
|
|
||||||
callback.init_callback(self)
|
|
||||||
return callback
|
|
||||||
|
|
||||||
def _setup_learn(
|
|
||||||
self,
|
|
||||||
total_timesteps: int,
|
|
||||||
callback: MaybeCallback = None,
|
|
||||||
reset_num_timesteps: bool = True,
|
|
||||||
tb_log_name: str = "run",
|
|
||||||
use_masking: bool = True,
|
|
||||||
progress_bar: bool = False,
|
|
||||||
) -> Tuple[int, BaseCallback]:
|
|
||||||
"""
|
|
||||||
Initialize different variables needed for training.
|
|
||||||
|
|
||||||
:param total_timesteps: The total number of samples (env steps) to train on
|
|
||||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
|
||||||
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
|
|
||||||
:param tb_log_name: the name of the run for tensorboard log
|
|
||||||
:param use_masking: Whether or not to use invalid action masks during training
|
|
||||||
:param progress_bar: Display a progress bar using tqdm and rich.
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.start_time = time.time_ns()
|
|
||||||
if self.ep_info_buffer is None or reset_num_timesteps:
|
|
||||||
# Initialize buffers if they don't exist, or reinitialize if resetting counters
|
|
||||||
self.ep_info_buffer = deque(maxlen=self._stats_window_size)
|
|
||||||
self.ep_success_buffer = deque(maxlen=self._stats_window_size)
|
|
||||||
|
|
||||||
if reset_num_timesteps:
|
|
||||||
self.num_timesteps = 0
|
|
||||||
self._episode_num = 0
|
|
||||||
else:
|
|
||||||
# Make sure training timesteps are ahead of the internal counter
|
|
||||||
total_timesteps += self.num_timesteps
|
|
||||||
self._total_timesteps = total_timesteps
|
|
||||||
self._num_timesteps_at_start = self.num_timesteps
|
|
||||||
|
|
||||||
# Avoid resetting the environment when calling ``.learn()`` consecutive times
|
|
||||||
if reset_num_timesteps or self._last_obs is None:
|
|
||||||
self._last_obs = self.env.reset()
|
|
||||||
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
|
|
||||||
# Retrieve unnormalized observation for saving into the buffer
|
|
||||||
if self._vec_normalize_env is not None:
|
|
||||||
self._last_original_obs = self._vec_normalize_env.get_original_obs()
|
|
||||||
|
|
||||||
# Configure logger's outputs if no logger was passed
|
|
||||||
if not self._custom_logger:
|
|
||||||
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
|
|
||||||
|
|
||||||
# Create eval callback if needed
|
|
||||||
callback = self._init_callback(callback, use_masking, progress_bar)
|
|
||||||
|
|
||||||
return total_timesteps, callback
|
|
||||||
|
|
||||||
def collect_rollouts(
|
def collect_rollouts(
|
||||||
self,
|
self,
|
||||||
env: VecEnv,
|
env: VecEnv,
|
||||||
|
|
@ -312,7 +239,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
if not callback.on_step():
|
if not callback.on_step():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self._update_info_buffer(infos)
|
self._update_info_buffer(infos, dones)
|
||||||
n_steps += 1
|
n_steps += 1
|
||||||
|
|
||||||
if isinstance(self.action_space, spaces.Discrete):
|
if isinstance(self.action_space, spaces.Discrete):
|
||||||
|
|
@ -341,14 +268,14 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
log_probs,
|
log_probs,
|
||||||
action_masks=action_masks,
|
action_masks=action_masks,
|
||||||
)
|
)
|
||||||
self._last_obs = new_obs
|
self._last_obs = new_obs # type: ignore[assignment]
|
||||||
self._last_episode_starts = dones
|
self._last_episode_starts = dones
|
||||||
|
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
# Compute value for the last timestep
|
# Compute value for the last timestep
|
||||||
# Masking is not needed here, the choice of action doesn't matter.
|
# Masking is not needed here, the choice of action doesn't matter.
|
||||||
# We only want the value of the current observation.
|
# We only want the value of the current observation.
|
||||||
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))
|
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) # type: ignore[arg-type]
|
||||||
|
|
||||||
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
|
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
|
||||||
|
|
||||||
|
|
@ -356,9 +283,9 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def predict(
|
def predict( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
observation: np.ndarray,
|
observation: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||||
state: Optional[Tuple[np.ndarray, ...]] = None,
|
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||||
episode_start: Optional[np.ndarray] = None,
|
episode_start: Optional[np.ndarray] = None,
|
||||||
deterministic: bool = False,
|
deterministic: bool = False,
|
||||||
|
|
@ -388,10 +315,10 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
# Update optimizer learning rate
|
# Update optimizer learning rate
|
||||||
self._update_learning_rate(self.policy.optimizer)
|
self._update_learning_rate(self.policy.optimizer)
|
||||||
# Compute current clip range
|
# Compute current clip range
|
||||||
clip_range = self.clip_range(self._current_progress_remaining)
|
clip_range = self.clip_range(self._current_progress_remaining) # type: ignore[operator]
|
||||||
# Optional: clip range for the value function
|
# Optional: clip range for the value function
|
||||||
if self.clip_range_vf is not None:
|
if self.clip_range_vf is not None:
|
||||||
clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
|
clip_range_vf = self.clip_range_vf(self._current_progress_remaining) # type: ignore[operator]
|
||||||
|
|
||||||
entropy_losses = []
|
entropy_losses = []
|
||||||
pg_losses, value_losses = [], []
|
pg_losses, value_losses = [], []
|
||||||
|
|
@ -499,7 +426,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
if self.clip_range_vf is not None:
|
if self.clip_range_vf is not None:
|
||||||
self.logger.record("train/clip_range_vf", clip_range_vf)
|
self.logger.record("train/clip_range_vf", clip_range_vf)
|
||||||
|
|
||||||
def learn(
|
def learn( # type: ignore[override]
|
||||||
self: SelfMaskablePPO,
|
self: SelfMaskablePPO,
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
callback: MaybeCallback = None,
|
callback: MaybeCallback = None,
|
||||||
|
|
@ -516,12 +443,13 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
callback,
|
callback,
|
||||||
reset_num_timesteps,
|
reset_num_timesteps,
|
||||||
tb_log_name,
|
tb_log_name,
|
||||||
use_masking,
|
|
||||||
progress_bar,
|
progress_bar,
|
||||||
)
|
)
|
||||||
|
|
||||||
callback.on_training_start(locals(), globals())
|
callback.on_training_start(locals(), globals())
|
||||||
|
|
||||||
|
assert self.env is not None
|
||||||
|
|
||||||
while self.num_timesteps < total_timesteps:
|
while self.num_timesteps < total_timesteps:
|
||||||
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
|
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
|
||||||
|
|
||||||
|
|
@ -533,16 +461,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
# Display training infos
|
# Display training infos
|
||||||
if log_interval is not None and iteration % log_interval == 0:
|
if log_interval is not None and iteration % log_interval == 0:
|
||||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
self._dump_logs(iteration)
|
||||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
|
||||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
|
||||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
|
||||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
|
||||||
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
|
||||||
self.logger.record("time/fps", fps)
|
|
||||||
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
|
||||||
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
|
||||||
self.logger.dump(step=self.num_timesteps)
|
|
||||||
|
|
||||||
self.train()
|
self.train()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
|
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
|
||||||
|
|
||||||
|
|
@ -11,7 +9,7 @@ from stable_baselines3.common.callbacks import BaseCallback
|
||||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
|
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor
|
||||||
from stable_baselines3.common.vec_env import VecEnv
|
from stable_baselines3.common.vec_env import VecEnv
|
||||||
|
|
||||||
from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer
|
from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer
|
||||||
|
|
@ -149,7 +147,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
self.action_space,
|
self.action_space,
|
||||||
self.lr_schedule,
|
self.lr_schedule,
|
||||||
use_sde=self.use_sde,
|
use_sde=self.use_sde,
|
||||||
**self.policy_kwargs, # pytype:disable=not-instantiable
|
**self.policy_kwargs,
|
||||||
)
|
)
|
||||||
self.policy = self.policy.to(self.device)
|
self.policy = self.policy.to(self.device)
|
||||||
|
|
||||||
|
|
@ -260,7 +258,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
if not callback.on_step():
|
if not callback.on_step():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self._update_info_buffer(infos)
|
self._update_info_buffer(infos, dones)
|
||||||
n_steps += 1
|
n_steps += 1
|
||||||
|
|
||||||
if isinstance(self.action_space, spaces.Discrete):
|
if isinstance(self.action_space, spaces.Discrete):
|
||||||
|
|
@ -453,42 +451,11 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
) -> SelfRecurrentPPO:
|
) -> SelfRecurrentPPO:
|
||||||
iteration = 0
|
return super().learn(
|
||||||
|
total_timesteps=total_timesteps,
|
||||||
total_timesteps, callback = self._setup_learn(
|
callback=callback,
|
||||||
total_timesteps,
|
log_interval=log_interval,
|
||||||
callback,
|
tb_log_name=tb_log_name,
|
||||||
reset_num_timesteps,
|
reset_num_timesteps=reset_num_timesteps,
|
||||||
tb_log_name,
|
progress_bar=progress_bar,
|
||||||
progress_bar,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
callback.on_training_start(locals(), globals())
|
|
||||||
|
|
||||||
while self.num_timesteps < total_timesteps:
|
|
||||||
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
|
|
||||||
|
|
||||||
if not continue_training:
|
|
||||||
break
|
|
||||||
|
|
||||||
iteration += 1
|
|
||||||
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
|
|
||||||
|
|
||||||
# Display training infos
|
|
||||||
if log_interval is not None and iteration % log_interval == 0:
|
|
||||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
|
||||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
|
||||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
|
||||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
|
||||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
|
||||||
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
|
||||||
self.logger.record("time/fps", fps)
|
|
||||||
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
|
||||||
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
|
||||||
self.logger.dump(step=self.num_timesteps)
|
|
||||||
|
|
||||||
self.train()
|
|
||||||
|
|
||||||
callback.on_training_end()
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,8 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
"""
|
"""
|
||||||
Quantile Regression Deep Q-Network (QR-DQN)
|
Quantile Regression Deep Q-Network (QR-DQN)
|
||||||
Paper: https://arxiv.org/abs/1710.10044
|
Paper: https://arxiv.org/abs/1710.10044
|
||||||
Default hyperparameters are taken from the paper and are tuned for Atari games.
|
Default hyperparameters are taken from the paper and are tuned for Atari games
|
||||||
|
(except for the ``learning_starts`` parameter).
|
||||||
|
|
||||||
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||||
:param env: The environment to learn from (if registered in Gym, can be str)
|
:param env: The environment to learn from (if registered in Gym, can be str)
|
||||||
|
|
@ -77,11 +78,11 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
env: Union[GymEnv, str],
|
env: Union[GymEnv, str],
|
||||||
learning_rate: Union[float, Schedule] = 5e-5,
|
learning_rate: Union[float, Schedule] = 5e-5,
|
||||||
buffer_size: int = 1000000, # 1e6
|
buffer_size: int = 1000000, # 1e6
|
||||||
learning_starts: int = 50000,
|
learning_starts: int = 100,
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
tau: float = 1.0,
|
tau: float = 1.0,
|
||||||
gamma: float = 0.99,
|
gamma: float = 0.99,
|
||||||
train_freq: int = 4,
|
train_freq: Union[int, Tuple[int, str]] = 4,
|
||||||
gradient_steps: int = 1,
|
gradient_steps: int = 1,
|
||||||
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
||||||
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
batch_size: int = 256,
|
batch_size: int = 256,
|
||||||
tau: float = 0.005,
|
tau: float = 0.005,
|
||||||
gamma: float = 0.99,
|
gamma: float = 0.99,
|
||||||
train_freq: int = 1,
|
train_freq: Union[int, Tuple[int, str]] = 1,
|
||||||
gradient_steps: int = 1,
|
gradient_steps: int = 1,
|
||||||
action_noise: Optional[ActionNoise] = None,
|
action_noise: Optional[ActionNoise] = None,
|
||||||
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
2.2.0
|
2.3.0
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=2.2.0,<3.0",
|
"stable_baselines3>=2.3.0,<3.0",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue