Update default TQC net arch when using NatureCnn (#79)
* Update default TQC net arch when using NatureCnn * Bump version
This commit is contained in:
parent
bfa86ce4fe
commit
db4c0114d0
|
|
@ -3,7 +3,7 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Release 1.5.1a8 (WIP)
|
Release 1.5.1a9 (WIP)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
**Add RecurrentPPO (aka PPO LSTM)**
|
**Add RecurrentPPO (aka PPO LSTM)**
|
||||||
|
|
@ -16,6 +16,9 @@ Breaking Changes:
|
||||||
- Renamed ``rollout/exploration rate`` key to ``rollout/exploration_rate`` for QRDQN (to be consistent with SB3 DQN)
|
- Renamed ``rollout/exploration rate`` key to ``rollout/exploration_rate`` for QRDQN (to be consistent with SB3 DQN)
|
||||||
- Upgraded to python 3.7+ syntax using ``pyupgrade``
|
- Upgraded to python 3.7+ syntax using ``pyupgrade``
|
||||||
- SB3 now requires PyTorch >= 1.11
|
- SB3 now requires PyTorch >= 1.11
|
||||||
|
- Changed the default network architecture when using ``CnnPolicy`` or ``MultiInputPolicy`` with TQC,
|
||||||
|
``share_features_extractor`` is now set to False by default and the ``net_arch=[256, 256]`` (instead of ``net_arch=[]`` that was before)
|
||||||
|
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -212,7 +212,7 @@ class Critic(BaseModel):
|
||||||
normalize_images: bool = True,
|
normalize_images: bool = True,
|
||||||
n_quantiles: int = 25,
|
n_quantiles: int = 25,
|
||||||
n_critics: int = 2,
|
n_critics: int = 2,
|
||||||
share_features_extractor: bool = True,
|
share_features_extractor: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
observation_space,
|
observation_space,
|
||||||
|
|
@ -297,7 +297,7 @@ class TQCPolicy(BasePolicy):
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
n_quantiles: int = 25,
|
n_quantiles: int = 25,
|
||||||
n_critics: int = 2,
|
n_critics: int = 2,
|
||||||
share_features_extractor: bool = True,
|
share_features_extractor: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
observation_space,
|
observation_space,
|
||||||
|
|
@ -310,10 +310,7 @@ class TQCPolicy(BasePolicy):
|
||||||
)
|
)
|
||||||
|
|
||||||
if net_arch is None:
|
if net_arch is None:
|
||||||
if features_extractor_class == NatureCNN:
|
net_arch = [256, 256]
|
||||||
net_arch = []
|
|
||||||
else:
|
|
||||||
net_arch = [256, 256]
|
|
||||||
|
|
||||||
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
|
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
|
||||||
|
|
||||||
|
|
@ -484,7 +481,7 @@ class CnnPolicy(TQCPolicy):
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
n_quantiles: int = 25,
|
n_quantiles: int = 25,
|
||||||
n_critics: int = 2,
|
n_critics: int = 2,
|
||||||
share_features_extractor: bool = True,
|
share_features_extractor: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
observation_space,
|
observation_space,
|
||||||
|
|
@ -558,7 +555,7 @@ class MultiInputPolicy(TQCPolicy):
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
n_quantiles: int = 25,
|
n_quantiles: int = 25,
|
||||||
n_critics: int = 2,
|
n_critics: int = 2,
|
||||||
share_features_extractor: bool = True,
|
share_features_extractor: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
observation_space,
|
observation_space,
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.5.1a8
|
1.5.1a9
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue