From 34eceaf3825d5beb5e6664ffd3c7866c601a69e6 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 31 Mar 2024 19:51:48 +0200 Subject: [PATCH] Log success rate for PPO variants (#235) --- sb3_contrib/ppo_mask/ppo_mask.py | 18 ++------ sb3_contrib/ppo_recurrent/ppo_recurrent.py | 51 ++++------------------ sb3_contrib/version.txt | 2 +- setup.py | 2 +- 4 files changed, 14 insertions(+), 59 deletions(-) diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index a046d2b..a0d7da6 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -1,5 +1,3 @@ -import sys -import time from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union import numpy as np @@ -10,7 +8,7 @@ from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule -from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean +from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor from stable_baselines3.common.vec_env import VecEnv from torch.nn import functional as F @@ -241,7 +239,7 @@ class MaskablePPO(OnPolicyAlgorithm): if not callback.on_step(): return False - self._update_info_buffer(infos) + self._update_info_buffer(infos, dones) n_steps += 1 if isinstance(self.action_space, spaces.Discrete): @@ -463,17 +461,7 @@ class MaskablePPO(OnPolicyAlgorithm): # Display training infos if log_interval is not None and iteration % log_interval == 0: - assert self.ep_info_buffer is not None - time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) - fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) - self.logger.record("time/iterations", iteration, exclude="tensorboard") - if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: - self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) - self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) - self.logger.record("time/fps", fps) - self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") - self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") - self.logger.dump(step=self.num_timesteps) + self._dump_logs(iteration) self.train() diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 856bde3..05ffb01 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -1,5 +1,3 @@ -import sys -import time from copy import deepcopy from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union @@ -11,7 +9,7 @@ from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule -from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean +from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor from stable_baselines3.common.vec_env import VecEnv from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer @@ -260,7 +258,7 @@ class RecurrentPPO(OnPolicyAlgorithm): if not callback.on_step(): return False - self._update_info_buffer(infos) + self._update_info_buffer(infos, dones) n_steps += 1 if isinstance(self.action_space, spaces.Discrete): @@ -453,42 +451,11 @@ class RecurrentPPO(OnPolicyAlgorithm): reset_num_timesteps: bool = True, progress_bar: bool = False, ) -> SelfRecurrentPPO: - iteration = 0 - - total_timesteps, callback = self._setup_learn( - total_timesteps, - callback, - reset_num_timesteps, - tb_log_name, - progress_bar, + return super().learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + tb_log_name=tb_log_name, + reset_num_timesteps=reset_num_timesteps, + progress_bar=progress_bar, ) - - callback.on_training_start(locals(), globals()) - - while self.num_timesteps < total_timesteps: - continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) - - if not continue_training: - break - - iteration += 1 - self._update_current_progress_remaining(self.num_timesteps, total_timesteps) - - # Display training infos - if log_interval is not None and iteration % log_interval == 0: - time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) - fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) - self.logger.record("time/iterations", iteration, exclude="tensorboard") - if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: - self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) - self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) - self.logger.record("time/fps", fps) - self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") - self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") - self.logger.dump(step=self.num_timesteps) - - self.train() - - callback.on_training_end() - - return self diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 87ced0f..a3b489b 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -2.3.0a4 +2.3.0a5 diff --git a/setup.py b/setup.py index 733fb43..63be4a6 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=2.3.0a4,<3.0", + "stable_baselines3>=2.3.0a5,<3.0", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin",