Maskable eval callback call callback fix (#93)
* correctly call both self.callback_on_new_best and self.callback - similarly to EvalCallback
* MaskableEvalCallback - updated sync_envs_normalization handling
* MaskableEvalCallback - updated sync_envs_normalization handling - test MaskablePPO - register policies (tests fail otherwise)
* MaskableEvalCallback - updated docstring
* updated changelog.rst
* changes for stable-baselines3==1.6.0
* version range
* suggested changes
* Reformat and update version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
fc68af8841
commit
eb48fec638
|
|
@ -15,7 +15,7 @@ dependencies:
|
|||
- numpy
|
||||
- matplotlib
|
||||
- sphinx_autodoc_typehints
|
||||
- stable-baselines3>=1.3.0
|
||||
- stable-baselines3>=1.6.0
|
||||
- sphinx>=4.2
|
||||
- sphinx_rtd_theme>=1.0
|
||||
- sphinx_copybutton
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.6.1a1 (WIP)
|
||||
Release 1.6.1a2 (WIP)
|
||||
-------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -17,6 +17,7 @@ Bug Fixes:
|
|||
^^^^^^^^^^
|
||||
- Fixed the issue of wrongly passing policy arguments when using CnnLstmPolicy or MultiInputLstmPolicy with ``RecurrentPPO`` (@mlodel)
|
||||
- Fixed division by zero error when computing FPS when only a small amount of time has elapsed in operating systems with low-precision timers.
|
||||
- Fixed calling child callbacks in MaskableEvalCallback (@CppMaster)
|
||||
|
||||
|
||||
Deprecations:
|
||||
|
|
@ -296,4 +297,4 @@ Contributors:
|
|||
-------------
|
||||
|
||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
||||
@mlodel
|
||||
@mlodel @CppMaster
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ class MaskableEvalCallback(EvalCallback):
|
|||
:param eval_env: The environment used for initialization
|
||||
:param callback_on_new_best: Callback to trigger
|
||||
when there is a new best model according to the ``mean_reward``
|
||||
:param callback_after_eval: Callback to trigger after every evaluation
|
||||
(not only when there is a new best model)
|
||||
:param n_eval_episodes: The number of episodes to test the agent
|
||||
:param eval_freq: Evaluate the agent every eval_freq call of the callback.
|
||||
:param log_path: Path to a folder where the evaluations (``evaluations.npz``)
|
||||
|
|
@ -26,7 +28,7 @@ class MaskableEvalCallback(EvalCallback):
|
|||
:param verbose:
|
||||
:param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
|
||||
wrapped with a Monitor wrapper)
|
||||
:param use_masking: Whether or not to use invalid action masks during evaluation
|
||||
:param use_masking: Whether to use invalid action masks during evaluation
|
||||
"""
|
||||
|
||||
def __init__(self, *args, use_masking: bool = True, **kwargs):
|
||||
|
|
@ -34,9 +36,21 @@ class MaskableEvalCallback(EvalCallback):
|
|||
self.use_masking = use_masking
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
|
||||
continue_training = True
|
||||
|
||||
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
|
||||
|
||||
# Sync training and eval env if there is VecNormalize
|
||||
sync_envs_normalization(self.training_env, self.eval_env)
|
||||
if self.model.get_vec_normalize_env() is not None:
|
||||
try:
|
||||
sync_envs_normalization(self.training_env, self.eval_env)
|
||||
except AttributeError as e:
|
||||
raise AssertionError(
|
||||
"Training and eval env are not wrapped the same way, "
|
||||
"see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
|
||||
"and warning above."
|
||||
) from e
|
||||
|
||||
# Reset success rate buffer
|
||||
self._is_success_buffer = []
|
||||
|
|
@ -91,7 +105,7 @@ class MaskableEvalCallback(EvalCallback):
|
|||
self.logger.record("eval/success_rate", success_rate)
|
||||
|
||||
# Dump log so the evaluation results are printed with the correct timestep
|
||||
self.logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
self.logger.dump(self.num_timesteps)
|
||||
|
||||
if mean_reward > self.best_mean_reward:
|
||||
|
|
@ -100,8 +114,12 @@ class MaskableEvalCallback(EvalCallback):
|
|||
if self.best_model_save_path is not None:
|
||||
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
|
||||
self.best_mean_reward = mean_reward
|
||||
# Trigger callback if needed
|
||||
if self.callback is not None:
|
||||
return self._on_event()
|
||||
# Trigger callback on new best model, if needed
|
||||
if self.callback_on_new_best is not None:
|
||||
continue_training = self.callback_on_new_best.on_step()
|
||||
|
||||
return True
|
||||
# Trigger callback after every evaluation, if needed
|
||||
if self.callback is not None:
|
||||
continue_training = continue_training and self._on_event()
|
||||
|
||||
return continue_training
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.6.1a1
|
||||
1.6.1a2
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import random
|
|||
|
||||
import gym
|
||||
import pytest
|
||||
from stable_baselines3.common.callbacks import EventCallback, StopTrainingOnNoModelImprovement
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
|
|
@ -191,6 +192,28 @@ def test_callback(tmp_path):
|
|||
model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False))
|
||||
|
||||
|
||||
def test_child_callback():
|
||||
"""
|
||||
Stop callback and callback on new best rewards
|
||||
"""
|
||||
|
||||
env = make_env()
|
||||
eval_env = make_env()
|
||||
model = MaskablePPO("MlpPolicy", env, n_steps=64, n_epochs=1)
|
||||
stop_callback = StopTrainingOnNoModelImprovement(1, 2)
|
||||
new_best_mean_callback = EventCallback()
|
||||
eval_callback = MaskableEvalCallback(
|
||||
Monitor(eval_env),
|
||||
eval_freq=64,
|
||||
callback_after_eval=stop_callback,
|
||||
callback_on_new_best=new_best_mean_callback,
|
||||
)
|
||||
model.learn(128, callback=eval_callback)
|
||||
assert new_best_mean_callback.n_calls > 0
|
||||
assert stop_callback.n_calls > 0
|
||||
assert stop_callback.n_calls >= new_best_mean_callback.n_calls
|
||||
|
||||
|
||||
def test_maskable_policy_required():
|
||||
"""
|
||||
MaskablePPO requires a policy that subclasses MaskableActorCriticPolicy
|
||||
|
|
|
|||
Loading…
Reference in New Issue