From bec00386d14b505015c54413cd5cd968e6f85c72 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 25 Apr 2022 13:02:07 +0200 Subject: [PATCH] Upgrade to python 3.7+ syntax (#69) * Upgrade to python 3.7+ syntax * Switch to PyTorch 1.11 --- .github/workflows/ci.yml | 2 +- docs/conf.py | 3 +-- docs/misc/changelog.rst | 6 ++++-- sb3_contrib/__init__.py | 2 +- sb3_contrib/common/maskable/buffers.py | 8 +++----- sb3_contrib/common/maskable/policies.py | 4 ++-- sb3_contrib/common/vec_env/async_eval.py | 2 +- sb3_contrib/common/wrappers/time_feature.py | 2 +- sb3_contrib/qrdqn/policies.py | 8 ++++---- sb3_contrib/qrdqn/qrdqn.py | 8 ++++---- sb3_contrib/tqc/policies.py | 8 ++++---- sb3_contrib/tqc/tqc.py | 8 ++++---- sb3_contrib/trpo/trpo.py | 4 ++-- sb3_contrib/version.txt | 2 +- setup.py | 4 ++-- tests/test_save_load.py | 6 +++--- tests/test_train_eval_mode.py | 2 +- 17 files changed, 39 insertions(+), 40 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b981df6..eb0e69e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,7 +28,7 @@ jobs: run: | python -m pip install --upgrade pip # cpu version of pytorch - pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html # Install dependencies for docs and tests pip install stable_baselines3[extra,tests,docs] # Install master version diff --git a/docs/conf.py b/docs/conf.py index ed733df..4521619 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Configuration file for the Sphinx documentation builder. # @@ -46,7 +45,7 @@ sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) # Read version from file version_file = os.path.join(os.path.dirname(__file__), "../sb3_contrib", "version.txt") -with open(version_file, "r") as file_handler: +with open(version_file) as file_handler: __version__ = file_handler.read().strip() # -- Project information ----------------------------------------------------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 6c38155..1569ceb 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,15 +3,17 @@ Changelog ========== -Release 1.5.1a1 (WIP) +Release 1.5.1a5 (WIP) ------------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ -- Upgraded to Stable-Baselines3 >= 1.5.1a1 +- Upgraded to Stable-Baselines3 >= 1.5.1a5 - Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former ``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar) - Renamed ``rollout/exploration rate`` key to ``rollout/exploration_rate`` for QRDQN (to be consistent with SB3 DQN) +- Upgraded to python 3.7+ syntax using ``pyupgrade`` +- SB3 now requires PyTorch >= 1.11 New Features: ^^^^^^^^^^^^^ diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 2ff218d..6ab21a1 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -8,5 +8,5 @@ from sb3_contrib.trpo import TRPO # Read version from file version_file = os.path.join(os.path.dirname(__file__), "version.txt") -with open(version_file, "r") as file_handler: +with open(version_file) as file_handler: __version__ = file_handler.read().strip() diff --git a/sb3_contrib/common/maskable/buffers.py b/sb3_contrib/common/maskable/buffers.py index 71229ef..adfbef7 100644 --- a/sb3_contrib/common/maskable/buffers.py +++ b/sb3_contrib/common/maskable/buffers.py @@ -145,9 +145,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer): n_envs: int = 1, ): self.action_masks = None - super(MaskableDictRolloutBuffer, self).__init__( - buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs - ) + super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs) def reset(self) -> None: if isinstance(self.action_space, spaces.Discrete): @@ -162,7 +160,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer): self.mask_dims = mask_dims self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32) - super(MaskableDictRolloutBuffer, self).reset() + super().reset() def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None: """ @@ -171,7 +169,7 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer): if action_masks is not None: self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims)) - super(MaskableDictRolloutBuffer, self).add(*args, **kwargs) + super().add(*args, **kwargs) def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]: assert self.full, "" diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index 1461eb6..e521c34 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -345,7 +345,7 @@ class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(MaskableActorCriticCnnPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -396,7 +396,7 @@ class MaskableMultiInputActorCriticPolicy(MaskableActorCriticPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(MaskableMultiInputActorCriticPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, diff --git a/sb3_contrib/common/vec_env/async_eval.py b/sb3_contrib/common/vec_env/async_eval.py index 767ff44..e17df73 100644 --- a/sb3_contrib/common/vec_env/async_eval.py +++ b/sb3_contrib/common/vec_env/async_eval.py @@ -72,7 +72,7 @@ def _worker( break -class AsyncEval(object): +class AsyncEval: """ Helper class to do asynchronous evaluation of different policies with multiple processes. It is useful when implementing population based methods like Evolution Strategies (ES), diff --git a/sb3_contrib/common/wrappers/time_feature.py b/sb3_contrib/common/wrappers/time_feature.py index b63f19c..2aa5a9d 100644 --- a/sb3_contrib/common/wrappers/time_feature.py +++ b/sb3_contrib/common/wrappers/time_feature.py @@ -50,7 +50,7 @@ class TimeFeatureWrapper(gym.Wrapper): else: env.observation_space = gym.spaces.Box(low=low, high=high, dtype=self.dtype) - super(TimeFeatureWrapper, self).__init__(env) + super().__init__(env) # Try to infer the max number of steps per episode try: diff --git a/sb3_contrib/qrdqn/policies.py b/sb3_contrib/qrdqn/policies.py index 4aa0f67..d283c2f 100644 --- a/sb3_contrib/qrdqn/policies.py +++ b/sb3_contrib/qrdqn/policies.py @@ -38,7 +38,7 @@ class QuantileNetwork(BasePolicy): activation_fn: Type[nn.Module] = nn.ReLU, normalize_images: bool = True, ): - super(QuantileNetwork, self).__init__( + super().__init__( observation_space, action_space, features_extractor=features_extractor, @@ -125,7 +125,7 @@ class QRDQNPolicy(BasePolicy): optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(QRDQNPolicy, self).__init__( + super().__init__( observation_space, action_space, features_extractor_class, @@ -246,7 +246,7 @@ class CnnPolicy(QRDQNPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(CnnPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -294,7 +294,7 @@ class MultiInputPolicy(QRDQNPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(MultiInputPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index 50aea1c..1d75644 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -93,7 +93,7 @@ class QRDQN(OffPolicyAlgorithm): _init_setup_model: bool = True, ): - super(QRDQN, self).__init__( + super().__init__( policy, env, learning_rate, @@ -139,7 +139,7 @@ class QRDQN(OffPolicyAlgorithm): self._setup_model() def _setup_model(self) -> None: - super(QRDQN, self)._setup_model() + super()._setup_model() self._create_aliases() self.exploration_schedule = get_linear_fn( self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction @@ -253,7 +253,7 @@ class QRDQN(OffPolicyAlgorithm): reset_num_timesteps: bool = True, ) -> OffPolicyAlgorithm: - return super(QRDQN, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, @@ -266,7 +266,7 @@ class QRDQN(OffPolicyAlgorithm): ) def _excluded_save_params(self) -> List[str]: - return super(QRDQN, self)._excluded_save_params() + ["quantile_net", "quantile_net_target"] + return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"] def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: state_dicts = ["policy", "policy.optimizer"] diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index 171846d..dbd073f 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -64,7 +64,7 @@ class Actor(BasePolicy): clip_mean: float = 2.0, normalize_images: bool = True, ): - super(Actor, self).__init__( + super().__init__( observation_space, action_space, features_extractor=features_extractor, @@ -299,7 +299,7 @@ class TQCPolicy(BasePolicy): n_critics: int = 2, share_features_extractor: bool = True, ): - super(TQCPolicy, self).__init__( + super().__init__( observation_space, action_space, features_extractor_class, @@ -486,7 +486,7 @@ class CnnPolicy(TQCPolicy): n_critics: int = 2, share_features_extractor: bool = True, ): - super(CnnPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -560,7 +560,7 @@ class MultiInputPolicy(TQCPolicy): n_critics: int = 2, share_features_extractor: bool = True, ): - super(MultiInputPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 9c574e2..51d9f12 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -103,7 +103,7 @@ class TQC(OffPolicyAlgorithm): _init_setup_model: bool = True, ): - super(TQC, self).__init__( + super().__init__( policy, env, learning_rate, @@ -144,7 +144,7 @@ class TQC(OffPolicyAlgorithm): self._setup_model() def _setup_model(self) -> None: - super(TQC, self)._setup_model() + super()._setup_model() self._create_aliases() # Target entropy is used when learning the entropy coefficient @@ -293,7 +293,7 @@ class TQC(OffPolicyAlgorithm): reset_num_timesteps: bool = True, ) -> OffPolicyAlgorithm: - return super(TQC, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, @@ -307,7 +307,7 @@ class TQC(OffPolicyAlgorithm): def _excluded_save_params(self) -> List[str]: # Exclude aliases - return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"] + return super()._excluded_save_params() + ["actor", "critic", "critic_target"] def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index e3d2f03..f20a9ae 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -101,7 +101,7 @@ class TRPO(OnPolicyAlgorithm): _init_setup_model: bool = True, ): - super(TRPO, self).__init__( + super().__init__( policy, env, learning_rate=learning_rate, @@ -414,7 +414,7 @@ class TRPO(OnPolicyAlgorithm): reset_num_timesteps: bool = True, ) -> OnPolicyAlgorithm: - return super(TRPO, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 1110517..bccb8c6 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.5.1a1 +1.5.1a5 diff --git a/setup.py b/setup.py index c26523a..5aaa498 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ import os from setuptools import find_packages, setup -with open(os.path.join("sb3_contrib", "version.txt"), "r") as file_handler: +with open(os.path.join("sb3_contrib", "version.txt")) as file_handler: __version__ = file_handler.read().strip() @@ -65,7 +65,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.5.1a1", + "stable_baselines3>=1.5.1a5", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 2841435..b2a62d2 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -66,7 +66,7 @@ def test_save_load(tmp_path, model_class): model.set_parameters(invalid_object_params, exact_match=False) # Test that exact_match catches when something was missed. - missing_object_params = dict((k, v) for k, v in list(original_params.items())[:-1]) + missing_object_params = {k: v for k, v in list(original_params.items())[:-1]} with pytest.raises(ValueError): model.set_parameters(missing_object_params, exact_match=True) @@ -312,7 +312,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str): params = deepcopy(policy.state_dict()) # Modify all parameters to be random values - random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + random_params = {param_name: th.rand_like(param) for param_name, param in params.items()} # Update model parameters with the new random values policy.load_state_dict(random_params) @@ -409,7 +409,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str): params = deepcopy(q_net.state_dict()) # Modify all parameters to be random values - random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + random_params = {param_name: th.rand_like(param) for param_name, param in params.items()} # Update model parameters with the new random values q_net.load_state_dict(random_params) diff --git a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py index eb08fcf..cf7e8aa 100644 --- a/tests/test_train_eval_mode.py +++ b/tests/test_train_eval_mode.py @@ -21,7 +21,7 @@ class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor): """ def __init__(self, observation_space: gym.Space): - super(FlattenBatchNormDropoutExtractor, self).__init__( + super().__init__( observation_space, get_flattened_obs_dim(observation_space), )