diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 41e40ba..6ad4d12 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 1.7.0a1 (WIP) +Release 1.7.0a2 (WIP) -------------------------- Breaking Changes: @@ -29,6 +29,8 @@ Others: - Fixed flake8 config - Fixed ``sb3_contrib/common/utils.py`` type hint - Fixed ``sb3_contrib/common/recurrent/type_aliases.py`` type hint +- Exposed modules in `__init__.py` with `__all__` attribute (@ZikangXiong) +- Removed ignores on Flake8 F401 (@ZikangXiong) Release 1.6.2 (2022-10-10) -------------------------- @@ -358,4 +360,4 @@ Contributors: ------------- @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec -@mlodel @CppMaster @burakdmb @honglu2875 +@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 1836ac4..3fbd28d 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -11,3 +11,12 @@ from sb3_contrib.trpo import TRPO version_file = os.path.join(os.path.dirname(__file__), "version.txt") with open(version_file) as file_handler: __version__ = file_handler.read().strip() + +__all__ = [ + "ARS", + "MaskablePPO", + "RecurrentPPO", + "QRDQN", + "TQC", + "TRPO", +] diff --git a/sb3_contrib/ars/__init__.py b/sb3_contrib/ars/__init__.py index 655ac8b..8181d57 100644 --- a/sb3_contrib/ars/__init__.py +++ b/sb3_contrib/ars/__init__.py @@ -1,2 +1,4 @@ from sb3_contrib.ars.ars import ARS from sb3_contrib.ars.policies import LinearPolicy, MlpPolicy + +__all__ = ["ARS", "LinearPolicy", "MlpPolicy"] diff --git a/sb3_contrib/common/envs/__init__.py b/sb3_contrib/common/envs/__init__.py index 20bdffe..e9f740b 100644 --- a/sb3_contrib/common/envs/__init__.py +++ b/sb3_contrib/common/envs/__init__.py @@ -3,3 +3,5 @@ from sb3_contrib.common.envs.invalid_actions_env import ( InvalidActionEnvMultiBinary, InvalidActionEnvMultiDiscrete, ) + +__all__ = ["InvalidActionEnvDiscrete", "InvalidActionEnvMultiBinary", "InvalidActionEnvMultiDiscrete"] diff --git a/sb3_contrib/common/vec_env/__init__.py b/sb3_contrib/common/vec_env/__init__.py index 488abf8..653b572 100644 --- a/sb3_contrib/common/vec_env/__init__.py +++ b/sb3_contrib/common/vec_env/__init__.py @@ -1 +1,3 @@ from sb3_contrib.common.vec_env.async_eval import AsyncEval + +__all__ = ["AsyncEval"] diff --git a/sb3_contrib/common/wrappers/__init__.py b/sb3_contrib/common/wrappers/__init__.py index 8b591b4..c7dc0b0 100644 --- a/sb3_contrib/common/wrappers/__init__.py +++ b/sb3_contrib/common/wrappers/__init__.py @@ -1,2 +1,4 @@ from sb3_contrib.common.wrappers.action_masker import ActionMasker from sb3_contrib.common.wrappers.time_feature import TimeFeatureWrapper + +__all__ = ["ActionMasker", "TimeFeatureWrapper"] diff --git a/sb3_contrib/ppo_mask/__init__.py b/sb3_contrib/ppo_mask/__init__.py index c26ace8..89d4ced 100644 --- a/sb3_contrib/ppo_mask/__init__.py +++ b/sb3_contrib/ppo_mask/__init__.py @@ -1,2 +1,4 @@ from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "MaskablePPO"] diff --git a/sb3_contrib/ppo_recurrent/__init__.py b/sb3_contrib/ppo_recurrent/__init__.py index 3fb5436..f830104 100644 --- a/sb3_contrib/ppo_recurrent/__init__.py +++ b/sb3_contrib/ppo_recurrent/__init__.py @@ -1,2 +1,4 @@ from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy from sb3_contrib.ppo_recurrent.ppo_recurrent import RecurrentPPO + +__all__ = ["CnnLstmPolicy", "MlpLstmPolicy", "MultiInputLstmPolicy", "RecurrentPPO"] diff --git a/sb3_contrib/qrdqn/__init__.py b/sb3_contrib/qrdqn/__init__.py index 6ca676e..6f13c23 100644 --- a/sb3_contrib/qrdqn/__init__.py +++ b/sb3_contrib/qrdqn/__init__.py @@ -1,2 +1,4 @@ from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from sb3_contrib.qrdqn.qrdqn import QRDQN + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "QRDQN"] diff --git a/sb3_contrib/tqc/__init__.py b/sb3_contrib/tqc/__init__.py index b8d9a71..e40a55a 100644 --- a/sb3_contrib/tqc/__init__.py +++ b/sb3_contrib/tqc/__init__.py @@ -1,2 +1,4 @@ from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from sb3_contrib.tqc.tqc import TQC + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "TQC"] diff --git a/sb3_contrib/trpo/__init__.py b/sb3_contrib/trpo/__init__.py index 7465a9d..8d6eea7 100644 --- a/sb3_contrib/trpo/__init__.py +++ b/sb3_contrib/trpo/__init__.py @@ -1,2 +1,4 @@ from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from sb3_contrib.trpo.trpo import TRPO + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "TRPO"] diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 12cd5fb..b895285 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.7.0a1 +1.7.0a2 diff --git a/setup.cfg b/setup.cfg index 4004a6b..54a7b44 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,18 +47,6 @@ exclude = (?x)( [flake8] # line breaks before and after binary operators ignore = W503,W504,E203,E231 -# Ignore import not used when aliases are defined -per-file-ignores = - ./sb3_contrib/__init__.py:F401 - ./sb3_contrib/ars/__init__.py:F401 - ./sb3_contrib/ppo_mask/__init__.py:F401 - ./sb3_contrib/ppo_recurrent/__init__.py:F401 - ./sb3_contrib/qrdqn/__init__.py:F401 - ./sb3_contrib/tqc/__init__.py:F401 - ./sb3_contrib/trpo/__init__.py:F401 - ./sb3_contrib/common/wrappers/__init__.py:F401 - ./sb3_contrib/common/envs/__init__.py:F401 - ./sb3_contrib/common/vec_env/__init__.py:F401 exclude = # No need to traverse our git directory