Update to new logger (#32)
This commit is contained in:
parent
08418a3cc8
commit
2258c72215
|
|
@ -3,14 +3,15 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Release 1.1.0a7 (WIP)
|
Release 1.1.0a11 (WIP)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^
|
||||||
- Added support for Dictionary observation spaces (cf SB3 doc)
|
- Added support for Dictionary observation spaces (cf. SB3 doc)
|
||||||
- Upgraded to Stable-Baselines3 >= 1.1.0
|
- Upgraded to Stable-Baselines3 >= 1.1.0
|
||||||
- Added proper handling of timeouts for off-policy algorithms (cf SB3 doc)
|
- Added proper handling of timeouts for off-policy algorithms (cf. SB3 doc)
|
||||||
|
- Updated usage of logger (cf. SB3 doc)
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common import logger
|
|
||||||
from stable_baselines3.common.buffers import ReplayBuffer
|
from stable_baselines3.common.buffers import ReplayBuffer
|
||||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||||
from stable_baselines3.common.preprocessing import maybe_transpose
|
from stable_baselines3.common.preprocessing import maybe_transpose
|
||||||
|
|
@ -153,7 +152,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
|
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
|
||||||
|
|
||||||
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
|
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
|
||||||
logger.record("rollout/exploration rate", self.exploration_rate)
|
self.logger.record("rollout/exploration rate", self.exploration_rate)
|
||||||
|
|
||||||
def train(self, gradient_steps: int, batch_size: int = 100) -> None:
|
def train(self, gradient_steps: int, batch_size: int = 100) -> None:
|
||||||
# Update learning rate according to schedule
|
# Update learning rate according to schedule
|
||||||
|
|
@ -199,8 +198,8 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
# Increase update counter
|
# Increase update counter
|
||||||
self._n_updates += gradient_steps
|
self._n_updates += gradient_steps
|
||||||
|
|
||||||
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
||||||
logger.record("train/loss", np.mean(losses))
|
self.logger.record("train/loss", np.mean(losses))
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common import logger
|
|
||||||
from stable_baselines3.common.buffers import ReplayBuffer
|
from stable_baselines3.common.buffers import ReplayBuffer
|
||||||
from stable_baselines3.common.noise import ActionNoise
|
from stable_baselines3.common.noise import ActionNoise
|
||||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||||
|
|
@ -265,12 +264,12 @@ class TQC(OffPolicyAlgorithm):
|
||||||
|
|
||||||
self._n_updates += gradient_steps
|
self._n_updates += gradient_steps
|
||||||
|
|
||||||
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
||||||
logger.record("train/ent_coef", np.mean(ent_coefs))
|
self.logger.record("train/ent_coef", np.mean(ent_coefs))
|
||||||
logger.record("train/actor_loss", np.mean(actor_losses))
|
self.logger.record("train/actor_loss", np.mean(actor_losses))
|
||||||
logger.record("train/critic_loss", np.mean(critic_losses))
|
self.logger.record("train/critic_loss", np.mean(critic_losses))
|
||||||
if len(ent_coef_losses) > 0:
|
if len(ent_coef_losses) > 0:
|
||||||
logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
||||||
|
|
||||||
def learn(
|
def learn(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.1.0a7
|
1.1.0a11
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -62,7 +62,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.1.0a7",
|
"stable_baselines3>=1.1.0a11",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue