Update default TQC net arch when using NatureCnn (#79)

* Update default TQC net arch when using NatureCnn

* Bump version
This commit is contained in:
Antonin RAFFIN 2022-06-18 10:53:29 +02:00 committed by GitHub
parent bfa86ce4fe
commit db4c0114d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 10 deletions

View File

@ -3,7 +3,7 @@
Changelog
==========
Release 1.5.1a8 (WIP)
Release 1.5.1a9 (WIP)
-------------------------------
**Add RecurrentPPO (aka PPO LSTM)**
@ -16,6 +16,9 @@ Breaking Changes:
- Renamed ``rollout/exploration rate`` key to ``rollout/exploration_rate`` for QRDQN (to be consistent with SB3 DQN)
- Upgraded to python 3.7+ syntax using ``pyupgrade``
- SB3 now requires PyTorch >= 1.11
- Changed the default network architecture when using ``CnnPolicy`` or ``MultiInputPolicy`` with TQC,
``share_features_extractor`` is now set to False by default and the ``net_arch=[256, 256]`` (instead of ``net_arch=[]`` that was before)
New Features:
^^^^^^^^^^^^^

View File

@ -212,7 +212,7 @@ class Critic(BaseModel):
normalize_images: bool = True,
n_quantiles: int = 25,
n_critics: int = 2,
share_features_extractor: bool = True,
share_features_extractor: bool = False,
):
super().__init__(
observation_space,
@ -297,7 +297,7 @@ class TQCPolicy(BasePolicy):
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_quantiles: int = 25,
n_critics: int = 2,
share_features_extractor: bool = True,
share_features_extractor: bool = False,
):
super().__init__(
observation_space,
@ -310,9 +310,6 @@ class TQCPolicy(BasePolicy):
)
if net_arch is None:
if features_extractor_class == NatureCNN:
net_arch = []
else:
net_arch = [256, 256]
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
@ -484,7 +481,7 @@ class CnnPolicy(TQCPolicy):
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_quantiles: int = 25,
n_critics: int = 2,
share_features_extractor: bool = True,
share_features_extractor: bool = False,
):
super().__init__(
observation_space,
@ -558,7 +555,7 @@ class MultiInputPolicy(TQCPolicy):
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_quantiles: int = 25,
n_critics: int = 2,
share_features_extractor: bool = True,
share_features_extractor: bool = False,
):
super().__init__(
observation_space,

View File

@ -1 +1 @@
1.5.1a8
1.5.1a9