Update default TQC net arch when using NatureCnn (#79)
* Update default TQC net arch when using NatureCnn * Bump version
This commit is contained in:
parent
bfa86ce4fe
commit
db4c0114d0
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.5.1a8 (WIP)
|
||||
Release 1.5.1a9 (WIP)
|
||||
-------------------------------
|
||||
|
||||
**Add RecurrentPPO (aka PPO LSTM)**
|
||||
|
|
@ -16,6 +16,9 @@ Breaking Changes:
|
|||
- Renamed ``rollout/exploration rate`` key to ``rollout/exploration_rate`` for QRDQN (to be consistent with SB3 DQN)
|
||||
- Upgraded to Python 3.7+ syntax using ``pyupgrade``
|
||||
- SB3 now requires PyTorch >= 1.11
|
||||
- Changed the default network architecture when using ``CnnPolicy`` or ``MultiInputPolicy`` with TQC,
|
||||
``share_features_extractor`` is now set to ``False`` by default and ``net_arch`` now defaults to ``[256, 256]`` (instead of the previous ``net_arch=[]``)
|
||||
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -212,7 +212,7 @@ class Critic(BaseModel):
|
|||
normalize_images: bool = True,
|
||||
n_quantiles: int = 25,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
share_features_extractor: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
observation_space,
|
||||
|
|
@ -297,7 +297,7 @@ class TQCPolicy(BasePolicy):
|
|||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_quantiles: int = 25,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
share_features_extractor: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
observation_space,
|
||||
|
|
@ -310,9 +310,6 @@ class TQCPolicy(BasePolicy):
|
|||
)
|
||||
|
||||
if net_arch is None:
|
||||
if features_extractor_class == NatureCNN:
|
||||
net_arch = []
|
||||
else:
|
||||
net_arch = [256, 256]
|
||||
|
||||
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
|
||||
|
|
@ -484,7 +481,7 @@ class CnnPolicy(TQCPolicy):
|
|||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_quantiles: int = 25,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
share_features_extractor: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
observation_space,
|
||||
|
|
@ -558,7 +555,7 @@ class MultiInputPolicy(TQCPolicy):
|
|||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_quantiles: int = 25,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
share_features_extractor: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
observation_space,
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.5.1a8
|
||||
1.5.1a9
|
||||
|
|
|
|||
Loading…
Reference in New Issue