Update to new logger (#32)

This commit is contained in:
Antonin RAFFIN 2021-06-14 17:25:08 +02:00 committed by GitHub
parent 08418a3cc8
commit 2258c72215
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 14 additions and 15 deletions

View File

@@ -3,14 +3,15 @@
Changelog
==========
Release 1.1.0a7 (WIP)
Release 1.1.0a11 (WIP)
-------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Added support for Dictionary observation spaces (cf SB3 doc)
- Added support for Dictionary observation spaces (cf. SB3 doc)
- Upgraded to Stable-Baselines3 >= 1.1.0
- Added proper handling of timeouts for off-policy algorithms (cf SB3 doc)
- Added proper handling of timeouts for off-policy algorithms (cf. SB3 doc)
- Updated usage of logger (cf. SB3 doc)
Bug Fixes:
^^^^^^^^^^

View File

@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import numpy as np
import torch as th
from stable_baselines3.common import logger
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.preprocessing import maybe_transpose
@@ -153,7 +152,7 @@ class QRDQN(OffPolicyAlgorithm):
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
logger.record("rollout/exploration rate", self.exploration_rate)
self.logger.record("rollout/exploration rate", self.exploration_rate)
def train(self, gradient_steps: int, batch_size: int = 100) -> None:
# Update learning rate according to schedule
@@ -199,8 +198,8 @@ class QRDQN(OffPolicyAlgorithm):
# Increase update counter
self._n_updates += gradient_steps
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
logger.record("train/loss", np.mean(losses))
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/loss", np.mean(losses))
def predict(
self,

View File

@@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import gym
import numpy as np
import torch as th
from stable_baselines3.common import logger
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
@@ -265,12 +264,12 @@ class TQC(OffPolicyAlgorithm):
self._n_updates += gradient_steps
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
logger.record("train/ent_coef", np.mean(ent_coefs))
logger.record("train/actor_loss", np.mean(actor_losses))
logger.record("train/critic_loss", np.mean(critic_losses))
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/ent_coef", np.mean(ent_coefs))
self.logger.record("train/actor_loss", np.mean(actor_losses))
self.logger.record("train/critic_loss", np.mean(critic_losses))
if len(ent_coef_losses) > 0:
logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
def learn(
self,

View File

@@ -1 +1 @@
1.1.0a7
1.1.0a11

View File

@@ -62,7 +62,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.1.0a7",
"stable_baselines3>=1.1.0a11",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",