Add progress bar argument (#107)
* Add progress bar argument
* Sort imports
This commit is contained in:
parent
e9c97948c8
commit
52795a307e
|
|
@ -15,6 +15,7 @@ Breaking Changes:
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
- Added ``progress_bar`` argument to the ``learn()`` method; the progress bar is displayed using the ``tqdm`` and ``rich`` packages
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -315,6 +315,7 @@ class ARS(BaseAlgorithm):
|
||||||
eval_log_path: Optional[str] = None,
|
eval_log_path: Optional[str] = None,
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
async_eval: Optional[AsyncEval] = None,
|
async_eval: Optional[AsyncEval] = None,
|
||||||
|
progress_bar: bool = False,
|
||||||
) -> ARSSelf:
|
) -> ARSSelf:
|
||||||
"""
|
"""
|
||||||
Return a trained model.
|
Return a trained model.
|
||||||
|
|
@ -333,11 +334,20 @@ class ARS(BaseAlgorithm):
|
||||||
:param eval_log_path: Path to a folder where the evaluations will be saved
|
:param eval_log_path: Path to a folder where the evaluations will be saved
|
||||||
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
|
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
|
||||||
:param async_eval: The object for asynchronous evaluation of candidates.
|
:param async_eval: The object for asynchronous evaluation of candidates.
|
||||||
|
:param progress_bar: Display a progress bar using tqdm and rich.
|
||||||
:return: the trained model
|
:return: the trained model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
total_steps, callback = self._setup_learn(
|
total_steps, callback = self._setup_learn(
|
||||||
total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
|
total_timesteps,
|
||||||
|
eval_env,
|
||||||
|
callback,
|
||||||
|
eval_freq,
|
||||||
|
n_eval_episodes,
|
||||||
|
eval_log_path,
|
||||||
|
reset_num_timesteps,
|
||||||
|
tb_log_name,
|
||||||
|
progress_bar,
|
||||||
)
|
)
|
||||||
|
|
||||||
callback.on_training_start(locals(), globals())
|
callback.on_training_start(locals(), globals())
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import torch as th
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
from stable_baselines3.common import utils
|
from stable_baselines3.common import utils
|
||||||
from stable_baselines3.common.buffers import RolloutBuffer
|
from stable_baselines3.common.buffers import RolloutBuffer
|
||||||
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback
|
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
|
||||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
|
|
@ -184,6 +184,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
n_eval_episodes: int = 5,
|
n_eval_episodes: int = 5,
|
||||||
log_path: Optional[str] = None,
|
log_path: Optional[str] = None,
|
||||||
use_masking: bool = True,
|
use_masking: bool = True,
|
||||||
|
progress_bar: bool = False,
|
||||||
) -> BaseCallback:
|
) -> BaseCallback:
|
||||||
"""
|
"""
|
||||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
:param callback: Callback(s) called at every step with state of the algorithm.
|
||||||
|
|
@ -196,6 +197,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
:param n_eval_episodes: How many episodes to play per evaluation
|
:param n_eval_episodes: How many episodes to play per evaluation
|
||||||
:param log_path: Path to a folder where the evaluations will be saved
|
:param log_path: Path to a folder where the evaluations will be saved
|
||||||
:param use_masking: Whether or not to use invalid action masks during evaluation
|
:param use_masking: Whether or not to use invalid action masks during evaluation
|
||||||
|
:param progress_bar: Display a progress bar using tqdm and rich.
|
||||||
:return: A hybrid callback calling `callback` and performing evaluation.
|
:return: A hybrid callback calling `callback` and performing evaluation.
|
||||||
"""
|
"""
|
||||||
# Convert a list of callbacks into a callback
|
# Convert a list of callbacks into a callback
|
||||||
|
|
@ -206,6 +208,10 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
if not isinstance(callback, BaseCallback):
|
if not isinstance(callback, BaseCallback):
|
||||||
callback = ConvertCallback(callback)
|
callback = ConvertCallback(callback)
|
||||||
|
|
||||||
|
# Add progress bar callback
|
||||||
|
if progress_bar:
|
||||||
|
callback = CallbackList([callback, ProgressBarCallback()])
|
||||||
|
|
||||||
# Create eval callback in charge of the evaluation
|
# Create eval callback in charge of the evaluation
|
||||||
if eval_env is not None:
|
if eval_env is not None:
|
||||||
# Avoid circular import error
|
# Avoid circular import error
|
||||||
|
|
@ -236,6 +242,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
tb_log_name: str = "run",
|
tb_log_name: str = "run",
|
||||||
use_masking: bool = True,
|
use_masking: bool = True,
|
||||||
|
progress_bar: bool = False,
|
||||||
) -> Tuple[int, BaseCallback]:
|
) -> Tuple[int, BaseCallback]:
|
||||||
"""
|
"""
|
||||||
Initialize different variables needed for training.
|
Initialize different variables needed for training.
|
||||||
|
|
@ -253,6 +260,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
|
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
|
||||||
:param tb_log_name: the name of the run for tensorboard log
|
:param tb_log_name: the name of the run for tensorboard log
|
||||||
:param use_masking: Whether or not to use invalid action masks during training
|
:param use_masking: Whether or not to use invalid action masks during training
|
||||||
|
:param progress_bar: Display a progress bar using tqdm and rich.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -299,7 +307,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
|
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
|
||||||
|
|
||||||
# Create eval callback if needed
|
# Create eval callback if needed
|
||||||
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking)
|
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking, progress_bar)
|
||||||
|
|
||||||
return total_timesteps, callback
|
return total_timesteps, callback
|
||||||
|
|
||||||
|
|
@ -563,6 +571,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
eval_log_path: Optional[str] = None,
|
eval_log_path: Optional[str] = None,
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
use_masking: bool = True,
|
use_masking: bool = True,
|
||||||
|
progress_bar: bool = False,
|
||||||
) -> MaskablePPOSelf:
|
) -> MaskablePPOSelf:
|
||||||
iteration = 0
|
iteration = 0
|
||||||
|
|
||||||
|
|
@ -576,6 +585,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
reset_num_timesteps,
|
reset_num_timesteps,
|
||||||
tb_log_name,
|
tb_log_name,
|
||||||
use_masking,
|
use_masking,
|
||||||
|
progress_bar,
|
||||||
)
|
)
|
||||||
|
|
||||||
callback.on_training_start(locals(), globals())
|
callback.on_training_start(locals(), globals())
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
|
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -198,47 +198,6 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
|
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
|
||||||
|
|
||||||
def _setup_learn(
|
|
||||||
self,
|
|
||||||
total_timesteps: int,
|
|
||||||
eval_env: Optional[GymEnv],
|
|
||||||
callback: MaybeCallback = None,
|
|
||||||
eval_freq: int = 10000,
|
|
||||||
n_eval_episodes: int = 5,
|
|
||||||
log_path: Optional[str] = None,
|
|
||||||
reset_num_timesteps: bool = True,
|
|
||||||
tb_log_name: str = "RecurrentPPO",
|
|
||||||
) -> Tuple[int, BaseCallback]:
|
|
||||||
"""
|
|
||||||
Initialize different variables needed for training.
|
|
||||||
|
|
||||||
:param total_timesteps: The total number of samples (env steps) to train on
|
|
||||||
:param eval_env: Environment to use for evaluation.
|
|
||||||
Caution, this parameter is deprecated and will be removed in the future.
|
|
||||||
Please use `EvalCallback` or a custom Callback instead.
|
|
||||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
|
||||||
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
|
|
||||||
Caution, this parameter is deprecated and will be removed in the future.
|
|
||||||
Please use `EvalCallback` or a custom Callback instead.
|
|
||||||
:param n_eval_episodes: How many episodes to play per evaluation
|
|
||||||
:param log_path: Path to a folder where the evaluations will be saved
|
|
||||||
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
|
|
||||||
:param tb_log_name: the name of the run for tensorboard log
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
|
|
||||||
total_timesteps, callback = super()._setup_learn(
|
|
||||||
total_timesteps,
|
|
||||||
eval_env,
|
|
||||||
callback,
|
|
||||||
eval_freq,
|
|
||||||
n_eval_episodes,
|
|
||||||
log_path,
|
|
||||||
reset_num_timesteps,
|
|
||||||
tb_log_name,
|
|
||||||
)
|
|
||||||
return total_timesteps, callback
|
|
||||||
|
|
||||||
def collect_rollouts(
|
def collect_rollouts(
|
||||||
self,
|
self,
|
||||||
env: VecEnv,
|
env: VecEnv,
|
||||||
|
|
@ -500,11 +459,20 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
tb_log_name: str = "RecurrentPPO",
|
tb_log_name: str = "RecurrentPPO",
|
||||||
eval_log_path: Optional[str] = None,
|
eval_log_path: Optional[str] = None,
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
|
progress_bar: bool = False,
|
||||||
) -> RecurrentPPOSelf:
|
) -> RecurrentPPOSelf:
|
||||||
iteration = 0
|
iteration = 0
|
||||||
|
|
||||||
total_timesteps, callback = self._setup_learn(
|
total_timesteps, callback = self._setup_learn(
|
||||||
total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
|
total_timesteps,
|
||||||
|
eval_env,
|
||||||
|
callback,
|
||||||
|
eval_freq,
|
||||||
|
n_eval_episodes,
|
||||||
|
eval_log_path,
|
||||||
|
reset_num_timesteps,
|
||||||
|
tb_log_name,
|
||||||
|
progress_bar,
|
||||||
)
|
)
|
||||||
|
|
||||||
callback.on_training_start(locals(), globals())
|
callback.on_training_start(locals(), globals())
|
||||||
|
|
|
||||||
|
|
@ -262,6 +262,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
tb_log_name: str = "QRDQN",
|
tb_log_name: str = "QRDQN",
|
||||||
eval_log_path: Optional[str] = None,
|
eval_log_path: Optional[str] = None,
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
|
progress_bar: bool = False,
|
||||||
) -> QRDQNSelf:
|
) -> QRDQNSelf:
|
||||||
|
|
||||||
return super().learn(
|
return super().learn(
|
||||||
|
|
@ -274,6 +275,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
tb_log_name=tb_log_name,
|
tb_log_name=tb_log_name,
|
||||||
eval_log_path=eval_log_path,
|
eval_log_path=eval_log_path,
|
||||||
reset_num_timesteps=reset_num_timesteps,
|
reset_num_timesteps=reset_num_timesteps,
|
||||||
|
progress_bar=progress_bar,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _excluded_save_params(self) -> List[str]:
|
def _excluded_save_params(self) -> List[str]:
|
||||||
|
|
|
||||||
|
|
@ -299,6 +299,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
tb_log_name: str = "TQC",
|
tb_log_name: str = "TQC",
|
||||||
eval_log_path: Optional[str] = None,
|
eval_log_path: Optional[str] = None,
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
|
progress_bar: bool = False,
|
||||||
) -> TQCSelf:
|
) -> TQCSelf:
|
||||||
|
|
||||||
return super().learn(
|
return super().learn(
|
||||||
|
|
@ -311,6 +312,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
tb_log_name=tb_log_name,
|
tb_log_name=tb_log_name,
|
||||||
eval_log_path=eval_log_path,
|
eval_log_path=eval_log_path,
|
||||||
reset_num_timesteps=reset_num_timesteps,
|
reset_num_timesteps=reset_num_timesteps,
|
||||||
|
progress_bar=progress_bar,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _excluded_save_params(self) -> List[str]:
|
def _excluded_save_params(self) -> List[str]:
|
||||||
|
|
|
||||||
|
|
@ -415,6 +415,7 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
tb_log_name: str = "TRPO",
|
tb_log_name: str = "TRPO",
|
||||||
eval_log_path: Optional[str] = None,
|
eval_log_path: Optional[str] = None,
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
|
progress_bar: bool = False,
|
||||||
) -> TRPOSelf:
|
) -> TRPOSelf:
|
||||||
|
|
||||||
return super().learn(
|
return super().learn(
|
||||||
|
|
@ -427,4 +428,5 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
tb_log_name=tb_log_name,
|
tb_log_name=tb_log_name,
|
||||||
eval_log_path=eval_log_path,
|
eval_log_path=eval_log_path,
|
||||||
reset_num_timesteps=reset_num_timesteps,
|
reset_num_timesteps=reset_num_timesteps,
|
||||||
|
progress_bar=progress_bar,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -190,7 +190,7 @@ def test_callback(tmp_path):
|
||||||
model = MaskablePPO("MlpPolicy", env, n_steps=64, gamma=0.4, seed=32, verbose=1)
|
model = MaskablePPO("MlpPolicy", env, n_steps=64, gamma=0.4, seed=32, verbose=1)
|
||||||
model.learn(100, callback=MaskableEvalCallback(eval_env, eval_freq=100, warn=False, log_path=tmp_path))
|
model.learn(100, callback=MaskableEvalCallback(eval_env, eval_freq=100, warn=False, log_path=tmp_path))
|
||||||
|
|
||||||
model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False))
|
model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False), progress_bar=True)
|
||||||
|
|
||||||
|
|
||||||
def test_child_callback():
|
def test_child_callback():
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ def test_tqc(ent_coef):
|
||||||
create_eval_env=True,
|
create_eval_env=True,
|
||||||
ent_coef=ent_coef,
|
ent_coef=ent_coef,
|
||||||
)
|
)
|
||||||
model.learn(total_timesteps=300, eval_freq=250)
|
model.learn(total_timesteps=300, eval_freq=250, progress_bar=True)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n_critics", [1, 3])
|
@pytest.mark.parametrize("n_critics", [1, 3])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue