From ce115982aa478e8f462d3d8d730c82b152d5255e Mon Sep 17 00:00:00 2001
From: Antonin RAFFIN
Date: Thu, 30 Mar 2023 11:50:26 +0200
Subject: [PATCH] Fix QR-DQN type hints (#170)

---
 docs/misc/changelog.rst       |  5 +++--
 pyproject.toml                |  2 --
 sb3_contrib/qrdqn/policies.py | 14 ++++++++++----
 sb3_contrib/qrdqn/qrdqn.py    | 22 ++++++++++++----------
 sb3_contrib/version.txt       |  2 +-
 setup.py                      |  2 +-
 6 files changed, 27 insertions(+), 20 deletions(-)

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 807eced..45c32ac 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -3,7 +3,7 @@
 Changelog
 ==========
 
-Release 1.8.0a10 (WIP)
+Release 1.8.0a13 (WIP)
 --------------------------
 
 Breaking Changes:
@@ -24,7 +24,8 @@ Others:
 ^^^^^^^
 - Moved to pyproject.toml (except flake8)
 - Added github issue forms
-- Fix Atari Roms download in CI
+- Fixed Atari Roms download in CI
+- Fixed ``sb3_contrib/qrdqn/*.py`` type hints
 
 Documentation:
 ^^^^^^^^^^^^^^
diff --git a/pyproject.toml b/pyproject.toml
index 0fda57f..ff29993 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,8 +29,6 @@ show_error_codes = true
 exclude = """(?x)(
     sb3_contrib/ppo_recurrent/ppo_recurrent.py$
     | sb3_contrib/ars/ars.py$
-    | sb3_contrib/qrdqn/qrdqn.py$
-    | sb3_contrib/qrdqn/policies.py$
     | sb3_contrib/common/recurrent/policies.py$
     | sb3_contrib/common/recurrent/buffers.py$
     | sb3_contrib/common/maskable/distributions.py$
diff --git a/sb3_contrib/qrdqn/policies.py b/sb3_contrib/qrdqn/policies.py
index 6090097..0491f50 100644
--- a/sb3_contrib/qrdqn/policies.py
+++ b/sb3_contrib/qrdqn/policies.py
@@ -31,7 +31,7 @@ class QuantileNetwork(BasePolicy):
         self,
         observation_space: spaces.Space,
         action_space: spaces.Space,
-        features_extractor: nn.Module,
+        features_extractor: BaseFeaturesExtractor,
         features_dim: int,
         n_quantiles: int = 200,
         net_arch: Optional[List[int]] = None,
@@ -50,7 +50,6 @@ class QuantileNetwork(BasePolicy):
 
         self.net_arch = net_arch
         self.activation_fn = activation_fn
-        self.features_extractor = features_extractor
         self.features_dim = features_dim
         self.n_quantiles = n_quantiles
         action_dim = self.action_space.n  # number of actions
@@ -64,6 +63,8 @@ class QuantileNetwork(BasePolicy):
 
         :param obs: Observation
         :return: The estimated quantiles for each action.
         """
+        # For type checker:
+        assert isinstance(self.features_extractor, BaseFeaturesExtractor)
         quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
         return quantiles.view(-1, self.n_quantiles, self.action_space.n)
@@ -152,7 +153,8 @@ class QRDQNPolicy(BasePolicy):
             "normalize_images": normalize_images,
         }
 
-        self.quantile_net, self.quantile_net_target = None, None
+        self.quantile_net: QuantileNetwork
+        self.quantile_net_target: QuantileNetwork
         self._build(lr_schedule)
 
     def _build(self, lr_schedule: Schedule) -> None:
@@ -168,7 +170,11 @@ class QRDQNPolicy(BasePolicy):
         self.quantile_net_target.set_training_mode(False)
 
         # Setup optimizer with initial learning rate
-        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
+        self.optimizer = self.optimizer_class(  # type: ignore[call-arg]
+            self.parameters(),
+            lr=lr_schedule(1),
+            **self.optimizer_kwargs,
+        )
 
     def make_quantile_net(self) -> QuantileNetwork:
         # Make sure we always have separate networks for features extractors etc
diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py
index a17b4e1..7e1785b 100644
--- a/sb3_contrib/qrdqn/qrdqn.py
+++ b/sb3_contrib/qrdqn/qrdqn.py
@@ -6,9 +6,8 @@ from gym import spaces
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
-from stable_baselines3.common.preprocessing import maybe_transpose
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
+from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
 
 from sb3_contrib.common.utils import quantile_huber_loss
 from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
@@ -71,12 +70,12 @@ class QRDQN(OffPolicyAlgorithm):
         learning_rate: Union[float, Schedule] = 5e-5,
         buffer_size: int = 1000000,  # 1e6
         learning_starts: int = 50000,
-        batch_size: Optional[int] = 32,
+        batch_size: int = 32,
         tau: float = 1.0,
         gamma: float = 0.99,
         train_freq: int = 4,
         gradient_steps: int = 1,
-        replay_buffer_class: Optional[ReplayBuffer] = None,
+        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
         target_update_interval: int = 10000,
@@ -124,8 +123,9 @@ class QRDQN(OffPolicyAlgorithm):
         # "epsilon" for the epsilon-greedy exploration
         self.exploration_rate = 0.0
         # Linear schedule will be defined in `_setup_model()`
-        self.exploration_schedule = None
-        self.quantile_net, self.quantile_net_target = None, None
+        self.exploration_schedule: Schedule
+        self.quantile_net: th.nn.Module
+        self.quantile_net_target: th.nn.Module
 
         if "optimizer_class" not in self.policy_kwargs:
             self.policy_kwargs["optimizer_class"] = th.optim.Adam
@@ -146,6 +146,8 @@ class QRDQN(OffPolicyAlgorithm):
         )
 
     def _create_aliases(self) -> None:
+        # For type checker:
+        assert isinstance(self.policy, QRDQNPolicy)
         self.quantile_net = self.policy.quantile_net
         self.quantile_net_target = self.policy.quantile_net_target
         self.n_quantiles = self.policy.n_quantiles
@@ -172,7 +174,7 @@ class QRDQN(OffPolicyAlgorithm):
         losses = []
         for _ in range(gradient_steps):
             # Sample replay buffer
-            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
+            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
 
             with th.no_grad():
                 # Compute the quantiles of next observation
@@ -214,7 +216,7 @@ class QRDQN(OffPolicyAlgorithm):
 
     def predict(
         self,
-        observation: np.ndarray,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
         state: Optional[Tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
@@ -233,8 +235,8 @@ class QRDQN(OffPolicyAlgorithm):
             (used in recurrent policies)
         """
         if not deterministic and np.random.rand() < self.exploration_rate:
-            if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
-                if isinstance(self.observation_space, spaces.Dict):
+            if self.policy.is_vectorized_observation(observation):
+                if isinstance(observation, dict):
                     n_batch = observation[list(observation.keys())[0]].shape[0]
                 else:
                     n_batch = observation.shape[0]
diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt
index cba76ac..e9033cc 100644
--- a/sb3_contrib/version.txt
+++ b/sb3_contrib/version.txt
@@ -1 +1 @@
-1.8.0a10
+1.8.0a13
diff --git a/setup.py b/setup.py
index acad2f9..b3a17cc 100644
--- a/setup.py
+++ b/setup.py
@@ -65,7 +65,7 @@ setup(
     packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
     package_data={"sb3_contrib": ["py.typed", "version.txt"]},
     install_requires=[
-        "stable_baselines3>=1.8.0a10",
+        "stable_baselines3>=1.8.0a13",
     ],
     description="Contrib package of Stable Baselines3, experimental code.",
     author="Antonin Raffin",
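
Note on the typing patterns used in this patch: it removes `Optional` attribute types in two ways. First, bare attribute annotations replace `= None` placeholders (e.g. `self.quantile_net: QuantileNetwork`), so mypy sees the concrete type rather than `Optional[...]`. Second, `assert isinstance(...)` statements, which are no-ops at runtime when the type is correct, let mypy narrow a broadly typed attribute (e.g. `self.policy`, typed as `BasePolicy` in the base class) to the concrete subclass. Below is a minimal, self-contained sketch of both patterns; the `Net`, `BasePolicyStub`, `Policy`, and `Algo` names are hypothetical stand-ins for illustration, not the actual sb3-contrib classes:

    class Net:
        """Hypothetical stand-in for QuantileNetwork."""

        def forward(self) -> str:
            return "quantiles"


    class BasePolicyStub:
        """Hypothetical stand-in for the broadly typed BasePolicy."""


    class Policy(BasePolicyStub):
        def __init__(self) -> None:
            # Pattern 1: a bare annotation declares the attribute's type
            # without assigning a None placeholder, so mypy sees `Net`
            # (not `Optional[Net]`) and later uses need no None checks.
            # At runtime the attribute only exists once _build() has run.
            self.net: Net
            self._build()

        def _build(self) -> None:
            self.net = Net()


    class Algo:
        def __init__(self) -> None:
            # Typed with the broad base class, the way OffPolicyAlgorithm
            # types its `policy` attribute.
            self.policy: BasePolicyStub = Policy()

        def _create_aliases(self) -> None:
            # Pattern 2: a runtime no-op when the type is correct, but it
            # lets mypy narrow `self.policy` to `Policy`, so accessing
            # `self.policy.net` below type-checks.
            assert isinstance(self.policy, Policy)
            self.net = self.policy.net


    if __name__ == "__main__":
        algo = Algo()
        algo._create_aliases()
        print(algo.net.forward())  # prints "quantiles"

The trade-off behind pattern 1 is that reading the attribute before `_build()` assigns it raises `AttributeError` instead of returning `None`, which is acceptable here because `_build()` is always called from the constructor.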