From cad9034fdbb0741378c9f197675a5c80078286ab Mon Sep 17 00:00:00 2001 From: Honglu Fan <64070721+honglu2875@users.noreply.github.com> Date: Sat, 27 Aug 2022 04:31:00 -0600 Subject: [PATCH] Handle batch norm in target update (#99) * Copy running stats regardless of tau in QRDQN and TQC. See https://github.com/DLR-RM/stable-baselines3/issues/996 * Copy running stats regardless of tau in QRDQN and TQC. See https://github.com/DLR-RM/stable-baselines3/issues/996 * Copy running stats regardless of tau in QRDQN and TQC. See https://github.com/DLR-RM/stable-baselines3/issues/996 * roll back test_cnn.py --- docs/misc/changelog.rst | 3 ++- sb3_contrib/qrdqn/qrdqn.py | 7 ++++++- sb3_contrib/tqc/tqc.py | 7 ++++++- tests/test_train_eval_mode.py | 18 ++++++++++++++---- 4 files changed, 28 insertions(+), 7 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9c12ef2..99d86c5 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -21,6 +21,7 @@ Bug Fixes: - Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers. - Fixed calling child callbacks in MaskableEvalCallback (@CppMaster) - Fixed missing verbose parameter passing in the ``MaskableEvalCallback`` constructor (@burakdmb) +- Fixed the issue that when updating the target network in QRDQN, TQC, the ``running_mean`` and ``running_var`` properties of batch norm layers are not updated (@honglu2875) Deprecations: @@ -305,4 +306,4 @@ Contributors: ------------- @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec -@mlodel @CppMaster @burakdmb +@mlodel @CppMaster @burakdmb @honglu2875 diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index ac07634..273665f 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -8,7 +8,7 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import maybe_transpose from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule -from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update +from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update from sb3_contrib.common.utils import quantile_huber_loss from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy @@ -141,6 +141,9 @@ class QRDQN(OffPolicyAlgorithm): def _setup_model(self) -> None: super()._setup_model() self._create_aliases() + # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996 + self.batch_norm_stats = get_parameters_by_name(self.quantile_net, ["running_"]) + self.batch_norm_stats_target = get_parameters_by_name(self.quantile_net_target, ["running_"]) self.exploration_schedule = get_linear_fn( self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction ) @@ -157,6 +160,8 @@ class QRDQN(OffPolicyAlgorithm): """ if self.num_timesteps % self.target_update_interval == 0: polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau) + # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996 + polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0) self.exploration_rate = self.exploration_schedule(self._current_progress_remaining) self.logger.record("rollout/exploration_rate", self.exploration_rate) diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 51d9f12..8ec746b 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -8,7 +8,7 @@ from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback -from stable_baselines3.common.utils import polyak_update +from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from sb3_contrib.common.utils import quantile_huber_loss from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy @@ -146,6 +146,9 @@ class TQC(OffPolicyAlgorithm): def _setup_model(self) -> None: super()._setup_model() self._create_aliases() + # Running mean and running var + self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"]) + self.batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"]) # Target entropy is used when learning the entropy coefficient if self.target_entropy == "auto": @@ -270,6 +273,8 @@ class TQC(OffPolicyAlgorithm): # Update target networks if gradient_step % self.target_update_interval == 0: polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau) + # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996 + polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0) self._n_updates += gradient_steps diff --git a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py index cf7e8aa..1973d45 100644 --- a/tests/test_train_eval_mode.py +++ b/tests/test_train_eval_mode.py @@ -143,6 +143,9 @@ def test_qrdqn_train_with_batch_norm(): ) = clone_qrdqn_batch_norm_stats(model) model.learn(total_timesteps=200) + # Force stats copy + model.target_update_interval = 1 + model._on_step() ( quantile_net_bias_after, @@ -152,10 +155,12 @@ def test_qrdqn_train_with_batch_norm(): ) = clone_qrdqn_batch_norm_stats(model) assert ~th.isclose(quantile_net_bias_before, quantile_net_bias_after).all() - assert ~th.isclose(quantile_net_running_mean_before, quantile_net_running_mean_after).all() + # Running stat should be copied even when tau=0 + assert th.isclose(quantile_net_running_mean_before, quantile_net_target_running_mean_before).all() assert th.isclose(quantile_net_target_bias_before, quantile_net_target_bias_after).all() - assert th.isclose(quantile_net_target_running_mean_before, quantile_net_target_running_mean_after).all() + # Running stat should be copied even when tau=0 + assert th.isclose(quantile_net_running_mean_after, quantile_net_target_running_mean_after).all() def test_tqc_train_with_batch_norm(): @@ -178,6 +183,9 @@ def test_tqc_train_with_batch_norm(): ) = clone_tqc_batch_norm_stats(model) model.learn(total_timesteps=200) + # Force stats copy + model.target_update_interval = 1 + model._on_step() ( actor_bias_after, @@ -192,10 +200,12 @@ def test_tqc_train_with_batch_norm(): assert ~th.isclose(actor_running_mean_before, actor_running_mean_after).all() assert ~th.isclose(critic_bias_before, critic_bias_after).all() - assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all() + # Running stat should be copied even when tau=0 + assert th.isclose(critic_running_mean_before, critic_target_running_mean_before).all() assert th.isclose(critic_target_bias_before, critic_target_bias_after).all() - assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all() + # Running stat should be copied even when tau=0 + assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all() @pytest.mark.parametrize("model_class", [QRDQN, TQC])