Remove sde net arch (#44)

This commit is contained in:
Antonin RAFFIN 2021-09-28 21:59:59 +02:00 committed by GitHub
parent c525c5107b
commit 91f9b1ed34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 15 additions and 20 deletions

View File

@ -3,13 +3,14 @@
Changelog Changelog
========== ==========
Release 1.2.1a0 (WIP) Release 1.2.1a2 (WIP)
------------------------------- -------------------------------
**Invalid action masking for PPO** **Invalid action masking for PPO**
Breaking Changes: Breaking Changes:
^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^
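- Removed ``sde_net_arch`` support: the argument is still accepted but is now ignored, and passing it emits a ``DeprecationWarning``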
New Features: New Features:
^^^^^^^^^^^^^ ^^^^^^^^^^^^^

View File

@ -1,9 +1,10 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, Union from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym import gym
import torch as th import torch as th
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_features_extractor, register_policy from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy
from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import ( from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor, BaseFeaturesExtractor,
@ -74,7 +75,6 @@ class Actor(BasePolicy):
# Save arguments to re-create object at loading # Save arguments to re-create object at loading
self.use_sde = use_sde self.use_sde = use_sde
self.sde_features_extractor = None self.sde_features_extractor = None
self.sde_net_arch = sde_net_arch
self.net_arch = net_arch self.net_arch = net_arch
self.features_dim = features_dim self.features_dim = features_dim
self.activation_fn = activation_fn self.activation_fn = activation_fn
@ -89,19 +89,15 @@ class Actor(BasePolicy):
self.latent_pi = nn.Sequential(*latent_pi_net) self.latent_pi = nn.Sequential(*latent_pi_net)
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
if self.use_sde: if sde_net_arch is not None:
latent_sde_dim = last_layer_dim warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
# Separate feature extractor for gSDE
if sde_net_arch is not None:
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
features_dim, sde_net_arch, activation_fn
)
if self.use_sde:
self.action_dist = StateDependentNoiseDistribution( self.action_dist = StateDependentNoiseDistribution(
action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
) )
self.mu, self.log_std = self.action_dist.proba_distribution_net( self.mu, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=last_layer_dim, latent_sde_dim=latent_sde_dim, log_std_init=log_std_init latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init
) )
# Avoid numerical issues by limiting the mean of the Gaussian # Avoid numerical issues by limiting the mean of the Gaussian
# to be in [-clip_mean, clip_mean] # to be in [-clip_mean, clip_mean]
@ -123,7 +119,6 @@ class Actor(BasePolicy):
use_sde=self.use_sde, use_sde=self.use_sde,
log_std_init=self.log_std_init, log_std_init=self.log_std_init,
full_std=self.full_std, full_std=self.full_std,
sde_net_arch=self.sde_net_arch,
use_expln=self.use_expln, use_expln=self.use_expln,
features_extractor=self.features_extractor, features_extractor=self.features_extractor,
clip_mean=self.clip_mean, clip_mean=self.clip_mean,
@ -168,10 +163,7 @@ class Actor(BasePolicy):
mean_actions = self.mu(latent_pi) mean_actions = self.mu(latent_pi)
if self.use_sde: if self.use_sde:
latent_sde = latent_pi return mean_actions, self.log_std, dict(latent_sde=latent_pi)
if self.sde_features_extractor is not None:
latent_sde = self.sde_features_extractor(features)
return mean_actions, self.log_std, dict(latent_sde=latent_sde)
# Unstructured exploration (Original implementation) # Unstructured exploration (Original implementation)
log_std = self.log_std(latent_pi) log_std = self.log_std(latent_pi)
# Original Implementation to cap the standard deviation # Original Implementation to cap the standard deviation
@ -335,10 +327,13 @@ class TQCPolicy(BasePolicy):
"normalize_images": normalize_images, "normalize_images": normalize_images,
} }
self.actor_kwargs = self.net_args.copy() self.actor_kwargs = self.net_args.copy()
if sde_net_arch is not None:
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
sde_kwargs = { sde_kwargs = {
"use_sde": use_sde, "use_sde": use_sde,
"log_std_init": log_std_init, "log_std_init": log_std_init,
"sde_net_arch": sde_net_arch,
"use_expln": use_expln, "use_expln": use_expln,
"clip_mean": clip_mean, "clip_mean": clip_mean,
} }
@ -390,7 +385,6 @@ class TQCPolicy(BasePolicy):
activation_fn=self.net_args["activation_fn"], activation_fn=self.net_args["activation_fn"],
use_sde=self.actor_kwargs["use_sde"], use_sde=self.actor_kwargs["use_sde"],
log_std_init=self.actor_kwargs["log_std_init"], log_std_init=self.actor_kwargs["log_std_init"],
sde_net_arch=self.actor_kwargs["sde_net_arch"],
use_expln=self.actor_kwargs["use_expln"], use_expln=self.actor_kwargs["use_expln"],
clip_mean=self.actor_kwargs["clip_mean"], clip_mean=self.actor_kwargs["clip_mean"],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone

View File

@ -1 +1 @@
1.2.1a0 1.2.1a2

View File

@ -34,7 +34,7 @@ def test_sde():
model = TQC( model = TQC(
"MlpPolicy", "MlpPolicy",
"Pendulum-v0", "Pendulum-v0",
policy_kwargs=dict(net_arch=[64], sde_net_arch=[8]), policy_kwargs=dict(net_arch=[64]),
use_sde=True, use_sde=True,
learning_starts=100, learning_starts=100,
verbose=1, verbose=1,