diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 368f7a6..3b393f5 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -22,6 +22,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to MaskablePPO Bug Fixes: ^^^^^^^^^^ @@ -31,8 +32,10 @@ Deprecations: Others: ^^^^^^^ - - Fixed ``train_freq`` type annotation for tqc and qrdqn (@Armandpl) +- Fixed ``sb3_contrib/common/maskable/*.py`` type annotations +- Fixed ``sb3_contrib/ppo_mask/ppo_mask.py`` type annotations +- Fixed ``sb3_contrib/common/vec_env/async_eval.py`` type annotations Documentation: ^^^^^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index 127d2b2..bd6303e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,12 +28,6 @@ exclude = """(?x)( | sb3_contrib/ars/ars.py$ | sb3_contrib/common/recurrent/policies.py$ | sb3_contrib/common/recurrent/buffers.py$ - | sb3_contrib/common/maskable/distributions.py$ - | sb3_contrib/common/maskable/callbacks.py$ - | sb3_contrib/common/maskable/policies.py$ - | sb3_contrib/common/maskable/buffers.py$ - | sb3_contrib/common/vec_env/async_eval.py$ - | sb3_contrib/ppo_mask/ppo_mask.py$ | tests/test_train_eval_mode.py$ )""" diff --git a/sb3_contrib/common/maskable/buffers.py b/sb3_contrib/common/maskable/buffers.py index 2ccb2bc..f4430ca 100644 --- a/sb3_contrib/common/maskable/buffers.py +++ b/sb3_contrib/common/maskable/buffers.py @@ -18,7 +18,7 @@ class MaskableRolloutBufferSamples(NamedTuple): action_masks: th.Tensor -class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples): +class MaskableDictRolloutBufferSamples(NamedTuple): observations: TensorDict actions: th.Tensor old_values: th.Tensor @@ -42,6 +42,8 @@ class MaskableRolloutBuffer(RolloutBuffer): :param n_envs: Number of parallel environments """ + action_masks: np.ndarray + def __init__( self, buffer_size: int, @@ -53,7 +55,6 @@ class MaskableRolloutBuffer(RolloutBuffer): n_envs: int = 1, ): super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) - self.action_masks = None def reset(self) -> None: if isinstance(self.action_space, spaces.Discrete): @@ -61,6 +62,10 @@ class MaskableRolloutBuffer(RolloutBuffer): elif isinstance(self.action_space, spaces.MultiDiscrete): mask_dims = sum(self.action_space.nvec) elif isinstance(self.action_space, spaces.MultiBinary): + assert isinstance(self.action_space.n, int), ( + f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. " + "You can flatten it instead." + ) mask_dims = 2 * self.action_space.n # One mask per binary outcome else: raise ValueError(f"Unsupported action space {type(self.action_space)}") @@ -79,7 +84,7 @@ class MaskableRolloutBuffer(RolloutBuffer): super().add(*args, **kwargs) - def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]: + def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]: # type: ignore[override] assert self.full, "" indices = np.random.permutation(self.buffer_size * self.n_envs) # Prepare the data @@ -105,7 +110,7 @@ class MaskableRolloutBuffer(RolloutBuffer): yield self._get_samples(indices[start_idx : start_idx + batch_size]) start_idx += batch_size - def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples: + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples: # type: ignore[override] data = ( self.observations[batch_inds], self.actions[batch_inds], @@ -143,17 +148,18 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer): :param n_envs: Number of parallel environments """ + action_masks: np.ndarray + def __init__( self, buffer_size: int, - observation_space: spaces.Space, + observation_space: spaces.Dict, action_space: spaces.Space, device: Union[th.device, str] = "auto", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, ): - self.action_masks = None super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs) def reset(self) -> None: @@ -162,6 +168,10 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer): elif isinstance(self.action_space, spaces.MultiDiscrete): mask_dims = sum(self.action_space.nvec) elif isinstance(self.action_space, spaces.MultiBinary): + assert isinstance(self.action_space.n, int), ( + f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. " + "You can flatten it instead." + ) mask_dims = 2 * self.action_space.n # One mask per binary outcome else: raise ValueError(f"Unsupported action space {type(self.action_space)}") @@ -180,7 +190,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer): super().add(*args, **kwargs) - def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]: + def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]: # type: ignore[override] assert self.full, "" indices = np.random.permutation(self.buffer_size * self.n_envs) # Prepare the data @@ -203,7 +213,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer): yield self._get_samples(indices[start_idx : start_idx + batch_size]) start_idx += batch_size - def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples: + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples: # type: ignore[override] return MaskableDictRolloutBufferSamples( observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, actions=self.to_torch(self.actions[batch_inds]), diff --git a/sb3_contrib/common/maskable/callbacks.py b/sb3_contrib/common/maskable/callbacks.py index f0839e8..27f93e9 100644 --- a/sb3_contrib/common/maskable/callbacks.py +++ b/sb3_contrib/common/maskable/callbacks.py @@ -55,7 +55,7 @@ class MaskableEvalCallback(EvalCallback): # Note that evaluate_policy() has been patched to support masking episode_rewards, episode_lengths = evaluate_policy( - self.model, + self.model, # type: ignore[arg-type] self.eval_env, n_eval_episodes=self.n_eval_episodes, render=self.render, @@ -67,6 +67,8 @@ class MaskableEvalCallback(EvalCallback): ) if self.log_path is not None: + assert isinstance(episode_rewards, list) + assert isinstance(episode_lengths, list) self.evaluations_timesteps.append(self.num_timesteps) self.evaluations_results.append(episode_rewards) self.evaluations_length.append(episode_lengths) @@ -87,7 +89,7 @@ class MaskableEvalCallback(EvalCallback): mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards) mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths) - self.last_mean_reward = mean_reward + self.last_mean_reward = float(mean_reward) if self.verbose > 0: print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}") @@ -111,7 +113,7 @@ class MaskableEvalCallback(EvalCallback): print("New best mean reward!") if self.best_model_save_path is not None: self.model.save(os.path.join(self.best_model_save_path, "best_model")) - self.best_mean_reward = mean_reward + self.best_mean_reward = float(mean_reward) # Trigger callback on new best model, if needed if self.callback_on_new_best is not None: continue_training = self.callback_on_new_best.on_step() diff --git a/sb3_contrib/common/maskable/distributions.py b/sb3_contrib/common/maskable/distributions.py index f37c361..bacf368 100644 --- a/sb3_contrib/common/maskable/distributions.py +++ b/sb3_contrib/common/maskable/distributions.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Tuple, TypeVar +from typing import List, Optional, Tuple, TypeVar, Union import numpy as np import torch as th @@ -13,6 +13,7 @@ SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistributi SelfMaskableMultiCategoricalDistribution = TypeVar( "SelfMaskableMultiCategoricalDistribution", bound="MaskableMultiCategoricalDistribution" ) +MaybeMasks = Union[th.Tensor, np.ndarray, None] class MaskableCategorical(Categorical): @@ -36,14 +37,14 @@ class MaskableCategorical(Categorical): probs: Optional[th.Tensor] = None, logits: Optional[th.Tensor] = None, validate_args: Optional[bool] = None, - masks: Optional[np.ndarray] = None, + masks: MaybeMasks = None, ): self.masks: Optional[th.Tensor] = None super().__init__(probs, logits, validate_args) self._original_logits = self.logits self.apply_masking(masks) - def apply_masking(self, masks: Optional[np.ndarray]) -> None: + def apply_masking(self, masks: MaybeMasks) -> None: """ Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0. @@ -84,7 +85,7 @@ class MaskableCategorical(Categorical): class MaskableDistribution(Distribution, ABC): @abstractmethod - def apply_masking(self, masks: Optional[np.ndarray]) -> None: + def apply_masking(self, masks: MaybeMasks) -> None: """ Eliminate ("mask out") chosen distribution outcomes by setting their probability to 0. @@ -94,6 +95,13 @@ class MaskableDistribution(Distribution, ABC): previously applied masking is removed, and the original logits are restored. """ + @abstractmethod + def proba_distribution_net(self, *args, **kwargs) -> nn.Module: + """Create the layers and parameters that represent the distribution. + + Subclasses must define this, but the arguments and return type vary between + concrete classes.""" + class MaskableCategoricalDistribution(MaskableDistribution): """ @@ -154,7 +162,7 @@ class MaskableCategoricalDistribution(MaskableDistribution): log_prob = self.log_prob(actions) return actions, log_prob - def apply_masking(self, masks: Optional[np.ndarray]) -> None: + def apply_masking(self, masks: MaybeMasks) -> None: assert self.distribution is not None, "Must set distribution parameters" self.distribution.apply_masking(masks) @@ -192,7 +200,7 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution): reshaped_logits = action_logits.view(-1, sum(self.action_dims)) self.distributions = [ - MaskableCategorical(logits=split) for split in th.split(reshaped_logits, tuple(self.action_dims), dim=1) + MaskableCategorical(logits=split) for split in th.split(reshaped_logits, list(self.action_dims), dim=1) ] return self @@ -229,18 +237,16 @@ class MaskableMultiCategoricalDistribution(MaskableDistribution): log_prob = self.log_prob(actions) return actions, log_prob - def apply_masking(self, masks: Optional[np.ndarray]) -> None: + def apply_masking(self, masks: MaybeMasks) -> None: assert len(self.distributions) > 0, "Must set distribution parameters" split_masks = [None] * len(self.distributions) if masks is not None: - masks = th.as_tensor(masks) - + masks_tensor = th.as_tensor(masks) # Restructure shape to align with logits - masks = masks.view(-1, sum(self.action_dims)) - + masks_tensor = masks_tensor.view(-1, sum(self.action_dims)) # Then split columnwise for each discrete action - split_masks = th.split(masks, tuple(self.action_dims), dim=1) + split_masks = th.split(masks_tensor, list(self.action_dims), dim=1) # type: ignore[assignment] for distribution, mask in zip(self.distributions, split_masks): distribution.apply_masking(mask) @@ -268,10 +274,13 @@ def make_masked_proba_distribution(action_space: spaces.Space) -> MaskableDistri """ if isinstance(action_space, spaces.Discrete): - return MaskableCategoricalDistribution(action_space.n) + return MaskableCategoricalDistribution(int(action_space.n)) elif isinstance(action_space, spaces.MultiDiscrete): - return MaskableMultiCategoricalDistribution(action_space.nvec) + return MaskableMultiCategoricalDistribution(list(action_space.nvec)) elif isinstance(action_space, spaces.MultiBinary): + assert isinstance( + action_space.n, int + ), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead." return MaskableBernoulliDistribution(action_space.n) else: raise NotImplementedError( diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index f6b5eeb..77dfde3 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -13,7 +13,7 @@ from stable_baselines3.common.torch_layers import ( MlpExtractor, NatureCNN, ) -from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.type_aliases import PyTorchObs, Schedule from torch import nn from sb3_contrib.common.maskable.distributions import MaskableDistribution, make_masked_proba_distribution @@ -141,8 +141,8 @@ class MaskableActorCriticPolicy(BasePolicy): log_prob = distribution.log_prob(actions) return actions, values, log_prob - def extract_features( - self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None + def extract_features( # type: ignore[override] + self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None ) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]: """ Preprocess the observation if needed and extract features. @@ -233,7 +233,11 @@ class MaskableActorCriticPolicy(BasePolicy): module.apply(partial(self.init_weights, gain=gain)) # Setup optimizer with initial learning rate - self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + self.optimizer = self.optimizer_class( + self.parameters(), + lr=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ) def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution: """ @@ -245,9 +249,9 @@ class MaskableActorCriticPolicy(BasePolicy): action_logits = self.action_net(latent_pi) return self.action_dist.proba_distribution(action_logits=action_logits) - def _predict( + def _predict( # type: ignore[override] self, - observation: th.Tensor, + observation: PyTorchObs, deterministic: bool = False, action_masks: Optional[np.ndarray] = None, ) -> th.Tensor: @@ -284,35 +288,45 @@ class MaskableActorCriticPolicy(BasePolicy): # Switch to eval mode (this affects batch norm / dropout) self.set_training_mode(False) - observation, vectorized_env = self.obs_to_tensor(observation) + # Check for common mistake that the user does not mix Gym/VecEnv API + # Tuple obs are not supported by SB3, so we can safely do that check + if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict): + raise ValueError( + "You have passed a tuple to the predict() function instead of a Numpy array or a Dict. " + "You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) " + "vs `obs = vec_env.reset()` (SB3 VecEnv). " + "See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 " + "and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api" + ) + + obs_tensor, vectorized_env = self.obs_to_tensor(observation) with th.no_grad(): - actions = self._predict(observation, deterministic=deterministic, action_masks=action_masks) + actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks) # Convert to numpy actions = actions.cpu().numpy() if isinstance(self.action_space, spaces.Box): if self.squash_output: # Rescale to proper domain when using squashing - actions = self.unscale_action(actions) + actions = self.unscale_action(actions) # type: ignore[assignment, arg-type] else: # Actions could be on arbitrary scale, so clip the actions to avoid # out of bound error (e.g. if sampling from a Gaussian distribution) - actions = np.clip(actions, self.action_space.low, self.action_space.high) + actions = np.clip(actions, self.action_space.low, self.action_space.high) # type: ignore[assignment, arg-type] if not vectorized_env: - if state is not None: - raise ValueError("Error: The environment must be vectorized when using recurrent policies.") + assert isinstance(actions, np.ndarray) actions = actions.squeeze(axis=0) - return actions, None + return actions, state # type: ignore[return-value] def evaluate_actions( self, obs: th.Tensor, actions: th.Tensor, - action_masks: Optional[np.ndarray] = None, - ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + action_masks: Optional[th.Tensor] = None, + ) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]: """ Evaluate actions according to the current policy, given the observations. @@ -337,7 +351,7 @@ class MaskableActorCriticPolicy(BasePolicy): values = self.value_net(latent_vf) return values, log_prob, distribution.entropy() - def get_distribution(self, obs: th.Tensor, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution: + def get_distribution(self, obs: PyTorchObs, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution: """ Get the current policy distribution given the observations. @@ -352,7 +366,7 @@ class MaskableActorCriticPolicy(BasePolicy): distribution.apply_masking(action_masks) return distribution - def predict_values(self, obs: th.Tensor) -> th.Tensor: + def predict_values(self, obs: PyTorchObs) -> th.Tensor: """ Get the estimated values according to the current policy given the observations. diff --git a/sb3_contrib/common/vec_env/async_eval.py b/sb3_contrib/common/vec_env/async_eval.py index a7be9ba..c18b771 100644 --- a/sb3_contrib/common/vec_env/async_eval.py +++ b/sb3_contrib/common/vec_env/async_eval.py @@ -59,10 +59,11 @@ def _worker( remote.send(vec_env.seed(seed=data)) elif cmd == "set_options": # Note: the options will only be effective at the next reset - remote.send(vec_env.set_options(data)) + remote.send(vec_env.set_options(data)) # type: ignore[func-returns-value] elif cmd == "get_obs_rms": remote.send(obs_rms) elif cmd == "sync_obs_rms": + assert vec_normalize is not None, "Tried to call `sync_obs_rms` when not using VecNormalize" vec_normalize.obs_rms = data obs_rms = data elif cmd == "close": @@ -130,7 +131,7 @@ class AsyncEval: n_eval_episodes, ) # daemon=True: if the main process crashes, we should not cause things to hang - process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error + process = ctx.Process(target=_worker, args=args, daemon=True) # type: ignore[attr-defined] process.start() self.processes.append(process) work_remote.close() @@ -157,6 +158,10 @@ class AsyncEval: :param seed: The seed for the pseudo-random generators. :return: """ + if seed is None: + # Do nothing + return [] + for idx, remote in enumerate(self.remotes): remote.send(("seed", seed + idx)) return [remote.recv() for remote in self.remotes] diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index 3248ba7..a046d2b 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -1,14 +1,12 @@ import sys import time -from collections import deque from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union import numpy as np import torch as th from gymnasium import spaces -from stable_baselines3.common import utils from stable_baselines3.common.buffers import RolloutBuffer -from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback +from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -74,6 +72,8 @@ class MaskablePPO(OnPolicyAlgorithm): "CnnPolicy": CnnPolicy, "MultiInputPolicy": MultiInputPolicy, } + policy: MaskableActorCriticPolicy # type: ignore[assignment] + rollout_buffer: MaskableRolloutBuffer # type: ignore[assignment] def __init__( self, @@ -91,6 +91,8 @@ class MaskablePPO(OnPolicyAlgorithm): ent_coef: float = 0.0, vf_coef: float = 0.5, max_grad_norm: float = 0.5, + rollout_buffer_class: Optional[Type[RolloutBuffer]] = None, + rollout_buffer_kwargs: Optional[Dict[str, Any]] = None, target_kl: Optional[float] = None, stats_window_size: int = 100, tensorboard_log: Optional[str] = None, @@ -101,7 +103,7 @@ class MaskablePPO(OnPolicyAlgorithm): _init_setup_model: bool = True, ): super().__init__( - policy, + policy, # type: ignore[arg-type] env, learning_rate=learning_rate, n_steps=n_steps, @@ -112,6 +114,8 @@ class MaskablePPO(OnPolicyAlgorithm): max_grad_norm=max_grad_norm, use_sde=False, sde_sample_freq=-1, + rollout_buffer_class=rollout_buffer_class, + rollout_buffer_kwargs=rollout_buffer_kwargs, stats_window_size=stats_window_size, tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs, @@ -140,20 +144,24 @@ class MaskablePPO(OnPolicyAlgorithm): self._setup_lr_schedule() self.set_random_seed(self.seed) - buffer_cls = MaskableDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else MaskableRolloutBuffer - - self.policy = self.policy_class( + self.policy = self.policy_class( # type: ignore[assignment] self.observation_space, self.action_space, self.lr_schedule, - **self.policy_kwargs, # pytype:disable=not-instantiable + **self.policy_kwargs, ) self.policy = self.policy.to(self.device) if not isinstance(self.policy, MaskableActorCriticPolicy): raise ValueError("Policy must subclass MaskableActorCriticPolicy") - self.rollout_buffer = buffer_cls( + if self.rollout_buffer_class is None: + if isinstance(self.observation_space, spaces.Dict): + self.rollout_buffer_class = MaskableDictRolloutBuffer + else: + self.rollout_buffer_class = MaskableRolloutBuffer + + self.rollout_buffer = self.rollout_buffer_class( # type: ignore[assignment] self.n_steps, self.observation_space, self.action_space, @@ -161,6 +169,7 @@ class MaskablePPO(OnPolicyAlgorithm): gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs, + **self.rollout_buffer_kwargs, ) # Initialize schedules for policy/value clipping @@ -171,86 +180,6 @@ class MaskablePPO(OnPolicyAlgorithm): self.clip_range_vf = get_schedule_fn(self.clip_range_vf) - def _init_callback( - self, - callback: MaybeCallback, - use_masking: bool = True, - progress_bar: bool = False, - ) -> BaseCallback: - """ - :param callback: Callback(s) called at every step with state of the algorithm. - :param use_masking: Whether or not to use invalid action masks during evaluation - :param progress_bar: Display a progress bar using tqdm and rich. - :return: A hybrid callback calling `callback` and performing evaluation. - """ - # Convert a list of callbacks into a callback - if isinstance(callback, list): - callback = CallbackList(callback) - - # Convert functional callback to object - if not isinstance(callback, BaseCallback): - callback = ConvertCallback(callback) - - # Add progress bar callback - if progress_bar: - callback = CallbackList([callback, ProgressBarCallback()]) - - callback.init_callback(self) - return callback - - def _setup_learn( - self, - total_timesteps: int, - callback: MaybeCallback = None, - reset_num_timesteps: bool = True, - tb_log_name: str = "run", - use_masking: bool = True, - progress_bar: bool = False, - ) -> Tuple[int, BaseCallback]: - """ - Initialize different variables needed for training. - - :param total_timesteps: The total number of samples (env steps) to train on - :param callback: Callback(s) called at every step with state of the algorithm. - :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute - :param tb_log_name: the name of the run for tensorboard log - :param use_masking: Whether or not to use invalid action masks during training - :param progress_bar: Display a progress bar using tqdm and rich. - :return: - """ - - self.start_time = time.time_ns() - if self.ep_info_buffer is None or reset_num_timesteps: - # Initialize buffers if they don't exist, or reinitialize if resetting counters - self.ep_info_buffer = deque(maxlen=self._stats_window_size) - self.ep_success_buffer = deque(maxlen=self._stats_window_size) - - if reset_num_timesteps: - self.num_timesteps = 0 - self._episode_num = 0 - else: - # Make sure training timesteps are ahead of the internal counter - total_timesteps += self.num_timesteps - self._total_timesteps = total_timesteps - self._num_timesteps_at_start = self.num_timesteps - - # Avoid resetting the environment when calling ``.learn()`` consecutive times - if reset_num_timesteps or self._last_obs is None: - self._last_obs = self.env.reset() - self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool) - # Retrieve unnormalized observation for saving into the buffer - if self._vec_normalize_env is not None: - self._last_original_obs = self._vec_normalize_env.get_original_obs() - - # Configure logger's outputs if no logger was passed - if not self._custom_logger: - self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps) - - # Create eval callback if needed - callback = self._init_callback(callback, use_masking, progress_bar) - - return total_timesteps, callback - def collect_rollouts( self, env: VecEnv, @@ -341,14 +270,14 @@ class MaskablePPO(OnPolicyAlgorithm): log_probs, action_masks=action_masks, ) - self._last_obs = new_obs + self._last_obs = new_obs # type: ignore[assignment] self._last_episode_starts = dones with th.no_grad(): # Compute value for the last timestep # Masking is not needed here, the choice of action doesn't matter. # We only want the value of the current observation. - values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) # type: ignore[arg-type] rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) @@ -356,9 +285,9 @@ class MaskablePPO(OnPolicyAlgorithm): return True - def predict( + def predict( # type: ignore[override] self, - observation: np.ndarray, + observation: Union[np.ndarray, Dict[str, np.ndarray]], state: Optional[Tuple[np.ndarray, ...]] = None, episode_start: Optional[np.ndarray] = None, deterministic: bool = False, @@ -388,10 +317,10 @@ class MaskablePPO(OnPolicyAlgorithm): # Update optimizer learning rate self._update_learning_rate(self.policy.optimizer) # Compute current clip range - clip_range = self.clip_range(self._current_progress_remaining) + clip_range = self.clip_range(self._current_progress_remaining) # type: ignore[operator] # Optional: clip range for the value function if self.clip_range_vf is not None: - clip_range_vf = self.clip_range_vf(self._current_progress_remaining) + clip_range_vf = self.clip_range_vf(self._current_progress_remaining) # type: ignore[operator] entropy_losses = [] pg_losses, value_losses = [], [] @@ -499,7 +428,7 @@ class MaskablePPO(OnPolicyAlgorithm): if self.clip_range_vf is not None: self.logger.record("train/clip_range_vf", clip_range_vf) - def learn( + def learn( # type: ignore[override] self: SelfMaskablePPO, total_timesteps: int, callback: MaybeCallback = None, @@ -516,12 +445,13 @@ class MaskablePPO(OnPolicyAlgorithm): callback, reset_num_timesteps, tb_log_name, - use_masking, progress_bar, ) callback.on_training_start(locals(), globals()) + assert self.env is not None + while self.num_timesteps < total_timesteps: continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking) @@ -533,6 +463,7 @@ class MaskablePPO(OnPolicyAlgorithm): # Display training infos if log_interval is not None and iteration % log_interval == 0: + assert self.ep_info_buffer is not None time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) self.logger.record("time/iterations", iteration, exclude="tensorboard") diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 39fd941..856bde3 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -149,7 +149,7 @@ class RecurrentPPO(OnPolicyAlgorithm): self.action_space, self.lr_schedule, use_sde=self.use_sde, - **self.policy_kwargs, # pytype:disable=not-instantiable + **self.policy_kwargs, ) self.policy = self.policy.to(self.device)