Handle batch norm in target update (#99)

* Copy running stats regardless of tau in QRDQN and TQC. See https://github.com/DLR-RM/stable-baselines3/issues/996

* Copy running stats regardless of tau in QRDQN and TQC. See https://github.com/DLR-RM/stable-baselines3/issues/996

* Copy running stats regardless of tau in QRDQN and TQC. See https://github.com/DLR-RM/stable-baselines3/issues/996

* roll back test_cnn.py
This commit is contained in:
Honglu Fan 2022-08-27 04:31:00 -06:00 committed by GitHub
parent 7993b75781
commit cad9034fdb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 7 deletions

View File

@ -21,6 +21,7 @@ Bug Fixes:
- Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
- Fixed calling child callbacks in MaskableEvalCallback (@CppMaster)
- Fixed missing verbose parameter passing in the ``MaskableEvalCallback`` constructor (@burakdmb)
- Fixed the issue that when updating the target network in QRDQN, TQC, the ``running_mean`` and ``running_var`` properties of batch norm layers are not updated (@honglu2875)
Deprecations:
@ -305,4 +306,4 @@ Contributors:
-------------
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
@mlodel @CppMaster @burakdmb
@mlodel @CppMaster @burakdmb @honglu2875

View File

@ -8,7 +8,7 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
@ -141,6 +141,9 @@ class QRDQN(OffPolicyAlgorithm):
def _setup_model(self) -> None:
super()._setup_model()
self._create_aliases()
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
self.batch_norm_stats = get_parameters_by_name(self.quantile_net, ["running_"])
self.batch_norm_stats_target = get_parameters_by_name(self.quantile_net_target, ["running_"])
self.exploration_schedule = get_linear_fn(
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
)
@ -157,6 +160,8 @@ class QRDQN(OffPolicyAlgorithm):
"""
if self.num_timesteps % self.target_update_interval == 0:
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
self.logger.record("rollout/exploration_rate", self.exploration_rate)

View File

@ -8,7 +8,7 @@ from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
@ -146,6 +146,9 @@ class TQC(OffPolicyAlgorithm):
def _setup_model(self) -> None:
super()._setup_model()
self._create_aliases()
# Running mean and running var
self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
self.batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"])
# Target entropy is used when learning the entropy coefficient
if self.target_entropy == "auto":
@ -270,6 +273,8 @@ class TQC(OffPolicyAlgorithm):
# Update target networks
if gradient_step % self.target_update_interval == 0:
polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
self._n_updates += gradient_steps

View File

@ -143,6 +143,9 @@ def test_qrdqn_train_with_batch_norm():
) = clone_qrdqn_batch_norm_stats(model)
model.learn(total_timesteps=200)
# Force stats copy
model.target_update_interval = 1
model._on_step()
(
quantile_net_bias_after,
@ -152,10 +155,12 @@ def test_qrdqn_train_with_batch_norm():
) = clone_qrdqn_batch_norm_stats(model)
assert ~th.isclose(quantile_net_bias_before, quantile_net_bias_after).all()
assert ~th.isclose(quantile_net_running_mean_before, quantile_net_running_mean_after).all()
# Running stat should be copied even when tau=0
assert th.isclose(quantile_net_running_mean_before, quantile_net_target_running_mean_before).all()
assert th.isclose(quantile_net_target_bias_before, quantile_net_target_bias_after).all()
assert th.isclose(quantile_net_target_running_mean_before, quantile_net_target_running_mean_after).all()
# Running stat should be copied even when tau=0
assert th.isclose(quantile_net_running_mean_after, quantile_net_target_running_mean_after).all()
def test_tqc_train_with_batch_norm():
@ -178,6 +183,9 @@ def test_tqc_train_with_batch_norm():
) = clone_tqc_batch_norm_stats(model)
model.learn(total_timesteps=200)
# Force stats copy
model.target_update_interval = 1
model._on_step()
(
actor_bias_after,
@ -192,10 +200,12 @@ def test_tqc_train_with_batch_norm():
assert ~th.isclose(actor_running_mean_before, actor_running_mean_after).all()
assert ~th.isclose(critic_bias_before, critic_bias_after).all()
assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all()
# Running stat should be copied even when tau=0
assert th.isclose(critic_running_mean_before, critic_target_running_mean_before).all()
assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all()
# Running stat should be copied even when tau=0
assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all()
@pytest.mark.parametrize("model_class", [QRDQN, TQC])