Handle batch norm in target update (#99)
* Copy running stats regardless of tau in QRDQN and TQC. See https://github.com/DLR-RM/stable-baselines3/issues/996 * roll back test_cnn.py
This commit is contained in:
parent
7993b75781
commit
cad9034fdb
|
|
@ -21,6 +21,7 @@ Bug Fixes:
|
||||||
- Fixed division by zero error when computing FPS when a small amount of time has elapsed in operating systems with low-precision timers.
|
- Fixed division by zero error when computing FPS when a small amount of time has elapsed in operating systems with low-precision timers.
|
||||||
- Fixed calling child callbacks in MaskableEvalCallback (@CppMaster)
|
- Fixed calling child callbacks in MaskableEvalCallback (@CppMaster)
|
||||||
- Fixed missing verbose parameter passing in the ``MaskableEvalCallback`` constructor (@burakdmb)
|
- Fixed missing verbose parameter passing in the ``MaskableEvalCallback`` constructor (@burakdmb)
|
||||||
|
- Fixed the ``running_mean`` and ``running_var`` statistics of batch norm layers not being copied to the target network when it is updated in QRDQN and TQC (@honglu2875)
|
||||||
|
|
||||||
|
|
||||||
Deprecations:
|
Deprecations:
|
||||||
|
|
@ -305,4 +306,4 @@ Contributors:
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
|
||||||
@mlodel @CppMaster @burakdmb
|
@mlodel @CppMaster @burakdmb @honglu2875
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.preprocessing import maybe_transpose
|
from stable_baselines3.common.preprocessing import maybe_transpose
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
|
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
|
||||||
|
|
||||||
from sb3_contrib.common.utils import quantile_huber_loss
|
from sb3_contrib.common.utils import quantile_huber_loss
|
||||||
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
|
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
|
||||||
|
|
@ -141,6 +141,9 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
def _setup_model(self) -> None:
|
def _setup_model(self) -> None:
|
||||||
super()._setup_model()
|
super()._setup_model()
|
||||||
self._create_aliases()
|
self._create_aliases()
|
||||||
|
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
|
||||||
|
self.batch_norm_stats = get_parameters_by_name(self.quantile_net, ["running_"])
|
||||||
|
self.batch_norm_stats_target = get_parameters_by_name(self.quantile_net_target, ["running_"])
|
||||||
self.exploration_schedule = get_linear_fn(
|
self.exploration_schedule = get_linear_fn(
|
||||||
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
|
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
|
||||||
)
|
)
|
||||||
|
|
@ -157,6 +160,8 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
"""
|
"""
|
||||||
if self.num_timesteps % self.target_update_interval == 0:
|
if self.num_timesteps % self.target_update_interval == 0:
|
||||||
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
|
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
|
||||||
|
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
|
||||||
|
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
|
||||||
|
|
||||||
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
|
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
|
||||||
self.logger.record("rollout/exploration_rate", self.exploration_rate)
|
self.logger.record("rollout/exploration_rate", self.exploration_rate)
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from stable_baselines3.common.noise import ActionNoise
|
||||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
||||||
from stable_baselines3.common.utils import polyak_update
|
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
||||||
|
|
||||||
from sb3_contrib.common.utils import quantile_huber_loss
|
from sb3_contrib.common.utils import quantile_huber_loss
|
||||||
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
|
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
|
||||||
|
|
@ -146,6 +146,9 @@ class TQC(OffPolicyAlgorithm):
|
||||||
def _setup_model(self) -> None:
|
def _setup_model(self) -> None:
|
||||||
super()._setup_model()
|
super()._setup_model()
|
||||||
self._create_aliases()
|
self._create_aliases()
|
||||||
|
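# Gather batch norm running mean and running var so they can be copied to the target, see https://github.com/DLR-RM/stable-baselines3/issues/996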
|
||||||
|
self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
|
||||||
|
self.batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"])
|
||||||
|
|
||||||
# Target entropy is used when learning the entropy coefficient
|
# Target entropy is used when learning the entropy coefficient
|
||||||
if self.target_entropy == "auto":
|
if self.target_entropy == "auto":
|
||||||
|
|
@ -270,6 +273,8 @@ class TQC(OffPolicyAlgorithm):
|
||||||
# Update target networks
|
# Update target networks
|
||||||
if gradient_step % self.target_update_interval == 0:
|
if gradient_step % self.target_update_interval == 0:
|
||||||
polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
|
polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
|
||||||
|
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
|
||||||
|
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
|
||||||
|
|
||||||
self._n_updates += gradient_steps
|
self._n_updates += gradient_steps
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -143,6 +143,9 @@ def test_qrdqn_train_with_batch_norm():
|
||||||
) = clone_qrdqn_batch_norm_stats(model)
|
) = clone_qrdqn_batch_norm_stats(model)
|
||||||
|
|
||||||
model.learn(total_timesteps=200)
|
model.learn(total_timesteps=200)
|
||||||
|
# Force stats copy
|
||||||
|
model.target_update_interval = 1
|
||||||
|
model._on_step()
|
||||||
|
|
||||||
(
|
(
|
||||||
quantile_net_bias_after,
|
quantile_net_bias_after,
|
||||||
|
|
@ -152,10 +155,12 @@ def test_qrdqn_train_with_batch_norm():
|
||||||
) = clone_qrdqn_batch_norm_stats(model)
|
) = clone_qrdqn_batch_norm_stats(model)
|
||||||
|
|
||||||
assert ~th.isclose(quantile_net_bias_before, quantile_net_bias_after).all()
|
assert ~th.isclose(quantile_net_bias_before, quantile_net_bias_after).all()
|
||||||
assert ~th.isclose(quantile_net_running_mean_before, quantile_net_running_mean_after).all()
|
# Running stat should be copied even when tau=0
|
||||||
|
assert th.isclose(quantile_net_running_mean_before, quantile_net_target_running_mean_before).all()
|
||||||
|
|
||||||
assert th.isclose(quantile_net_target_bias_before, quantile_net_target_bias_after).all()
|
assert th.isclose(quantile_net_target_bias_before, quantile_net_target_bias_after).all()
|
||||||
assert th.isclose(quantile_net_target_running_mean_before, quantile_net_target_running_mean_after).all()
|
# Running stat should be copied even when tau=0
|
||||||
|
assert th.isclose(quantile_net_running_mean_after, quantile_net_target_running_mean_after).all()
|
||||||
|
|
||||||
|
|
||||||
def test_tqc_train_with_batch_norm():
|
def test_tqc_train_with_batch_norm():
|
||||||
|
|
@ -178,6 +183,9 @@ def test_tqc_train_with_batch_norm():
|
||||||
) = clone_tqc_batch_norm_stats(model)
|
) = clone_tqc_batch_norm_stats(model)
|
||||||
|
|
||||||
model.learn(total_timesteps=200)
|
model.learn(total_timesteps=200)
|
||||||
|
# Force stats copy
|
||||||
|
model.target_update_interval = 1
|
||||||
|
model._on_step()
|
||||||
|
|
||||||
(
|
(
|
||||||
actor_bias_after,
|
actor_bias_after,
|
||||||
|
|
@ -192,10 +200,12 @@ def test_tqc_train_with_batch_norm():
|
||||||
assert ~th.isclose(actor_running_mean_before, actor_running_mean_after).all()
|
assert ~th.isclose(actor_running_mean_before, actor_running_mean_after).all()
|
||||||
|
|
||||||
assert ~th.isclose(critic_bias_before, critic_bias_after).all()
|
assert ~th.isclose(critic_bias_before, critic_bias_after).all()
|
||||||
assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all()
|
# Running stat should be copied even when tau=0
|
||||||
|
assert th.isclose(critic_running_mean_before, critic_target_running_mean_before).all()
|
||||||
|
|
||||||
assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
|
assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
|
||||||
assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all()
|
# Running stat should be copied even when tau=0
|
||||||
|
assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
|
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue