Add progress bar argument (#107)

* Add progress bar argument

* Sort imports
Antonin RAFFIN 2022-10-10 18:44:13 +02:00 committed by GitHub
parent e9c97948c8
commit 52795a307e
9 changed files with 43 additions and 48 deletions

@@ -15,6 +15,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added ``progress_bar`` argument in the ``learn()`` method, displayed using the tqdm and rich packages
Bug Fixes:
^^^^^^^^^^

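A minimal usage sketch of the new argument (mirroring the updated tests at the bottom of this diff; tqdm and rich must be installed for the bar to display):

    from sb3_contrib import TQC

    model = TQC("MlpPolicy", "Pendulum-v1")
    # progress_bar=True attaches a progress bar callback alongside any user callbacks
    model.learn(total_timesteps=1_000, progress_bar=True)
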
@@ -315,6 +315,7 @@ class ARS(BaseAlgorithm):
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
async_eval: Optional[AsyncEval] = None,
progress_bar: bool = False,
) -> ARSSelf:
"""
Return a trained model.
@@ -333,11 +334,20 @@ class ARS(BaseAlgorithm):
:param eval_log_path: Path to a folder where the evaluations will be saved
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
:param async_eval: The object for asynchronous evaluation of candidates.
:param progress_bar: Display a progress bar using tqdm and rich.
:return: the trained model
"""
total_steps, callback = self._setup_learn(
total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
total_timesteps,
eval_env,
callback,
eval_freq,
n_eval_episodes,
eval_log_path,
reset_num_timesteps,
tb_log_name,
progress_bar,
)
callback.on_training_start(locals(), globals())

@@ -10,7 +10,7 @@ import torch as th
from gym import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@@ -184,6 +184,7 @@ class MaskablePPO(OnPolicyAlgorithm):
n_eval_episodes: int = 5,
log_path: Optional[str] = None,
use_masking: bool = True,
progress_bar: bool = False,
) -> BaseCallback:
"""
:param callback: Callback(s) called at every step with state of the algorithm.
@@ -196,6 +197,7 @@ class MaskablePPO(OnPolicyAlgorithm):
:param n_eval_episodes: How many episodes to play per evaluation
:param log_path: Path to a folder where the evaluations will be saved
:param use_masking: Whether or not to use invalid action masks during evaluation
:param progress_bar: Display a progress bar using tqdm and rich.
:return: A hybrid callback calling `callback` and performing evaluation.
"""
# Convert a list of callbacks into a callback
@@ -206,6 +208,10 @@ class MaskablePPO(OnPolicyAlgorithm):
if not isinstance(callback, BaseCallback):
callback = ConvertCallback(callback)
# Add progress bar callback
if progress_bar:
callback = CallbackList([callback, ProgressBarCallback()])
# Create eval callback in charge of the evaluation
if eval_env is not None:
# Avoid circular import error
@@ -236,6 +242,7 @@ class MaskablePPO(OnPolicyAlgorithm):
reset_num_timesteps: bool = True,
tb_log_name: str = "run",
use_masking: bool = True,
progress_bar: bool = False,
) -> Tuple[int, BaseCallback]:
"""
Initialize different variables needed for training.
@@ -253,6 +260,7 @@ class MaskablePPO(OnPolicyAlgorithm):
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
:param tb_log_name: the name of the run for tensorboard log
:param use_masking: Whether or not to use invalid action masks during training
:param progress_bar: Display a progress bar using tqdm and rich.
:return:
"""
@@ -299,7 +307,7 @@ class MaskablePPO(OnPolicyAlgorithm):
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
# Create eval callback if needed
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking)
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking, progress_bar)
return total_timesteps, callback
@@ -563,6 +571,7 @@ class MaskablePPO(OnPolicyAlgorithm):
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
use_masking: bool = True,
progress_bar: bool = False,
) -> MaskablePPOSelf:
iteration = 0
@@ -576,6 +585,7 @@ class MaskablePPO(OnPolicyAlgorithm):
reset_num_timesteps,
tb_log_name,
use_masking,
progress_bar,
)
callback.on_training_start(locals(), globals())

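The ProgressBarCallback imported above is provided by stable_baselines3.common.callbacks; a simplified, self-contained sketch of the same idea (illustrative only, not the library's actual implementation) looks roughly like this:

    # Simplified sketch of a tqdm-based progress bar callback; the real
    # ProgressBarCallback ships with stable_baselines3.common.callbacks.
    from tqdm import tqdm

    from stable_baselines3.common.callbacks import BaseCallback

    class SimpleProgressBarCallback(BaseCallback):
        """Display a progress bar spanning the remaining training timesteps."""

        def __init__(self) -> None:
            super().__init__()
            self.pbar = None

        def _on_training_start(self) -> None:
            # total_timesteps is available through the locals() dict that
            # learn() passes to callback.on_training_start(), as seen above
            self.pbar = tqdm(total=self.locals["total_timesteps"] - self.model.num_timesteps)

        def _on_step(self) -> bool:
            # One callback step corresponds to num_envs environment steps
            self.pbar.update(self.training_env.num_envs)
            return True

        def _on_training_end(self) -> None:
            self.pbar.close()
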
@@ -1,7 +1,7 @@
import sys
import time
from copy import deepcopy
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Dict, Optional, Type, TypeVar, Union
import gym
import numpy as np
@@ -198,47 +198,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
def _setup_learn(
self,
total_timesteps: int,
eval_env: Optional[GymEnv],
callback: MaybeCallback = None,
eval_freq: int = 10000,
n_eval_episodes: int = 5,
log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
tb_log_name: str = "RecurrentPPO",
) -> Tuple[int, BaseCallback]:
"""
Initialize different variables needed for training.
:param total_timesteps: The total number of samples (env steps) to train on
:param eval_env: Environment to use for evaluation.
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param callback: Callback(s) called at every step with state of the algorithm.
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param n_eval_episodes: How many episodes to play per evaluation
:param log_path: Path to a folder where the evaluations will be saved
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
:param tb_log_name: the name of the run for tensorboard log
:return:
"""
total_timesteps, callback = super()._setup_learn(
total_timesteps,
eval_env,
callback,
eval_freq,
n_eval_episodes,
log_path,
reset_num_timesteps,
tb_log_name,
)
return total_timesteps, callback
def collect_rollouts(
self,
env: VecEnv,
@@ -500,11 +459,20 @@ class RecurrentPPO(OnPolicyAlgorithm):
tb_log_name: str = "RecurrentPPO",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> RecurrentPPOSelf:
iteration = 0
total_timesteps, callback = self._setup_learn(
total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
total_timesteps,
eval_env,
callback,
eval_freq,
n_eval_episodes,
eval_log_path,
reset_num_timesteps,
tb_log_name,
progress_bar,
)
callback.on_training_start(locals(), globals())

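As the _init_callback change above shows, progress_bar=True wraps the user callbacks together with a ProgressBarCallback, so passing one manually gives roughly the same behavior (a sketch; RecurrentPPO on CartPole-v1 is an assumed stand-in):

    from stable_baselines3.common.callbacks import ProgressBarCallback

    from sb3_contrib import RecurrentPPO

    model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1")
    # Roughly equivalent to model.learn(500, progress_bar=True)
    model.learn(500, callback=ProgressBarCallback())
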
@@ -262,6 +262,7 @@ class QRDQN(OffPolicyAlgorithm):
tb_log_name: str = "QRDQN",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> QRDQNSelf:
return super().learn(
@@ -274,6 +275,7 @@ class QRDQN(OffPolicyAlgorithm):
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)
def _excluded_save_params(self) -> List[str]:

@@ -299,6 +299,7 @@ class TQC(OffPolicyAlgorithm):
tb_log_name: str = "TQC",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> TQCSelf:
return super().learn(
@@ -311,6 +312,7 @@ class TQC(OffPolicyAlgorithm):
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)
def _excluded_save_params(self) -> List[str]:

@@ -415,6 +415,7 @@ class TRPO(OnPolicyAlgorithm):
tb_log_name: str = "TRPO",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> TRPOSelf:
return super().learn(
@@ -427,4 +428,5 @@ class TRPO(OnPolicyAlgorithm):
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)

@@ -190,7 +190,7 @@ def test_callback(tmp_path):
model = MaskablePPO("MlpPolicy", env, n_steps=64, gamma=0.4, seed=32, verbose=1)
model.learn(100, callback=MaskableEvalCallback(eval_env, eval_freq=100, warn=False, log_path=tmp_path))
model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False))
model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False), progress_bar=True)
def test_child_callback():

@@ -20,7 +20,7 @@ def test_tqc(ent_coef):
create_eval_env=True,
ent_coef=ent_coef,
)
model.learn(total_timesteps=300, eval_freq=250)
model.learn(total_timesteps=300, eval_freq=250, progress_bar=True)
@pytest.mark.parametrize("n_critics", [1, 3])