Remove sde net arch (#44)

This commit is contained in:
Antonin RAFFIN 2021-09-28 21:59:59 +02:00 committed by GitHub
parent c525c5107b
commit 91f9b1ed34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 15 additions and 20 deletions

View File

@ -3,13 +3,14 @@
Changelog
==========
Release 1.2.1a0 (WIP)
Release 1.2.1a2 (WIP)
-------------------------------
**Invalid action masking for PPO**
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed ``sde_net_arch``
New Features:
^^^^^^^^^^^^^

View File

@ -1,9 +1,10 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import torch as th
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_features_extractor, register_policy
from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
@ -74,7 +75,6 @@ class Actor(BasePolicy):
# Save arguments to re-create object at loading
self.use_sde = use_sde
self.sde_features_extractor = None
self.sde_net_arch = sde_net_arch
self.net_arch = net_arch
self.features_dim = features_dim
self.activation_fn = activation_fn
@ -89,19 +89,15 @@ class Actor(BasePolicy):
self.latent_pi = nn.Sequential(*latent_pi_net)
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
if self.use_sde:
latent_sde_dim = last_layer_dim
# Separate feature extractor for gSDE
if sde_net_arch is not None:
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
features_dim, sde_net_arch, activation_fn
)
if sde_net_arch is not None:
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
if self.use_sde:
self.action_dist = StateDependentNoiseDistribution(
action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
)
self.mu, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=last_layer_dim, latent_sde_dim=latent_sde_dim, log_std_init=log_std_init
latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init
)
# Avoid numerical issues by limiting the mean of the Gaussian
# to be in [-clip_mean, clip_mean]
@ -123,7 +119,6 @@ class Actor(BasePolicy):
use_sde=self.use_sde,
log_std_init=self.log_std_init,
full_std=self.full_std,
sde_net_arch=self.sde_net_arch,
use_expln=self.use_expln,
features_extractor=self.features_extractor,
clip_mean=self.clip_mean,
@ -168,10 +163,7 @@ class Actor(BasePolicy):
mean_actions = self.mu(latent_pi)
if self.use_sde:
latent_sde = latent_pi
if self.sde_features_extractor is not None:
latent_sde = self.sde_features_extractor(features)
return mean_actions, self.log_std, dict(latent_sde=latent_sde)
return mean_actions, self.log_std, dict(latent_sde=latent_pi)
# Unstructured exploration (Original implementation)
log_std = self.log_std(latent_pi)
# Original Implementation to cap the standard deviation
@ -335,10 +327,13 @@ class TQCPolicy(BasePolicy):
"normalize_images": normalize_images,
}
self.actor_kwargs = self.net_args.copy()
if sde_net_arch is not None:
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
sde_kwargs = {
"use_sde": use_sde,
"log_std_init": log_std_init,
"sde_net_arch": sde_net_arch,
"use_expln": use_expln,
"clip_mean": clip_mean,
}
@ -390,7 +385,6 @@ class TQCPolicy(BasePolicy):
activation_fn=self.net_args["activation_fn"],
use_sde=self.actor_kwargs["use_sde"],
log_std_init=self.actor_kwargs["log_std_init"],
sde_net_arch=self.actor_kwargs["sde_net_arch"],
use_expln=self.actor_kwargs["use_expln"],
clip_mean=self.actor_kwargs["clip_mean"],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone

View File

@ -1 +1 @@
1.2.1a0
1.2.1a2

View File

@ -34,7 +34,7 @@ def test_sde():
model = TQC(
"MlpPolicy",
"Pendulum-v0",
policy_kwargs=dict(net_arch=[64], sde_net_arch=[8]),
policy_kwargs=dict(net_arch=[64]),
use_sde=True,
learning_starts=100,
verbose=1,