Log success rate for PPO variants (#235)

Antonin RAFFIN 2024-03-31 19:51:48 +02:00 committed by Andreas Schaarschmidt
parent 89d0113037
commit 34eceaf382
4 changed files with 14 additions and 59 deletions

sb3_contrib/ppo_mask/ppo_mask.py

@@ -1,5 +1,3 @@
-import sys
-import time
 from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union
 
 import numpy as np
@@ -10,7 +8,7 @@ from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
+from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor
 from stable_baselines3.common.vec_env import VecEnv
 from torch.nn import functional as F
 
@@ -241,7 +239,7 @@ class MaskablePPO(OnPolicyAlgorithm):
             if not callback.on_step():
                 return False
 
-            self._update_info_buffer(infos)
+            self._update_info_buffer(infos, dones)
             n_steps += 1
 
             if isinstance(self.action_space, spaces.Discrete):
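What this one-argument change buys: stable-baselines3's `BaseAlgorithm._update_info_buffer` only records an env's `is_success` flag when it knows the episode actually ended. A paraphrased sketch of the upstream logic (not verbatim; details may differ slightly across SB3 versions):

import numpy as np

def _update_info_buffer(self, infos, dones=None):
    # Without `dones`, every step looks non-terminal and no success
    # statistic is ever collected.
    if dones is None:
        dones = np.array([False] * len(infos))
    for idx, info in enumerate(infos):
        maybe_ep_info = info.get("episode")
        maybe_is_success = info.get("is_success")
        if maybe_ep_info is not None:
            self.ep_info_buffer.extend([maybe_ep_info])
        # Success flags are only meaningful at episode boundaries.
        if maybe_is_success is not None and dones[idx]:
            self.ep_success_buffer.append(maybe_is_success)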
@@ -463,17 +461,7 @@ class MaskablePPO(OnPolicyAlgorithm):
 
             # Display training infos
             if log_interval is not None and iteration % log_interval == 0:
-                assert self.ep_info_buffer is not None
-                time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
-                fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
-                self.logger.record("time/iterations", iteration, exclude="tensorboard")
-                if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
-                    self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
-                    self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
-                self.logger.record("time/fps", fps)
-                self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
-                self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
-                self.logger.dump(step=self.num_timesteps)
+                self._dump_logs(iteration)
 
             self.train()
 
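The deleted logging block is now provided by `OnPolicyAlgorithm._dump_logs` in stable-baselines3 >= 2.3.0a5 (hence the dependency bump below). Roughly what it records, paraphrased rather than quoted from upstream, with the success-rate line that motivates this commit:

import sys
import time

from stable_baselines3.common.utils import safe_mean

def _dump_logs(self, iteration):
    # Same timing and episode statistics as the removed block above ...
    time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
    fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
    self.logger.record("time/iterations", iteration, exclude="tensorboard")
    if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
        self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
        self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
    self.logger.record("time/fps", fps)
    self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
    self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
    # ... plus the new metric, fed by ep_success_buffer
    # (see the _update_info_buffer sketch above):
    if len(self.ep_success_buffer) > 0:
        self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))
    self.logger.dump(step=self.num_timesteps)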

sb3_contrib/ppo_recurrent/ppo_recurrent.py

@@ -1,5 +1,3 @@
-import sys
-import time
 from copy import deepcopy
 from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
 
@@ -11,7 +9,7 @@ from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
+from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor
 from stable_baselines3.common.vec_env import VecEnv
 
 from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer
@@ -260,7 +258,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
             if not callback.on_step():
                 return False
 
-            self._update_info_buffer(infos)
+            self._update_info_buffer(infos, dones)
             n_steps += 1
 
             if isinstance(self.action_space, spaces.Discrete):
@@ -453,42 +451,11 @@ class RecurrentPPO(OnPolicyAlgorithm):
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfRecurrentPPO:
-        iteration = 0
-
-        total_timesteps, callback = self._setup_learn(
-            total_timesteps,
-            callback,
-            reset_num_timesteps,
-            tb_log_name,
-            progress_bar,
-        )
-
-        callback.on_training_start(locals(), globals())
-
-        while self.num_timesteps < total_timesteps:
-            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
-
-            if not continue_training:
-                break
-
-            iteration += 1
-            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
-
-            # Display training infos
-            if log_interval is not None and iteration % log_interval == 0:
-                time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
-                fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
-                self.logger.record("time/iterations", iteration, exclude="tensorboard")
-                if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
-                    self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
-                    self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
-                self.logger.record("time/fps", fps)
-                self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
-                self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
-                self.logger.dump(step=self.num_timesteps)
-
-            self.train()
-
-        callback.on_training_end()
-
-        return self
+        return super().learn(
+            total_timesteps=total_timesteps,
+            callback=callback,
+            log_interval=log_interval,
+            tb_log_name=tb_log_name,
+            reset_num_timesteps=reset_num_timesteps,
+            progress_bar=progress_bar,
+        )
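With `learn()` now delegating to the shared `OnPolicyAlgorithm` implementation, MaskablePPO and RecurrentPPO log through exactly the same path as plain PPO. A minimal usage sketch; CartPole-v1 is only a stand-in here, since the `rollout/success_rate` entry appears only for environments that set `info["is_success"]` when an episode ends:

from sb3_contrib import RecurrentPPO

# Any env whose terminal info dict contains an "is_success" boolean now
# gets rollout/success_rate in the console/TensorBoard logs for free.
model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=10_000, log_interval=1)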

sb3_contrib/version.txt

@@ -1 +1 @@
-2.3.0a4
+2.3.0a5

setup.py

@@ -65,7 +65,7 @@ setup(
     packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
     package_data={"sb3_contrib": ["py.typed", "version.txt"]},
     install_requires=[
-        "stable_baselines3>=2.3.0a4,<3.0",
+        "stable_baselines3>=2.3.0a5,<3.0",
     ],
     description="Contrib package of Stable Baselines3, experimental code.",
     author="Antonin Raffin",