Upgrade to python 3.7+ syntax (#69)

* Upgrade to python 3.7+ syntax

* Switch to PyTorch 1.11
This commit is contained in:
Antonin RAFFIN 2022-04-25 13:02:07 +02:00 committed by GitHub
parent 812648e6cd
commit bec00386d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 39 additions and 40 deletions

View File

@ -28,7 +28,7 @@ jobs:
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
# cpu version of pytorch # cpu version of pytorch
pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
# Install dependencies for docs and tests # Install dependencies for docs and tests
pip install stable_baselines3[extra,tests,docs] pip install stable_baselines3[extra,tests,docs]
# Install master version # Install master version

View File

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# #
# Configuration file for the Sphinx documentation builder. # Configuration file for the Sphinx documentation builder.
# #
@ -46,7 +45,7 @@ sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
# Read version from file # Read version from file
version_file = os.path.join(os.path.dirname(__file__), "../sb3_contrib", "version.txt") version_file = os.path.join(os.path.dirname(__file__), "../sb3_contrib", "version.txt")
with open(version_file, "r") as file_handler: with open(version_file) as file_handler:
__version__ = file_handler.read().strip() __version__ = file_handler.read().strip()
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------

View File

@ -3,15 +3,17 @@
Changelog Changelog
========== ==========
Release 1.5.1a1 (WIP) Release 1.5.1a5 (WIP)
------------------------------- -------------------------------
Breaking Changes: Breaking Changes:
^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 1.5.1a1 - Upgraded to Stable-Baselines3 >= 1.5.1a5
- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former - Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former
``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar) ``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar)
- Renamed ``rollout/exploration rate`` key to ``rollout/exploration_rate`` for QRDQN (to be consistent with SB3 DQN) - Renamed ``rollout/exploration rate`` key to ``rollout/exploration_rate`` for QRDQN (to be consistent with SB3 DQN)
- Upgraded to python 3.7+ syntax using ``pyupgrade``
- SB3 now requires PyTorch >= 1.11
New Features: New Features:
^^^^^^^^^^^^^ ^^^^^^^^^^^^^

View File

@ -8,5 +8,5 @@ from sb3_contrib.trpo import TRPO
# Read version from file # Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt") version_file = os.path.join(os.path.dirname(__file__), "version.txt")
with open(version_file, "r") as file_handler: with open(version_file) as file_handler:
__version__ = file_handler.read().strip() __version__ = file_handler.read().strip()

View File

@ -145,9 +145,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
n_envs: int = 1, n_envs: int = 1,
): ):
self.action_masks = None self.action_masks = None
super(MaskableDictRolloutBuffer, self).__init__( super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)
buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs
)
def reset(self) -> None: def reset(self) -> None:
if isinstance(self.action_space, spaces.Discrete): if isinstance(self.action_space, spaces.Discrete):
@ -162,7 +160,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
self.mask_dims = mask_dims self.mask_dims = mask_dims
self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32) self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)
super(MaskableDictRolloutBuffer, self).reset() super().reset()
def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None: def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
""" """
@ -171,7 +169,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
if action_masks is not None: if action_masks is not None:
self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims)) self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))
super(MaskableDictRolloutBuffer, self).add(*args, **kwargs) super().add(*args, **kwargs)
def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]: def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]:
assert self.full, "" assert self.full, ""

View File

@ -345,7 +345,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None, optimizer_kwargs: Optional[Dict[str, Any]] = None,
): ):
super(MaskableActorCriticCnnPolicy, self).__init__( super().__init__(
observation_space, observation_space,
action_space, action_space,
lr_schedule, lr_schedule,
@ -396,7 +396,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None, optimizer_kwargs: Optional[Dict[str, Any]] = None,
): ):
super(MaskableMultiInputActorCriticPolicy, self).__init__( super().__init__(
observation_space, observation_space,
action_space, action_space,
lr_schedule, lr_schedule,

View File

@ -72,7 +72,7 @@ def _worker(
break break
class AsyncEval(object): class AsyncEval:
""" """
Helper class to do asynchronous evaluation of different policies with multiple processes. Helper class to do asynchronous evaluation of different policies with multiple processes.
It is useful when implementing population based methods like Evolution Strategies (ES), It is useful when implementing population based methods like Evolution Strategies (ES),

View File

@ -50,7 +50,7 @@ class TimeFeatureWrapper(gym.Wrapper):
else: else:
env.observation_space = gym.spaces.Box(low=low, high=high, dtype=self.dtype) env.observation_space = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
super(TimeFeatureWrapper, self).__init__(env) super().__init__(env)
# Try to infer the max number of steps per episode # Try to infer the max number of steps per episode
try: try:

View File

@ -38,7 +38,7 @@ class QuantileNetwork(BasePolicy):
activation_fn: Type[nn.Module] = nn.ReLU, activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True, normalize_images: bool = True,
): ):
super(QuantileNetwork, self).__init__( super().__init__(
observation_space, observation_space,
action_space, action_space,
features_extractor=features_extractor, features_extractor=features_extractor,
@ -125,7 +125,7 @@ class QRDQNPolicy(BasePolicy):
optimizer_kwargs: Optional[Dict[str, Any]] = None, optimizer_kwargs: Optional[Dict[str, Any]] = None,
): ):
super(QRDQNPolicy, self).__init__( super().__init__(
observation_space, observation_space,
action_space, action_space,
features_extractor_class, features_extractor_class,
@ -246,7 +246,7 @@ class CnnPolicy(QRDQNPolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None, optimizer_kwargs: Optional[Dict[str, Any]] = None,
): ):
super(CnnPolicy, self).__init__( super().__init__(
observation_space, observation_space,
action_space, action_space,
lr_schedule, lr_schedule,
@ -294,7 +294,7 @@ class MultiInputPolicy(QRDQNPolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None, optimizer_kwargs: Optional[Dict[str, Any]] = None,
): ):
super(MultiInputPolicy, self).__init__( super().__init__(
observation_space, observation_space,
action_space, action_space,
lr_schedule, lr_schedule,

View File

@ -93,7 +93,7 @@ class QRDQN(OffPolicyAlgorithm):
_init_setup_model: bool = True, _init_setup_model: bool = True,
): ):
super(QRDQN, self).__init__( super().__init__(
policy, policy,
env, env,
learning_rate, learning_rate,
@ -139,7 +139,7 @@ class QRDQN(OffPolicyAlgorithm):
self._setup_model() self._setup_model()
def _setup_model(self) -> None: def _setup_model(self) -> None:
super(QRDQN, self)._setup_model() super()._setup_model()
self._create_aliases() self._create_aliases()
self.exploration_schedule = get_linear_fn( self.exploration_schedule = get_linear_fn(
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
@ -253,7 +253,7 @@ class QRDQN(OffPolicyAlgorithm):
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm: ) -> OffPolicyAlgorithm:
return super(QRDQN, self).learn( return super().learn(
total_timesteps=total_timesteps, total_timesteps=total_timesteps,
callback=callback, callback=callback,
log_interval=log_interval, log_interval=log_interval,
@ -266,7 +266,7 @@ class QRDQN(OffPolicyAlgorithm):
) )
def _excluded_save_params(self) -> List[str]: def _excluded_save_params(self) -> List[str]:
return super(QRDQN, self)._excluded_save_params() + ["quantile_net", "quantile_net_target"] return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"]
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "policy.optimizer"] state_dicts = ["policy", "policy.optimizer"]

View File

@ -64,7 +64,7 @@ class Actor(BasePolicy):
clip_mean: float = 2.0, clip_mean: float = 2.0,
normalize_images: bool = True, normalize_images: bool = True,
): ):
super(Actor, self).__init__( super().__init__(
observation_space, observation_space,
action_space, action_space,
features_extractor=features_extractor, features_extractor=features_extractor,
@ -299,7 +299,7 @@ class TQCPolicy(BasePolicy):
n_critics: int = 2, n_critics: int = 2,
share_features_extractor: bool = True, share_features_extractor: bool = True,
): ):
super(TQCPolicy, self).__init__( super().__init__(
observation_space, observation_space,
action_space, action_space,
features_extractor_class, features_extractor_class,
@ -486,7 +486,7 @@ class CnnPolicy(TQCPolicy):
n_critics: int = 2, n_critics: int = 2,
share_features_extractor: bool = True, share_features_extractor: bool = True,
): ):
super(CnnPolicy, self).__init__( super().__init__(
observation_space, observation_space,
action_space, action_space,
lr_schedule, lr_schedule,
@ -560,7 +560,7 @@ class MultiInputPolicy(TQCPolicy):
n_critics: int = 2, n_critics: int = 2,
share_features_extractor: bool = True, share_features_extractor: bool = True,
): ):
super(MultiInputPolicy, self).__init__( super().__init__(
observation_space, observation_space,
action_space, action_space,
lr_schedule, lr_schedule,

View File

@ -103,7 +103,7 @@ class TQC(OffPolicyAlgorithm):
_init_setup_model: bool = True, _init_setup_model: bool = True,
): ):
super(TQC, self).__init__( super().__init__(
policy, policy,
env, env,
learning_rate, learning_rate,
@ -144,7 +144,7 @@ class TQC(OffPolicyAlgorithm):
self._setup_model() self._setup_model()
def _setup_model(self) -> None: def _setup_model(self) -> None:
super(TQC, self)._setup_model() super()._setup_model()
self._create_aliases() self._create_aliases()
# Target entropy is used when learning the entropy coefficient # Target entropy is used when learning the entropy coefficient
@ -293,7 +293,7 @@ class TQC(OffPolicyAlgorithm):
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm: ) -> OffPolicyAlgorithm:
return super(TQC, self).learn( return super().learn(
total_timesteps=total_timesteps, total_timesteps=total_timesteps,
callback=callback, callback=callback,
log_interval=log_interval, log_interval=log_interval,
@ -307,7 +307,7 @@ class TQC(OffPolicyAlgorithm):
def _excluded_save_params(self) -> List[str]: def _excluded_save_params(self) -> List[str]:
# Exclude aliases # Exclude aliases
return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"] return super()._excluded_save_params() + ["actor", "critic", "critic_target"]
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]

View File

@ -101,7 +101,7 @@ class TRPO(OnPolicyAlgorithm):
_init_setup_model: bool = True, _init_setup_model: bool = True,
): ):
super(TRPO, self).__init__( super().__init__(
policy, policy,
env, env,
learning_rate=learning_rate, learning_rate=learning_rate,
@ -414,7 +414,7 @@ class TRPO(OnPolicyAlgorithm):
reset_num_timesteps: bool = True, reset_num_timesteps: bool = True,
) -> OnPolicyAlgorithm: ) -> OnPolicyAlgorithm:
return super(TRPO, self).learn( return super().learn(
total_timesteps=total_timesteps, total_timesteps=total_timesteps,
callback=callback, callback=callback,
log_interval=log_interval, log_interval=log_interval,

View File

@ -1 +1 @@
1.5.1a1 1.5.1a5

View File

@ -2,7 +2,7 @@ import os
from setuptools import find_packages, setup from setuptools import find_packages, setup
with open(os.path.join("sb3_contrib", "version.txt"), "r") as file_handler: with open(os.path.join("sb3_contrib", "version.txt")) as file_handler:
__version__ = file_handler.read().strip() __version__ = file_handler.read().strip()
@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")], packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]}, package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[ install_requires=[
"stable_baselines3>=1.5.1a1", "stable_baselines3>=1.5.1a5",
], ],
description="Contrib package of Stable Baselines3, experimental code.", description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin", author="Antonin Raffin",

View File

@ -66,7 +66,7 @@ def test_save_load(tmp_path, model_class):
model.set_parameters(invalid_object_params, exact_match=False) model.set_parameters(invalid_object_params, exact_match=False)
# Test that exact_match catches when something was missed. # Test that exact_match catches when something was missed.
missing_object_params = dict((k, v) for k, v in list(original_params.items())[:-1]) missing_object_params = {k: v for k, v in list(original_params.items())[:-1]}
with pytest.raises(ValueError): with pytest.raises(ValueError):
model.set_parameters(missing_object_params, exact_match=True) model.set_parameters(missing_object_params, exact_match=True)
@ -312,7 +312,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
params = deepcopy(policy.state_dict()) params = deepcopy(policy.state_dict())
# Modify all parameters to be random values # Modify all parameters to be random values
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}
# Update model parameters with the new random values # Update model parameters with the new random values
policy.load_state_dict(random_params) policy.load_state_dict(random_params)
@ -409,7 +409,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str):
params = deepcopy(q_net.state_dict()) params = deepcopy(q_net.state_dict())
# Modify all parameters to be random values # Modify all parameters to be random values
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}
# Update model parameters with the new random values # Update model parameters with the new random values
q_net.load_state_dict(random_params) q_net.load_state_dict(random_params)

View File

@ -21,7 +21,7 @@ class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):
""" """
def __init__(self, observation_space: gym.Space): def __init__(self, observation_space: gym.Space):
super(FlattenBatchNormDropoutExtractor, self).__init__( super().__init__(
observation_space, observation_space,
get_flattened_obs_dim(observation_space), get_flattened_obs_dim(observation_space),
) )