Fix MaskablePPO type annotations (#233)
* Fix `sb3_contrib/common/maskable/callbacks.py` type annotations * Fix `sb3_contrib/common/vec_env/async_eval.py` type annotations * Fix `sb3_contrib/common/maskable/buffers.py` type hints * Fix `sb3_contrib/common/maskable/distributions.py` type hints * Fix `sb3_contrib/common/maskable/policies.py` type hints * Fix `sb3_contrib/ppo_mask/ppo_mask.py` type hints * Update changelog and fix type hints
This commit is contained in:
parent
cd31e89e26
commit
7dd6c39fba
|
|
@ -22,6 +22,7 @@ Breaking Changes:
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to MaskablePPO
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
@ -31,8 +32,10 @@ Deprecations:
|
||||||
|
|
||||||
Others:
|
Others:
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
||||||
- Fixed ``train_freq`` type annotation for tqc and qrdqn (@Armandpl)
|
- Fixed ``train_freq`` type annotation for tqc and qrdqn (@Armandpl)
|
||||||
|
- Fixed ``sb3_contrib/common/maskable/*.py`` type annotations
|
||||||
|
- Fixed ``sb3_contrib/ppo_mask/ppo_mask.py`` type annotations
|
||||||
|
- Fixed ``sb3_contrib/common/vec_env/async_eval.py`` type annotations
|
||||||
|
|
||||||
Documentation:
|
Documentation:
|
||||||
^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -28,12 +28,6 @@ exclude = """(?x)(
|
||||||
| sb3_contrib/ars/ars.py$
|
| sb3_contrib/ars/ars.py$
|
||||||
| sb3_contrib/common/recurrent/policies.py$
|
| sb3_contrib/common/recurrent/policies.py$
|
||||||
| sb3_contrib/common/recurrent/buffers.py$
|
| sb3_contrib/common/recurrent/buffers.py$
|
||||||
| sb3_contrib/common/maskable/distributions.py$
|
|
||||||
| sb3_contrib/common/maskable/callbacks.py$
|
|
||||||
| sb3_contrib/common/maskable/policies.py$
|
|
||||||
| sb3_contrib/common/maskable/buffers.py$
|
|
||||||
| sb3_contrib/common/vec_env/async_eval.py$
|
|
||||||
| sb3_contrib/ppo_mask/ppo_mask.py$
|
|
||||||
| tests/test_train_eval_mode.py$
|
| tests/test_train_eval_mode.py$
|
||||||
)"""
|
)"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ class MaskableRolloutBufferSamples(NamedTuple):
|
||||||
action_masks: th.Tensor
|
action_masks: th.Tensor
|
||||||
|
|
||||||
|
|
||||||
class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples):
|
class MaskableDictRolloutBufferSamples(NamedTuple):
|
||||||
observations: TensorDict
|
observations: TensorDict
|
||||||
actions: th.Tensor
|
actions: th.Tensor
|
||||||
old_values: th.Tensor
|
old_values: th.Tensor
|
||||||
|
|
@ -42,6 +42,8 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
||||||
:param n_envs: Number of parallel environments
|
:param n_envs: Number of parallel environments
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
action_masks: np.ndarray
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
buffer_size: int,
|
buffer_size: int,
|
||||||
|
|
@ -53,7 +55,6 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
||||||
n_envs: int = 1,
|
n_envs: int = 1,
|
||||||
):
|
):
|
||||||
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
|
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
|
||||||
self.action_masks = None
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
if isinstance(self.action_space, spaces.Discrete):
|
if isinstance(self.action_space, spaces.Discrete):
|
||||||
|
|
@ -61,6 +62,10 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
||||||
elif isinstance(self.action_space, spaces.MultiDiscrete):
|
elif isinstance(self.action_space, spaces.MultiDiscrete):
|
||||||
mask_dims = sum(self.action_space.nvec)
|
mask_dims = sum(self.action_space.nvec)
|
||||||
elif isinstance(self.action_space, spaces.MultiBinary):
|
elif isinstance(self.action_space, spaces.MultiBinary):
|
||||||
|
assert isinstance(self.action_space.n, int), (
|
||||||
|
f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
|
||||||
|
"You can flatten it instead."
|
||||||
|
)
|
||||||
mask_dims = 2 * self.action_space.n # One mask per binary outcome
|
mask_dims = 2 * self.action_space.n # One mask per binary outcome
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported action space {type(self.action_space)}")
|
raise ValueError(f"Unsupported action space {type(self.action_space)}")
|
||||||
|
|
@ -79,7 +84,7 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
||||||
|
|
||||||
super().add(*args, **kwargs)
|
super().add(*args, **kwargs)
|
||||||
|
|
||||||
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]:
|
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]: # type: ignore[override]
|
||||||
assert self.full, ""
|
assert self.full, ""
|
||||||
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
||||||
# Prepare the data
|
# Prepare the data
|
||||||
|
|
@ -105,7 +110,7 @@ class MaskableRolloutBuffer(RolloutBuffer):
|
||||||
yield self._get_samples(indices[start_idx : start_idx + batch_size])
|
yield self._get_samples(indices[start_idx : start_idx + batch_size])
|
||||||
start_idx += batch_size
|
start_idx += batch_size
|
||||||
|
|
||||||
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples:
|
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples: # type: ignore[override]
|
||||||
data = (
|
data = (
|
||||||
self.observations[batch_inds],
|
self.observations[batch_inds],
|
||||||
self.actions[batch_inds],
|
self.actions[batch_inds],
|
||||||
|
|
@ -143,17 +148,18 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
||||||
:param n_envs: Number of parallel environments
|
:param n_envs: Number of parallel environments
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
action_masks: np.ndarray
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
buffer_size: int,
|
buffer_size: int,
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Dict,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
device: Union[th.device, str] = "auto",
|
device: Union[th.device, str] = "auto",
|
||||||
gae_lambda: float = 1,
|
gae_lambda: float = 1,
|
||||||
gamma: float = 0.99,
|
gamma: float = 0.99,
|
||||||
n_envs: int = 1,
|
n_envs: int = 1,
|
||||||
):
|
):
|
||||||
self.action_masks = None
|
|
||||||
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)
|
super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
|
@ -162,6 +168,10 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
||||||
elif isinstance(self.action_space, spaces.MultiDiscrete):
|
elif isinstance(self.action_space, spaces.MultiDiscrete):
|
||||||
mask_dims = sum(self.action_space.nvec)
|
mask_dims = sum(self.action_space.nvec)
|
||||||
elif isinstance(self.action_space, spaces.MultiBinary):
|
elif isinstance(self.action_space, spaces.MultiBinary):
|
||||||
|
assert isinstance(self.action_space.n, int), (
|
||||||
|
f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
|
||||||
|
"You can flatten it instead."
|
||||||
|
)
|
||||||
mask_dims = 2 * self.action_space.n # One mask per binary outcome
|
mask_dims = 2 * self.action_space.n # One mask per binary outcome
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported action space {type(self.action_space)}")
|
raise ValueError(f"Unsupported action space {type(self.action_space)}")
|
||||||
|
|
@ -180,7 +190,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
||||||
|
|
||||||
super().add(*args, **kwargs)
|
super().add(*args, **kwargs)
|
||||||
|
|
||||||
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]:
|
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]: # type: ignore[override]
|
||||||
assert self.full, ""
|
assert self.full, ""
|
||||||
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
||||||
# Prepare the data
|
# Prepare the data
|
||||||
|
|
@ -203,7 +213,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
|
||||||
yield self._get_samples(indices[start_idx : start_idx + batch_size])
|
yield self._get_samples(indices[start_idx : start_idx + batch_size])
|
||||||
start_idx += batch_size
|
start_idx += batch_size
|
||||||
|
|
||||||
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:
|
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples: # type: ignore[override]
|
||||||
return MaskableDictRolloutBufferSamples(
|
return MaskableDictRolloutBufferSamples(
|
||||||
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
|
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
|
||||||
actions=self.to_torch(self.actions[batch_inds]),
|
actions=self.to_torch(self.actions[batch_inds]),
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ class MaskableEvalCallback(EvalCallback):
|
||||||
|
|
||||||
# Note that evaluate_policy() has been patched to support masking
|
# Note that evaluate_policy() has been patched to support masking
|
||||||
episode_rewards, episode_lengths = evaluate_policy(
|
episode_rewards, episode_lengths = evaluate_policy(
|
||||||
self.model,
|
self.model, # type: ignore[arg-type]
|
||||||
self.eval_env,
|
self.eval_env,
|
||||||
n_eval_episodes=self.n_eval_episodes,
|
n_eval_episodes=self.n_eval_episodes,
|
||||||
render=self.render,
|
render=self.render,
|
||||||
|
|
@ -67,6 +67,8 @@ class MaskableEvalCallback(EvalCallback):
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.log_path is not None:
|
if self.log_path is not None:
|
||||||
|
assert isinstance(episode_rewards, list)
|
||||||
|
assert isinstance(episode_lengths, list)
|
||||||
self.evaluations_timesteps.append(self.num_timesteps)
|
self.evaluations_timesteps.append(self.num_timesteps)
|
||||||
self.evaluations_results.append(episode_rewards)
|
self.evaluations_results.append(episode_rewards)
|
||||||
self.evaluations_length.append(episode_lengths)
|
self.evaluations_length.append(episode_lengths)
|
||||||
|
|
@ -87,7 +89,7 @@ class MaskableEvalCallback(EvalCallback):
|
||||||
|
|
||||||
mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
|
mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
|
||||||
mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
|
mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
|
||||||
self.last_mean_reward = mean_reward
|
self.last_mean_reward = float(mean_reward)
|
||||||
|
|
||||||
if self.verbose > 0:
|
if self.verbose > 0:
|
||||||
print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
|
print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
|
||||||
|
|
@ -111,7 +113,7 @@ class MaskableEvalCallback(EvalCallback):
|
||||||
print("New best mean reward!")
|
print("New best mean reward!")
|
||||||
if self.best_model_save_path is not None:
|
if self.best_model_save_path is not None:
|
||||||
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
|
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
|
||||||
self.best_mean_reward = mean_reward
|
self.best_mean_reward = float(mean_reward)
|
||||||
# Trigger callback on new best model, if needed
|
# Trigger callback on new best model, if needed
|
||||||
if self.callback_on_new_best is not None:
|
if self.callback_on_new_best is not None:
|
||||||
continue_training = self.callback_on_new_best.on_step()
|
continue_training = self.callback_on_new_best.on_step()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Optional, Tuple, TypeVar
|
from typing import List, Optional, Tuple, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
|
|
@ -13,6 +13,7 @@ SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistributi
|
||||||
SelfMaskableMultiCategoricalDistribution = TypeVar(
|
SelfMaskableMultiCategoricalDistribution = TypeVar(
|
||||||
"SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution"
|
"SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution"
|
||||||
)
|
)
|
||||||
|
MaybeMasks = Union[th.Tensor, np.ndarray, None]
|
||||||
|
|
||||||
|
|
||||||
class MaskableCategorical(Categorical):
|
class MaskableCategorical(Categorical):
|
||||||
|
|
@ -36,14 +37,14 @@ class MaskableCategorical(Categorical):
|
||||||
probs: Optional[th.Tensor] = None,
|
probs: Optional[th.Tensor] = None,
|
||||||
logits: Optional[th.Tensor] = None,
|
logits: Optional[th.Tensor] = None,
|
||||||
validate_args: Optional[bool] = None,
|
validate_args: Optional[bool] = None,
|
||||||
masks: Optional[np.ndarray] = None,
|
masks: MaybeMasks = None,
|
||||||
):
|
):
|
||||||
self.masks: Optional[th.Tensor] = None
|
self.masks: Optional[th.Tensor] = None
|
||||||
super().__init__(probs, logits, validate_args)
|
super().__init__(probs, logits, validate_args)
|
||||||
self._original_logits = self.logits
|
self._original_logits = self.logits
|
||||||
self.apply_masking(masks)
|
self.apply_masking(masks)
|
||||||
|
|
||||||
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
|
def apply_masking(self, masks: MaybeMasks) -> None:
|
||||||
"""
|
"""
|
||||||
Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0.
|
Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0.
|
||||||
|
|
||||||
|
|
@ -84,7 +85,7 @@ class MaskableCategorical(Categorical):
|
||||||
|
|
||||||
class MaskableDistribution(Distribution, ABC):
|
class MaskableDistribution(Distribution, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
|
def apply_masking(self, masks: MaybeMasks) -> None:
|
||||||
"""
|
"""
|
||||||
Eliminate ("mask out") chosen distribution outcomes by setting their probability to 0.
|
Eliminate ("mask out") chosen distribution outcomes by setting their probability to 0.
|
||||||
|
|
||||||
|
|
@ -94,6 +95,13 @@ class MaskableDistribution(Distribution, ABC):
|
||||||
previously applied masking is removed, and the original logits are restored.
|
previously applied masking is removed, and the original logits are restored.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def proba_distribution_net(self, *args, **kwargs) -> nn.Module:
|
||||||
|
"""Create the layers and parameters that represent the distribution.
|
||||||
|
|
||||||
|
Subclasses must define this, but the arguments and return type vary between
|
||||||
|
concrete classes."""
|
||||||
|
|
||||||
|
|
||||||
class MaskableCategoricalDistribution(MaskableDistribution):
|
class MaskableCategoricalDistribution(MaskableDistribution):
|
||||||
"""
|
"""
|
||||||
|
|
@ -154,7 +162,7 @@ class MaskableCategoricalDistribution(MaskableDistribution):
|
||||||
log_prob = self.log_prob(actions)
|
log_prob = self.log_prob(actions)
|
||||||
return actions, log_prob
|
return actions, log_prob
|
||||||
|
|
||||||
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
|
def apply_masking(self, masks: MaybeMasks) -> None:
|
||||||
assert self.distribution is not None, "Must set distribution parameters"
|
assert self.distribution is not None, "Must set distribution parameters"
|
||||||
self.distribution.apply_masking(masks)
|
self.distribution.apply_masking(masks)
|
||||||
|
|
||||||
|
|
@ -192,7 +200,7 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
|
||||||
reshaped_logits = action_logits.view(-1, sum(self.action_dims))
|
reshaped_logits = action_logits.view(-1, sum(self.action_dims))
|
||||||
|
|
||||||
self.distributions = [
|
self.distributions = [
|
||||||
MaskableCategorical(logits=split) for split in th.split(reshaped_logits, tuple(self.action_dims), dim=1)
|
MaskableCategorical(logits=split) for split in th.split(reshaped_logits, list(self.action_dims), dim=1)
|
||||||
]
|
]
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -229,18 +237,16 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution):
|
||||||
log_prob = self.log_prob(actions)
|
log_prob = self.log_prob(actions)
|
||||||
return actions, log_prob
|
return actions, log_prob
|
||||||
|
|
||||||
def apply_masking(self, masks: Optional[np.ndarray]) -> None:
|
def apply_masking(self, masks: MaybeMasks) -> None:
|
||||||
assert len(self.distributions) > 0, "Must set distribution parameters"
|
assert len(self.distributions) > 0, "Must set distribution parameters"
|
||||||
|
|
||||||
split_masks = [None] * len(self.distributions)
|
split_masks = [None] * len(self.distributions)
|
||||||
if masks is not None:
|
if masks is not None:
|
||||||
masks = th.as_tensor(masks)
|
masks_tensor = th.as_tensor(masks)
|
||||||
|
|
||||||
# Restructure shape to align with logits
|
# Restructure shape to align with logits
|
||||||
masks = masks.view(-1, sum(self.action_dims))
|
masks_tensor = masks_tensor.view(-1, sum(self.action_dims))
|
||||||
|
|
||||||
# Then split columnwise for each discrete action
|
# Then split columnwise for each discrete action
|
||||||
split_masks = th.split(masks, tuple(self.action_dims), dim=1)
|
split_masks = th.split(masks_tensor, list(self.action_dims), dim=1) # type: ignore[assignment]
|
||||||
|
|
||||||
for distribution, mask in zip(self.distributions, split_masks):
|
for distribution, mask in zip(self.distributions, split_masks):
|
||||||
distribution.apply_masking(mask)
|
distribution.apply_masking(mask)
|
||||||
|
|
@ -268,10 +274,13 @@ def make_masked_proba_distribution(action_space: spaces.Space) -> MaskableDistri
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(action_space, spaces.Discrete):
|
if isinstance(action_space, spaces.Discrete):
|
||||||
return MaskableCategoricalDistribution(action_space.n)
|
return MaskableCategoricalDistribution(int(action_space.n))
|
||||||
elif isinstance(action_space, spaces.MultiDiscrete):
|
elif isinstance(action_space, spaces.MultiDiscrete):
|
||||||
return MaskableMultiCategoricalDistribution(action_space.nvec)
|
return MaskableMultiCategoricalDistribution(list(action_space.nvec))
|
||||||
elif isinstance(action_space, spaces.MultiBinary):
|
elif isinstance(action_space, spaces.MultiBinary):
|
||||||
|
assert isinstance(
|
||||||
|
action_space.n, int
|
||||||
|
), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
|
||||||
return MaskableBernoulliDistribution(action_space.n)
|
return MaskableBernoulliDistribution(action_space.n)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from stable_baselines3.common.torch_layers import (
|
||||||
MlpExtractor,
|
MlpExtractor,
|
||||||
NatureCNN,
|
NatureCNN,
|
||||||
)
|
)
|
||||||
from stable_baselines3.common.type_aliases import Schedule
|
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from sb3_contrib.common.maskable.distributions import MaskableDistribution, make_masked_proba_distribution
|
from sb3_contrib.common.maskable.distributions import MaskableDistribution, make_masked_proba_distribution
|
||||||
|
|
@ -141,8 +141,8 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
log_prob = distribution.log_prob(actions)
|
log_prob = distribution.log_prob(actions)
|
||||||
return actions, values, log_prob
|
return actions, values, log_prob
|
||||||
|
|
||||||
def extract_features(
|
def extract_features( # type: ignore[override]
|
||||||
self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None
|
self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None
|
||||||
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
|
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Preprocess the observation if needed and extract features.
|
Preprocess the observation if needed and extract features.
|
||||||
|
|
@ -233,7 +233,11 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
module.apply(partial(self.init_weights, gain=gain))
|
module.apply(partial(self.init_weights, gain=gain))
|
||||||
|
|
||||||
# Setup optimizer with initial learning rate
|
# Setup optimizer with initial learning rate
|
||||||
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
self.optimizer = self.optimizer_class(
|
||||||
|
self.parameters(),
|
||||||
|
lr=lr_schedule(1), # type: ignore[call-arg]
|
||||||
|
**self.optimizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution:
|
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution:
|
||||||
"""
|
"""
|
||||||
|
|
@ -245,9 +249,9 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
action_logits = self.action_net(latent_pi)
|
action_logits = self.action_net(latent_pi)
|
||||||
return self.action_dist.proba_distribution(action_logits=action_logits)
|
return self.action_dist.proba_distribution(action_logits=action_logits)
|
||||||
|
|
||||||
def _predict(
|
def _predict( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
observation: th.Tensor,
|
observation: PyTorchObs,
|
||||||
deterministic: bool = False,
|
deterministic: bool = False,
|
||||||
action_masks: Optional[np.ndarray] = None,
|
action_masks: Optional[np.ndarray] = None,
|
||||||
) -> th.Tensor:
|
) -> th.Tensor:
|
||||||
|
|
@ -284,35 +288,45 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
# Switch to eval mode (this affects batch norm / dropout)
|
# Switch to eval mode (this affects batch norm / dropout)
|
||||||
self.set_training_mode(False)
|
self.set_training_mode(False)
|
||||||
|
|
||||||
observation, vectorized_env = self.obs_to_tensor(observation)
|
# Check for common mistake that the user does not mix Gym/VecEnv API
|
||||||
|
# Tuple obs are not supported by SB3, so we can safely do that check
|
||||||
|
if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict):
|
||||||
|
raise ValueError(
|
||||||
|
"You have passed a tuple to the predict() function instead of a Numpy array or a Dict. "
|
||||||
|
"You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) "
|
||||||
|
"vs `obs = vec_env.reset()` (SB3 VecEnv). "
|
||||||
|
"See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 "
|
||||||
|
"and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
|
||||||
|
)
|
||||||
|
|
||||||
|
obs_tensor, vectorized_env = self.obs_to_tensor(observation)
|
||||||
|
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
actions = self._predict(observation, deterministic=deterministic, action_masks=action_masks)
|
actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks)
|
||||||
# Convert to numpy
|
# Convert to numpy
|
||||||
actions = actions.cpu().numpy()
|
actions = actions.cpu().numpy()
|
||||||
|
|
||||||
if isinstance(self.action_space, spaces.Box):
|
if isinstance(self.action_space, spaces.Box):
|
||||||
if self.squash_output:
|
if self.squash_output:
|
||||||
# Rescale to proper domain when using squashing
|
# Rescale to proper domain when using squashing
|
||||||
actions = self.unscale_action(actions)
|
actions = self.unscale_action(actions) # type: ignore[assignment, arg-type]
|
||||||
else:
|
else:
|
||||||
# Actions could be on arbitrary scale, so clip the actions to avoid
|
# Actions could be on arbitrary scale, so clip the actions to avoid
|
||||||
# out of bound error (e.g. if sampling from a Gaussian distribution)
|
# out of bound error (e.g. if sampling from a Gaussian distribution)
|
||||||
actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
actions = np.clip(actions, self.action_space.low, self.action_space.high) # type: ignore[assignment, arg-type]
|
||||||
|
|
||||||
if not vectorized_env:
|
if not vectorized_env:
|
||||||
if state is not None:
|
assert isinstance(actions, np.ndarray)
|
||||||
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
|
|
||||||
actions = actions.squeeze(axis=0)
|
actions = actions.squeeze(axis=0)
|
||||||
|
|
||||||
return actions, None
|
return actions, state # type: ignore[return-value]
|
||||||
|
|
||||||
def evaluate_actions(
|
def evaluate_actions(
|
||||||
self,
|
self,
|
||||||
obs: th.Tensor,
|
obs: th.Tensor,
|
||||||
actions: th.Tensor,
|
actions: th.Tensor,
|
||||||
action_masks: Optional[np.ndarray] = None,
|
action_masks: Optional[th.Tensor] = None,
|
||||||
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
|
) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Evaluate actions according to the current policy,
|
Evaluate actions according to the current policy,
|
||||||
given the observations.
|
given the observations.
|
||||||
|
|
@ -337,7 +351,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
values = self.value_net(latent_vf)
|
values = self.value_net(latent_vf)
|
||||||
return values, log_prob, distribution.entropy()
|
return values, log_prob, distribution.entropy()
|
||||||
|
|
||||||
def get_distribution(self, obs: th.Tensor, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
|
def get_distribution(self, obs: PyTorchObs, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
|
||||||
"""
|
"""
|
||||||
Get the current policy distribution given the observations.
|
Get the current policy distribution given the observations.
|
||||||
|
|
||||||
|
|
@ -352,7 +366,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
distribution.apply_masking(action_masks)
|
distribution.apply_masking(action_masks)
|
||||||
return distribution
|
return distribution
|
||||||
|
|
||||||
def predict_values(self, obs: th.Tensor) -> th.Tensor:
|
def predict_values(self, obs: PyTorchObs) -> th.Tensor:
|
||||||
"""
|
"""
|
||||||
Get the estimated values according to the current policy given the observations.
|
Get the estimated values according to the current policy given the observations.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -59,10 +59,11 @@ def _worker(
|
||||||
remote.send(vec_env.seed(seed=data))
|
remote.send(vec_env.seed(seed=data))
|
||||||
elif cmd == "set_options":
|
elif cmd == "set_options":
|
||||||
# Note: the options will only be effective at the next reset
|
# Note: the options will only be effective at the next reset
|
||||||
remote.send(vec_env.set_options(data))
|
remote.send(vec_env.set_options(data)) # type: ignore[func-returns-value]
|
||||||
elif cmd == "get_obs_rms":
|
elif cmd == "get_obs_rms":
|
||||||
remote.send(obs_rms)
|
remote.send(obs_rms)
|
||||||
elif cmd == "sync_obs_rms":
|
elif cmd == "sync_obs_rms":
|
||||||
|
assert vec_normalize is not None, "Tried to call `sync_obs_rms` when not using VecNormalize"
|
||||||
vec_normalize.obs_rms = data
|
vec_normalize.obs_rms = data
|
||||||
obs_rms = data
|
obs_rms = data
|
||||||
elif cmd == "close":
|
elif cmd == "close":
|
||||||
|
|
@ -130,7 +131,7 @@ class AsyncEval:
|
||||||
n_eval_episodes,
|
n_eval_episodes,
|
||||||
)
|
)
|
||||||
# daemon=True: if the main process crashes, we should not cause things to hang
|
# daemon=True: if the main process crashes, we should not cause things to hang
|
||||||
process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error
|
process = ctx.Process(target=_worker, args=args, daemon=True) # type: ignore[attr-defined]
|
||||||
process.start()
|
process.start()
|
||||||
self.processes.append(process)
|
self.processes.append(process)
|
||||||
work_remote.close()
|
work_remote.close()
|
||||||
|
|
@ -157,6 +158,10 @@ class AsyncEval:
|
||||||
:param seed: The seed for the pseudo-random generators.
|
:param seed: The seed for the pseudo-random generators.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
if seed is None:
|
||||||
|
# Do nothing
|
||||||
|
return []
|
||||||
|
|
||||||
for idx, remote in enumerate(self.remotes):
|
for idx, remote in enumerate(self.remotes):
|
||||||
remote.send(("seed", seed + idx))
|
remote.send(("seed", seed + idx))
|
||||||
return [remote.recv() for remote in self.remotes]
|
return [remote.recv() for remote in self.remotes]
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,12 @@
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from collections import deque
|
|
||||||
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union
|
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common import utils
|
|
||||||
from stable_baselines3.common.buffers import RolloutBuffer
|
from stable_baselines3.common.buffers import RolloutBuffer
|
||||||
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
|
|
@ -74,6 +72,8 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
"CnnPolicy": CnnPolicy,
|
"CnnPolicy": CnnPolicy,
|
||||||
"MultiInputPolicy": MultiInputPolicy,
|
"MultiInputPolicy": MultiInputPolicy,
|
||||||
}
|
}
|
||||||
|
policy: MaskableActorCriticPolicy # type: ignore[assignment]
|
||||||
|
rollout_buffer: MaskableRolloutBuffer # type: ignore[assignment]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -91,6 +91,8 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
ent_coef: float = 0.0,
|
ent_coef: float = 0.0,
|
||||||
vf_coef: float = 0.5,
|
vf_coef: float = 0.5,
|
||||||
max_grad_norm: float = 0.5,
|
max_grad_norm: float = 0.5,
|
||||||
|
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
|
||||||
|
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
target_kl: Optional[float] = None,
|
target_kl: Optional[float] = None,
|
||||||
stats_window_size: int = 100,
|
stats_window_size: int = 100,
|
||||||
tensorboard_log: Optional[str] = None,
|
tensorboard_log: Optional[str] = None,
|
||||||
|
|
@ -101,7 +103,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
_init_setup_model: bool = True,
|
_init_setup_model: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
policy,
|
policy, # type: ignore[arg-type]
|
||||||
env,
|
env,
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
n_steps=n_steps,
|
n_steps=n_steps,
|
||||||
|
|
@ -112,6 +114,8 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
max_grad_norm=max_grad_norm,
|
max_grad_norm=max_grad_norm,
|
||||||
use_sde=False,
|
use_sde=False,
|
||||||
sde_sample_freq=-1,
|
sde_sample_freq=-1,
|
||||||
|
rollout_buffer_class=rollout_buffer_class,
|
||||||
|
rollout_buffer_kwargs=rollout_buffer_kwargs,
|
||||||
stats_window_size=stats_window_size,
|
stats_window_size=stats_window_size,
|
||||||
tensorboard_log=tensorboard_log,
|
tensorboard_log=tensorboard_log,
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
|
|
@ -140,20 +144,24 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
self._setup_lr_schedule()
|
self._setup_lr_schedule()
|
||||||
self.set_random_seed(self.seed)
|
self.set_random_seed(self.seed)
|
||||||
|
|
||||||
buffer_cls = MaskableDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else MaskableRolloutBuffer
|
self.policy = self.policy_class( # type: ignore[assignment]
|
||||||
|
|
||||||
self.policy = self.policy_class(
|
|
||||||
self.observation_space,
|
self.observation_space,
|
||||||
self.action_space,
|
self.action_space,
|
||||||
self.lr_schedule,
|
self.lr_schedule,
|
||||||
**self.policy_kwargs, # pytype:disable=not-instantiable
|
**self.policy_kwargs,
|
||||||
)
|
)
|
||||||
self.policy = self.policy.to(self.device)
|
self.policy = self.policy.to(self.device)
|
||||||
|
|
||||||
if not isinstance(self.policy, MaskableActorCriticPolicy):
|
if not isinstance(self.policy, MaskableActorCriticPolicy):
|
||||||
raise ValueError("Policy must subclass MaskableActorCriticPolicy")
|
raise ValueError("Policy must subclass MaskableActorCriticPolicy")
|
||||||
|
|
||||||
self.rollout_buffer = buffer_cls(
|
if self.rollout_buffer_class is None:
|
||||||
|
if isinstance(self.observation_space, spaces.Dict):
|
||||||
|
self.rollout_buffer_class = MaskableDictRolloutBuffer
|
||||||
|
else:
|
||||||
|
self.rollout_buffer_class = MaskableRolloutBuffer
|
||||||
|
|
||||||
|
self.rollout_buffer = self.rollout_buffer_class( # type: ignore[assignment]
|
||||||
self.n_steps,
|
self.n_steps,
|
||||||
self.observation_space,
|
self.observation_space,
|
||||||
self.action_space,
|
self.action_space,
|
||||||
|
|
@ -161,6 +169,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
gamma=self.gamma,
|
gamma=self.gamma,
|
||||||
gae_lambda=self.gae_lambda,
|
gae_lambda=self.gae_lambda,
|
||||||
n_envs=self.n_envs,
|
n_envs=self.n_envs,
|
||||||
|
**self.rollout_buffer_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize schedules for policy/value clipping
|
# Initialize schedules for policy/value clipping
|
||||||
|
|
@ -171,86 +180,6 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
|
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
|
||||||
|
|
||||||
def _init_callback(
|
|
||||||
self,
|
|
||||||
callback: MaybeCallback,
|
|
||||||
use_masking: bool = True,
|
|
||||||
progress_bar: bool = False,
|
|
||||||
) -> BaseCallback:
|
|
||||||
"""
|
|
||||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
|
||||||
:param use_masking: Whether or not to use invalid action masks during evaluation
|
|
||||||
:param progress_bar: Display a progress bar using tqdm and rich.
|
|
||||||
:return: A hybrid callback calling `callback` and performing evaluation.
|
|
||||||
"""
|
|
||||||
# Convert a list of callbacks into a callback
|
|
||||||
if isinstance(callback, list):
|
|
||||||
callback = CallbackList(callback)
|
|
||||||
|
|
||||||
# Convert functional callback to object
|
|
||||||
if not isinstance(callback, BaseCallback):
|
|
||||||
callback = ConvertCallback(callback)
|
|
||||||
|
|
||||||
# Add progress bar callback
|
|
||||||
if progress_bar:
|
|
||||||
callback = CallbackList([callback, ProgressBarCallback()])
|
|
||||||
|
|
||||||
callback.init_callback(self)
|
|
||||||
return callback
|
|
||||||
|
|
||||||
def _setup_learn(
|
|
||||||
self,
|
|
||||||
total_timesteps: int,
|
|
||||||
callback: MaybeCallback = None,
|
|
||||||
reset_num_timesteps: bool = True,
|
|
||||||
tb_log_name: str = "run",
|
|
||||||
use_masking: bool = True,
|
|
||||||
progress_bar: bool = False,
|
|
||||||
) -> Tuple[int, BaseCallback]:
|
|
||||||
"""
|
|
||||||
Initialize different variables needed for training.
|
|
||||||
|
|
||||||
:param total_timesteps: The total number of samples (env steps) to train on
|
|
||||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
|
||||||
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
|
|
||||||
:param tb_log_name: the name of the run for tensorboard log
|
|
||||||
:param use_masking: Whether or not to use invalid action masks during training
|
|
||||||
:param progress_bar: Display a progress bar using tqdm and rich.
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.start_time = time.time_ns()
|
|
||||||
if self.ep_info_buffer is None or reset_num_timesteps:
|
|
||||||
# Initialize buffers if they don't exist, or reinitialize if resetting counters
|
|
||||||
self.ep_info_buffer = deque(maxlen=self._stats_window_size)
|
|
||||||
self.ep_success_buffer = deque(maxlen=self._stats_window_size)
|
|
||||||
|
|
||||||
if reset_num_timesteps:
|
|
||||||
self.num_timesteps = 0
|
|
||||||
self._episode_num = 0
|
|
||||||
else:
|
|
||||||
# Make sure training timesteps are ahead of the internal counter
|
|
||||||
total_timesteps += self.num_timesteps
|
|
||||||
self._total_timesteps = total_timesteps
|
|
||||||
self._num_timesteps_at_start = self.num_timesteps
|
|
||||||
|
|
||||||
# Avoid resetting the environment when calling ``.learn()`` consecutive times
|
|
||||||
if reset_num_timesteps or self._last_obs is None:
|
|
||||||
self._last_obs = self.env.reset()
|
|
||||||
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
|
|
||||||
# Retrieve unnormalized observation for saving into the buffer
|
|
||||||
if self._vec_normalize_env is not None:
|
|
||||||
self._last_original_obs = self._vec_normalize_env.get_original_obs()
|
|
||||||
|
|
||||||
# Configure logger's outputs if no logger was passed
|
|
||||||
if not self._custom_logger:
|
|
||||||
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
|
|
||||||
|
|
||||||
# Create eval callback if needed
|
|
||||||
callback = self._init_callback(callback, use_masking, progress_bar)
|
|
||||||
|
|
||||||
return total_timesteps, callback
|
|
||||||
|
|
||||||
def collect_rollouts(
|
def collect_rollouts(
|
||||||
self,
|
self,
|
||||||
env: VecEnv,
|
env: VecEnv,
|
||||||
|
|
@ -341,14 +270,14 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
log_probs,
|
log_probs,
|
||||||
action_masks=action_masks,
|
action_masks=action_masks,
|
||||||
)
|
)
|
||||||
self._last_obs = new_obs
|
self._last_obs = new_obs # type: ignore[assignment]
|
||||||
self._last_episode_starts = dones
|
self._last_episode_starts = dones
|
||||||
|
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
# Compute value for the last timestep
|
# Compute value for the last timestep
|
||||||
# Masking is not needed here, the choice of action doesn't matter.
|
# Masking is not needed here, the choice of action doesn't matter.
|
||||||
# We only want the value of the current observation.
|
# We only want the value of the current observation.
|
||||||
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))
|
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) # type: ignore[arg-type]
|
||||||
|
|
||||||
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
|
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
|
||||||
|
|
||||||
|
|
@ -356,9 +285,9 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def predict(
|
def predict( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
observation: np.ndarray,
|
observation: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||||
state: Optional[Tuple[np.ndarray, ...]] = None,
|
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||||
episode_start: Optional[np.ndarray] = None,
|
episode_start: Optional[np.ndarray] = None,
|
||||||
deterministic: bool = False,
|
deterministic: bool = False,
|
||||||
|
|
@ -388,10 +317,10 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
# Update optimizer learning rate
|
# Update optimizer learning rate
|
||||||
self._update_learning_rate(self.policy.optimizer)
|
self._update_learning_rate(self.policy.optimizer)
|
||||||
# Compute current clip range
|
# Compute current clip range
|
||||||
clip_range = self.clip_range(self._current_progress_remaining)
|
clip_range = self.clip_range(self._current_progress_remaining) # type: ignore[operator]
|
||||||
# Optional: clip range for the value function
|
# Optional: clip range for the value function
|
||||||
if self.clip_range_vf is not None:
|
if self.clip_range_vf is not None:
|
||||||
clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
|
clip_range_vf = self.clip_range_vf(self._current_progress_remaining) # type: ignore[operator]
|
||||||
|
|
||||||
entropy_losses = []
|
entropy_losses = []
|
||||||
pg_losses, value_losses = [], []
|
pg_losses, value_losses = [], []
|
||||||
|
|
@ -499,7 +428,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
if self.clip_range_vf is not None:
|
if self.clip_range_vf is not None:
|
||||||
self.logger.record("train/clip_range_vf", clip_range_vf)
|
self.logger.record("train/clip_range_vf", clip_range_vf)
|
||||||
|
|
||||||
def learn(
|
def learn( # type: ignore[override]
|
||||||
self: SelfMaskablePPO,
|
self: SelfMaskablePPO,
|
||||||
total_timesteps: int,
|
total_timesteps: int,
|
||||||
callback: MaybeCallback = None,
|
callback: MaybeCallback = None,
|
||||||
|
|
@ -516,12 +445,13 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
callback,
|
callback,
|
||||||
reset_num_timesteps,
|
reset_num_timesteps,
|
||||||
tb_log_name,
|
tb_log_name,
|
||||||
use_masking,
|
|
||||||
progress_bar,
|
progress_bar,
|
||||||
)
|
)
|
||||||
|
|
||||||
callback.on_training_start(locals(), globals())
|
callback.on_training_start(locals(), globals())
|
||||||
|
|
||||||
|
assert self.env is not None
|
||||||
|
|
||||||
while self.num_timesteps < total_timesteps:
|
while self.num_timesteps < total_timesteps:
|
||||||
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
|
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
|
||||||
|
|
||||||
|
|
@ -533,6 +463,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
# Display training infos
|
# Display training infos
|
||||||
if log_interval is not None and iteration % log_interval == 0:
|
if log_interval is not None and iteration % log_interval == 0:
|
||||||
|
assert self.ep_info_buffer is not None
|
||||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
||||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
||||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
||||||
|
|
|
||||||
|
|
@ -149,7 +149,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
self.action_space,
|
self.action_space,
|
||||||
self.lr_schedule,
|
self.lr_schedule,
|
||||||
use_sde=self.use_sde,
|
use_sde=self.use_sde,
|
||||||
**self.policy_kwargs, # pytype:disable=not-instantiable
|
**self.policy_kwargs,
|
||||||
)
|
)
|
||||||
self.policy = self.policy.to(self.device)
|
self.policy = self.policy.to(self.device)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue