Update to new logger (#32)

This commit is contained in:
Antonin RAFFIN 2021-06-14 17:25:08 +02:00 committed by GitHub
parent 08418a3cc8
commit 2258c72215
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 14 additions and 15 deletions

View File

@@ -3,14 +3,15 @@
Changelog Changelog
========== ==========
Release 1.1.0a7 (WIP) Release 1.1.0a11 (WIP)
------------------------------- -------------------------------
Breaking Changes: Breaking Changes:
^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^
- Added support for Dictionary observation spaces (cf SB3 doc) - Added support for Dictionary observation spaces (cf. SB3 doc)
- Upgraded to Stable-Baselines3 >= 1.1.0 - Upgraded to Stable-Baselines3 >= 1.1.0
- Added proper handling of timeouts for off-policy algorithms (cf SB3 doc) - Added proper handling of timeouts for off-policy algorithms (cf. SB3 doc)
- Updated usage of logger (cf. SB3 doc)
Bug Fixes: Bug Fixes:
^^^^^^^^^^ ^^^^^^^^^^

View File

@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym import gym
import numpy as np import numpy as np
import torch as th import torch as th
from stable_baselines3.common import logger
from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.preprocessing import maybe_transpose from stable_baselines3.common.preprocessing import maybe_transpose
@@ -153,7 +152,7 @@ class QRDQN(OffPolicyAlgorithm):
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau) polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining) self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
logger.record("rollout/exploration rate", self.exploration_rate) self.logger.record("rollout/exploration rate", self.exploration_rate)
def train(self, gradient_steps: int, batch_size: int = 100) -> None: def train(self, gradient_steps: int, batch_size: int = 100) -> None:
# Update learning rate according to schedule # Update learning rate according to schedule
@@ -199,8 +198,8 @@ class QRDQN(OffPolicyAlgorithm):
# Increase update counter # Increase update counter
self._n_updates += gradient_steps self._n_updates += gradient_steps
logger.record("train/n_updates", self._n_updates, exclude="tensorboard") self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
logger.record("train/loss", np.mean(losses)) self.logger.record("train/loss", np.mean(losses))
def predict( def predict(
self, self,

View File

@@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import gym import gym
import numpy as np import numpy as np
import torch as th import torch as th
from stable_baselines3.common import logger
from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
@@ -265,12 +264,12 @@ class TQC(OffPolicyAlgorithm):
self._n_updates += gradient_steps self._n_updates += gradient_steps
logger.record("train/n_updates", self._n_updates, exclude="tensorboard") self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
logger.record("train/ent_coef", np.mean(ent_coefs)) self.logger.record("train/ent_coef", np.mean(ent_coefs))
logger.record("train/actor_loss", np.mean(actor_losses)) self.logger.record("train/actor_loss", np.mean(actor_losses))
logger.record("train/critic_loss", np.mean(critic_losses)) self.logger.record("train/critic_loss", np.mean(critic_losses))
if len(ent_coef_losses) > 0: if len(ent_coef_losses) > 0:
logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
def learn( def learn(
self, self,

View File

@@ -1 +1 @@
1.1.0a7 1.1.0a11

View File

@@ -62,7 +62,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")], packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]}, package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[ install_requires=[
"stable_baselines3>=1.1.0a7", "stable_baselines3>=1.1.0a11",
], ],
description="Contrib package of Stable Baselines3, experimental code.", description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin", author="Antonin Raffin",