Remove sde net arch (#44)
This commit is contained in:
parent
c525c5107b
commit
91f9b1ed34
|
|
@ -3,13 +3,14 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.2.1a0 (WIP)
|
||||
Release 1.2.1a2 (WIP)
|
||||
-------------------------------
|
||||
|
||||
**Invalid action masking for PPO**
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Removed ``sde_net_arch``
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||
from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_features_extractor, register_policy
|
||||
from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.torch_layers import (
|
||||
BaseFeaturesExtractor,
|
||||
|
|
@ -74,7 +75,6 @@ class Actor(BasePolicy):
|
|||
# Save arguments to re-create object at loading
|
||||
self.use_sde = use_sde
|
||||
self.sde_features_extractor = None
|
||||
self.sde_net_arch = sde_net_arch
|
||||
self.net_arch = net_arch
|
||||
self.features_dim = features_dim
|
||||
self.activation_fn = activation_fn
|
||||
|
|
@ -89,19 +89,15 @@ class Actor(BasePolicy):
|
|||
self.latent_pi = nn.Sequential(*latent_pi_net)
|
||||
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
|
||||
|
||||
if self.use_sde:
|
||||
latent_sde_dim = last_layer_dim
|
||||
# Separate feature extractor for gSDE
|
||||
if sde_net_arch is not None:
|
||||
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
|
||||
features_dim, sde_net_arch, activation_fn
|
||||
)
|
||||
if sde_net_arch is not None:
|
||||
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
||||
|
||||
if self.use_sde:
|
||||
self.action_dist = StateDependentNoiseDistribution(
|
||||
action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
|
||||
)
|
||||
self.mu, self.log_std = self.action_dist.proba_distribution_net(
|
||||
latent_dim=last_layer_dim, latent_sde_dim=latent_sde_dim, log_std_init=log_std_init
|
||||
latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init
|
||||
)
|
||||
# Avoid numerical issues by limiting the mean of the Gaussian
|
||||
# to be in [-clip_mean, clip_mean]
|
||||
|
|
@ -123,7 +119,6 @@ class Actor(BasePolicy):
|
|||
use_sde=self.use_sde,
|
||||
log_std_init=self.log_std_init,
|
||||
full_std=self.full_std,
|
||||
sde_net_arch=self.sde_net_arch,
|
||||
use_expln=self.use_expln,
|
||||
features_extractor=self.features_extractor,
|
||||
clip_mean=self.clip_mean,
|
||||
|
|
@ -168,10 +163,7 @@ class Actor(BasePolicy):
|
|||
mean_actions = self.mu(latent_pi)
|
||||
|
||||
if self.use_sde:
|
||||
latent_sde = latent_pi
|
||||
if self.sde_features_extractor is not None:
|
||||
latent_sde = self.sde_features_extractor(features)
|
||||
return mean_actions, self.log_std, dict(latent_sde=latent_sde)
|
||||
return mean_actions, self.log_std, dict(latent_sde=latent_pi)
|
||||
# Unstructured exploration (Original implementation)
|
||||
log_std = self.log_std(latent_pi)
|
||||
# Original Implementation to cap the standard deviation
|
||||
|
|
@ -335,10 +327,13 @@ class TQCPolicy(BasePolicy):
|
|||
"normalize_images": normalize_images,
|
||||
}
|
||||
self.actor_kwargs = self.net_args.copy()
|
||||
|
||||
if sde_net_arch is not None:
|
||||
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
||||
|
||||
sde_kwargs = {
|
||||
"use_sde": use_sde,
|
||||
"log_std_init": log_std_init,
|
||||
"sde_net_arch": sde_net_arch,
|
||||
"use_expln": use_expln,
|
||||
"clip_mean": clip_mean,
|
||||
}
|
||||
|
|
@ -390,7 +385,6 @@ class TQCPolicy(BasePolicy):
|
|||
activation_fn=self.net_args["activation_fn"],
|
||||
use_sde=self.actor_kwargs["use_sde"],
|
||||
log_std_init=self.actor_kwargs["log_std_init"],
|
||||
sde_net_arch=self.actor_kwargs["sde_net_arch"],
|
||||
use_expln=self.actor_kwargs["use_expln"],
|
||||
clip_mean=self.actor_kwargs["clip_mean"],
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.2.1a0
|
||||
1.2.1a2
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ def test_sde():
|
|||
model = TQC(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v0",
|
||||
policy_kwargs=dict(net_arch=[64], sde_net_arch=[8]),
|
||||
policy_kwargs=dict(net_arch=[64]),
|
||||
use_sde=True,
|
||||
learning_starts=100,
|
||||
verbose=1,
|
||||
|
|
|
|||
Loading…
Reference in New Issue