From 91f9b1ed34fbaa9243a044ea67aa4c677663bfc2 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 28 Sep 2021 21:59:59 +0200 Subject: [PATCH] Remove sde net arch (#44) --- docs/misc/changelog.rst | 3 ++- sb3_contrib/tqc/policies.py | 28 +++++++++++----------------- sb3_contrib/version.txt | 2 +- tests/test_run.py | 2 +- 4 files changed, 15 insertions(+), 20 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 79693dd..fcc94da 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,13 +3,14 @@ Changelog ========== -Release 1.2.1a0 (WIP) +Release 1.2.1a2 (WIP) ------------------------------- **Invalid action masking for PPO** Breaking Changes: ^^^^^^^^^^^^^^^^^ +- Removed ``sde_net_arch`` New Features: ^^^^^^^^^^^^^ diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index d02601c..089a2f1 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -1,9 +1,10 @@ +import warnings from typing import Any, Dict, List, Optional, Tuple, Type, Union import gym import torch as th from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution -from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_features_extractor, register_policy +from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, @@ -74,7 +75,6 @@ class Actor(BasePolicy): # Save arguments to re-create object at loading self.use_sde = use_sde self.sde_features_extractor = None - self.sde_net_arch = sde_net_arch self.net_arch = net_arch self.features_dim = features_dim self.activation_fn = activation_fn @@ -89,19 +89,15 @@ class Actor(BasePolicy): self.latent_pi = nn.Sequential(*latent_pi_net) last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim - if self.use_sde: - latent_sde_dim = last_layer_dim - # Separate feature extractor for gSDE - if sde_net_arch is not None: - self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor( - features_dim, sde_net_arch, activation_fn - ) + if sde_net_arch is not None: + warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning) + if self.use_sde: self.action_dist = StateDependentNoiseDistribution( action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True ) self.mu, self.log_std = self.action_dist.proba_distribution_net( - latent_dim=last_layer_dim, latent_sde_dim=latent_sde_dim, log_std_init=log_std_init + latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init ) # Avoid numerical issues by limiting the mean of the Gaussian # to be in [-clip_mean, clip_mean] @@ -123,7 +119,6 @@ class Actor(BasePolicy): use_sde=self.use_sde, log_std_init=self.log_std_init, full_std=self.full_std, - sde_net_arch=self.sde_net_arch, use_expln=self.use_expln, features_extractor=self.features_extractor, clip_mean=self.clip_mean, @@ -168,10 +163,7 @@ class Actor(BasePolicy): mean_actions = self.mu(latent_pi) if self.use_sde: - latent_sde = latent_pi - if self.sde_features_extractor is not None: - latent_sde = self.sde_features_extractor(features) - return mean_actions, self.log_std, dict(latent_sde=latent_sde) + return mean_actions, self.log_std, dict(latent_sde=latent_pi) # Unstructured exploration (Original implementation) log_std = self.log_std(latent_pi) # Original Implementation to cap the standard deviation @@ -335,10 +327,13 @@ class TQCPolicy(BasePolicy): "normalize_images": normalize_images, } self.actor_kwargs = self.net_args.copy() + + if sde_net_arch is not None: + warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning) + sde_kwargs = { "use_sde": use_sde, "log_std_init": log_std_init, - "sde_net_arch": sde_net_arch, "use_expln": use_expln, "clip_mean": clip_mean, } @@ -390,7 +385,6 @@ class TQCPolicy(BasePolicy): activation_fn=self.net_args["activation_fn"], use_sde=self.actor_kwargs["use_sde"], log_std_init=self.actor_kwargs["log_std_init"], - sde_net_arch=self.actor_kwargs["sde_net_arch"], use_expln=self.actor_kwargs["use_expln"], clip_mean=self.actor_kwargs["clip_mean"], lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 348e216..c4baa5c 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.2.1a0 +1.2.1a2 diff --git a/tests/test_run.py b/tests/test_run.py index 195d011..b1431a3 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -34,7 +34,7 @@ def test_sde(): model = TQC( "MlpPolicy", "Pendulum-v0", - policy_kwargs=dict(net_arch=[64], sde_net_arch=[8]), + policy_kwargs=dict(net_arch=[64]), use_sde=True, learning_starts=100, verbose=1,