Remove sde net arch (#44)
This commit is contained in:
parent
c525c5107b
commit
91f9b1ed34
|
|
@ -3,13 +3,14 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Release 1.2.1a0 (WIP)
|
Release 1.2.1a2 (WIP)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
**Invalid action masking for PPO**
|
**Invalid action masking for PPO**
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^
|
||||||
|
- Removed ``sde_net_arch``
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
|
import warnings
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||||
from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_features_extractor, register_policy
|
from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy
|
||||||
from stable_baselines3.common.preprocessing import get_action_dim
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
from stable_baselines3.common.torch_layers import (
|
from stable_baselines3.common.torch_layers import (
|
||||||
BaseFeaturesExtractor,
|
BaseFeaturesExtractor,
|
||||||
|
|
@ -74,7 +75,6 @@ class Actor(BasePolicy):
|
||||||
# Save arguments to re-create object at loading
|
# Save arguments to re-create object at loading
|
||||||
self.use_sde = use_sde
|
self.use_sde = use_sde
|
||||||
self.sde_features_extractor = None
|
self.sde_features_extractor = None
|
||||||
self.sde_net_arch = sde_net_arch
|
|
||||||
self.net_arch = net_arch
|
self.net_arch = net_arch
|
||||||
self.features_dim = features_dim
|
self.features_dim = features_dim
|
||||||
self.activation_fn = activation_fn
|
self.activation_fn = activation_fn
|
||||||
|
|
@ -89,19 +89,15 @@ class Actor(BasePolicy):
|
||||||
self.latent_pi = nn.Sequential(*latent_pi_net)
|
self.latent_pi = nn.Sequential(*latent_pi_net)
|
||||||
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
|
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
|
||||||
|
|
||||||
if self.use_sde:
|
|
||||||
latent_sde_dim = last_layer_dim
|
|
||||||
# Separate feature extractor for gSDE
|
|
||||||
if sde_net_arch is not None:
|
if sde_net_arch is not None:
|
||||||
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
|
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
||||||
features_dim, sde_net_arch, activation_fn
|
|
||||||
)
|
|
||||||
|
|
||||||
|
if self.use_sde:
|
||||||
self.action_dist = StateDependentNoiseDistribution(
|
self.action_dist = StateDependentNoiseDistribution(
|
||||||
action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
|
action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
|
||||||
)
|
)
|
||||||
self.mu, self.log_std = self.action_dist.proba_distribution_net(
|
self.mu, self.log_std = self.action_dist.proba_distribution_net(
|
||||||
latent_dim=last_layer_dim, latent_sde_dim=latent_sde_dim, log_std_init=log_std_init
|
latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init
|
||||||
)
|
)
|
||||||
# Avoid numerical issues by limiting the mean of the Gaussian
|
# Avoid numerical issues by limiting the mean of the Gaussian
|
||||||
# to be in [-clip_mean, clip_mean]
|
# to be in [-clip_mean, clip_mean]
|
||||||
|
|
@ -123,7 +119,6 @@ class Actor(BasePolicy):
|
||||||
use_sde=self.use_sde,
|
use_sde=self.use_sde,
|
||||||
log_std_init=self.log_std_init,
|
log_std_init=self.log_std_init,
|
||||||
full_std=self.full_std,
|
full_std=self.full_std,
|
||||||
sde_net_arch=self.sde_net_arch,
|
|
||||||
use_expln=self.use_expln,
|
use_expln=self.use_expln,
|
||||||
features_extractor=self.features_extractor,
|
features_extractor=self.features_extractor,
|
||||||
clip_mean=self.clip_mean,
|
clip_mean=self.clip_mean,
|
||||||
|
|
@ -168,10 +163,7 @@ class Actor(BasePolicy):
|
||||||
mean_actions = self.mu(latent_pi)
|
mean_actions = self.mu(latent_pi)
|
||||||
|
|
||||||
if self.use_sde:
|
if self.use_sde:
|
||||||
latent_sde = latent_pi
|
return mean_actions, self.log_std, dict(latent_sde=latent_pi)
|
||||||
if self.sde_features_extractor is not None:
|
|
||||||
latent_sde = self.sde_features_extractor(features)
|
|
||||||
return mean_actions, self.log_std, dict(latent_sde=latent_sde)
|
|
||||||
# Unstructured exploration (Original implementation)
|
# Unstructured exploration (Original implementation)
|
||||||
log_std = self.log_std(latent_pi)
|
log_std = self.log_std(latent_pi)
|
||||||
# Original Implementation to cap the standard deviation
|
# Original Implementation to cap the standard deviation
|
||||||
|
|
@ -335,10 +327,13 @@ class TQCPolicy(BasePolicy):
|
||||||
"normalize_images": normalize_images,
|
"normalize_images": normalize_images,
|
||||||
}
|
}
|
||||||
self.actor_kwargs = self.net_args.copy()
|
self.actor_kwargs = self.net_args.copy()
|
||||||
|
|
||||||
|
if sde_net_arch is not None:
|
||||||
|
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
||||||
|
|
||||||
sde_kwargs = {
|
sde_kwargs = {
|
||||||
"use_sde": use_sde,
|
"use_sde": use_sde,
|
||||||
"log_std_init": log_std_init,
|
"log_std_init": log_std_init,
|
||||||
"sde_net_arch": sde_net_arch,
|
|
||||||
"use_expln": use_expln,
|
"use_expln": use_expln,
|
||||||
"clip_mean": clip_mean,
|
"clip_mean": clip_mean,
|
||||||
}
|
}
|
||||||
|
|
@ -390,7 +385,6 @@ class TQCPolicy(BasePolicy):
|
||||||
activation_fn=self.net_args["activation_fn"],
|
activation_fn=self.net_args["activation_fn"],
|
||||||
use_sde=self.actor_kwargs["use_sde"],
|
use_sde=self.actor_kwargs["use_sde"],
|
||||||
log_std_init=self.actor_kwargs["log_std_init"],
|
log_std_init=self.actor_kwargs["log_std_init"],
|
||||||
sde_net_arch=self.actor_kwargs["sde_net_arch"],
|
|
||||||
use_expln=self.actor_kwargs["use_expln"],
|
use_expln=self.actor_kwargs["use_expln"],
|
||||||
clip_mean=self.actor_kwargs["clip_mean"],
|
clip_mean=self.actor_kwargs["clip_mean"],
|
||||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.2.1a0
|
1.2.1a2
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ def test_sde():
|
||||||
model = TQC(
|
model = TQC(
|
||||||
"MlpPolicy",
|
"MlpPolicy",
|
||||||
"Pendulum-v0",
|
"Pendulum-v0",
|
||||||
policy_kwargs=dict(net_arch=[64], sde_net_arch=[8]),
|
policy_kwargs=dict(net_arch=[64]),
|
||||||
use_sde=True,
|
use_sde=True,
|
||||||
learning_starts=100,
|
learning_starts=100,
|
||||||
verbose=1,
|
verbose=1,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue