diff --git a/docs/conda_env.yml b/docs/conda_env.yml index 64b1aaf..cd22d6b 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -15,7 +15,7 @@ dependencies: - numpy - matplotlib - sphinx_autodoc_typehints - - stable-baselines3>=1.3.0 + - stable-baselines3>=1.6.0 - sphinx>=4.2 - sphinx_rtd_theme>=1.0 - sphinx_copybutton diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 5a5f92f..957ba1e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 1.6.1a1 (WIP) +Release 1.6.1a2 (WIP) ------------------------------- Breaking Changes: @@ -17,6 +17,7 @@ Bug Fixes: ^^^^^^^^^^ - Fixed the issue of wrongly passing policy arguments when using CnnLstmPolicy or MultiInputLstmPolicy with ``RecurrentPPO`` (@mlodel) - Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers. +- Fixed calling child callbacks in MaskableEvalCallback (@CppMaster) Deprecations: @@ -296,4 +297,4 @@ Contributors: ------------- @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec -@mlodel +@mlodel @CppMaster diff --git a/sb3_contrib/common/maskable/callbacks.py b/sb3_contrib/common/maskable/callbacks.py index 9832548..46927c6 100644 --- a/sb3_contrib/common/maskable/callbacks.py +++ b/sb3_contrib/common/maskable/callbacks.py @@ -14,6 +14,8 @@ class MaskableEvalCallback(EvalCallback): :param eval_env: The environment used for initialization :param callback_on_new_best: Callback to trigger when there is a new best model according to the ``mean_reward`` + :param callback_after_eval: Callback to trigger after every evaluation, + independent of whether there is a new best model :param n_eval_episodes: The number of episodes to test the agent :param eval_freq: Evaluate the agent every eval_freq call of the callback. 
:param log_path: Path to a folder where the evaluations (``evaluations.npz``) @@ -26,7 +28,7 @@ class MaskableEvalCallback(EvalCallback): :param verbose: :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been wrapped with a Monitor wrapper) - :param use_masking: Whether or not to use invalid action masks during evaluation + :param use_masking: Whether to use invalid action masks during evaluation """ def __init__(self, *args, use_masking: bool = True, **kwargs): @@ -34,9 +36,21 @@ class MaskableEvalCallback(EvalCallback): self.use_masking = use_masking def _on_step(self) -> bool: + + continue_training = True + if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: + # Sync training and eval env if there is VecNormalize - sync_envs_normalization(self.training_env, self.eval_env) + if self.model.get_vec_normalize_env() is not None: + try: + sync_envs_normalization(self.training_env, self.eval_env) + except AttributeError as e: + raise AssertionError( + "Training and eval env are not wrapped the same way, " + "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback " + "and warning above." 
+ ) from e # Reset success rate buffer self._is_success_buffer = [] @@ -91,7 +105,7 @@ class MaskableEvalCallback(EvalCallback): self.logger.record("eval/success_rate", success_rate) # Dump log so the evaluation results are printed with the correct timestep - self.logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard") + self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") self.logger.dump(self.num_timesteps) if mean_reward > self.best_mean_reward: @@ -100,8 +114,12 @@ class MaskableEvalCallback(EvalCallback): if self.best_model_save_path is not None: self.model.save(os.path.join(self.best_model_save_path, "best_model")) self.best_mean_reward = mean_reward - # Trigger callback if needed - if self.callback is not None: - return self._on_event() + # Trigger callback on new best model, if needed + if self.callback_on_new_best is not None: + continue_training = self.callback_on_new_best.on_step() - return True + # Trigger callback after every evaluation, if needed + if self.callback is not None: + continue_training = continue_training and self._on_event() + + return continue_training diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index e36b727..51cf83a 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.6.1a1 +1.6.1a2 diff --git a/tests/test_invalid_actions.py b/tests/test_invalid_actions.py index d530cfc..7187800 100644 --- a/tests/test_invalid_actions.py +++ b/tests/test_invalid_actions.py @@ -2,6 +2,7 @@ import random import gym import pytest +from stable_baselines3.common.callbacks import EventCallback, StopTrainingOnNoModelImprovement from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox from stable_baselines3.common.monitor import Monitor @@ -191,6 +192,28 @@ def test_callback(tmp_path): model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, 
warn=False)) +def test_child_callback(): + """ + Test that MaskableEvalCallback triggers its child callbacks (stop callback and callback on new best) + """ + + env = make_env() + eval_env = make_env() + model = MaskablePPO("MlpPolicy", env, n_steps=64, n_epochs=1) + stop_callback = StopTrainingOnNoModelImprovement(1, 2) + new_best_mean_callback = EventCallback() + eval_callback = MaskableEvalCallback( + Monitor(eval_env), + eval_freq=64, + callback_after_eval=stop_callback, + callback_on_new_best=new_best_mean_callback, + ) + model.learn(128, callback=eval_callback) + assert new_best_mean_callback.n_calls > 0 + assert stop_callback.n_calls > 0 + assert stop_callback.n_calls >= new_best_mean_callback.n_calls + + +def test_maskable_policy_required(): """ MaskablePPO requires a policy that subclasses MaskableActorCriticPolicy