Fix QR-DQN type hints (#170)

Antonin RAFFIN 2023-03-30 11:50:26 +02:00 committed by GitHub
parent b5fd6e65ba
commit ce115982aa
6 changed files with 27 additions and 20 deletions

docs/misc/changelog.rst

@@ -3,7 +3,7 @@
 Changelog
 ==========

-Release 1.8.0a10 (WIP)
+Release 1.8.0a13 (WIP)
 --------------------------

 Breaking Changes:
@@ -24,7 +24,8 @@ Others:
 ^^^^^^^
 - Moved to pyproject.toml (except flake8)
 - Added github issue forms
-- Fix Atari Roms download in CI
+- Fixed Atari Roms download in CI
+- Fixed ``sb3_contrib/qrdqn/*.py`` type hints

 Documentation:
 ^^^^^^^^^^^^^^

pyproject.toml

@@ -29,8 +29,6 @@ show_error_codes = true
 exclude = """(?x)(
     sb3_contrib/ppo_recurrent/ppo_recurrent.py$
     | sb3_contrib/ars/ars.py$
-    | sb3_contrib/qrdqn/qrdqn.py$
-    | sb3_contrib/qrdqn/policies.py$
     | sb3_contrib/common/recurrent/policies.py$
     | sb3_contrib/common/recurrent/buffers.py$
     | sb3_contrib/common/maskable/distributions.py$

sb3_contrib/qrdqn/policies.py

@@ -31,7 +31,7 @@ class QuantileNetwork(BasePolicy):
         self,
         observation_space: spaces.Space,
         action_space: spaces.Space,
-        features_extractor: nn.Module,
+        features_extractor: BaseFeaturesExtractor,
         features_dim: int,
         n_quantiles: int = 200,
         net_arch: Optional[List[int]] = None,
@@ -50,7 +50,6 @@ class QuantileNetwork(BasePolicy):
         self.net_arch = net_arch
         self.activation_fn = activation_fn
-        self.features_extractor = features_extractor
         self.features_dim = features_dim
         self.n_quantiles = n_quantiles
         action_dim = self.action_space.n  # number of actions
@@ -64,6 +63,8 @@ class QuantileNetwork(BasePolicy):
         :param obs: Observation
         :return: The estimated quantiles for each action.
         """
+        # For type checker:
+        assert isinstance(self.features_extractor, BaseFeaturesExtractor)
         quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
         return quantiles.view(-1, self.n_quantiles, self.action_space.n)
@@ -152,7 +153,8 @@ class QRDQNPolicy(BasePolicy):
             "normalize_images": normalize_images,
         }
-        self.quantile_net, self.quantile_net_target = None, None
+        self.quantile_net: QuantileNetwork
+        self.quantile_net_target: QuantileNetwork
         self._build(lr_schedule)

     def _build(self, lr_schedule: Schedule) -> None:
@@ -168,7 +170,11 @@ class QRDQNPolicy(BasePolicy):
         self.quantile_net_target.set_training_mode(False)
         # Setup optimizer with initial learning rate
-        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
+        self.optimizer = self.optimizer_class(  # type: ignore[call-arg]
+            self.parameters(),
+            lr=lr_schedule(1),
+            **self.optimizer_kwargs,
+        )

     def make_quantile_net(self) -> QuantileNetwork:
         # Make sure we always have separate networks for features extractors etc
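Two typing patterns in the hunks above generalize well: declaring an attribute with a bare annotation (`self.quantile_net: QuantileNetwork`) instead of assigning `None`, so mypy does not widen the type to `Optional[...]`; and narrowing a broadly typed attribute with a runtime `assert isinstance(...)`. A minimal sketch of both, with illustrative names (`Extractor`, `Base`, `Policy` are not from this commit):

    import torch.nn as nn


    class Extractor(nn.Module):
        """Stand-in for a concrete feature-extractor type (illustrative)."""

        def output_dim(self) -> int:
            return 64


    class Base:
        # Broad type, as a generic base class might declare it
        features_extractor: nn.Module


    class Policy(Base):
        def __init__(self) -> None:
            # Bare annotation: declares the type without assigning a value.
            # With `self.extractor = None`, mypy would infer Optional and
            # flag every later use; this form avoids that, provided _build()
            # really assigns the attribute before anything reads it.
            self.extractor: Extractor
            self._build()

        def _build(self) -> None:
            self.extractor = Extractor()
            self.features_extractor = Extractor()

        def forward(self) -> int:
            # For type checker: narrow nn.Module -> Extractor at runtime,
            # the same trick as the `# For type checker:` lines above.
            assert isinstance(self.features_extractor, Extractor)
            return self.features_extractor.output_dim()

The trade-off of the bare annotation is that it is unchecked until runtime: an attribute that is never assigned surfaces as an AttributeError rather than a mypy error, so `_build()` must run before the attribute is used.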

sb3_contrib/qrdqn/qrdqn.py

@@ -6,9 +6,8 @@ from gym import spaces
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
-from stable_baselines3.common.preprocessing import maybe_transpose
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
+from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update

 from sb3_contrib.common.utils import quantile_huber_loss
 from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
@@ -71,12 +70,12 @@ class QRDQN(OffPolicyAlgorithm):
         learning_rate: Union[float, Schedule] = 5e-5,
         buffer_size: int = 1000000,  # 1e6
         learning_starts: int = 50000,
-        batch_size: Optional[int] = 32,
+        batch_size: int = 32,
         tau: float = 1.0,
         gamma: float = 0.99,
         train_freq: int = 4,
         gradient_steps: int = 1,
-        replay_buffer_class: Optional[ReplayBuffer] = None,
+        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
         target_update_interval: int = 10000,
@@ -124,8 +123,9 @@ class QRDQN(OffPolicyAlgorithm):
         # "epsilon" for the epsilon-greedy exploration
         self.exploration_rate = 0.0
         # Linear schedule will be defined in `_setup_model()`
-        self.exploration_schedule = None
-        self.quantile_net, self.quantile_net_target = None, None
+        self.exploration_schedule: Schedule
+        self.quantile_net: th.nn.Module
+        self.quantile_net_target: th.nn.Module

         if "optimizer_class" not in self.policy_kwargs:
             self.policy_kwargs["optimizer_class"] = th.optim.Adam
@@ -146,6 +146,8 @@
         )

     def _create_aliases(self) -> None:
+        # For type checker:
+        assert isinstance(self.policy, QRDQNPolicy)
         self.quantile_net = self.policy.quantile_net
         self.quantile_net_target = self.policy.quantile_net_target
         self.n_quantiles = self.policy.n_quantiles
@@ -172,7 +174,7 @@
         losses = []
         for _ in range(gradient_steps):
             # Sample replay buffer
-            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
+            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

             with th.no_grad():
                 # Compute the quantiles of next observation
@@ -214,7 +216,7 @@
     def predict(
         self,
-        observation: np.ndarray,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
         state: Optional[Tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
@@ -233,8 +235,8 @@
             (used in recurrent policies)
         """
         if not deterministic and np.random.rand() < self.exploration_rate:
-            if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
-                if isinstance(self.observation_space, spaces.Dict):
+            if self.policy.is_vectorized_observation(observation):
+                if isinstance(observation, dict):
                     n_batch = observation[list(observation.keys())[0]].shape[0]
                 else:
                     n_batch = observation.shape[0]
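The `replay_buffer_class` change above is the classic instance-versus-class distinction: the parameter receives a buffer class to instantiate later, so its hint must be `Optional[Type[ReplayBuffer]]` rather than `Optional[ReplayBuffer]`. A minimal sketch under illustrative names (`Buffer` and `make_buffer` are not from this commit):

    from typing import Optional, Type


    class Buffer:
        def __init__(self, size: int) -> None:
            self.size = size


    def make_buffer(buffer_class: Optional[Type[Buffer]] = None, size: int = 100) -> Buffer:
        # Type[Buffer] means "Buffer itself or any subclass", i.e. something
        # callable to build an instance. Optional[Buffer] would promise an
        # already-built instance, so mypy would reject `buffer_class(size)`.
        if buffer_class is None:
            buffer_class = Buffer
        return buffer_class(size)


    buffer = make_buffer()  # instantiates the default class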

sb3_contrib/version.txt

@@ -1 +1 @@
-1.8.0a10
+1.8.0a13

setup.py

@@ -65,7 +65,7 @@ setup(
     packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
     package_data={"sb3_contrib": ["py.typed", "version.txt"]},
     install_requires=[
-        "stable_baselines3>=1.8.0a10",
+        "stable_baselines3>=1.8.0a13",
     ],
     description="Contrib package of Stable Baselines3, experimental code.",
     author="Antonin Raffin",