Log success rate for PPO variants (#235)

This commit is contained in:
Antonin RAFFIN 2024-03-31 19:51:48 +02:00 committed by Andreas Schaarschmidt
parent 89d0113037
commit 34eceaf382
4 changed files with 14 additions and 59 deletions

View File

@ -1,5 +1,3 @@
import sys
import time
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union
import numpy as np
@ -10,7 +8,7 @@ from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor
from stable_baselines3.common.vec_env import VecEnv
from torch.nn import functional as F
@ -241,7 +239,7 @@ class MaskablePPO(OnPolicyAlgorithm):
if not callback.on_step():
return False
self._update_info_buffer(infos)
self._update_info_buffer(infos, dones)
n_steps += 1
if isinstance(self.action_space, spaces.Discrete):
@ -463,17 +461,7 @@ class MaskablePPO(OnPolicyAlgorithm):
# Display training infos
if log_interval is not None and iteration % log_interval == 0:
assert self.ep_info_buffer is not None
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/iterations", iteration, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
self.logger.record("time/fps", fps)
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
self.logger.dump(step=self.num_timesteps)
self._dump_logs(iteration)
self.train()

View File

@ -1,5 +1,3 @@
import sys
import time
from copy import deepcopy
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
@ -11,7 +9,7 @@ from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor
from stable_baselines3.common.vec_env import VecEnv
from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer
@ -260,7 +258,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
if not callback.on_step():
return False
self._update_info_buffer(infos)
self._update_info_buffer(infos, dones)
n_steps += 1
if isinstance(self.action_space, spaces.Discrete):
@ -453,42 +451,11 @@ class RecurrentPPO(OnPolicyAlgorithm):
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfRecurrentPPO:
iteration = 0
total_timesteps, callback = self._setup_learn(
total_timesteps,
callback,
reset_num_timesteps,
tb_log_name,
progress_bar,
return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
tb_log_name=tb_log_name,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)
callback.on_training_start(locals(), globals())
while self.num_timesteps < total_timesteps:
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
if not continue_training:
break
iteration += 1
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
# Display training infos
if log_interval is not None and iteration % log_interval == 0:
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/iterations", iteration, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
self.logger.record("time/fps", fps)
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
self.logger.dump(step=self.num_timesteps)
self.train()
callback.on_training_end()
return self

View File

@ -1 +1 @@
2.3.0a4
2.3.0a5

View File

@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=2.3.0a4,<3.0",
"stable_baselines3>=2.3.0a5,<3.0",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",