import sys
import time
from collections import deque
from typing import Any, Dict, Optional, Tuple, Type, Union

import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
from torch.nn import functional as F

from sb3_contrib.common.maskable.buffers import MaskableRolloutBuffer
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported


class MaskablePPO(OnPolicyAlgorithm):
    """
    Proximal Policy Optimization algorithm (PPO) (clip version) with Invalid Action Masking.

    Based on the original Stable Baselines 3 implementation.

    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
    Background on Invalid Action Masking: https://arxiv.org/abs/2006.14171

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    :param batch_size: Minibatch size
    :param n_epochs: Number of epoch when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param clip_range: Clipping parameter, it can be a function of the current progress
        remaining (from 1 to 0).
    :param clip_range_vf: Clipping parameter for the value function,
        it can be a function of the current progress remaining (from 1 to 0).
        This is a parameter specific to the OpenAI implementation. If None is passed (default),
        no clipping will be done on the value function.
        IMPORTANT: this clipping depends on the reward scaling.
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param target_kl: Limit the KL divergence between updates,
        because the clipping is not enough to prevent large update
        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
        By default, there is no limit on the kl div.
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing string for the environment)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    def __init__(
        self,
        policy: Union[str, Type[MaskableActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: Optional[int] = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        target_kl: Optional[float] = None,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        # The parent is always told NOT to build the model: the PPO-specific
        # attributes below must exist before ``_setup_model`` runs.
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=False,
            sde_sample_freq=-1,
            tensorboard_log=tensorboard_log,
            create_eval_env=create_eval_env,
            policy_kwargs=policy_kwargs,
            policy_base=MaskableActorCriticPolicy,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Discrete,
                spaces.MultiDiscrete,
                spaces.MultiBinary,
            ),
        )

        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.target_kl = target_kl

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """
        Build the policy, the maskable rollout buffer and the clipping schedules.

        :raises ValueError: if the instantiated policy is not a ``MaskableActorCriticPolicy``
        """
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        self.policy = self.policy_class(
            self.observation_space,
            self.action_space,
            self.lr_schedule,
            **self.policy_kwargs,  # pytype:disable=not-instantiable
        )
        self.policy = self.policy.to(self.device)

        if not isinstance(self.policy, MaskableActorCriticPolicy):
            raise ValueError("Policy must subclass MaskableActorCriticPolicy")

        self.rollout_buffer = MaskableRolloutBuffer(
            self.n_steps,
            self.observation_space,
            self.action_space,
            self.device,
            gamma=self.gamma,
            gae_lambda=self.gae_lambda,
            n_envs=self.n_envs,
        )

        # Initialize schedules for policy/value clipping
        self.clip_range = get_schedule_fn(self.clip_range)
        if self.clip_range_vf is not None:
            if isinstance(self.clip_range_vf, (float, int)):
                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"

            self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

    def _init_callback(
        self,
        callback: MaybeCallback,
        eval_env: Optional[VecEnv] = None,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        log_path: Optional[str] = None,
        use_masking: bool = True,
    ) -> BaseCallback:
        """
        Normalize ``callback`` into a ``BaseCallback`` and optionally attach an evaluation callback.

        :param callback: Callback(s) called at every step with state of the algorithm.
        :param eval_env: Environment to use for evaluation; if None, no evaluation callback is added.
        :param eval_freq: How many steps between evaluations (forwarded to ``MaskableEvalCallback``).
        :param n_eval_episodes: Number of episodes to rollout during evaluation.
        :param log_path: Path to a folder where the evaluations will be saved
            (also used as the save path for the best model)
        :param use_masking: Whether or not to use invalid action masks during evaluation
        :return: A hybrid callback calling `callback` and performing evaluation.
        """
        # Convert a list of callbacks into a callback
        if isinstance(callback, list):
            callback = CallbackList(callback)

        # Convert functional callback to object
        if not isinstance(callback, BaseCallback):
            callback = ConvertCallback(callback)

        # Create eval callback in charge of the evaluation
        if eval_env is not None:
            # Avoid circular import error
            from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback

            eval_callback = MaskableEvalCallback(
                eval_env,
                best_model_save_path=log_path,
                log_path=log_path,
                eval_freq=eval_freq,
                n_eval_episodes=n_eval_episodes,
                use_masking=use_masking,
            )
            callback = CallbackList([callback, eval_callback])

        callback.init_callback(self)
        return callback

    def _setup_learn(
        self,
        total_timesteps: int,
        eval_env: Optional[GymEnv],
        callback: MaybeCallback = None,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        tb_log_name: str = "run",
        use_masking: bool = True,
    ) -> Tuple[int, BaseCallback]:
        """
        Initialize different variables needed for training.

        :param total_timesteps: The total number of samples (env steps) to train on
        :param eval_env: Environment to use for evaluation.
        :param callback: Callback(s) called at every step with state of the algorithm.
        :param eval_freq: How many steps between evaluations
        :param n_eval_episodes: How many episodes to play per evaluation
        :param log_path: Path to a folder where the evaluations will be saved
        :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
        :param tb_log_name: the name of the run for tensorboard log
        :param use_masking: Whether or not to use invalid action masks during training
        :return: The timestep count at which training must stop (offset by the current
            ``num_timesteps`` when counters are not reset) and the initialized callback.
        """
        self.start_time = time.time()
        if self.ep_info_buffer is None or reset_num_timesteps:
            # Initialize buffers if they don't exist, or reinitialize if resetting counters
            self.ep_info_buffer = deque(maxlen=100)
            self.ep_success_buffer = deque(maxlen=100)

        if reset_num_timesteps:
            self.num_timesteps = 0
            self._episode_num = 0
        else:
            # Make sure training timesteps are ahead of the internal counter
            total_timesteps += self.num_timesteps
        self._total_timesteps = total_timesteps
        # Remember where this call to ``learn()`` started so that the fps reported
        # in ``learn()`` only counts steps collected since ``start_time`` was (re)set.
        self._num_timesteps_at_start = self.num_timesteps

        # Avoid resetting the environment when calling ``.learn()`` consecutive times
        if reset_num_timesteps or self._last_obs is None:
            self._last_obs = self.env.reset()
            self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
            # Retrieve unnormalized observation for saving into the buffer
            if self._vec_normalize_env is not None:
                self._last_original_obs = self._vec_normalize_env.get_original_obs()

        if eval_env is not None and self.seed is not None:
            eval_env.seed(self.seed)

        eval_env = self._get_eval_env(eval_env)

        # Configure logger's outputs if no logger was passed
        if not self._custom_logger:
            self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)

        # Create eval callback if needed
        callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking)

        return total_timesteps, callback

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
        use_masking: bool = True,
    ) -> bool:
        """
        Collect experiences using the current policy and fill a ``RolloutBuffer``.
        The term rollout here refers to the model-free notion and should not
        be used with the concept of rollout used in model-based RL or planning.

        This method is largely identical to the implementation found in the parent class.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param rollout_buffer: Buffer to fill with rollouts
        :param n_rollout_steps: Number of experiences to collect per environment
        :param use_masking: Whether or not to use invalid action masks during training
        :return: True if function returned with at least `n_rollout_steps`
            collected, False if callback terminated rollout prematurely.
        :raises ValueError: if ``use_masking`` is True but ``env`` does not support action masking
        """
        # NOTE: local variable names in this method are part of the callback contract
        # (they are exposed through ``callback.update_locals(locals())``); do not rename them.

        assert isinstance(rollout_buffer, MaskableRolloutBuffer), "RolloutBuffer doesn't support action masking"
        assert self._last_obs is not None, "No previous observation was provided"
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)
        n_steps = 0
        action_masks = None
        rollout_buffer.reset()

        if use_masking and not is_masking_supported(env):
            raise ValueError("Environment does not support action masking. Consider using ActionMasker wrapper")

        callback.on_rollout_start()

        while n_steps < n_rollout_steps:
            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)

                # This is the only change related to invalid action masking
                if use_masking:
                    action_masks = get_action_masks(env)

                actions, values, log_probs = self.policy.forward(obs_tensor, action_masks=action_masks)

            actions = actions.cpu().numpy()
            new_obs, rewards, dones, infos = env.step(actions)

            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if callback.on_step() is False:
                return False

            self._update_info_buffer(infos)
            n_steps += 1

            if isinstance(self.action_space, spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            rollout_buffer.add(
                self._last_obs,
                actions,
                rewards,
                self._last_episode_starts,
                values,
                log_probs,
                action_masks=action_masks,
            )
            self._last_obs = new_obs
            self._last_episode_starts = dones

        with th.no_grad():
            # Compute value for the last timestep
            obs_tensor = obs_as_tensor(new_obs, self.device)

            # Masking is not needed here, the choice of action doesn't matter.
            # We only want the value of the current observation.
            _, values, _ = self.policy.forward(obs_tensor)

        rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

        callback.on_rollout_end()

        return True

    def predict(
        self,
        observation: np.ndarray,
        state: Optional[np.ndarray] = None,
        mask: Optional[np.ndarray] = None,
        deterministic: bool = False,
        action_masks: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Get the model's action(s) from an observation.

        :param observation: the input observation
        :param state: The last states (can be None, used in recurrent policies)
        :param mask: The last masks (can be None, used in recurrent policies)
        :param deterministic: Whether or not to return deterministic actions.
        :param action_masks: Action masks to apply to the action distribution.
        :return: the model's action and the next state (used in recurrent policies)
        """
        return self.policy.predict(observation, state, mask, deterministic, action_masks=action_masks)

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            # Reset every epoch: only the last epoch's KL values are logged below
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                values, log_prob, entropy = self.policy.evaluate_actions(
                    rollout_data.observations,
                    actions,
                    action_masks=rollout_data.action_masks,
                )

                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                # The check happens BEFORE the optimizer step, so the minibatch that
                # exceeds the KL budget does not update the policy.
                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            if not continue_training:
                break

        self._n_updates += self.n_epochs
        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "PPO",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        use_masking: bool = True,
    ) -> "MaskablePPO":
        """
        Alternate rollout collection and policy updates until ``total_timesteps`` is reached.

        :param total_timesteps: The total number of samples (env steps) to train on
        :param callback: Callback(s) called at every step with state of the algorithm.
        :param log_interval: Dump the logs every ``log_interval`` iterations; if None, never dump them.
        :param eval_env: Environment to use for evaluation.
        :param eval_freq: How many steps between evaluations
        :param n_eval_episodes: How many episodes to play per evaluation
        :param tb_log_name: the name of the run for tensorboard log
        :param eval_log_path: Path to a folder where the evaluations will be saved
        :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
        :param use_masking: Whether or not to use invalid action masks during training
        :return: the trained model (``self``)
        """
        iteration = 0

        total_timesteps, callback = self._setup_learn(
            total_timesteps,
            eval_env,
            callback,
            eval_freq,
            n_eval_episodes,
            eval_log_path,
            reset_num_timesteps,
            tb_log_name,
            use_masking,
        )

        callback.on_training_start(locals(), globals())

        while self.num_timesteps < total_timesteps:
            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)

            if continue_training is False:
                break

            iteration += 1
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

            # Display training infos
            if log_interval is not None and iteration % log_interval == 0:
                # Clamp the elapsed time: with a coarse clock the difference can be
                # exactly 0.0, which would raise ZeroDivisionError below.
                time_elapsed = max(time.time() - self.start_time, sys.float_info.epsilon)
                # Only count the steps collected since ``start_time`` was set, otherwise
                # fps is inflated when training resumes with ``reset_num_timesteps=False``.
                fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
                self.logger.record("time/iterations", iteration, exclude="tensorboard")
                if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                    self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
                    self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
                self.logger.record("time/fps", fps)
                self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
                self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
                self.logger.dump(step=self.num_timesteps)

            self.train()

        callback.on_training_end()

        return self