Use higher resolution time_ns() and avoid division by zero (#91)
* Use higher resolution time_ns and add max to avoid division by zero * Add missing imports * Update changelog
This commit is contained in:
parent
3cbd2429be
commit
7e687ac47c
|
|
@ -15,6 +15,7 @@ New Features:
|
|||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import copy
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from functools import partial
|
||||
|
|
@ -242,8 +243,8 @@ class ARS(BaseAlgorithm):
|
|||
"""
|
||||
Dump information to the logger.
|
||||
"""
|
||||
time_elapsed = time.time() - self.start_time
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8))
|
||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import sys
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any, Dict, Optional, Tuple, Type, Union
|
||||
|
|
@ -242,7 +243,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
:return:
|
||||
"""
|
||||
|
||||
self.start_time = time.time()
|
||||
self.start_time = time.time_ns()
|
||||
if self.ep_info_buffer is None or reset_num_timesteps:
|
||||
# Initialize buffers if they don't exist, or reinitialize if resetting counters
|
||||
self.ep_info_buffer = deque(maxlen=100)
|
||||
|
|
@ -566,13 +567,14 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
|
||||
# Display training infos
|
||||
if log_interval is not None and iteration % log_interval == 0:
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
|
||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("time/fps", fps)
|
||||
self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
|
||||
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
||||
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
self.logger.dump(step=self.num_timesteps)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue