From db4c0114d07bb7439dc7ef02eef2a6e0d0de09bd Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 18 Jun 2022 10:53:29 +0200 Subject: [PATCH] Update default TQC net arch when using NatureCnn (#79) * Update default TQC net arch when using NatureCnn * Bump version --- docs/misc/changelog.rst | 5 ++++- sb3_contrib/tqc/policies.py | 13 +++++-------- sb3_contrib/version.txt | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index fafe6f6..82766aa 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 1.5.1a8 (WIP) +Release 1.5.1a9 (WIP) ------------------------------- **Add RecurrentPPO (aka PPO LSTM)** @@ -16,6 +16,9 @@ Breaking Changes: - Renamed ``rollout/exploration rate`` key to ``rollout/exploration_rate`` for QRDQN (to be consistent with SB3 DQN) - Upgraded to python 3.7+ syntax using ``pyupgrade`` - SB3 now requires PyTorch >= 1.11 +- Changed the default network architecture when using ``CnnPolicy`` or ``MultiInputPolicy`` with TQC: + ``share_features_extractor`` is now set to ``False`` by default and ``net_arch`` now defaults to ``[256, 256]`` (instead of the previous default ``net_arch=[]``) + New Features: ^^^^^^^^^^^^^ diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index dbd073f..e8022f9 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -212,7 +212,7 @@ class Critic(BaseModel): normalize_images: bool = True, n_quantiles: int = 25, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): super().__init__( observation_space, @@ -297,7 +297,7 @@ class TQCPolicy(BasePolicy): optimizer_kwargs: Optional[Dict[str, Any]] = None, n_quantiles: int = 25, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): super().__init__( observation_space, @@ -310,10 +310,7 @@ class TQCPolicy(BasePolicy): ) if net_arch is 
None: - if features_extractor_class == NatureCNN: - net_arch = [] - else: - net_arch = [256, 256] + net_arch = [256, 256] actor_arch, critic_arch = get_actor_critic_arch(net_arch) @@ -484,7 +481,7 @@ class CnnPolicy(TQCPolicy): optimizer_kwargs: Optional[Dict[str, Any]] = None, n_quantiles: int = 25, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): super().__init__( observation_space, @@ -558,7 +555,7 @@ class MultiInputPolicy(TQCPolicy): optimizer_kwargs: Optional[Dict[str, Any]] = None, n_quantiles: int = 25, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): super().__init__( observation_space, diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 511e75b..125ec27 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.5.1a8 +1.5.1a9