Fix QR-DQN type hints (#170)
This commit is contained in:
parent
b5fd6e65ba
commit
ce115982aa
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.8.0a10 (WIP)
|
||||
Release 1.8.0a13 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -24,7 +24,8 @@ Others:
|
|||
^^^^^^^
|
||||
- Moved to pyproject.toml (except flake8)
|
||||
- Added github issue forms
|
||||
- Fix Atari Roms download in CI
|
||||
- Fixed Atari Roms download in CI
|
||||
- Fixed ``sb3_contrib/qrdqn/*.py`` type hints
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -29,8 +29,6 @@ show_error_codes = true
|
|||
exclude = """(?x)(
|
||||
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
|
||||
| sb3_contrib/ars/ars.py$
|
||||
| sb3_contrib/qrdqn/qrdqn.py$
|
||||
| sb3_contrib/qrdqn/policies.py$
|
||||
| sb3_contrib/common/recurrent/policies.py$
|
||||
| sb3_contrib/common/recurrent/buffers.py$
|
||||
| sb3_contrib/common/maskable/distributions.py$
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class QuantileNetwork(BasePolicy):
|
|||
self,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
features_extractor: nn.Module,
|
||||
features_extractor: BaseFeaturesExtractor,
|
||||
features_dim: int,
|
||||
n_quantiles: int = 200,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
|
|
@ -50,7 +50,6 @@ class QuantileNetwork(BasePolicy):
|
|||
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
self.features_extractor = features_extractor
|
||||
self.features_dim = features_dim
|
||||
self.n_quantiles = n_quantiles
|
||||
action_dim = self.action_space.n # number of actions
|
||||
|
|
@ -64,6 +63,8 @@ class QuantileNetwork(BasePolicy):
|
|||
:param obs: Observation
|
||||
:return: The estimated quantiles for each action.
|
||||
"""
|
||||
# For type checker:
|
||||
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
|
||||
quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
|
||||
return quantiles.view(-1, self.n_quantiles, self.action_space.n)
|
||||
|
||||
|
|
@ -152,7 +153,8 @@ class QRDQNPolicy(BasePolicy):
|
|||
"normalize_images": normalize_images,
|
||||
}
|
||||
|
||||
self.quantile_net, self.quantile_net_target = None, None
|
||||
self.quantile_net: QuantileNetwork
|
||||
self.quantile_net_target: QuantileNetwork
|
||||
self._build(lr_schedule)
|
||||
|
||||
def _build(self, lr_schedule: Schedule) -> None:
|
||||
|
|
@ -168,7 +170,11 @@ class QRDQNPolicy(BasePolicy):
|
|||
self.quantile_net_target.set_training_mode(False)
|
||||
|
||||
# Setup optimizer with initial learning rate
|
||||
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
self.optimizer = self.optimizer_class( # type: ignore[call-arg]
|
||||
self.parameters(),
|
||||
lr=lr_schedule(1),
|
||||
**self.optimizer_kwargs,
|
||||
)
|
||||
|
||||
def make_quantile_net(self) -> QuantileNetwork:
|
||||
# Make sure we always have separate networks for features extractors etc
|
||||
|
|
|
|||
|
|
@ -6,9 +6,8 @@ from gym import spaces
|
|||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.preprocessing import maybe_transpose
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
|
||||
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
|
||||
|
||||
from sb3_contrib.common.utils import quantile_huber_loss
|
||||
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
|
||||
|
|
@ -71,12 +70,12 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
learning_rate: Union[float, Schedule] = 5e-5,
|
||||
buffer_size: int = 1000000, # 1e6
|
||||
learning_starts: int = 50000,
|
||||
batch_size: Optional[int] = 32,
|
||||
batch_size: int = 32,
|
||||
tau: float = 1.0,
|
||||
gamma: float = 0.99,
|
||||
train_freq: int = 4,
|
||||
gradient_steps: int = 1,
|
||||
replay_buffer_class: Optional[ReplayBuffer] = None,
|
||||
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
||||
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
optimize_memory_usage: bool = False,
|
||||
target_update_interval: int = 10000,
|
||||
|
|
@ -124,8 +123,9 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
# "epsilon" for the epsilon-greedy exploration
|
||||
self.exploration_rate = 0.0
|
||||
# Linear schedule will be defined in `_setup_model()`
|
||||
self.exploration_schedule = None
|
||||
self.quantile_net, self.quantile_net_target = None, None
|
||||
self.exploration_schedule: Schedule
|
||||
self.quantile_net: th.nn.Module
|
||||
self.quantile_net_target: th.nn.Module
|
||||
|
||||
if "optimizer_class" not in self.policy_kwargs:
|
||||
self.policy_kwargs["optimizer_class"] = th.optim.Adam
|
||||
|
|
@ -146,6 +146,8 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
)
|
||||
|
||||
def _create_aliases(self) -> None:
|
||||
# For type checker:
|
||||
assert isinstance(self.policy, QRDQNPolicy)
|
||||
self.quantile_net = self.policy.quantile_net
|
||||
self.quantile_net_target = self.policy.quantile_net_target
|
||||
self.n_quantiles = self.policy.n_quantiles
|
||||
|
|
@ -172,7 +174,7 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
losses = []
|
||||
for _ in range(gradient_steps):
|
||||
# Sample replay buffer
|
||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]
|
||||
|
||||
with th.no_grad():
|
||||
# Compute the quantiles of next observation
|
||||
|
|
@ -214,7 +216,7 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
|
||||
def predict(
|
||||
self,
|
||||
observation: np.ndarray,
|
||||
observation: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||
episode_start: Optional[np.ndarray] = None,
|
||||
deterministic: bool = False,
|
||||
|
|
@ -233,8 +235,8 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
(used in recurrent policies)
|
||||
"""
|
||||
if not deterministic and np.random.rand() < self.exploration_rate:
|
||||
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
|
||||
if isinstance(self.observation_space, spaces.Dict):
|
||||
if self.policy.is_vectorized_observation(observation):
|
||||
if isinstance(observation, dict):
|
||||
n_batch = observation[list(observation.keys())[0]].shape[0]
|
||||
else:
|
||||
n_batch = observation.shape[0]
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.8.0a10
|
||||
1.8.0a13
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=1.8.0a10",
|
||||
"stable_baselines3>=1.8.0a13",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
Loading…
Reference in New Issue