diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a9e94af..368f7a6 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -32,6 +32,8 @@ Deprecations: Others: ^^^^^^^ +- Fixed ``train_freq`` type annotation for tqc and qrdqn (@Armandpl) + Documentation: ^^^^^^^^^^^^^^ - Add some additional notes about ``MaskablePPO`` (evaluation and multi-process) (@icheered) @@ -545,4 +547,4 @@ Contributors: ------------- @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec -@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered +@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index 303862e..5271ac8 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -82,7 +82,7 @@ class QRDQN(OffPolicyAlgorithm): batch_size: int = 32, tau: float = 1.0, gamma: float = 0.99, - train_freq: int = 4, + train_freq: Union[int, Tuple[int, str]] = 4, gradient_steps: int = 1, replay_buffer_class: Optional[Type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index f5e4ec6..d4d8538 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -88,7 +88,7 @@ class TQC(OffPolicyAlgorithm): batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: int = 1, + train_freq: Union[int, Tuple[int, str]] = 1, gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, replay_buffer_class: Optional[Type[ReplayBuffer]] = None,