Add progress bar argument (#107)

* Add progress bar argument

* Sort imports
This commit is contained in:
Antonin RAFFIN 2022-10-10 18:44:13 +02:00 committed by GitHub
parent e9c97948c8
commit 52795a307e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 43 additions and 48 deletions

View File

@ -15,6 +15,7 @@ Breaking Changes:
New Features: New Features:
^^^^^^^^^^^^^ ^^^^^^^^^^^^^
- Added ``progress_bar`` argument to the ``learn()`` method; the progress bar is displayed using the ``tqdm`` and ``rich`` packages
Bug Fixes: Bug Fixes:
^^^^^^^^^^ ^^^^^^^^^^

View File

@ -315,6 +315,7 @@ class ARS(BaseAlgorithm):
eval_log_path: Optional[str] = None, eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
async_eval: Optional[AsyncEval] = None, async_eval: Optional[AsyncEval] = None,
progress_bar: bool = False,
) -> ARSSelf: ) -> ARSSelf:
""" """
Return a trained model. Return a trained model.
@ -333,11 +334,20 @@ class ARS(BaseAlgorithm):
:param eval_log_path: Path to a folder where the evaluations will be saved :param eval_log_path: Path to a folder where the evaluations will be saved
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging) :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
:param async_eval: The object for asynchronous evaluation of candidates. :param async_eval: The object for asynchronous evaluation of candidates.
:param progress_bar: Display a progress bar using tqdm and rich.
:return: the trained model :return: the trained model
""" """
total_steps, callback = self._setup_learn( total_steps, callback = self._setup_learn(
total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name total_timesteps,
eval_env,
callback,
eval_freq,
n_eval_episodes,
eval_log_path,
reset_num_timesteps,
tb_log_name,
progress_bar,
) )
callback.on_training_start(locals(), globals()) callback.on_training_start(locals(), globals())

View File

@ -10,7 +10,7 @@ import torch as th
from gym import spaces from gym import spaces
from stable_baselines3.common import utils from stable_baselines3.common import utils
from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@ -184,6 +184,7 @@ class MaskablePPO(OnPolicyAlgorithm):
n_eval_episodes: int = 5, n_eval_episodes: int = 5,
log_path: Optional[str] = None, log_path: Optional[str] = None,
use_masking: bool = True, use_masking: bool = True,
progress_bar: bool = False,
) -> BaseCallback: ) -> BaseCallback:
""" """
:param callback: Callback(s) called at every step with state of the algorithm. :param callback: Callback(s) called at every step with state of the algorithm.
@ -196,6 +197,7 @@ class MaskablePPO(OnPolicyAlgorithm):
:param n_eval_episodes: How many episodes to play per evaluation :param n_eval_episodes: How many episodes to play per evaluation
:param log_path: Path to a folder where the evaluations will be saved :param log_path: Path to a folder where the evaluations will be saved
:param use_masking: Whether or not to use invalid action masks during evaluation :param use_masking: Whether or not to use invalid action masks during evaluation
:param progress_bar: Display a progress bar using tqdm and rich.
:return: A hybrid callback calling `callback` and performing evaluation. :return: A hybrid callback calling `callback` and performing evaluation.
""" """
# Convert a list of callbacks into a callback # Convert a list of callbacks into a callback
@ -206,6 +208,10 @@ class MaskablePPO(OnPolicyAlgorithm):
if not isinstance(callback, BaseCallback): if not isinstance(callback, BaseCallback):
callback = ConvertCallback(callback) callback = ConvertCallback(callback)
# Add progress bar callback
if progress_bar:
callback = CallbackList([callback, ProgressBarCallback()])
# Create eval callback in charge of the evaluation # Create eval callback in charge of the evaluation
if eval_env is not None: if eval_env is not None:
# Avoid circular import error # Avoid circular import error
@ -236,6 +242,7 @@ class MaskablePPO(OnPolicyAlgorithm):
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
tb_log_name: str = "run", tb_log_name: str = "run",
use_masking: bool = True, use_masking: bool = True,
progress_bar: bool = False,
) -> Tuple[int, BaseCallback]: ) -> Tuple[int, BaseCallback]:
""" """
Initialize different variables needed for training. Initialize different variables needed for training.
@ -253,6 +260,7 @@ class MaskablePPO(OnPolicyAlgorithm):
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
:param tb_log_name: the name of the run for tensorboard log :param tb_log_name: the name of the run for tensorboard log
:param use_masking: Whether or not to use invalid action masks during training :param use_masking: Whether or not to use invalid action masks during training
:param progress_bar: Display a progress bar using tqdm and rich.
:return: :return:
""" """
@ -299,7 +307,7 @@ class MaskablePPO(OnPolicyAlgorithm):
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps) self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
# Create eval callback if needed # Create eval callback if needed
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking) callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking, progress_bar)
return total_timesteps, callback return total_timesteps, callback
@ -563,6 +571,7 @@ class MaskablePPO(OnPolicyAlgorithm):
eval_log_path: Optional[str] = None, eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
use_masking: bool = True, use_masking: bool = True,
progress_bar: bool = False,
) -> MaskablePPOSelf: ) -> MaskablePPOSelf:
iteration = 0 iteration = 0
@ -576,6 +585,7 @@ class MaskablePPO(OnPolicyAlgorithm):
reset_num_timesteps, reset_num_timesteps,
tb_log_name, tb_log_name,
use_masking, use_masking,
progress_bar,
) )
callback.on_training_start(locals(), globals()) callback.on_training_start(locals(), globals())

View File

@ -1,7 +1,7 @@
import sys import sys
import time import time
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union from typing import Any, Dict, Optional, Type, TypeVar, Union
import gym import gym
import numpy as np import numpy as np
@ -198,47 +198,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
self.clip_range_vf = get_schedule_fn(self.clip_range_vf) self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
def _setup_learn(
    self,
    total_timesteps: int,
    eval_env: Optional[GymEnv],
    callback: MaybeCallback = None,
    eval_freq: int = 10000,
    n_eval_episodes: int = 5,
    log_path: Optional[str] = None,
    reset_num_timesteps: bool = True,
    tb_log_name: str = "RecurrentPPO",
    progress_bar: bool = False,
) -> Tuple[int, BaseCallback]:
    """
    Initialize different variables needed for training.

    This override only forwards its arguments to the parent implementation.
    ``progress_bar`` is accepted (and forwarded) so that ``learn()`` can pass
    it positionally, like the other algorithms do.

    :param total_timesteps: The total number of samples (env steps) to train on
    :param eval_env: Environment to use for evaluation.
        Caution, this parameter is deprecated and will be removed in the future.
        Please use `EvalCallback` or a custom Callback instead.
    :param callback: Callback(s) called at every step with state of the algorithm.
    :param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
        Caution, this parameter is deprecated and will be removed in the future.
        Please use `EvalCallback` or a custom Callback instead.
    :param n_eval_episodes: How many episodes to play per evaluation
    :param log_path: Path to a folder where the evaluations will be saved
    :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
    :param tb_log_name: the name of the run for tensorboard log
    :param progress_bar: Display a progress bar using tqdm and rich.
    :return: The ``(total_timesteps, callback)`` pair returned by the parent class.
    """
    # Arguments are forwarded positionally, in the same order as the parent signature.
    return super()._setup_learn(
        total_timesteps,
        eval_env,
        callback,
        eval_freq,
        n_eval_episodes,
        log_path,
        reset_num_timesteps,
        tb_log_name,
        progress_bar,
    )
def collect_rollouts( def collect_rollouts(
self, self,
env: VecEnv, env: VecEnv,
@ -500,11 +459,20 @@ class RecurrentPPO(OnPolicyAlgorithm):
tb_log_name: str = "RecurrentPPO", tb_log_name: str = "RecurrentPPO",
eval_log_path: Optional[str] = None, eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> RecurrentPPOSelf: ) -> RecurrentPPOSelf:
iteration = 0 iteration = 0
total_timesteps, callback = self._setup_learn( total_timesteps, callback = self._setup_learn(
total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name total_timesteps,
eval_env,
callback,
eval_freq,
n_eval_episodes,
eval_log_path,
reset_num_timesteps,
tb_log_name,
progress_bar,
) )
callback.on_training_start(locals(), globals()) callback.on_training_start(locals(), globals())

View File

@ -262,6 +262,7 @@ class QRDQN(OffPolicyAlgorithm):
tb_log_name: str = "QRDQN", tb_log_name: str = "QRDQN",
eval_log_path: Optional[str] = None, eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> QRDQNSelf: ) -> QRDQNSelf:
return super().learn( return super().learn(
@ -274,6 +275,7 @@ class QRDQN(OffPolicyAlgorithm):
tb_log_name=tb_log_name, tb_log_name=tb_log_name,
eval_log_path=eval_log_path, eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps, reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
) )
def _excluded_save_params(self) -> List[str]: def _excluded_save_params(self) -> List[str]:

View File

@ -299,6 +299,7 @@ class TQC(OffPolicyAlgorithm):
tb_log_name: str = "TQC", tb_log_name: str = "TQC",
eval_log_path: Optional[str] = None, eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> TQCSelf: ) -> TQCSelf:
return super().learn( return super().learn(
@ -311,6 +312,7 @@ class TQC(OffPolicyAlgorithm):
tb_log_name=tb_log_name, tb_log_name=tb_log_name,
eval_log_path=eval_log_path, eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps, reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
) )
def _excluded_save_params(self) -> List[str]: def _excluded_save_params(self) -> List[str]:

View File

@ -415,6 +415,7 @@ class TRPO(OnPolicyAlgorithm):
tb_log_name: str = "TRPO", tb_log_name: str = "TRPO",
eval_log_path: Optional[str] = None, eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> TRPOSelf: ) -> TRPOSelf:
return super().learn( return super().learn(
@ -427,4 +428,5 @@ class TRPO(OnPolicyAlgorithm):
tb_log_name=tb_log_name, tb_log_name=tb_log_name,
eval_log_path=eval_log_path, eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps, reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
) )

View File

@ -190,7 +190,7 @@ def test_callback(tmp_path):
model = MaskablePPO("MlpPolicy", env, n_steps=64, gamma=0.4, seed=32, verbose=1) model = MaskablePPO("MlpPolicy", env, n_steps=64, gamma=0.4, seed=32, verbose=1)
model.learn(100, callback=MaskableEvalCallback(eval_env, eval_freq=100, warn=False, log_path=tmp_path)) model.learn(100, callback=MaskableEvalCallback(eval_env, eval_freq=100, warn=False, log_path=tmp_path))
model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False)) model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False), progress_bar=True)
def test_child_callback(): def test_child_callback():

View File

@ -20,7 +20,7 @@ def test_tqc(ent_coef):
create_eval_env=True, create_eval_env=True,
ent_coef=ent_coef, ent_coef=ent_coef,
) )
model.learn(total_timesteps=300, eval_freq=250) model.learn(total_timesteps=300, eval_freq=250, progress_bar=True)
@pytest.mark.parametrize("n_critics", [1, 3]) @pytest.mark.parametrize("n_critics", [1, 3])