Maskable eval callback call callback fix (#93)
* call correctly both self.callback_on_new_best and self.callback - similar as in EvalCallback * MaskableEvalCallback - updated sync_envs_normalization handling * MaskableEvalCallback - updated sync_envs_normalization handling - test MaskablePPO - register policies (tests fails otherwise) * MaskableEvalCallback - updated docstring * updated changelog.rst * changes for stable-baselines3==1.6.0 * version range * suggested changes * Reformat and update version Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
fc68af8841
commit
eb48fec638
|
|
@ -15,7 +15,7 @@ dependencies:
|
||||||
- numpy
|
- numpy
|
||||||
- matplotlib
|
- matplotlib
|
||||||
- sphinx_autodoc_typehints
|
- sphinx_autodoc_typehints
|
||||||
- stable-baselines3>=1.3.0
|
- stable-baselines3>=1.6.0
|
||||||
- sphinx>=4.2
|
- sphinx>=4.2
|
||||||
- sphinx_rtd_theme>=1.0
|
- sphinx_rtd_theme>=1.0
|
||||||
- sphinx_copybutton
|
- sphinx_copybutton
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Release 1.6.1a1 (WIP)
|
Release 1.6.1a2 (WIP)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
|
|
@ -17,6 +17,7 @@ Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
- Fixed the issue of wrongly passing policy arguments when using CnnLstmPolicy or MultiInputLstmPolicy with ``RecurrentPPO`` (@mlodel)
|
- Fixed the issue of wrongly passing policy arguments when using CnnLstmPolicy or MultiInputLstmPolicy with ``RecurrentPPO`` (@mlodel)
|
||||||
- Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
|
- Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
|
||||||
|
- Fixed calling child callbacks in MaskableEvalCallback (@CppMaster)
|
||||||
|
|
||||||
|
|
||||||
Deprecations:
|
Deprecations:
|
||||||
|
|
@ -296,4 +297,4 @@ Contributors:
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
||||||
@mlodel
|
@mlodel @CppMaster
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,8 @@ class MaskableEvalCallback(EvalCallback):
|
||||||
:param eval_env: The environment used for initialization
|
:param eval_env: The environment used for initialization
|
||||||
:param callback_on_new_best: Callback to trigger
|
:param callback_on_new_best: Callback to trigger
|
||||||
when there is a new best model according to the ``mean_reward``
|
when there is a new best model according to the ``mean_reward``
|
||||||
|
:param callback_after_eval: Callback to trigger after every evaluation
|
||||||
|
when there is a new best model according to the ``mean_reward``
|
||||||
:param n_eval_episodes: The number of episodes to test the agent
|
:param n_eval_episodes: The number of episodes to test the agent
|
||||||
:param eval_freq: Evaluate the agent every eval_freq call of the callback.
|
:param eval_freq: Evaluate the agent every eval_freq call of the callback.
|
||||||
:param log_path: Path to a folder where the evaluations (``evaluations.npz``)
|
:param log_path: Path to a folder where the evaluations (``evaluations.npz``)
|
||||||
|
|
@ -26,7 +28,7 @@ class MaskableEvalCallback(EvalCallback):
|
||||||
:param verbose:
|
:param verbose:
|
||||||
:param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
|
:param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
|
||||||
wrapped with a Monitor wrapper)
|
wrapped with a Monitor wrapper)
|
||||||
:param use_masking: Whether or not to use invalid action masks during evaluation
|
:param use_masking: Whether to use invalid action masks during evaluation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, use_masking: bool = True, **kwargs):
|
def __init__(self, *args, use_masking: bool = True, **kwargs):
|
||||||
|
|
@ -34,9 +36,21 @@ class MaskableEvalCallback(EvalCallback):
|
||||||
self.use_masking = use_masking
|
self.use_masking = use_masking
|
||||||
|
|
||||||
def _on_step(self) -> bool:
|
def _on_step(self) -> bool:
|
||||||
|
|
||||||
|
continue_training = True
|
||||||
|
|
||||||
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
|
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
|
||||||
|
|
||||||
# Sync training and eval env if there is VecNormalize
|
# Sync training and eval env if there is VecNormalize
|
||||||
sync_envs_normalization(self.training_env, self.eval_env)
|
if self.model.get_vec_normalize_env() is not None:
|
||||||
|
try:
|
||||||
|
sync_envs_normalization(self.training_env, self.eval_env)
|
||||||
|
except AttributeError as e:
|
||||||
|
raise AssertionError(
|
||||||
|
"Training and eval env are not wrapped the same way, "
|
||||||
|
"see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
|
||||||
|
"and warning above."
|
||||||
|
) from e
|
||||||
|
|
||||||
# Reset success rate buffer
|
# Reset success rate buffer
|
||||||
self._is_success_buffer = []
|
self._is_success_buffer = []
|
||||||
|
|
@ -91,7 +105,7 @@ class MaskableEvalCallback(EvalCallback):
|
||||||
self.logger.record("eval/success_rate", success_rate)
|
self.logger.record("eval/success_rate", success_rate)
|
||||||
|
|
||||||
# Dump log so the evaluation results are printed with the correct timestep
|
# Dump log so the evaluation results are printed with the correct timestep
|
||||||
self.logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard")
|
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
||||||
self.logger.dump(self.num_timesteps)
|
self.logger.dump(self.num_timesteps)
|
||||||
|
|
||||||
if mean_reward > self.best_mean_reward:
|
if mean_reward > self.best_mean_reward:
|
||||||
|
|
@ -100,8 +114,12 @@ class MaskableEvalCallback(EvalCallback):
|
||||||
if self.best_model_save_path is not None:
|
if self.best_model_save_path is not None:
|
||||||
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
|
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
|
||||||
self.best_mean_reward = mean_reward
|
self.best_mean_reward = mean_reward
|
||||||
# Trigger callback if needed
|
# Trigger callback on new best model, if needed
|
||||||
if self.callback is not None:
|
if self.callback_on_new_best is not None:
|
||||||
return self._on_event()
|
continue_training = self.callback_on_new_best.on_step()
|
||||||
|
|
||||||
return True
|
# Trigger callback after every evaluation, if needed
|
||||||
|
if self.callback is not None:
|
||||||
|
continue_training = continue_training and self._on_event()
|
||||||
|
|
||||||
|
return continue_training
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.6.1a1
|
1.6.1a2
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import random
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import pytest
|
import pytest
|
||||||
|
from stable_baselines3.common.callbacks import EventCallback, StopTrainingOnNoModelImprovement
|
||||||
from stable_baselines3.common.env_util import make_vec_env
|
from stable_baselines3.common.env_util import make_vec_env
|
||||||
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
||||||
from stable_baselines3.common.monitor import Monitor
|
from stable_baselines3.common.monitor import Monitor
|
||||||
|
|
@ -191,6 +192,28 @@ def test_callback(tmp_path):
|
||||||
model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False))
|
model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False))
|
||||||
|
|
||||||
|
|
||||||
|
def test_child_callback():
|
||||||
|
"""
|
||||||
|
Stop callback and callback on new best rewards
|
||||||
|
"""
|
||||||
|
|
||||||
|
env = make_env()
|
||||||
|
eval_env = make_env()
|
||||||
|
model = MaskablePPO("MlpPolicy", env, n_steps=64, n_epochs=1)
|
||||||
|
stop_callback = StopTrainingOnNoModelImprovement(1, 2)
|
||||||
|
new_best_mean_callback = EventCallback()
|
||||||
|
eval_callback = MaskableEvalCallback(
|
||||||
|
Monitor(eval_env),
|
||||||
|
eval_freq=64,
|
||||||
|
callback_after_eval=stop_callback,
|
||||||
|
callback_on_new_best=new_best_mean_callback,
|
||||||
|
)
|
||||||
|
model.learn(128, callback=eval_callback)
|
||||||
|
assert new_best_mean_callback.n_calls > 0
|
||||||
|
assert stop_callback.n_calls > 0
|
||||||
|
assert stop_callback.n_calls >= new_best_mean_callback.n_calls
|
||||||
|
|
||||||
|
|
||||||
def test_maskable_policy_required():
|
def test_maskable_policy_required():
|
||||||
"""
|
"""
|
||||||
MaskablePPO requires a policy that subclasses MaskableActorCriticPolicy
|
MaskablePPO requires a policy that subclasses MaskableActorCriticPolicy
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue