Fix QR-DQN type hints (#170)
This commit is contained in:
parent
b5fd6e65ba
commit
ce115982aa
|
|
@ -3,7 +3,7 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Release 1.8.0a10 (WIP)
|
Release 1.8.0a13 (WIP)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
|
|
@ -24,7 +24,8 @@ Others:
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
- Moved to pyproject.toml (except flake8)
|
- Moved to pyproject.toml (except flake8)
|
||||||
- Added github issue forms
|
- Added github issue forms
|
||||||
- Fix Atari Roms download in CI
|
- Fixed Atari Roms download in CI
|
||||||
|
- Fixed ``sb3_contrib/qrdqn/*.py`` type hints
|
||||||
|
|
||||||
Documentation:
|
Documentation:
|
||||||
^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -29,8 +29,6 @@ show_error_codes = true
|
||||||
exclude = """(?x)(
|
exclude = """(?x)(
|
||||||
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
|
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
|
||||||
| sb3_contrib/ars/ars.py$
|
| sb3_contrib/ars/ars.py$
|
||||||
| sb3_contrib/qrdqn/qrdqn.py$
|
|
||||||
| sb3_contrib/qrdqn/policies.py$
|
|
||||||
| sb3_contrib/common/recurrent/policies.py$
|
| sb3_contrib/common/recurrent/policies.py$
|
||||||
| sb3_contrib/common/recurrent/buffers.py$
|
| sb3_contrib/common/recurrent/buffers.py$
|
||||||
| sb3_contrib/common/maskable/distributions.py$
|
| sb3_contrib/common/maskable/distributions.py$
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ class QuantileNetwork(BasePolicy):
|
||||||
self,
|
self,
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
features_extractor: nn.Module,
|
features_extractor: BaseFeaturesExtractor,
|
||||||
features_dim: int,
|
features_dim: int,
|
||||||
n_quantiles: int = 200,
|
n_quantiles: int = 200,
|
||||||
net_arch: Optional[List[int]] = None,
|
net_arch: Optional[List[int]] = None,
|
||||||
|
|
@ -50,7 +50,6 @@ class QuantileNetwork(BasePolicy):
|
||||||
|
|
||||||
self.net_arch = net_arch
|
self.net_arch = net_arch
|
||||||
self.activation_fn = activation_fn
|
self.activation_fn = activation_fn
|
||||||
self.features_extractor = features_extractor
|
|
||||||
self.features_dim = features_dim
|
self.features_dim = features_dim
|
||||||
self.n_quantiles = n_quantiles
|
self.n_quantiles = n_quantiles
|
||||||
action_dim = self.action_space.n # number of actions
|
action_dim = self.action_space.n # number of actions
|
||||||
|
|
@ -64,6 +63,8 @@ class QuantileNetwork(BasePolicy):
|
||||||
:param obs: Observation
|
:param obs: Observation
|
||||||
:return: The estimated quantiles for each action.
|
:return: The estimated quantiles for each action.
|
||||||
"""
|
"""
|
||||||
|
# For type checker:
|
||||||
|
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
|
||||||
quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
|
quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
|
||||||
return quantiles.view(-1, self.n_quantiles, self.action_space.n)
|
return quantiles.view(-1, self.n_quantiles, self.action_space.n)
|
||||||
|
|
||||||
|
|
@ -152,7 +153,8 @@ class QRDQNPolicy(BasePolicy):
|
||||||
"normalize_images": normalize_images,
|
"normalize_images": normalize_images,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.quantile_net, self.quantile_net_target = None, None
|
self.quantile_net: QuantileNetwork
|
||||||
|
self.quantile_net_target: QuantileNetwork
|
||||||
self._build(lr_schedule)
|
self._build(lr_schedule)
|
||||||
|
|
||||||
def _build(self, lr_schedule: Schedule) -> None:
|
def _build(self, lr_schedule: Schedule) -> None:
|
||||||
|
|
@ -168,7 +170,11 @@ class QRDQNPolicy(BasePolicy):
|
||||||
self.quantile_net_target.set_training_mode(False)
|
self.quantile_net_target.set_training_mode(False)
|
||||||
|
|
||||||
# Setup optimizer with initial learning rate
|
# Setup optimizer with initial learning rate
|
||||||
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
self.optimizer = self.optimizer_class( # type: ignore[call-arg]
|
||||||
|
self.parameters(),
|
||||||
|
lr=lr_schedule(1),
|
||||||
|
**self.optimizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def make_quantile_net(self) -> QuantileNetwork:
|
def make_quantile_net(self) -> QuantileNetwork:
|
||||||
# Make sure we always have separate networks for features extractors etc
|
# Make sure we always have separate networks for features extractors etc
|
||||||
|
|
|
||||||
|
|
@ -6,9 +6,8 @@ from gym import spaces
|
||||||
from stable_baselines3.common.buffers import ReplayBuffer
|
from stable_baselines3.common.buffers import ReplayBuffer
|
||||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.preprocessing import maybe_transpose
|
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
|
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
|
||||||
|
|
||||||
from sb3_contrib.common.utils import quantile_huber_loss
|
from sb3_contrib.common.utils import quantile_huber_loss
|
||||||
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
|
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
|
||||||
|
|
@ -71,12 +70,12 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
learning_rate: Union[float, Schedule] = 5e-5,
|
learning_rate: Union[float, Schedule] = 5e-5,
|
||||||
buffer_size: int = 1000000, # 1e6
|
buffer_size: int = 1000000, # 1e6
|
||||||
learning_starts: int = 50000,
|
learning_starts: int = 50000,
|
||||||
batch_size: Optional[int] = 32,
|
batch_size: int = 32,
|
||||||
tau: float = 1.0,
|
tau: float = 1.0,
|
||||||
gamma: float = 0.99,
|
gamma: float = 0.99,
|
||||||
train_freq: int = 4,
|
train_freq: int = 4,
|
||||||
gradient_steps: int = 1,
|
gradient_steps: int = 1,
|
||||||
replay_buffer_class: Optional[ReplayBuffer] = None,
|
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
||||||
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
optimize_memory_usage: bool = False,
|
optimize_memory_usage: bool = False,
|
||||||
target_update_interval: int = 10000,
|
target_update_interval: int = 10000,
|
||||||
|
|
@ -124,8 +123,9 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
# "epsilon" for the epsilon-greedy exploration
|
# "epsilon" for the epsilon-greedy exploration
|
||||||
self.exploration_rate = 0.0
|
self.exploration_rate = 0.0
|
||||||
# Linear schedule will be defined in `_setup_model()`
|
# Linear schedule will be defined in `_setup_model()`
|
||||||
self.exploration_schedule = None
|
self.exploration_schedule: Schedule
|
||||||
self.quantile_net, self.quantile_net_target = None, None
|
self.quantile_net: th.nn.Module
|
||||||
|
self.quantile_net_target: th.nn.Module
|
||||||
|
|
||||||
if "optimizer_class" not in self.policy_kwargs:
|
if "optimizer_class" not in self.policy_kwargs:
|
||||||
self.policy_kwargs["optimizer_class"] = th.optim.Adam
|
self.policy_kwargs["optimizer_class"] = th.optim.Adam
|
||||||
|
|
@ -146,6 +146,8 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_aliases(self) -> None:
|
def _create_aliases(self) -> None:
|
||||||
|
# For type checker:
|
||||||
|
assert isinstance(self.policy, QRDQNPolicy)
|
||||||
self.quantile_net = self.policy.quantile_net
|
self.quantile_net = self.policy.quantile_net
|
||||||
self.quantile_net_target = self.policy.quantile_net_target
|
self.quantile_net_target = self.policy.quantile_net_target
|
||||||
self.n_quantiles = self.policy.n_quantiles
|
self.n_quantiles = self.policy.n_quantiles
|
||||||
|
|
@ -172,7 +174,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
losses = []
|
losses = []
|
||||||
for _ in range(gradient_steps):
|
for _ in range(gradient_steps):
|
||||||
# Sample replay buffer
|
# Sample replay buffer
|
||||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]
|
||||||
|
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
# Compute the quantiles of next observation
|
# Compute the quantiles of next observation
|
||||||
|
|
@ -214,7 +216,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
observation: np.ndarray,
|
observation: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||||
state: Optional[Tuple[np.ndarray, ...]] = None,
|
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||||
episode_start: Optional[np.ndarray] = None,
|
episode_start: Optional[np.ndarray] = None,
|
||||||
deterministic: bool = False,
|
deterministic: bool = False,
|
||||||
|
|
@ -233,8 +235,8 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
(used in recurrent policies)
|
(used in recurrent policies)
|
||||||
"""
|
"""
|
||||||
if not deterministic and np.random.rand() < self.exploration_rate:
|
if not deterministic and np.random.rand() < self.exploration_rate:
|
||||||
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
|
if self.policy.is_vectorized_observation(observation):
|
||||||
if isinstance(self.observation_space, spaces.Dict):
|
if isinstance(observation, dict):
|
||||||
n_batch = observation[list(observation.keys())[0]].shape[0]
|
n_batch = observation[list(observation.keys())[0]].shape[0]
|
||||||
else:
|
else:
|
||||||
n_batch = observation.shape[0]
|
n_batch = observation.shape[0]
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.8.0a10
|
1.8.0a13
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.8.0a10",
|
"stable_baselines3>=1.8.0a13",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue