Log success rate for PPO variants (#235)
This commit is contained in:
parent
89d0113037
commit
34eceaf382
|
|
@ -1,5 +1,3 @@
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union
|
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -10,7 +8,7 @@ from stable_baselines3.common.callbacks import BaseCallback
|
||||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
|
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor
|
||||||
from stable_baselines3.common.vec_env import VecEnv
|
from stable_baselines3.common.vec_env import VecEnv
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
@ -241,7 +239,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
if not callback.on_step():
|
if not callback.on_step():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self._update_info_buffer(infos)
|
self._update_info_buffer(infos, dones)
|
||||||
n_steps += 1
|
n_steps += 1
|
||||||
|
|
||||||
if isinstance(self.action_space, spaces.Discrete):
|
if isinstance(self.action_space, spaces.Discrete):
|
||||||
|
|
@ -463,17 +461,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
# Display training infos
|
# Display training infos
|
||||||
if log_interval is not None and iteration % log_interval == 0:
|
if log_interval is not None and iteration % log_interval == 0:
|
||||||
assert self.ep_info_buffer is not None
|
self._dump_logs(iteration)
|
||||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
|
||||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
|
||||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
|
||||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
|
||||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
|
||||||
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
|
||||||
self.logger.record("time/fps", fps)
|
|
||||||
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
|
||||||
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
|
||||||
self.logger.dump(step=self.num_timesteps)
|
|
||||||
|
|
||||||
self.train()
|
self.train()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
|
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
|
||||||
|
|
||||||
|
|
@ -11,7 +9,7 @@ from stable_baselines3.common.callbacks import BaseCallback
|
||||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
|
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor
|
||||||
from stable_baselines3.common.vec_env import VecEnv
|
from stable_baselines3.common.vec_env import VecEnv
|
||||||
|
|
||||||
from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer
|
from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer
|
||||||
|
|
@ -260,7 +258,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
if not callback.on_step():
|
if not callback.on_step():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self._update_info_buffer(infos)
|
self._update_info_buffer(infos, dones)
|
||||||
n_steps += 1
|
n_steps += 1
|
||||||
|
|
||||||
if isinstance(self.action_space, spaces.Discrete):
|
if isinstance(self.action_space, spaces.Discrete):
|
||||||
|
|
@ -453,42 +451,11 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
) -> SelfRecurrentPPO:
|
) -> SelfRecurrentPPO:
|
||||||
iteration = 0
|
return super().learn(
|
||||||
|
total_timesteps=total_timesteps,
|
||||||
total_timesteps, callback = self._setup_learn(
|
callback=callback,
|
||||||
total_timesteps,
|
log_interval=log_interval,
|
||||||
callback,
|
tb_log_name=tb_log_name,
|
||||||
reset_num_timesteps,
|
reset_num_timesteps=reset_num_timesteps,
|
||||||
tb_log_name,
|
progress_bar=progress_bar,
|
||||||
progress_bar,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
callback.on_training_start(locals(), globals())
|
|
||||||
|
|
||||||
while self.num_timesteps < total_timesteps:
|
|
||||||
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
|
|
||||||
|
|
||||||
if not continue_training:
|
|
||||||
break
|
|
||||||
|
|
||||||
iteration += 1
|
|
||||||
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
|
|
||||||
|
|
||||||
# Display training infos
|
|
||||||
if log_interval is not None and iteration % log_interval == 0:
|
|
||||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
|
||||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
|
||||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
|
||||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
|
||||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
|
||||||
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
|
||||||
self.logger.record("time/fps", fps)
|
|
||||||
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
|
||||||
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
|
||||||
self.logger.dump(step=self.num_timesteps)
|
|
||||||
|
|
||||||
self.train()
|
|
||||||
|
|
||||||
callback.on_training_end()
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
2.3.0a4
|
2.3.0a5
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=2.3.0a4,<3.0",
|
"stable_baselines3>=2.3.0a5,<3.0",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue