Upgrade to python 3.7+ syntax (#69)

* Upgrade to python 3.7+ syntax

* Switch to PyTorch 1.11
Antonin RAFFIN 2022-04-25 13:02:07 +02:00 committed by GitHub
parent 812648e6cd
commit bec00386d1
17 changed files with 39 additions and 40 deletions
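
All 17 files below receive the same few mechanical rewrites (applied with pyupgrade, per the changelog entry in this commit). A minimal standalone sketch of the idioms involved; the class names here are hypothetical, not taken from the repository:

# Before: Python 2-compatible spellings that pyupgrade removes
class OldStyle(object):  # explicit "object" base class
    def __init__(self):
        super(OldStyle, self).__init__()  # super() with explicit arguments

# After: equivalent Python 3.7+ spellings it produces
class NewStyle:  # the "object" base class is implicit
    def __init__(self):
        super().__init__()  # zero-argument super()

# open() defaults to mode "r", so the explicit mode argument is redundant:
# before: open(path, "r")    after: open(path)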

.github/workflows/ci.yml

@@ -28,7 +28,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         # cpu version of pytorch
-        pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+        pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
         # Install dependencies for docs and tests
         pip install stable_baselines3[extra,tests,docs]
         # Install master version
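
After pinning the new wheel, a quick sanity check (illustrative only, not part of the workflow file) can confirm from Python that the CPU-only build was installed:

import torch

print(torch.__version__)          # expected: a 1.11 "+cpu" build
print(torch.cuda.is_available())  # False for the CPU-only wheel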

docs/conf.py

@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 #
 # Configuration file for the Sphinx documentation builder.
 #
@@ -46,7 +45,7 @@ sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
 
 # Read version from file
 version_file = os.path.join(os.path.dirname(__file__), "../sb3_contrib", "version.txt")
-with open(version_file, "r") as file_handler:
+with open(version_file) as file_handler:
     __version__ = file_handler.read().strip()
 
 # -- Project information -----------------------------------------------------

docs/misc/changelog.rst

@@ -3,15 +3,17 @@
 Changelog
 ==========
 
-Release 1.5.1a1 (WIP)
+Release 1.5.1a5 (WIP)
 -------------------------------
 
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
-- Upgraded to Stable-Baselines3 >= 1.5.1a1
+- Upgraded to Stable-Baselines3 >= 1.5.1a5
 - Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former
   ``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar)
 - Renamed ``rollout/exploration rate`` key to ``rollout/exploration_rate`` for QRDQN (to be consistent with SB3 DQN)
+- Upgraded to python 3.7+ syntax using ``pyupgrade``
+- SB3 now requires PyTorch >= 1.11
 
 New Features:
 ^^^^^^^^^^^^^
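
Given the raised PyTorch floor above, downstream code could guard against an older install with a runtime check; a hedged sketch, not part of SB3 itself:

import torch
from packaging import version

# Fail fast if the installed PyTorch predates the new requirement.
assert version.parse(torch.__version__) >= version.parse("1.11"), (
    f"PyTorch >= 1.11 required, found {torch.__version__}"
)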

sb3_contrib/__init__.py

@@ -8,5 +8,5 @@ from sb3_contrib.trpo import TRPO
 
 # Read version from file
 version_file = os.path.join(os.path.dirname(__file__), "version.txt")
-with open(version_file, "r") as file_handler:
+with open(version_file) as file_handler:
     __version__ = file_handler.read().strip()

sb3_contrib/common/maskable/buffers.py

@@ -145,9 +145,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
         n_envs: int = 1,
     ):
         self.action_masks = None
-        super(MaskableDictRolloutBuffer, self).__init__(
-            buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs
-        )
+        super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)
 
     def reset(self) -> None:
         if isinstance(self.action_space, spaces.Discrete):
@@ -162,7 +160,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
         self.mask_dims = mask_dims
         self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)
 
-        super(MaskableDictRolloutBuffer, self).reset()
+        super().reset()
 
     def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
         """
@@ -171,7 +169,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer):
         if action_masks is not None:
             self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))
 
-        super(MaskableDictRolloutBuffer, self).add(*args, **kwargs)
+        super().add(*args, **kwargs)
 
     def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]:
         assert self.full, ""

sb3_contrib/common/maskable/policies.py

@@ -345,7 +345,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        super(MaskableActorCriticCnnPolicy, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             lr_schedule,
@@ -396,7 +396,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy):
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        super(MaskableMultiInputActorCriticPolicy, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             lr_schedule,

sb3_contrib/common/vec_env/async_eval.py

@@ -72,7 +72,7 @@ def _worker(
             break
 
 
-class AsyncEval(object):
+class AsyncEval:
     """
     Helper class to do asynchronous evaluation of different policies with multiple processes.
     It is useful when implementing population based methods like Evolution Strategies (ES),

sb3_contrib/common/wrappers/time_feature.py

@@ -50,7 +50,7 @@ class TimeFeatureWrapper(gym.Wrapper):
         else:
             env.observation_space = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
 
-        super(TimeFeatureWrapper, self).__init__(env)
+        super().__init__(env)
 
         # Try to infer the max number of steps per episode
         try:
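
A short usage sketch for the wrapper touched above, assuming the classic gym API and an environment with a Box observation space:

import gym

from sb3_contrib.common.wrappers import TimeFeatureWrapper

# Appends a normalized remaining-time feature to each observation.
env = TimeFeatureWrapper(gym.make("Pendulum-v1"))
obs = env.reset()  # observation now has one extra dimension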

sb3_contrib/qrdqn/policies.py

@@ -38,7 +38,7 @@ class QuantileNetwork(BasePolicy):
         activation_fn: Type[nn.Module] = nn.ReLU,
         normalize_images: bool = True,
     ):
-        super(QuantileNetwork, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             features_extractor=features_extractor,
@@ -125,7 +125,7 @@ class QRDQNPolicy(BasePolicy):
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        super(QRDQNPolicy, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             features_extractor_class,
@@ -246,7 +246,7 @@ class CnnPolicy(QRDQNPolicy):
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        super(CnnPolicy, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             lr_schedule,
@@ -294,7 +294,7 @@ class MultiInputPolicy(QRDQNPolicy):
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        super(MultiInputPolicy, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             lr_schedule,

sb3_contrib/qrdqn/qrdqn.py

@@ -93,7 +93,7 @@ class QRDQN(OffPolicyAlgorithm):
         _init_setup_model: bool = True,
     ):
-        super(QRDQN, self).__init__(
+        super().__init__(
             policy,
             env,
             learning_rate,
@@ -139,7 +139,7 @@ class QRDQN(OffPolicyAlgorithm):
             self._setup_model()
 
     def _setup_model(self) -> None:
-        super(QRDQN, self)._setup_model()
+        super()._setup_model()
         self._create_aliases()
         self.exploration_schedule = get_linear_fn(
             self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
@@ -253,7 +253,7 @@ class QRDQN(OffPolicyAlgorithm):
         reset_num_timesteps: bool = True,
     ) -> OffPolicyAlgorithm:
 
-        return super(QRDQN, self).learn(
+        return super().learn(
             total_timesteps=total_timesteps,
             callback=callback,
             log_interval=log_interval,
@@ -266,7 +266,7 @@ class QRDQN(OffPolicyAlgorithm):
         )
 
     def _excluded_save_params(self) -> List[str]:
-        return super(QRDQN, self)._excluded_save_params() + ["quantile_net", "quantile_net_target"]
+        return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"]
 
     def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
         state_dicts = ["policy", "policy.optimizer"]

sb3_contrib/tqc/policies.py

@@ -64,7 +64,7 @@ class Actor(BasePolicy):
         clip_mean: float = 2.0,
         normalize_images: bool = True,
     ):
-        super(Actor, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             features_extractor=features_extractor,
@@ -299,7 +299,7 @@ class TQCPolicy(BasePolicy):
         n_critics: int = 2,
         share_features_extractor: bool = True,
     ):
-        super(TQCPolicy, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             features_extractor_class,
@@ -486,7 +486,7 @@ class CnnPolicy(TQCPolicy):
         n_critics: int = 2,
         share_features_extractor: bool = True,
     ):
-        super(CnnPolicy, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             lr_schedule,
@@ -560,7 +560,7 @@ class MultiInputPolicy(TQCPolicy):
         n_critics: int = 2,
         share_features_extractor: bool = True,
     ):
-        super(MultiInputPolicy, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             lr_schedule,

sb3_contrib/tqc/tqc.py

@@ -103,7 +103,7 @@ class TQC(OffPolicyAlgorithm):
         _init_setup_model: bool = True,
     ):
-        super(TQC, self).__init__(
+        super().__init__(
             policy,
             env,
             learning_rate,
@@ -144,7 +144,7 @@ class TQC(OffPolicyAlgorithm):
             self._setup_model()
 
     def _setup_model(self) -> None:
-        super(TQC, self)._setup_model()
+        super()._setup_model()
         self._create_aliases()
         # Target entropy is used when learning the entropy coefficient
@@ -293,7 +293,7 @@ class TQC(OffPolicyAlgorithm):
         reset_num_timesteps: bool = True,
     ) -> OffPolicyAlgorithm:
 
-        return super(TQC, self).learn(
+        return super().learn(
             total_timesteps=total_timesteps,
             callback=callback,
             log_interval=log_interval,
@@ -307,7 +307,7 @@ class TQC(OffPolicyAlgorithm):
 
     def _excluded_save_params(self) -> List[str]:
         # Exclude aliases
-        return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]
+        return super()._excluded_save_params() + ["actor", "critic", "critic_target"]
 
     def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
         state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]

sb3_contrib/trpo/trpo.py

@@ -101,7 +101,7 @@ class TRPO(OnPolicyAlgorithm):
         _init_setup_model: bool = True,
     ):
-        super(TRPO, self).__init__(
+        super().__init__(
             policy,
             env,
             learning_rate=learning_rate,
@@ -414,7 +414,7 @@ class TRPO(OnPolicyAlgorithm):
         reset_num_timesteps: bool = True,
     ) -> OnPolicyAlgorithm:
 
-        return super(TRPO, self).learn(
+        return super().learn(
             total_timesteps=total_timesteps,
             callback=callback,
             log_interval=log_interval,

sb3_contrib/version.txt

@@ -1 +1 @@
-1.5.1a1
+1.5.1a5

setup.py

@@ -2,7 +2,7 @@ import os
 
 from setuptools import find_packages, setup
 
-with open(os.path.join("sb3_contrib", "version.txt"), "r") as file_handler:
+with open(os.path.join("sb3_contrib", "version.txt")) as file_handler:
     __version__ = file_handler.read().strip()
@@ -65,7 +65,7 @@ setup(
     packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
     package_data={"sb3_contrib": ["py.typed", "version.txt"]},
     install_requires=[
-        "stable_baselines3>=1.5.1a1",
+        "stable_baselines3>=1.5.1a5",
     ],
     description="Contrib package of Stable Baselines3, experimental code.",
     author="Antonin Raffin",

tests/test_save_load.py

@@ -66,7 +66,7 @@ def test_save_load(tmp_path, model_class):
         model.set_parameters(invalid_object_params, exact_match=False)
 
     # Test that exact_match catches when something was missed.
-    missing_object_params = dict((k, v) for k, v in list(original_params.items())[:-1])
+    missing_object_params = {k: v for k, v in list(original_params.items())[:-1]}
     with pytest.raises(ValueError):
         model.set_parameters(missing_object_params, exact_match=True)
@@ -312,7 +312,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
     params = deepcopy(policy.state_dict())
 
     # Modify all parameters to be random values
-    random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
+    random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}
 
     # Update model parameters with the new random values
     policy.load_state_dict(random_params)
@@ -409,7 +409,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str):
     params = deepcopy(q_net.state_dict())
 
     # Modify all parameters to be random values
-    random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
+    random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}
 
     # Update model parameters with the new random values
     q_net.load_state_dict(random_params)
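
The two spellings changed in this test file are equivalent; the comprehension is simply the idiomatic Python 3 form. A standalone illustration, not repository code:

params = {"weight": 1.0, "bias": 0.0}
old_style = dict((k, v) for k, v in params.items())
new_style = {k: v for k, v in params.items()}
assert old_style == new_style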

tests/test_train_eval_mode.py

@@ -21,7 +21,7 @@ class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):
     """
 
     def __init__(self, observation_space: gym.Space):
-        super(FlattenBatchNormDropoutExtractor, self).__init__(
+        super().__init__(
            observation_space,
            get_flattened_obs_dim(observation_space),
        )