From cd0a5e516f11322404c047dc755c1bf5928d86ae Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 1 Dec 2021 19:09:32 +0100 Subject: [PATCH] Update citation (#54) * Update citation * Fixes for new SB3 version * Fix type hint * Additional fixes --- .github/workflows/ci.yml | 2 ++ CITATION.bib | 10 ++++++++ CITATION.cff | 33 ------------------------- README.md | 16 ++++++------ docs/misc/changelog.rst | 24 ++++++++++++++++++ sb3_contrib/common/maskable/policies.py | 6 ++--- sb3_contrib/ppo_mask/ppo_mask.py | 8 +++--- sb3_contrib/qrdqn/qrdqn.py | 8 +++--- sb3_contrib/version.txt | 2 +- setup.py | 2 +- tests/test_dict_env.py | 6 ++--- 11 files changed, 61 insertions(+), 56 deletions(-) create mode 100644 CITATION.bib delete mode 100644 CITATION.cff diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dc56e25..410e0c6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,6 +36,8 @@ jobs: pip install . # Use headless version pip install opencv-python-headless + # Tmp fix: ROM missing in the newest atari-py version + pip install atari-py==0.2.5 - name: Build the doc run: | make doc diff --git a/CITATION.bib b/CITATION.bib new file mode 100644 index 0000000..1f24b40 --- /dev/null +++ b/CITATION.bib @@ -0,0 +1,10 @@ +@article{stable-baselines3, + author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann}, + title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations}, + journal = {Journal of Machine Learning Research}, + year = {2021}, + volume = {22}, + number = {268}, + pages = {1-8}, + url = {http://jmlr.org/papers/v22/20-1364.html} +} diff --git a/CITATION.cff b/CITATION.cff deleted file mode 100644 index 8696aa0..0000000 --- a/CITATION.cff +++ /dev/null @@ -1,33 +0,0 @@ -cff-version: 1.2.0 -message: "If you use this software, please cite it as below." -authors: - - - family-names: Raffin - given-names: Antonin - orcid: "https://orcid.org/0000-0001-6036-6950" - - - family-names: Hill - given-names: Ashley - - - family-names: Enerstus - given-names: Maximilian - - - family-names: Gleave - given-names: Adam - orcid: "https://orcid.org/0000-0002-3467-528X" - - - family-names: Kanervisto - given-names: Anssi - orcid: "https://orcid.org/0000-0002-7479-4574" - - - family-names: Dormann - given-names: Noah - -title: "Stable Baselines3" -abbreviation: SB3 -url: "https://github.com/DLR-RM/stable-baselines3" -repository-code: "https://github.com/DLR-RM/stable-baselines3" -version: 1.1.0 -date-released: 2020-05-05 -license: MIT -# doi: # TODO when paper is released diff --git a/README.md b/README.md index b8d205e..3e4c607 100644 --- a/README.md +++ b/README.md @@ -68,12 +68,14 @@ If you want to contribute, please read [**CONTRIBUTING.md**](./CONTRIBUTING.md) To cite this repository in publications (please cite SB3 directly): ```bibtex -@misc{stable-baselines3, - author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah}, - title = {Stable Baselines3}, - year = {2019}, - publisher = {GitHub}, - journal = {GitHub repository}, - howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}}, +@article{stable-baselines3, + author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann}, + title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations}, + journal = {Journal of Machine Learning Research}, + year = {2021}, + volume = {22}, + number = {268}, + pages = {1-8}, + url = {http://jmlr.org/papers/v22/20-1364.html} } ``` diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 474db6b..e885656 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,30 @@ Changelog ========== + +Release 1.3.1a3 (WIP) +------------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Upgraded to Stable-Baselines3 >= 1.3.1a3 + +New Features: +^^^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + + +Documentation: +^^^^^^^^^^^^^^ + Release 1.3.0 (2021-10-23) ------------------------------- diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index d7c9522..449d362 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -227,11 +227,11 @@ class MaskableActorCriticPolicy(BasePolicy): def predict( self, observation: Union[np.ndarray, Dict[str, np.ndarray]], - state: Optional[np.ndarray] = None, - mask: Optional[np.ndarray] = None, + state: Optional[Tuple[np.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, deterministic: bool = False, action_masks: Optional[np.ndarray] = None, - ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """ Get the policy action and state from an observation (and optional state). Includes sugar-coating to handle different observations (e.g. normalizing images). diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index 4648af3..ad08054 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -365,11 +365,11 @@ class MaskablePPO(OnPolicyAlgorithm): def predict( self, observation: np.ndarray, - state: Optional[np.ndarray] = None, - mask: Optional[np.ndarray] = None, + state: Optional[Tuple[np.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, deterministic: bool = False, action_masks: Optional[np.ndarray] = None, - ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """ Get the model's action(s) from an observation. @@ -380,7 +380,7 @@ class MaskablePPO(OnPolicyAlgorithm): :param action_masks: Action masks to apply to the action distribution. :return: the model's action and the next state (used in recurrent policies) """ - return self.policy.predict(observation, state, mask, deterministic, action_masks=action_masks) + return self.policy.predict(observation, state, episode_start, deterministic, action_masks=action_masks) def train(self) -> None: """ diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index ad6016e..e51ca06 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -206,10 +206,10 @@ class QRDQN(OffPolicyAlgorithm): def predict( self, observation: np.ndarray, - state: Optional[np.ndarray] = None, - mask: Optional[np.ndarray] = None, + state: Optional[Tuple[np.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, deterministic: bool = False, - ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """ Overrides the base_class predict function to include epsilon-greedy exploration. @@ -230,7 +230,7 @@ class QRDQN(OffPolicyAlgorithm): else: action = np.array(self.action_space.sample()) else: - action, state = self.policy.predict(observation, state, mask, deterministic) + action, state = self.policy.predict(observation, state, episode_start, deterministic) return action, state def learn( diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index f0bb29e..896c1f3 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.3.0 +1.3.1a3 diff --git a/setup.py b/setup.py index dafed9a..492f059 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.3.0", + "stable_baselines3>=1.3.1a3", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index fda2725..86f5b5e 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -2,7 +2,7 @@ import gym import numpy as np import pytest from gym import spaces -from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv +from stable_baselines3.common.envs import SimpleMultiObsEnv from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize @@ -224,8 +224,8 @@ def test_vec_normalize(model_class): Additional tests to check observation space support for GoalEnv and VecNormalize using MultiInputPolicy. """ - env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=4, continuous=not (model_class == QRDQN))]) - env = VecNormalize(env) + env = DummyVecEnv([lambda: gym.wrappers.TimeLimit(DummyDictEnv(use_discrete_actions=model_class == QRDQN), 100)]) + env = VecNormalize(env, norm_obs_keys=["vec"]) kwargs = {} n_steps = 256