Handle batch norm in target update (#99)
* Copy running stats regardless of tau in QRDQN and TQC. See https://github.com/DLR-RM/stable-baselines3/issues/996 * roll back test_cnn.py
This commit is contained in:
parent
7993b75781
commit
cad9034fdb
|
|
@ -21,6 +21,7 @@ Bug Fixes:
|
|||
- Fixed division by zero error when computing FPS when a small amount of time has elapsed in operating systems with low-precision timers.
|
||||
- Fixed calling child callbacks in MaskableEvalCallback (@CppMaster)
|
||||
- Fixed missing verbose parameter passing in the ``MaskableEvalCallback`` constructor (@burakdmb)
|
||||
- Fixed an issue where, when updating the target network in QRDQN and TQC, the ``running_mean`` and ``running_var`` properties of batch norm layers were not updated (@honglu2875)
|
||||
|
||||
|
||||
Deprecations:
|
||||
|
|
@ -305,4 +306,4 @@ Contributors:
|
|||
-------------
|
||||
|
||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
||||
@mlodel @CppMaster @burakdmb
|
||||
@mlodel @CppMaster @burakdmb @honglu2875
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
|||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.preprocessing import maybe_transpose
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
|
||||
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
|
||||
|
||||
from sb3_contrib.common.utils import quantile_huber_loss
|
||||
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
|
||||
|
|
@ -141,6 +141,9 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
def _setup_model(self) -> None:
|
||||
super()._setup_model()
|
||||
self._create_aliases()
|
||||
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
|
||||
self.batch_norm_stats = get_parameters_by_name(self.quantile_net, ["running_"])
|
||||
self.batch_norm_stats_target = get_parameters_by_name(self.quantile_net_target, ["running_"])
|
||||
self.exploration_schedule = get_linear_fn(
|
||||
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
|
||||
)
|
||||
|
|
@ -157,6 +160,8 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
"""
|
||||
if self.num_timesteps % self.target_update_interval == 0:
|
||||
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
|
||||
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
|
||||
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
|
||||
|
||||
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
|
||||
self.logger.record("rollout/exploration_rate", self.exploration_rate)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from stable_baselines3.common.noise import ActionNoise
|
|||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
||||
from stable_baselines3.common.utils import polyak_update
|
||||
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
||||
|
||||
from sb3_contrib.common.utils import quantile_huber_loss
|
||||
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
|
||||
|
|
@ -146,6 +146,9 @@ class TQC(OffPolicyAlgorithm):
|
|||
def _setup_model(self) -> None:
|
||||
super()._setup_model()
|
||||
self._create_aliases()
|
||||
# Running mean and running var
|
||||
self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
|
||||
self.batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"])
|
||||
|
||||
# Target entropy is used when learning the entropy coefficient
|
||||
if self.target_entropy == "auto":
|
||||
|
|
@ -270,6 +273,8 @@ class TQC(OffPolicyAlgorithm):
|
|||
# Update target networks
|
||||
if gradient_step % self.target_update_interval == 0:
|
||||
polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
|
||||
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
|
||||
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
|
||||
|
||||
self._n_updates += gradient_steps
|
||||
|
||||
|
|
|
|||
|
|
@ -143,6 +143,9 @@ def test_qrdqn_train_with_batch_norm():
|
|||
) = clone_qrdqn_batch_norm_stats(model)
|
||||
|
||||
model.learn(total_timesteps=200)
|
||||
# Force stats copy
|
||||
model.target_update_interval = 1
|
||||
model._on_step()
|
||||
|
||||
(
|
||||
quantile_net_bias_after,
|
||||
|
|
@ -152,10 +155,12 @@ def test_qrdqn_train_with_batch_norm():
|
|||
) = clone_qrdqn_batch_norm_stats(model)
|
||||
|
||||
assert ~th.isclose(quantile_net_bias_before, quantile_net_bias_after).all()
|
||||
assert ~th.isclose(quantile_net_running_mean_before, quantile_net_running_mean_after).all()
|
||||
# Running stat should be copied even when tau=0
|
||||
assert th.isclose(quantile_net_running_mean_before, quantile_net_target_running_mean_before).all()
|
||||
|
||||
assert th.isclose(quantile_net_target_bias_before, quantile_net_target_bias_after).all()
|
||||
assert th.isclose(quantile_net_target_running_mean_before, quantile_net_target_running_mean_after).all()
|
||||
# Running stat should be copied even when tau=0
|
||||
assert th.isclose(quantile_net_running_mean_after, quantile_net_target_running_mean_after).all()
|
||||
|
||||
|
||||
def test_tqc_train_with_batch_norm():
|
||||
|
|
@ -178,6 +183,9 @@ def test_tqc_train_with_batch_norm():
|
|||
) = clone_tqc_batch_norm_stats(model)
|
||||
|
||||
model.learn(total_timesteps=200)
|
||||
# Force stats copy
|
||||
model.target_update_interval = 1
|
||||
model._on_step()
|
||||
|
||||
(
|
||||
actor_bias_after,
|
||||
|
|
@ -192,10 +200,12 @@ def test_tqc_train_with_batch_norm():
|
|||
assert ~th.isclose(actor_running_mean_before, actor_running_mean_after).all()
|
||||
|
||||
assert ~th.isclose(critic_bias_before, critic_bias_after).all()
|
||||
assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all()
|
||||
# Running stat should be copied even when tau=0
|
||||
assert th.isclose(critic_running_mean_before, critic_target_running_mean_before).all()
|
||||
|
||||
assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
|
||||
assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all()
|
||||
# Running stat should be copied even when tau=0
|
||||
assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
|
||||
|
|
|
|||
Loading…
Reference in New Issue