Update to new logger (#32)
This commit is contained in:
parent
08418a3cc8
commit
2258c72215
|
|
@@ -3,14 +3,15 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.1.0a7 (WIP)
|
||||
Release 1.1.0a11 (WIP)
|
||||
-------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Added support for Dictionary observation spaces (cf SB3 doc)
|
||||
- Added support for Dictionary observation spaces (cf. SB3 doc)
|
||||
- Upgraded to Stable-Baselines3 >= 1.1.0
|
||||
- Added proper handling of timeouts for off-policy algorithms (cf SB3 doc)
|
||||
- Added proper handling of timeouts for off-policy algorithms (cf. SB3 doc)
|
||||
- Updated usage of logger (cf. SB3 doc)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
|||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.preprocessing import maybe_transpose
|
||||
|
|
@@ -153,7 +152,7 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
|
||||
|
||||
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
|
||||
logger.record("rollout/exploration rate", self.exploration_rate)
|
||||
self.logger.record("rollout/exploration rate", self.exploration_rate)
|
||||
|
||||
def train(self, gradient_steps: int, batch_size: int = 100) -> None:
|
||||
# Update learning rate according to schedule
|
||||
|
|
@@ -199,8 +198,8 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
# Increase update counter
|
||||
self._n_updates += gradient_steps
|
||||
|
||||
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
||||
logger.record("train/loss", np.mean(losses))
|
||||
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
||||
self.logger.record("train/loss", np.mean(losses))
|
||||
|
||||
def predict(
|
||||
self,
|
||||
|
|
|
|||
|
|
@@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
|||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
|
|
@@ -265,12 +264,12 @@ class TQC(OffPolicyAlgorithm):
|
|||
|
||||
self._n_updates += gradient_steps
|
||||
|
||||
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
||||
logger.record("train/ent_coef", np.mean(ent_coefs))
|
||||
logger.record("train/actor_loss", np.mean(actor_losses))
|
||||
logger.record("train/critic_loss", np.mean(critic_losses))
|
||||
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
||||
self.logger.record("train/ent_coef", np.mean(ent_coefs))
|
||||
self.logger.record("train/actor_loss", np.mean(actor_losses))
|
||||
self.logger.record("train/critic_loss", np.mean(critic_losses))
|
||||
if len(ent_coef_losses) > 0:
|
||||
logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
||||
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
||||
|
||||
def learn(
|
||||
self,
|
||||
|
|
|
|||
|
|
@@ -1 +1 @@
|
|||
1.1.0a7
|
||||
1.1.0a11
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@@ -62,7 +62,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=1.1.0a7",
|
||||
"stable_baselines3>=1.1.0a11",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
Loading…
Reference in New Issue