Fix MaskablePPO type annotations (#233)

* Fix `sb3_contrib/common/maskable/callbacks.py` type annotations

* Fix `sb3_contrib/common/vec_env/async_eval.py` type annotations

* Fix `sb3_contrib/common/maskable/buffers.py` type hints

* Fix `sb3_contrib/common/maskable/distributions.py` type hints

* Fix `sb3_contrib/common/maskable/policies.py` type hints

* Fix `sb3_contrib/ppo_mask/ppo_mask.py` type hints

* Update changelog and fix type hints
This commit is contained in:
Antonin RAFFIN 2024-03-11 14:10:12 +01:00 committed by Andreas Schaarschmidt
parent cd31e89e26
commit 7dd6c39fba
9 changed files with 117 additions and 149 deletions

View File

@ -22,6 +22,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to MaskablePPO
Bug Fixes:
^^^^^^^^^^
@ -31,8 +32,10 @@ Deprecations:
Others:
^^^^^^^
- Fixed ``train_freq`` type annotation for TQC and QRDQN (@Armandpl)
- Fixed ``sb3_contrib/common/maskable/*.py`` type annotations
- Fixed ``sb3_contrib/ppo_mask/ppo_mask.py`` type annotations
- Fixed ``sb3_contrib/common/vec_env/async_eval.py`` type annotations
Documentation:
^^^^^^^^^^^^^^

View File

@ -28,12 +28,6 @@ exclude = """(?x)(
| sb3_contrib/ars/ars.py$
| sb3_contrib/common/recurrent/policies.py$
| sb3_contrib/common/recurrent/buffers.py$
| sb3_contrib/common/maskable/distributions.py$
| sb3_contrib/common/maskable/callbacks.py$
| sb3_contrib/common/maskable/policies.py$
| sb3_contrib/common/maskable/buffers.py$
| sb3_contrib/common/vec_env/async_eval.py$
| sb3_contrib/ppo_mask/ppo_mask.py$
| tests/test_train_eval_mode.py$
)"""

View File

@ -18,7 +18,7 @@ class MaskableRolloutBufferSamples(NamedTuple):
action_masks: th.Tensor
class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples):
class MaskableDictRolloutBufferSamples(NamedTuple):
observations: TensorDict
actions: th.Tensor
old_values: th.Tensor
@ -42,6 +42,8 @@ class MaskableRolloutBuffer(RolloutBuffer):
:param n_envs: Number of parallel environments
"""
action_masks: np.ndarray
def __init__(
self,
buffer_size: int,
@ -53,7 +55,6 @@ class MaskableRolloutBuffer(RolloutBuffer):
n_envs: int = 1,
):
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
self.action_masks = None
def reset(self) -> None:
if isinstance(self.action_space, spaces.Discrete):
@ -61,6 +62,10 @@ class MaskableRolloutBuffer(RolloutBuffer):
elif isinstance(self.action_space, spaces.MultiDiscrete):
mask_dims = sum(self.action_space.nvec)
elif isinstance(self.action_space, spaces.MultiBinary):
assert isinstance(self.action_space.n, int), (
f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
"You can flatten it instead."
)
mask_dims = 2 * self.action_space.n # One mask per binary outcome
else:
raise ValueError(f"Unsupported action space {type(self.action_space)}")
@ -79,7 +84,7 @@ class MaskableRolloutBuffer(RolloutBuffer):
super().add(*args, **kwargs)
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]:
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]: # type: ignore[override]
assert self.full, ""
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
@ -105,7 +110,7 @@ class MaskableRolloutBuffer(RolloutBuffer):
yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples:
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples: # type: ignore[override]
data = (
self.observations[batch_inds],
self.actions[batch_inds],
@ -143,17 +148,18 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
:param n_envs: Number of parallel environments
"""
action_masks: np.ndarray
def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
observation_space: spaces.Dict,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
self.action_masks = None
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)
def reset(self) -> None:
@ -162,6 +168,10 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
elif isinstance(self.action_space, spaces.MultiDiscrete):
mask_dims = sum(self.action_space.nvec)
elif isinstance(self.action_space, spaces.MultiBinary):
assert isinstance(self.action_space.n, int), (
f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
"You can flatten it instead."
)
mask_dims = 2 * self.action_space.n # One mask per binary outcome
else:
raise ValueError(f"Unsupported action space {type(self.action_space)}")
@ -180,7 +190,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
super().add(*args, **kwargs)
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]:
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]: # type: ignore[override]
assert self.full, ""
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
@ -203,7 +213,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples: # type: ignore[override]
return MaskableDictRolloutBufferSamples(
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
actions=self.to_torch(self.actions[batch_inds]),

View File

@ -55,7 +55,7 @@ class MaskableEvalCallback(EvalCallback):
# Note that evaluate_policy() has been patched to support masking
episode_rewards, episode_lengths = evaluate_policy(
self.model,
self.model, # type: ignore[arg-type]
self.eval_env,
n_eval_episodes=self.n_eval_episodes,
render=self.render,
@ -67,6 +67,8 @@ class MaskableEvalCallback(EvalCallback):
)
if self.log_path is not None:
assert isinstance(episode_rewards, list)
assert isinstance(episode_lengths, list)
self.evaluations_timesteps.append(self.num_timesteps)
self.evaluations_results.append(episode_rewards)
self.evaluations_length.append(episode_lengths)
@ -87,7 +89,7 @@ class MaskableEvalCallback(EvalCallback):
mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
self.last_mean_reward = mean_reward
self.last_mean_reward = float(mean_reward)
if self.verbose > 0:
print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
@ -111,7 +113,7 @@ class MaskableEvalCallback(EvalCallback):
print("New best mean reward!")
if self.best_model_save_path is not None:
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
self.best_mean_reward = mean_reward
self.best_mean_reward = float(mean_reward)
# Trigger callback on new best model, if needed
if self.callback_on_new_best is not None:
continue_training = self.callback_on_new_best.on_step()

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, TypeVar
from typing import List, Optional, Tuple, TypeVar, Union
import numpy as np
import torch as th
@ -13,6 +13,7 @@ SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistributi
SelfMaskableMultiCategoricalDistribution = TypeVar(
"SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution"
)
MaybeMasks = Union[th.Tensor, np.ndarray, None]
class MaskableCategorical(Categorical):
@ -36,14 +37,14 @@ class MaskableCategorical(Categorical):
probs: Optional[th.Tensor] = None,
logits: Optional[th.Tensor] = None,
validate_args: Optional[bool] = None,
masks: Optional[np.ndarray] = None,
masks: MaybeMasks = None,
):
self.masks: Optional[th.Tensor] = None
super().__init__(probs, logits, validate_args)
self._original_logits = self.logits
self.apply_masking(masks)
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
def apply_masking(self, masks: MaybeMasks) -> None:
"""
Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0.
@ -84,7 +85,7 @@ class MaskableCategorical(Categorical):
class MaskableDistribution(Distribution, ABC):
@abstractmethod
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
def apply_masking(self, masks: MaybeMasks) -> None:
"""
Eliminate ("mask out") chosen distribution outcomes by setting their probability to 0.
@ -94,6 +95,13 @@ class MaskableDistribution(Distribution, ABC):
previously applied masking is removed, and the original logits are restored.
"""
@abstractmethod
def proba_distribution_net(self, *args, **kwargs) -> nn.Module:
"""Create the layers and parameters that represent the distribution.
Subclasses must define this, but the arguments and return type vary between
concrete classes."""
class MaskableCategoricalDistribution(MaskableDistribution):
"""
@ -154,7 +162,7 @@ class MaskableCategoricalDistribution(MaskableDistribution):
log_prob = self.log_prob(actions)
return actions, log_prob
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
def apply_masking(self, masks: MaybeMasks) -> None:
assert self.distribution is not None, "Must set distribution parameters"
self.distribution.apply_masking(masks)
@ -192,7 +200,7 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
reshaped_logits = action_logits.view(-1, sum(self.action_dims))
self.distributions = [
MaskableCategorical(logits=split) for split in th.split(reshaped_logits, tuple(self.action_dims), dim=1)
MaskableCategorical(logits=split) for split in th.split(reshaped_logits, list(self.action_dims), dim=1)
]
return self
@ -229,18 +237,16 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
log_prob = self.log_prob(actions)
return actions, log_prob
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
def apply_masking(self, masks: MaybeMasks) -> None:
assert len(self.distributions) > 0, "Must set distribution parameters"
split_masks = [None] * len(self.distributions)
if masks is not None:
masks = th.as_tensor(masks)
masks_tensor = th.as_tensor(masks)
# Restructure shape to align with logits
masks = masks.view(-1, sum(self.action_dims))
masks_tensor = masks_tensor.view(-1, sum(self.action_dims))
# Then split columnwise for each discrete action
split_masks = th.split(masks, tuple(self.action_dims), dim=1)
split_masks = th.split(masks_tensor, list(self.action_dims), dim=1) # type: ignore[assignment]
for distribution, mask in zip(self.distributions, split_masks):
distribution.apply_masking(mask)
@ -268,10 +274,13 @@ def make_masked_proba_distribution(action_space: spaces.Space) -> MaskableDistri
"""
if isinstance(action_space, spaces.Discrete):
return MaskableCategoricalDistribution(action_space.n)
return MaskableCategoricalDistribution(int(action_space.n))
elif isinstance(action_space, spaces.MultiDiscrete):
return MaskableMultiCategoricalDistribution(action_space.nvec)
return MaskableMultiCategoricalDistribution(list(action_space.nvec))
elif isinstance(action_space, spaces.MultiBinary):
assert isinstance(
action_space.n, int
), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
return MaskableBernoulliDistribution(action_space.n)
else:
raise NotImplementedError(

View File

@ -13,7 +13,7 @@ from stable_baselines3.common.torch_layers import (
MlpExtractor,
NatureCNN,
)
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
from torch import nn
from sb3_contrib.common.maskable.distributions import MaskableDistribution, make_masked_proba_distribution
@ -141,8 +141,8 @@ class MaskableActorCriticPolicy(BasePolicy):
log_prob = distribution.log_prob(actions)
return actions, values, log_prob
def extract_features(
self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None
def extract_features( # type: ignore[override]
self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
"""
Preprocess the observation if needed and extract features.
@ -233,7 +233,11 @@ class MaskableActorCriticPolicy(BasePolicy):
module.apply(partial(self.init_weights, gain=gain))
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
self.optimizer = self.optimizer_class(
self.parameters(),
lr=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs,
)
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution:
"""
@ -245,9 +249,9 @@ class MaskableActorCriticPolicy(BasePolicy):
action_logits = self.action_net(latent_pi)
return self.action_dist.proba_distribution(action_logits=action_logits)
def _predict(
def _predict( # type: ignore[override]
self,
observation: th.Tensor,
observation: PyTorchObs,
deterministic: bool = False,
action_masks: Optional[np.ndarray] = None,
) -> th.Tensor:
@ -284,35 +288,45 @@ class MaskableActorCriticPolicy(BasePolicy):
# Switch to eval mode (this affects batch norm / dropout)
self.set_training_mode(False)
observation, vectorized_env = self.obs_to_tensor(observation)
# Check for common mistake that the user does not mix Gym/VecEnv API
# Tuple obs are not supported by SB3, so we can safely do that check
if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict):
raise ValueError(
"You have passed a tuple to the predict() function instead of a Numpy array or a Dict. "
"You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) "
"vs `obs = vec_env.reset()` (SB3 VecEnv). "
"See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 "
"and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
)
obs_tensor, vectorized_env = self.obs_to_tensor(observation)
with th.no_grad():
actions = self._predict(observation, deterministic=deterministic, action_masks=action_masks)
actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks)
# Convert to numpy
actions = actions.cpu().numpy()
if isinstance(self.action_space, spaces.Box):
if self.squash_output:
# Rescale to proper domain when using squashing
actions = self.unscale_action(actions)
actions = self.unscale_action(actions) # type: ignore[assignment, arg-type]
else:
# Actions could be on arbitrary scale, so clip the actions to avoid
# out of bound error (e.g. if sampling from a Gaussian distribution)
actions = np.clip(actions, self.action_space.low, self.action_space.high)
actions = np.clip(actions, self.action_space.low, self.action_space.high) # type: ignore[assignment, arg-type]
if not vectorized_env:
if state is not None:
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
assert isinstance(actions, np.ndarray)
actions = actions.squeeze(axis=0)
return actions, None
return actions, state # type: ignore[return-value]
def evaluate_actions(
self,
obs: th.Tensor,
actions: th.Tensor,
action_masks: Optional[np.ndarray] = None,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
action_masks: Optional[th.Tensor] = None,
) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
"""
Evaluate actions according to the current policy,
given the observations.
@ -337,7 +351,7 @@ class MaskableActorCriticPolicy(BasePolicy):
values = self.value_net(latent_vf)
return values, log_prob, distribution.entropy()
def get_distribution(self, obs: th.Tensor, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
def get_distribution(self, obs: PyTorchObs, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
"""
Get the current policy distribution given the observations.
@ -352,7 +366,7 @@ class MaskableActorCriticPolicy(BasePolicy):
distribution.apply_masking(action_masks)
return distribution
def predict_values(self, obs: th.Tensor) -> th.Tensor:
def predict_values(self, obs: PyTorchObs) -> th.Tensor:
"""
Get the estimated values according to the current policy given the observations.

View File

@ -59,10 +59,11 @@ def _worker(
remote.send(vec_env.seed(seed=data))
elif cmd == "set_options":
# Note: the options will only be effective at the next reset
remote.send(vec_env.set_options(data))
remote.send(vec_env.set_options(data)) # type: ignore[func-returns-value]
elif cmd == "get_obs_rms":
remote.send(obs_rms)
elif cmd == "sync_obs_rms":
assert vec_normalize is not None, "Tried to call `sync_obs_rms` when not using VecNormalize"
vec_normalize.obs_rms = data
obs_rms = data
elif cmd == "close":
@ -130,7 +131,7 @@ class AsyncEval:
n_eval_episodes,
)
# daemon=True: if the main process crashes, we should not cause things to hang
process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error
process = ctx.Process(target=_worker, args=args, daemon=True) # type: ignore[attr-defined]
process.start()
self.processes.append(process)
work_remote.close()
@ -157,6 +158,10 @@ class AsyncEval:
:param seed: The seed for the pseudo-random generators.
:return:
"""
if seed is None:
# Do nothing
return []
for idx, remote in enumerate(self.remotes):
remote.send(("seed", seed + idx))
return [remote.recv() for remote in self.remotes]

View File

@ -1,14 +1,12 @@
import sys
import time
from collections import deque
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@ -74,6 +72,8 @@ class MaskablePPO(OnPolicyAlgorithm):
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
}
policy: MaskableActorCriticPolicy # type: ignore[assignment]
rollout_buffer: MaskableRolloutBuffer # type: ignore[assignment]
def __init__(
self,
@ -91,6 +91,8 @@ class MaskablePPO(OnPolicyAlgorithm):
ent_coef: float = 0.0,
vf_coef: float = 0.5,
max_grad_norm: float = 0.5,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
target_kl: Optional[float] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
@ -101,7 +103,7 @@ class MaskablePPO(OnPolicyAlgorithm):
_init_setup_model: bool = True,
):
super().__init__(
policy,
policy, # type: ignore[arg-type]
env,
learning_rate=learning_rate,
n_steps=n_steps,
@ -112,6 +114,8 @@ class MaskablePPO(OnPolicyAlgorithm):
max_grad_norm=max_grad_norm,
use_sde=False,
sde_sample_freq=-1,
rollout_buffer_class=rollout_buffer_class,
rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
policy_kwargs=policy_kwargs,
@ -140,20 +144,24 @@ class MaskablePPO(OnPolicyAlgorithm):
self._setup_lr_schedule()
self.set_random_seed(self.seed)
buffer_cls = MaskableDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else MaskableRolloutBuffer
self.policy = self.policy_class(
self.policy = self.policy_class( # type: ignore[assignment]
self.observation_space,
self.action_space,
self.lr_schedule,
**self.policy_kwargs, # pytype:disable=not-instantiable
**self.policy_kwargs,
)
self.policy = self.policy.to(self.device)
if not isinstance(self.policy, MaskableActorCriticPolicy):
raise ValueError("Policy must subclass MaskableActorCriticPolicy")
self.rollout_buffer = buffer_cls(
if self.rollout_buffer_class is None:
if isinstance(self.observation_space, spaces.Dict):
self.rollout_buffer_class = MaskableDictRolloutBuffer
else:
self.rollout_buffer_class = MaskableRolloutBuffer
self.rollout_buffer = self.rollout_buffer_class( # type: ignore[assignment]
self.n_steps,
self.observation_space,
self.action_space,
@ -161,6 +169,7 @@ class MaskablePPO(OnPolicyAlgorithm):
gamma=self.gamma,
gae_lambda=self.gae_lambda,
n_envs=self.n_envs,
**self.rollout_buffer_kwargs,
)
# Initialize schedules for policy/value clipping
@ -171,86 +180,6 @@ class MaskablePPO(OnPolicyAlgorithm):
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
def _init_callback(
self,
callback: MaybeCallback,
use_masking: bool = True,
progress_bar: bool = False,
) -> BaseCallback:
"""
:param callback: Callback(s) called at every step with state of the algorithm.
:param use_masking: Whether or not to use invalid action masks during evaluation
:param progress_bar: Display a progress bar using tqdm and rich.
:return: A hybrid callback calling `callback` and performing evaluation.
"""
# Convert a list of callbacks into a callback
if isinstance(callback, list):
callback = CallbackList(callback)
# Convert functional callback to object
if not isinstance(callback, BaseCallback):
callback = ConvertCallback(callback)
# Add progress bar callback
if progress_bar:
callback = CallbackList([callback, ProgressBarCallback()])
callback.init_callback(self)
return callback
def _setup_learn(
self,
total_timesteps: int,
callback: MaybeCallback = None,
reset_num_timesteps: bool = True,
tb_log_name: str = "run",
use_masking: bool = True,
progress_bar: bool = False,
) -> Tuple[int, BaseCallback]:
"""
Initialize different variables needed for training.
:param total_timesteps: The total number of samples (env steps) to train on
:param callback: Callback(s) called at every step with state of the algorithm.
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
:param tb_log_name: the name of the run for tensorboard log
:param use_masking: Whether or not to use invalid action masks during training
:param progress_bar: Display a progress bar using tqdm and rich.
:return:
"""
self.start_time = time.time_ns()
if self.ep_info_buffer is None or reset_num_timesteps:
# Initialize buffers if they don't exist, or reinitialize if resetting counters
self.ep_info_buffer = deque(maxlen=self._stats_window_size)
self.ep_success_buffer = deque(maxlen=self._stats_window_size)
if reset_num_timesteps:
self.num_timesteps = 0
self._episode_num = 0
else:
# Make sure training timesteps are ahead of the internal counter
total_timesteps += self.num_timesteps
self._total_timesteps = total_timesteps
self._num_timesteps_at_start = self.num_timesteps
# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:
self._last_obs = self.env.reset()
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:
self._last_original_obs = self._vec_normalize_env.get_original_obs()
# Configure logger's outputs if no logger was passed
if not self._custom_logger:
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
# Create eval callback if needed
callback = self._init_callback(callback, use_masking, progress_bar)
return total_timesteps, callback
def collect_rollouts(
self,
env: VecEnv,
@ -341,14 +270,14 @@ class MaskablePPO(OnPolicyAlgorithm):
log_probs,
action_masks=action_masks,
)
self._last_obs = new_obs
self._last_obs = new_obs # type: ignore[assignment]
self._last_episode_starts = dones
with th.no_grad():
# Compute value for the last timestep
# Masking is not needed here, the choice of action doesn't matter.
# We only want the value of the current observation.
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) # type: ignore[arg-type]
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
@ -356,9 +285,9 @@ class MaskablePPO(OnPolicyAlgorithm):
return True
def predict(
def predict( # type: ignore[override]
self,
observation: np.ndarray,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
@ -388,10 +317,10 @@ class MaskablePPO(OnPolicyAlgorithm):
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)
# Compute current clip range
clip_range = self.clip_range(self._current_progress_remaining)
clip_range = self.clip_range(self._current_progress_remaining) # type: ignore[operator]
# Optional: clip range for the value function
if self.clip_range_vf is not None:
clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
clip_range_vf = self.clip_range_vf(self._current_progress_remaining) # type: ignore[operator]
entropy_losses = []
pg_losses, value_losses = [], []
@ -499,7 +428,7 @@ class MaskablePPO(OnPolicyAlgorithm):
if self.clip_range_vf is not None:
self.logger.record("train/clip_range_vf", clip_range_vf)
def learn(
def learn( # type: ignore[override]
self: SelfMaskablePPO,
total_timesteps: int,
callback: MaybeCallback = None,
@ -516,12 +445,13 @@ class MaskablePPO(OnPolicyAlgorithm):
callback,
reset_num_timesteps,
tb_log_name,
use_masking,
progress_bar,
)
callback.on_training_start(locals(), globals())
assert self.env is not None
while self.num_timesteps < total_timesteps:
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
@ -533,6 +463,7 @@ class MaskablePPO(OnPolicyAlgorithm):
# Display training infos
if log_interval is not None and iteration % log_interval == 0:
assert self.ep_info_buffer is not None
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/iterations", iteration, exclude="tensorboard")

View File

@ -149,7 +149,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
self.action_space,
self.lr_schedule,
use_sde=self.use_sde,
**self.policy_kwargs, # pytype:disable=not-instantiable
**self.policy_kwargs,
)
self.policy = self.policy.to(self.device)