Fix QR-DQN type hints (#170)

This commit is contained in:
Antonin RAFFIN 2023-03-30 11:50:26 +02:00 committed by GitHub
parent b5fd6e65ba
commit ce115982aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 27 additions and 20 deletions

View File

@ -3,7 +3,7 @@
Changelog
==========
Release 1.8.0a10 (WIP)
Release 1.8.0a13 (WIP)
--------------------------
Breaking Changes:
@ -24,7 +24,8 @@ Others:
^^^^^^^
- Moved to pyproject.toml (except flake8)
- Added github issue forms
- Fix Atari Roms download in CI
- Fixed Atari Roms download in CI
- Fixed ``sb3_contrib/qrdqn/*.py`` type hints
Documentation:
^^^^^^^^^^^^^^

View File

@ -29,8 +29,6 @@ show_error_codes = true
exclude = """(?x)(
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
| sb3_contrib/ars/ars.py$
| sb3_contrib/qrdqn/qrdqn.py$
| sb3_contrib/qrdqn/policies.py$
| sb3_contrib/common/recurrent/policies.py$
| sb3_contrib/common/recurrent/buffers.py$
| sb3_contrib/common/maskable/distributions.py$

View File

@ -31,7 +31,7 @@ class QuantileNetwork(BasePolicy):
self,
observation_space: spaces.Space,
action_space: spaces.Space,
features_extractor: nn.Module,
features_extractor: BaseFeaturesExtractor,
features_dim: int,
n_quantiles: int = 200,
net_arch: Optional[List[int]] = None,
@ -50,7 +50,6 @@ class QuantileNetwork(BasePolicy):
self.net_arch = net_arch
self.activation_fn = activation_fn
self.features_extractor = features_extractor
self.features_dim = features_dim
self.n_quantiles = n_quantiles
action_dim = self.action_space.n # number of actions
@ -64,6 +63,8 @@ class QuantileNetwork(BasePolicy):
:param obs: Observation
:return: The estimated quantiles for each action.
"""
# For type checker:
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
return quantiles.view(-1, self.n_quantiles, self.action_space.n)
@ -152,7 +153,8 @@ class QRDQNPolicy(BasePolicy):
"normalize_images": normalize_images,
}
self.quantile_net, self.quantile_net_target = None, None
self.quantile_net: QuantileNetwork
self.quantile_net_target: QuantileNetwork
self._build(lr_schedule)
def _build(self, lr_schedule: Schedule) -> None:
@ -168,7 +170,11 @@ class QRDQNPolicy(BasePolicy):
self.quantile_net_target.set_training_mode(False)
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
self.optimizer = self.optimizer_class( # type: ignore[call-arg]
self.parameters(),
lr=lr_schedule(1),
**self.optimizer_kwargs,
)
def make_quantile_net(self) -> QuantileNetwork:
# Make sure we always have separate networks for features extractors etc

View File

@ -6,9 +6,8 @@ from gym import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
@ -71,12 +70,12 @@ class QRDQN(OffPolicyAlgorithm):
learning_rate: Union[float, Schedule] = 5e-5,
buffer_size: int = 1000000, # 1e6
learning_starts: int = 50000,
batch_size: Optional[int] = 32,
batch_size: int = 32,
tau: float = 1.0,
gamma: float = 0.99,
train_freq: int = 4,
gradient_steps: int = 1,
replay_buffer_class: Optional[ReplayBuffer] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
target_update_interval: int = 10000,
@ -124,8 +123,9 @@ class QRDQN(OffPolicyAlgorithm):
# "epsilon" for the epsilon-greedy exploration
self.exploration_rate = 0.0
# Linear schedule will be defined in `_setup_model()`
self.exploration_schedule = None
self.quantile_net, self.quantile_net_target = None, None
self.exploration_schedule: Schedule
self.quantile_net: th.nn.Module
self.quantile_net_target: th.nn.Module
if "optimizer_class" not in self.policy_kwargs:
self.policy_kwargs["optimizer_class"] = th.optim.Adam
@ -146,6 +146,8 @@ class QRDQN(OffPolicyAlgorithm):
)
def _create_aliases(self) -> None:
# For type checker:
assert isinstance(self.policy, QRDQNPolicy)
self.quantile_net = self.policy.quantile_net
self.quantile_net_target = self.policy.quantile_net_target
self.n_quantiles = self.policy.n_quantiles
@ -172,7 +174,7 @@ class QRDQN(OffPolicyAlgorithm):
losses = []
for _ in range(gradient_steps):
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]
with th.no_grad():
# Compute the quantiles of next observation
@ -214,7 +216,7 @@ class QRDQN(OffPolicyAlgorithm):
def predict(
self,
observation: np.ndarray,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
@ -233,8 +235,8 @@ class QRDQN(OffPolicyAlgorithm):
(used in recurrent policies)
"""
if not deterministic and np.random.rand() < self.exploration_rate:
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
if isinstance(self.observation_space, spaces.Dict):
if self.policy.is_vectorized_observation(observation):
if isinstance(observation, dict):
n_batch = observation[list(observation.keys())[0]].shape[0]
else:
n_batch = observation.shape[0]

View File

@ -1 +1 @@
1.8.0a10
1.8.0a13

View File

@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.8.0a10",
"stable_baselines3>=1.8.0a13",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",