From 2258c72215c264dccc2061e10f0ae82909f9a733 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 14 Jun 2021 17:25:08 +0200 Subject: [PATCH] Update to new logger (#32) --- docs/misc/changelog.rst | 7 ++++--- sb3_contrib/qrdqn/qrdqn.py | 7 +++---- sb3_contrib/tqc/tqc.py | 11 +++++------ sb3_contrib/version.txt | 2 +- setup.py | 2 +- 5 files changed, 14 insertions(+), 15 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 7ccf472..822a028 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,14 +3,15 @@ Changelog ========== -Release 1.1.0a7 (WIP) +Release 1.1.0a11 (WIP) ------------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ -- Added support for Dictionary observation spaces (cf SB3 doc) +- Added support for Dictionary observation spaces (cf. SB3 doc) - Upgraded to Stable-Baselines3 >= 1.1.0 -- Added proper handling of timeouts for off-policy algorithms (cf SB3 doc) +- Added proper handling of timeouts for off-policy algorithms (cf. SB3 doc) +- Updated usage of logger (cf. SB3 doc) Bug Fixes: ^^^^^^^^^^ diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index d0f5c95..001f550 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union import gym import numpy as np import torch as th -from stable_baselines3.common import logger from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.preprocessing import maybe_transpose @@ -153,7 +152,7 @@ class QRDQN(OffPolicyAlgorithm): polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau) self.exploration_rate = self.exploration_schedule(self._current_progress_remaining) - logger.record("rollout/exploration rate", self.exploration_rate) + self.logger.record("rollout/exploration rate", self.exploration_rate) def train(self, gradient_steps: int, batch_size: int = 100) -> None: # Update learning rate according to schedule @@ -199,8 +198,8 @@ class QRDQN(OffPolicyAlgorithm): # Increase update counter self._n_updates += gradient_steps - logger.record("train/n_updates", self._n_updates, exclude="tensorboard") - logger.record("train/loss", np.mean(losses)) + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/loss", np.mean(losses)) def predict( self, diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index a9d5df4..015f8d4 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import gym import numpy as np import torch as th -from stable_baselines3.common import logger from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm @@ -265,12 +264,12 @@ class TQC(OffPolicyAlgorithm): self._n_updates += gradient_steps - logger.record("train/n_updates", self._n_updates, exclude="tensorboard") - logger.record("train/ent_coef", np.mean(ent_coefs)) - logger.record("train/actor_loss", np.mean(actor_losses)) - logger.record("train/critic_loss", np.mean(critic_losses)) + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/ent_coef", np.mean(ent_coefs)) + self.logger.record("train/actor_loss", np.mean(actor_losses)) + self.logger.record("train/critic_loss", np.mean(critic_losses)) if len(ent_coef_losses) > 0: - logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) + self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) def learn( self, diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 5e7d254..a149840 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.1.0a7 +1.1.0a11 diff --git a/setup.py b/setup.py index 5d2af6b..023bdad 100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.1.0a7", + "stable_baselines3>=1.1.0a11", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin",