Fix MaskablePPO type annotations (#233)
* Fix `sb3_contrib/common/maskable/callbacks.py` type annotations
* Fix `sb3_contrib/common/vec_env/async_eval.py` type annotations
* Fix `sb3_contrib/common/maskable/buffers.py` type hints
* Fix `sb3_contrib/common/maskable/distributions.py` type hints
* Fix `sb3_contrib/common/maskable/policies.py` type hints
* Fix `sb3_contrib/ppo_mask/ppo_mask.py` type hints
* Update changelog and fix type hints
This commit is contained in:
parent cd31e89e26
commit 7dd6c39fba
@@ -22,6 +22,7 @@ Breaking Changes:

 New Features:
 ^^^^^^^^^^^^^
+- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to MaskablePPO

 Bug Fixes:
 ^^^^^^^^^^
@@ -31,8 +32,10 @@ Deprecations:

 Others:
 ^^^^^^^
 - Fixed ``train_freq`` type annotation for tqc and qrdqn (@Armandpl)
+- Fixed ``sb3_contrib/common/maskable/*.py`` type annotations
+- Fixed ``sb3_contrib/ppo_mask/ppo_mask.py`` type annotations
+- Fixed ``sb3_contrib/common/vec_env/async_eval.py`` type annotations

 Documentation:
 ^^^^^^^^^^^^^^
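The `rollout_buffer_class` / `rollout_buffer_kwargs` arguments announced above can be exercised as in the sketch below. This is illustrative only and not part of the diff: `StatsMaskableBuffer` is a hypothetical subclass used to show kwargs forwarding, and `InvalidActionEnvDiscrete` is the mask-capable toy environment shipped with sb3-contrib.

    from sb3_contrib import MaskablePPO
    from sb3_contrib.common.envs import InvalidActionEnvDiscrete
    from sb3_contrib.common.maskable.buffers import MaskableRolloutBuffer


    class StatsMaskableBuffer(MaskableRolloutBuffer):
        # Hypothetical subclass: one extra constructor argument,
        # forwarded by MaskablePPO via rollout_buffer_kwargs.
        def __init__(self, *args, track_stats: bool = False, **kwargs):
            super().__init__(*args, **kwargs)
            self.track_stats = track_stats


    env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
    model = MaskablePPO(
        "MlpPolicy",
        env,
        rollout_buffer_class=StatsMaskableBuffer,
        rollout_buffer_kwargs={"track_stats": True},
        verbose=1,
    )
    model.learn(total_timesteps=1_000)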
@@ -28,12 +28,6 @@ exclude = """(?x)(
     | sb3_contrib/ars/ars.py$
     | sb3_contrib/common/recurrent/policies.py$
     | sb3_contrib/common/recurrent/buffers.py$
-    | sb3_contrib/common/maskable/distributions.py$
-    | sb3_contrib/common/maskable/callbacks.py$
-    | sb3_contrib/common/maskable/policies.py$
-    | sb3_contrib/common/maskable/buffers.py$
-    | sb3_contrib/common/vec_env/async_eval.py$
-    | sb3_contrib/ppo_mask/ppo_mask.py$
     | tests/test_train_eval_mode.py$
 )"""
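The entries removed from the type-checker exclude list above are exactly the modules made type-clean by this commit. Much of that work relies on deliberately silencing incompatible-override errors where a subclass narrows an inherited signature; the minimal, self-contained sketch below illustrates the idea (all names here are made up for illustration, they are not from the repository).

    from typing import Iterator, NamedTuple


    class Samples(NamedTuple):
        actions: list


    class MaskedSamples(NamedTuple):
        actions: list
        action_masks: list


    class Buffer:
        def get(self) -> Iterator[Samples]:
            yield Samples(actions=[0, 1])


    class MaskedBuffer(Buffer):
        # MaskedSamples is an independent NamedTuple rather than a subclass of
        # Samples, so mypy flags this override as incompatible; the commit
        # silences such cases with targeted ignores instead of excluding files.
        def get(self) -> Iterator[MaskedSamples]:  # type: ignore[override]
            yield MaskedSamples(actions=[0, 1], action_masks=[True, False])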
@@ -18,7 +18,7 @@ class MaskableRolloutBufferSamples(NamedTuple):
     action_masks: th.Tensor


-class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples):
+class MaskableDictRolloutBufferSamples(NamedTuple):
     observations: TensorDict
     actions: th.Tensor
     old_values: th.Tensor
@@ -42,6 +42,8 @@ class MaskableRolloutBuffer(RolloutBuffer):
     :param n_envs: Number of parallel environments
     """

+    action_masks: np.ndarray
+
     def __init__(
         self,
         buffer_size: int,
@@ -53,7 +55,6 @@ class MaskableRolloutBuffer(RolloutBuffer):
         n_envs: int = 1,
     ):
         super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
-        self.action_masks = None

     def reset(self) -> None:
         if isinstance(self.action_space, spaces.Discrete):
@@ -61,6 +62,10 @@ class MaskableRolloutBuffer(RolloutBuffer):
         elif isinstance(self.action_space, spaces.MultiDiscrete):
             mask_dims = sum(self.action_space.nvec)
         elif isinstance(self.action_space, spaces.MultiBinary):
+            assert isinstance(self.action_space.n, int), (
+                f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
+                "You can flatten it instead."
+            )
             mask_dims = 2 * self.action_space.n  # One mask per binary outcome
         else:
             raise ValueError(f"Unsupported action space {type(self.action_space)}")
@@ -79,7 +84,7 @@ class MaskableRolloutBuffer(RolloutBuffer):

         super().add(*args, **kwargs)

-    def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]:
+    def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]:  # type: ignore[override]
         assert self.full, ""
         indices = np.random.permutation(self.buffer_size * self.n_envs)
         # Prepare the data
@@ -105,7 +110,7 @@ class MaskableRolloutBuffer(RolloutBuffer):
             yield self._get_samples(indices[start_idx : start_idx + batch_size])
             start_idx += batch_size

-    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples:
+    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples:  # type: ignore[override]
         data = (
             self.observations[batch_inds],
             self.actions[batch_inds],
@@ -143,17 +148,18 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
     :param n_envs: Number of parallel environments
     """

+    action_masks: np.ndarray
+
     def __init__(
         self,
         buffer_size: int,
-        observation_space: spaces.Space,
+        observation_space: spaces.Dict,
         action_space: spaces.Space,
         device: Union[th.device, str] = "auto",
         gae_lambda: float = 1,
         gamma: float = 0.99,
         n_envs: int = 1,
     ):
-        self.action_masks = None
         super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)

     def reset(self) -> None:
@@ -162,6 +168,10 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
         elif isinstance(self.action_space, spaces.MultiDiscrete):
             mask_dims = sum(self.action_space.nvec)
         elif isinstance(self.action_space, spaces.MultiBinary):
+            assert isinstance(self.action_space.n, int), (
+                f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
+                "You can flatten it instead."
+            )
             mask_dims = 2 * self.action_space.n  # One mask per binary outcome
         else:
             raise ValueError(f"Unsupported action space {type(self.action_space)}")
@@ -180,7 +190,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):

         super().add(*args, **kwargs)

-    def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]:
+    def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]:  # type: ignore[override]
         assert self.full, ""
         indices = np.random.permutation(self.buffer_size * self.n_envs)
         # Prepare the data
@@ -203,7 +213,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
             yield self._get_samples(indices[start_idx : start_idx + batch_size])
             start_idx += batch_size

-    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:
+    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:  # type: ignore[override]
         return MaskableDictRolloutBufferSamples(
             observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
             actions=self.to_torch(self.actions[batch_inds]),
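For reference, the mask sizing rule used by both buffers above can be restated on its own. The sketch below only assumes gymnasium spaces; the `Discrete` branch falls outside the hunks shown and is assumed to be one mask entry per action.

    import numpy as np
    from gymnasium import spaces


    def mask_dims_for(action_space: spaces.Space) -> int:
        # Illustrative restatement of the buffer sizing rule above.
        if isinstance(action_space, spaces.Discrete):
            return int(action_space.n)
        if isinstance(action_space, spaces.MultiDiscrete):
            return int(sum(action_space.nvec))
        if isinstance(action_space, spaces.MultiBinary):
            assert isinstance(action_space.n, int), "flatten multi-dimensional MultiBinary first"
            return 2 * action_space.n  # one mask per binary outcome
        raise ValueError(f"Unsupported action space {type(action_space)}")


    print(mask_dims_for(spaces.Discrete(4)))                      # 4
    print(mask_dims_for(spaces.MultiDiscrete(np.array([3, 2]))))  # 5
    print(mask_dims_for(spaces.MultiBinary(3)))                   # 6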
@@ -55,7 +55,7 @@ class MaskableEvalCallback(EvalCallback):

             # Note that evaluate_policy() has been patched to support masking
             episode_rewards, episode_lengths = evaluate_policy(
-                self.model,
+                self.model,  # type: ignore[arg-type]
                 self.eval_env,
                 n_eval_episodes=self.n_eval_episodes,
                 render=self.render,
@@ -67,6 +67,8 @@ class MaskableEvalCallback(EvalCallback):
             )

             if self.log_path is not None:
+                assert isinstance(episode_rewards, list)
+                assert isinstance(episode_lengths, list)
                 self.evaluations_timesteps.append(self.num_timesteps)
                 self.evaluations_results.append(episode_rewards)
                 self.evaluations_length.append(episode_lengths)
@@ -87,7 +89,7 @@ class MaskableEvalCallback(EvalCallback):

             mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
             mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
-            self.last_mean_reward = mean_reward
+            self.last_mean_reward = float(mean_reward)

             if self.verbose > 0:
                 print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
@@ -111,7 +113,7 @@ class MaskableEvalCallback(EvalCallback):
                     print("New best mean reward!")
                 if self.best_model_save_path is not None:
                     self.model.save(os.path.join(self.best_model_save_path, "best_model"))
-                self.best_mean_reward = mean_reward
+                self.best_mean_reward = float(mean_reward)
                 # Trigger callback on new best model, if needed
                 if self.callback_on_new_best is not None:
                     continue_training = self.callback_on_new_best.on_step()
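A usage sketch for the callback patched above, loosely following the sb3-contrib documentation; the save path, frequencies, and the `InvalidActionEnvDiscrete` parameters are placeholders.

    from sb3_contrib import MaskablePPO
    from sb3_contrib.common.envs import InvalidActionEnvDiscrete
    from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback

    env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
    eval_env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)

    # Evaluates with action masking enabled and keeps track of the best model.
    eval_callback = MaskableEvalCallback(
        eval_env,
        best_model_save_path="./logs/",
        eval_freq=500,
        n_eval_episodes=5,
        deterministic=True,
    )

    model = MaskablePPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=2_000, callback=eval_callback)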
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional, Tuple, TypeVar
+from typing import List, Optional, Tuple, TypeVar, Union

 import numpy as np
 import torch as th
@@ -13,6 +13,7 @@ SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistributi
 SelfMaskableMultiCategoricalDistribution = TypeVar(
     "SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution"
 )
+MaybeMasks = Union[th.Tensor, np.ndarray, None]


 class MaskableCategorical(Categorical):
@@ -36,14 +37,14 @@ class MaskableCategorical(Categorical):
         probs: Optional[th.Tensor] = None,
         logits: Optional[th.Tensor] = None,
         validate_args: Optional[bool] = None,
-        masks: Optional[np.ndarray] = None,
+        masks: MaybeMasks = None,
     ):
         self.masks: Optional[th.Tensor] = None
         super().__init__(probs, logits, validate_args)
         self._original_logits = self.logits
         self.apply_masking(masks)

-    def apply_masking(self, masks: Optional[np.ndarray]) -> None:
+    def apply_masking(self, masks: MaybeMasks) -> None:
         """
         Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0.

@@ -84,7 +85,7 @@ class MaskableCategorical(Categorical):

 class MaskableDistribution(Distribution, ABC):
     @abstractmethod
-    def apply_masking(self, masks: Optional[np.ndarray]) -> None:
+    def apply_masking(self, masks: MaybeMasks) -> None:
         """
         Eliminate ("mask out") chosen distribution outcomes by setting their probability to 0.

@@ -94,6 +95,13 @@ class MaskableDistribution(Distribution, ABC):
         previously applied masking is removed, and the original logits are restored.
         """

+    @abstractmethod
+    def proba_distribution_net(self, *args, **kwargs) -> nn.Module:
+        """Create the layers and parameters that represent the distribution.
+
+        Subclasses must define this, but the arguments and return type vary between
+        concrete classes."""
+

 class MaskableCategoricalDistribution(MaskableDistribution):
     """
@@ -154,7 +162,7 @@ class MaskableCategoricalDistribution(MaskableDistribution):
         log_prob = self.log_prob(actions)
         return actions, log_prob

-    def apply_masking(self, masks: Optional[np.ndarray]) -> None:
+    def apply_masking(self, masks: MaybeMasks) -> None:
         assert self.distribution is not None, "Must set distribution parameters"
         self.distribution.apply_masking(masks)

@@ -192,7 +200,7 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
         reshaped_logits = action_logits.view(-1, sum(self.action_dims))

         self.distributions = [
-            MaskableCategorical(logits=split) for split in th.split(reshaped_logits, tuple(self.action_dims), dim=1)
+            MaskableCategorical(logits=split) for split in th.split(reshaped_logits, list(self.action_dims), dim=1)
         ]
         return self

@@ -229,18 +237,16 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
         log_prob = self.log_prob(actions)
         return actions, log_prob

-    def apply_masking(self, masks: Optional[np.ndarray]) -> None:
+    def apply_masking(self, masks: MaybeMasks) -> None:
         assert len(self.distributions) > 0, "Must set distribution parameters"

         split_masks = [None] * len(self.distributions)
         if masks is not None:
-            masks = th.as_tensor(masks)
+            masks_tensor = th.as_tensor(masks)
             # Restructure shape to align with logits
-            masks = masks.view(-1, sum(self.action_dims))
+            masks_tensor = masks_tensor.view(-1, sum(self.action_dims))
             # Then split columnwise for each discrete action
-            split_masks = th.split(masks, tuple(self.action_dims), dim=1)
+            split_masks = th.split(masks_tensor, list(self.action_dims), dim=1)  # type: ignore[assignment]

         for distribution, mask in zip(self.distributions, split_masks):
             distribution.apply_masking(mask)
@@ -268,10 +274,13 @@ def make_masked_proba_distribution(action_space: spaces.Space) -> MaskableDistri
     """

     if isinstance(action_space, spaces.Discrete):
-        return MaskableCategoricalDistribution(action_space.n)
+        return MaskableCategoricalDistribution(int(action_space.n))
     elif isinstance(action_space, spaces.MultiDiscrete):
-        return MaskableMultiCategoricalDistribution(action_space.nvec)
+        return MaskableMultiCategoricalDistribution(list(action_space.nvec))
     elif isinstance(action_space, spaces.MultiBinary):
+        assert isinstance(
+            action_space.n, int
+        ), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
         return MaskableBernoulliDistribution(action_space.n)
     else:
         raise NotImplementedError(
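A small sketch of the `apply_masking` API typed above, assuming the usual sb3-contrib convention that `True` (or 1) marks a valid action. Masks may be a NumPy array or a tensor, which is what the new `MaybeMasks` alias expresses.

    import numpy as np
    import torch as th

    from sb3_contrib.common.maskable.distributions import MaskableCategorical

    dist = MaskableCategorical(logits=th.tensor([[1.0, 2.0, 3.0, 4.0]]))

    # Disallow the last two actions.
    dist.apply_masking(np.array([[True, True, False, False]]))
    print(dist.probs)     # probability mass only on the first two actions
    print(dist.sample())  # always samples action 0 or 1

    # Passing None removes the mask and restores the original logits.
    dist.apply_masking(None)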
@@ -13,7 +13,7 @@ from stable_baselines3.common.torch_layers import (
     MlpExtractor,
     NatureCNN,
 )
-from stable_baselines3.common.type_aliases import Schedule
+from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
 from torch import nn

 from sb3_contrib.common.maskable.distributions import MaskableDistribution, make_masked_proba_distribution
@@ -141,8 +141,8 @@ class MaskableActorCriticPolicy(BasePolicy):
         log_prob = distribution.log_prob(actions)
         return actions, values, log_prob

-    def extract_features(
-        self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None
+    def extract_features(  # type: ignore[override]
+        self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None
     ) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
         """
         Preprocess the observation if needed and extract features.
@@ -233,7 +233,11 @@ class MaskableActorCriticPolicy(BasePolicy):
             module.apply(partial(self.init_weights, gain=gain))

         # Setup optimizer with initial learning rate
-        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
+        self.optimizer = self.optimizer_class(
+            self.parameters(),
+            lr=lr_schedule(1),  # type: ignore[call-arg]
+            **self.optimizer_kwargs,
+        )

     def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution:
         """
@@ -245,9 +249,9 @@ class MaskableActorCriticPolicy(BasePolicy):
         action_logits = self.action_net(latent_pi)
         return self.action_dist.proba_distribution(action_logits=action_logits)

-    def _predict(
+    def _predict(  # type: ignore[override]
         self,
-        observation: th.Tensor,
+        observation: PyTorchObs,
         deterministic: bool = False,
         action_masks: Optional[np.ndarray] = None,
     ) -> th.Tensor:
@@ -284,35 +288,45 @@ class MaskableActorCriticPolicy(BasePolicy):
         # Switch to eval mode (this affects batch norm / dropout)
         self.set_training_mode(False)

-        observation, vectorized_env = self.obs_to_tensor(observation)
+        # Check for common mistake that the user does not mix Gym/VecEnv API
+        # Tuple obs are not supported by SB3, so we can safely do that check
+        if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict):
+            raise ValueError(
+                "You have passed a tuple to the predict() function instead of a Numpy array or a Dict. "
+                "You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) "
+                "vs `obs = vec_env.reset()` (SB3 VecEnv). "
+                "See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 "
+                "and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
+            )
+
+        obs_tensor, vectorized_env = self.obs_to_tensor(observation)

         with th.no_grad():
-            actions = self._predict(observation, deterministic=deterministic, action_masks=action_masks)
+            actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks)
             # Convert to numpy
             actions = actions.cpu().numpy()

         if isinstance(self.action_space, spaces.Box):
             if self.squash_output:
                 # Rescale to proper domain when using squashing
-                actions = self.unscale_action(actions)
+                actions = self.unscale_action(actions)  # type: ignore[assignment, arg-type]
             else:
                 # Actions could be on arbitrary scale, so clip the actions to avoid
                 # out of bound error (e.g. if sampling from a Gaussian distribution)
-                actions = np.clip(actions, self.action_space.low, self.action_space.high)
+                actions = np.clip(actions, self.action_space.low, self.action_space.high)  # type: ignore[assignment, arg-type]

         if not vectorized_env:
             if state is not None:
                 raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
             assert isinstance(actions, np.ndarray)
             actions = actions.squeeze(axis=0)

-        return actions, None
+        return actions, state  # type: ignore[return-value]

     def evaluate_actions(
         self,
         obs: th.Tensor,
         actions: th.Tensor,
-        action_masks: Optional[np.ndarray] = None,
-    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+        action_masks: Optional[th.Tensor] = None,
+    ) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
         """
         Evaluate actions according to the current policy,
         given the observations.
@@ -337,7 +351,7 @@ class MaskableActorCriticPolicy(BasePolicy):
         values = self.value_net(latent_vf)
         return values, log_prob, distribution.entropy()

-    def get_distribution(self, obs: th.Tensor, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
+    def get_distribution(self, obs: PyTorchObs, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
         """
         Get the current policy distribution given the observations.

@@ -352,7 +366,7 @@ class MaskableActorCriticPolicy(BasePolicy):
         distribution.apply_masking(action_masks)
         return distribution

-    def predict_values(self, obs: th.Tensor) -> th.Tensor:
+    def predict_values(self, obs: PyTorchObs) -> th.Tensor:
         """
         Get the estimated values according to the current policy given the observations.
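The new error message in `predict()` points at the Gym vs. VecEnv reset mismatch. The sketch below shows the intended call pattern with action masks; `get_action_masks` is the helper from `sb3_contrib.common.maskable.utils`, and the environment/parameters are illustrative.

    from sb3_contrib import MaskablePPO
    from sb3_contrib.common.envs import InvalidActionEnvDiscrete
    from sb3_contrib.common.maskable.utils import get_action_masks
    from stable_baselines3.common.vec_env import DummyVecEnv

    vec_env = DummyVecEnv([lambda: InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)])
    model = MaskablePPO("MlpPolicy", vec_env)

    # SB3 VecEnv API: reset() returns only the observation, which is what predict() expects.
    obs = vec_env.reset()
    action_masks = get_action_masks(vec_env)
    actions, _state = model.predict(obs, action_masks=action_masks, deterministic=True)
    print(actions, action_masks.shape)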
@@ -59,10 +59,11 @@ def _worker(
                 remote.send(vec_env.seed(seed=data))
             elif cmd == "set_options":
                 # Note: the options will only be effective at the next reset
-                remote.send(vec_env.set_options(data))
+                remote.send(vec_env.set_options(data))  # type: ignore[func-returns-value]
             elif cmd == "get_obs_rms":
                 remote.send(obs_rms)
             elif cmd == "sync_obs_rms":
+                assert vec_normalize is not None, "Tried to call `sync_obs_rms` when not using VecNormalize"
                 vec_normalize.obs_rms = data
                 obs_rms = data
             elif cmd == "close":
@@ -130,7 +131,7 @@ class AsyncEval:
                 n_eval_episodes,
             )
             # daemon=True: if the main process crashes, we should not cause things to hang
-            process = ctx.Process(target=_worker, args=args, daemon=True)  # pytype:disable=attribute-error
+            process = ctx.Process(target=_worker, args=args, daemon=True)  # type: ignore[attr-defined]
            process.start()
            self.processes.append(process)
            work_remote.close()
@@ -157,6 +158,10 @@ class AsyncEval:
         :param seed: The seed for the pseudo-random generators.
         :return:
         """
+        if seed is None:
+            # Do nothing
+            return []
+
         for idx, remote in enumerate(self.remotes):
             remote.send(("seed", seed + idx))
         return [remote.recv() for remote in self.remotes]
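`AsyncEval` is normally driven by ARS; a usage sketch loosely adapted from the sb3-contrib documentation follows. The environment id, worker counts, and hyperparameters are arbitrary placeholders.

    from sb3_contrib import ARS
    from sb3_contrib.common.vec_env.async_eval import AsyncEval
    from stable_baselines3.common.env_util import make_vec_env

    env_id = "CartPole-v1"
    model = ARS("LinearPolicy", env_id, n_delta=2, n_top=1, verbose=1)

    # Spawn two worker processes, each holding its own VecEnv, to evaluate
    # candidate policies in parallel (the `seed` / `set_options` commands
    # patched above are sent to these workers).
    async_eval = AsyncEval([lambda: make_vec_env(env_id, n_envs=2) for _ in range(2)], model.policy)
    model.learn(total_timesteps=2_000, async_eval=async_eval)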
@@ -1,14 +1,12 @@
 import sys
 import time
-from collections import deque
 from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union

 import numpy as np
 import torch as th
 from gymnasium import spaces
-from stable_baselines3.common import utils
 from stable_baselines3.common.buffers import RolloutBuffer
-from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
+from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@@ -74,6 +72,8 @@ class MaskablePPO(OnPolicyAlgorithm):
         "CnnPolicy": CnnPolicy,
         "MultiInputPolicy": MultiInputPolicy,
     }
+    policy: MaskableActorCriticPolicy  # type: ignore[assignment]
+    rollout_buffer: MaskableRolloutBuffer  # type: ignore[assignment]

     def __init__(
         self,
@@ -91,6 +91,8 @@ class MaskablePPO(OnPolicyAlgorithm):
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
         max_grad_norm: float = 0.5,
+        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
+        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
         target_kl: Optional[float] = None,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
@@ -101,7 +103,7 @@ class MaskablePPO(OnPolicyAlgorithm):
         _init_setup_model: bool = True,
     ):
         super().__init__(
-            policy,
+            policy,  # type: ignore[arg-type]
             env,
             learning_rate=learning_rate,
             n_steps=n_steps,
@@ -112,6 +114,8 @@ class MaskablePPO(OnPolicyAlgorithm):
             max_grad_norm=max_grad_norm,
             use_sde=False,
             sde_sample_freq=-1,
+            rollout_buffer_class=rollout_buffer_class,
+            rollout_buffer_kwargs=rollout_buffer_kwargs,
             stats_window_size=stats_window_size,
             tensorboard_log=tensorboard_log,
             policy_kwargs=policy_kwargs,
@@ -140,20 +144,24 @@ class MaskablePPO(OnPolicyAlgorithm):
         self._setup_lr_schedule()
         self.set_random_seed(self.seed)

-        buffer_cls = MaskableDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else MaskableRolloutBuffer
-
-        self.policy = self.policy_class(
+        self.policy = self.policy_class(  # type: ignore[assignment]
             self.observation_space,
             self.action_space,
             self.lr_schedule,
-            **self.policy_kwargs,  # pytype:disable=not-instantiable
+            **self.policy_kwargs,
         )
         self.policy = self.policy.to(self.device)

         if not isinstance(self.policy, MaskableActorCriticPolicy):
             raise ValueError("Policy must subclass MaskableActorCriticPolicy")

-        self.rollout_buffer = buffer_cls(
+        if self.rollout_buffer_class is None:
+            if isinstance(self.observation_space, spaces.Dict):
+                self.rollout_buffer_class = MaskableDictRolloutBuffer
+            else:
+                self.rollout_buffer_class = MaskableRolloutBuffer
+
+        self.rollout_buffer = self.rollout_buffer_class(  # type: ignore[assignment]
             self.n_steps,
             self.observation_space,
             self.action_space,
@@ -161,6 +169,7 @@ class MaskablePPO(OnPolicyAlgorithm):
             gamma=self.gamma,
             gae_lambda=self.gae_lambda,
             n_envs=self.n_envs,
+            **self.rollout_buffer_kwargs,
         )

         # Initialize schedules for policy/value clipping
@@ -171,86 +180,6 @@ class MaskablePPO(OnPolicyAlgorithm):

             self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

-    def _init_callback(
-        self,
-        callback: MaybeCallback,
-        use_masking: bool = True,
-        progress_bar: bool = False,
-    ) -> BaseCallback:
-        """
-        :param callback: Callback(s) called at every step with state of the algorithm.
-        :param use_masking: Whether or not to use invalid action masks during evaluation
-        :param progress_bar: Display a progress bar using tqdm and rich.
-        :return: A hybrid callback calling `callback` and performing evaluation.
-        """
-        # Convert a list of callbacks into a callback
-        if isinstance(callback, list):
-            callback = CallbackList(callback)
-
-        # Convert functional callback to object
-        if not isinstance(callback, BaseCallback):
-            callback = ConvertCallback(callback)
-
-        # Add progress bar callback
-        if progress_bar:
-            callback = CallbackList([callback, ProgressBarCallback()])
-
-        callback.init_callback(self)
-        return callback
-
-    def _setup_learn(
-        self,
-        total_timesteps: int,
-        callback: MaybeCallback = None,
-        reset_num_timesteps: bool = True,
-        tb_log_name: str = "run",
-        use_masking: bool = True,
-        progress_bar: bool = False,
-    ) -> Tuple[int, BaseCallback]:
-        """
-        Initialize different variables needed for training.
-
-        :param total_timesteps: The total number of samples (env steps) to train on
-        :param callback: Callback(s) called at every step with state of the algorithm.
-        :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
-        :param tb_log_name: the name of the run for tensorboard log
-        :param use_masking: Whether or not to use invalid action masks during training
-        :param progress_bar: Display a progress bar using tqdm and rich.
-        :return:
-        """
-        self.start_time = time.time_ns()
-        if self.ep_info_buffer is None or reset_num_timesteps:
-            # Initialize buffers if they don't exist, or reinitialize if resetting counters
-            self.ep_info_buffer = deque(maxlen=self._stats_window_size)
-            self.ep_success_buffer = deque(maxlen=self._stats_window_size)
-
-        if reset_num_timesteps:
-            self.num_timesteps = 0
-            self._episode_num = 0
-        else:
-            # Make sure training timesteps are ahead of the internal counter
-            total_timesteps += self.num_timesteps
-        self._total_timesteps = total_timesteps
-        self._num_timesteps_at_start = self.num_timesteps
-
-        # Avoid resetting the environment when calling ``.learn()`` consecutive times
-        if reset_num_timesteps or self._last_obs is None:
-            self._last_obs = self.env.reset()
-            self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
-            # Retrieve unnormalized observation for saving into the buffer
-            if self._vec_normalize_env is not None:
-                self._last_original_obs = self._vec_normalize_env.get_original_obs()
-
-        # Configure logger's outputs if no logger was passed
-        if not self._custom_logger:
-            self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
-
-        # Create eval callback if needed
-        callback = self._init_callback(callback, use_masking, progress_bar)
-
-        return total_timesteps, callback
-
     def collect_rollouts(
         self,
         env: VecEnv,
@@ -341,14 +270,14 @@ class MaskablePPO(OnPolicyAlgorithm):
                 log_probs,
                 action_masks=action_masks,
             )
-            self._last_obs = new_obs
+            self._last_obs = new_obs  # type: ignore[assignment]
             self._last_episode_starts = dones

         with th.no_grad():
             # Compute value for the last timestep
             # Masking is not needed here, the choice of action doesn't matter.
             # We only want the value of the current observation.
-            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))
+            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))  # type: ignore[arg-type]

         rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

@@ -356,9 +285,9 @@ class MaskablePPO(OnPolicyAlgorithm):

         return True

-    def predict(
+    def predict(  # type: ignore[override]
         self,
-        observation: np.ndarray,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
         state: Optional[Tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
@@ -388,10 +317,10 @@ class MaskablePPO(OnPolicyAlgorithm):
         # Update optimizer learning rate
         self._update_learning_rate(self.policy.optimizer)
         # Compute current clip range
-        clip_range = self.clip_range(self._current_progress_remaining)
+        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
         # Optional: clip range for the value function
         if self.clip_range_vf is not None:
-            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
+            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

         entropy_losses = []
         pg_losses, value_losses = [], []
@@ -499,7 +428,7 @@ class MaskablePPO(OnPolicyAlgorithm):
             if self.clip_range_vf is not None:
                 self.logger.record("train/clip_range_vf", clip_range_vf)

-    def learn(
+    def learn(  # type: ignore[override]
         self: SelfMaskablePPO,
         total_timesteps: int,
         callback: MaybeCallback = None,
@@ -516,12 +445,13 @@ class MaskablePPO(OnPolicyAlgorithm):
             callback,
             reset_num_timesteps,
             tb_log_name,
-            use_masking,
             progress_bar,
         )

         callback.on_training_start(locals(), globals())

+        assert self.env is not None
+
         while self.num_timesteps < total_timesteps:
             continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)

@@ -533,6 +463,7 @@ class MaskablePPO(OnPolicyAlgorithm):

             # Display training infos
             if log_interval is not None and iteration % log_interval == 0:
+                assert self.ep_info_buffer is not None
                 time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
                 fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
                 self.logger.record("time/iterations", iteration, exclude="tensorboard")
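End to end, the typed `MaskablePPO` above is trained on a mask-aware environment. The sketch below follows the usual `ActionMasker` pattern from the sb3-contrib documentation; the `mask_fn` shown assumes the wrapped env exposes a boolean `action_masks()` helper, and the hyperparameters are placeholders.

    import numpy as np
    from gymnasium import Env

    from sb3_contrib import MaskablePPO
    from sb3_contrib.common.envs import InvalidActionEnvDiscrete
    from sb3_contrib.common.wrappers import ActionMasker


    def mask_fn(env: Env) -> np.ndarray:
        # Environment-specific logic goes here; this assumes the wrapped env
        # already exposes a boolean action_masks() helper.
        return env.action_masks()


    env = ActionMasker(InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60), mask_fn)
    model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, verbose=1)
    model.learn(5_000)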
@@ -149,7 +149,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
             self.action_space,
             self.lr_schedule,
             use_sde=self.use_sde,
-            **self.policy_kwargs,  # pytype:disable=not-instantiable
+            **self.policy_kwargs,
         )
         self.policy = self.policy.to(self.device)