Expose modules in `__init__.py` with `__all__` attribute (#124)

* expose modules in __init__.py with __all__ attribute

* Update version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Zikang Xiong 2022-12-05 09:53:57 -05:00 committed by GitHub
parent b3e4ddd09a
commit ddb3a1355e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 32 additions and 15 deletions

View File

@ -3,7 +3,7 @@
Changelog
==========
Release 1.7.0a1 (WIP)
Release 1.7.0a2 (WIP)
--------------------------
Breaking Changes:
@ -29,6 +29,8 @@ Others:
- Fixed flake8 config
- Fixed ``sb3_contrib/common/utils.py`` type hint
- Fixed ``sb3_contrib/common/recurrent/type_aliases.py`` type hint
- Exposed modules in `__init__.py` with `__all__` attribute (@ZikangXiong)
- Removed ignores on Flake8 F401 (@ZikangXiong)
Release 1.6.2 (2022-10-10)
--------------------------
@ -358,4 +360,4 @@ Contributors:
-------------
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
@mlodel @CppMaster @burakdmb @honglu2875
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong

View File

@ -11,3 +11,12 @@ from sb3_contrib.trpo import TRPO
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
with open(version_file) as file_handler:
__version__ = file_handler.read().strip()
__all__ = [
"ARS",
"MaskablePPO",
"RecurrentPPO",
"QRDQN",
"TQC",
"TRPO",
]

View File

@ -1,2 +1,4 @@
from sb3_contrib.ars.ars import ARS
from sb3_contrib.ars.policies import LinearPolicy, MlpPolicy
__all__ = ["ARS", "LinearPolicy", "MlpPolicy"]

View File

@ -3,3 +3,5 @@ from sb3_contrib.common.envs.invalid_actions_env import (
InvalidActionEnvMultiBinary,
InvalidActionEnvMultiDiscrete,
)
__all__ = ["InvalidActionEnvDiscrete", "InvalidActionEnvMultiBinary", "InvalidActionEnvMultiDiscrete"]

View File

@ -1 +1,3 @@
from sb3_contrib.common.vec_env.async_eval import AsyncEval
__all__ = ["AsyncEval"]

View File

@ -1,2 +1,4 @@
from sb3_contrib.common.wrappers.action_masker import ActionMasker
from sb3_contrib.common.wrappers.time_feature import TimeFeatureWrapper
__all__ = ["ActionMasker", "TimeFeatureWrapper"]

View File

@ -1,2 +1,4 @@
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "MaskablePPO"]

View File

@ -1,2 +1,4 @@
from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy
from sb3_contrib.ppo_recurrent.ppo_recurrent import RecurrentPPO
__all__ = ["CnnLstmPolicy", "MlpLstmPolicy", "MultiInputLstmPolicy", "RecurrentPPO"]

View File

@ -1,2 +1,4 @@
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.qrdqn.qrdqn import QRDQN
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "QRDQN"]

View File

@ -1,2 +1,4 @@
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.tqc.tqc import TQC
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "TQC"]

View File

@ -1,2 +1,4 @@
from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.trpo.trpo import TRPO
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "TRPO"]

View File

@ -1 +1 @@
1.7.0a1
1.7.0a2

View File

@ -47,18 +47,6 @@ exclude = (?x)(
[flake8]
# line breaks before and after binary operators
ignore = W503,W504,E203,E231
# Ignore import not used when aliases are defined
per-file-ignores =
./sb3_contrib/__init__.py:F401
./sb3_contrib/ars/__init__.py:F401
./sb3_contrib/ppo_mask/__init__.py:F401
./sb3_contrib/ppo_recurrent/__init__.py:F401
./sb3_contrib/qrdqn/__init__.py:F401
./sb3_contrib/tqc/__init__.py:F401
./sb3_contrib/trpo/__init__.py:F401
./sb3_contrib/common/wrappers/__init__.py:F401
./sb3_contrib/common/envs/__init__.py:F401
./sb3_contrib/common/vec_env/__init__.py:F401
exclude =
# No need to traverse our git directory