Rename QRDQN logger key (#67)
This commit is contained in:
parent
99853265a9
commit
812648e6cd
|
|
@ -11,6 +11,7 @@ Breaking Changes:
|
||||||
- Upgraded to Stable-Baselines3 >= 1.5.1a1
|
- Upgraded to Stable-Baselines3 >= 1.5.1a1
|
||||||
- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former
|
- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former
|
||||||
``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar)
|
``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar)
|
||||||
|
- Renamed ``rollout/exploration rate`` key to ``rollout/exploration_rate`` for QRDQN (to be consistent with SB3 DQN)
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -159,7 +159,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
|
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
|
||||||
|
|
||||||
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
|
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
|
||||||
self.logger.record("rollout/exploration rate", self.exploration_rate)
|
self.logger.record("rollout/exploration_rate", self.exploration_rate)
|
||||||
|
|
||||||
def train(self, gradient_steps: int, batch_size: int = 100) -> None:
|
def train(self, gradient_steps: int, batch_size: int = 100) -> None:
|
||||||
# Switch to train mode (this affects batch norm / dropout)
|
# Switch to train mode (this affects batch norm / dropout)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue