Update citation (#54)

* Update citation

* Fixes for new SB3 version

* Fix type hint

* Additional fixes
This commit is contained in:
Antonin RAFFIN 2021-12-01 19:09:32 +01:00 committed by GitHub
parent b1397bbb72
commit cd0a5e516f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 61 additions and 56 deletions

View File

@@ -36,6 +36,8 @@ jobs:
pip install .
# Use headless version
pip install opencv-python-headless
# Tmp fix: ROM missing in the newest atari-py version
pip install atari-py==0.2.5
- name: Build the doc
run: |
make doc

10
CITATION.bib Normal file
View File

@@ -0,0 +1,10 @@
@article{stable-baselines3,
author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
journal = {Journal of Machine Learning Research},
year = {2021},
volume = {22},
number = {268},
pages = {1-8},
url = {http://jmlr.org/papers/v22/20-1364.html}
}

View File

@@ -1,33 +0,0 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
-
family-names: Raffin
given-names: Antonin
orcid: "https://orcid.org/0000-0001-6036-6950"
-
family-names: Hill
given-names: Ashley
-
family-names: Enerstus
given-names: Maximilian
-
family-names: Gleave
given-names: Adam
orcid: "https://orcid.org/0000-0002-3467-528X"
-
family-names: Kanervisto
given-names: Anssi
orcid: "https://orcid.org/0000-0002-7479-4574"
-
family-names: Dormann
given-names: Noah
title: "Stable Baselines3"
abbreviation: SB3
url: "https://github.com/DLR-RM/stable-baselines3"
repository-code: "https://github.com/DLR-RM/stable-baselines3"
version: 1.1.0
date-released: 2020-05-05
license: MIT
# doi: # TODO when paper is released

View File

@@ -68,12 +68,14 @@ If you want to contribute, please read [**CONTRIBUTING.md**](./CONTRIBUTING.md)
To cite this repository in publications (please cite SB3 directly):
```bibtex
@misc{stable-baselines3,
author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
title = {Stable Baselines3},
year = {2019},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}},
@article{stable-baselines3,
author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
journal = {Journal of Machine Learning Research},
year = {2021},
volume = {22},
number = {268},
pages = {1-8},
url = {http://jmlr.org/papers/v22/20-1364.html}
}
```

View File

@@ -3,6 +3,30 @@
Changelog
==========
Release 1.3.1a3 (WIP)
-------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 1.3.1a3
New Features:
^^^^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
Documentation:
^^^^^^^^^^^^^^
Release 1.3.0 (2021-10-23)
-------------------------------

View File

@@ -227,11 +227,11 @@ class MaskableActorCriticPolicy(BasePolicy):
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
action_masks: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the policy action and state from an observation (and optional state).
Includes sugar-coating to handle different observations (e.g. normalizing images).

View File

@@ -365,11 +365,11 @@ class MaskablePPO(OnPolicyAlgorithm):
def predict(
self,
observation: np.ndarray,
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
action_masks: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the model's action(s) from an observation.
@@ -380,7 +380,7 @@ class MaskablePPO(OnPolicyAlgorithm):
:param action_masks: Action masks to apply to the action distribution.
:return: the model's action and the next state (used in recurrent policies)
"""
return self.policy.predict(observation, state, mask, deterministic, action_masks=action_masks)
return self.policy.predict(observation, state, episode_start, deterministic, action_masks=action_masks)
def train(self) -> None:
"""

View File

@@ -206,10 +206,10 @@ class QRDQN(OffPolicyAlgorithm):
def predict(
self,
observation: np.ndarray,
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Overrides the base_class predict function to include epsilon-greedy exploration.
@@ -230,7 +230,7 @@ class QRDQN(OffPolicyAlgorithm):
else:
action = np.array(self.action_space.sample())
else:
action, state = self.policy.predict(observation, state, mask, deterministic)
action, state = self.policy.predict(observation, state, episode_start, deterministic)
return action, state
def learn(

View File

@@ -1 +1 @@
1.3.0
1.3.1a3

View File

@@ -63,7 +63,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.3.0",
"stable_baselines3>=1.3.1a3",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",

View File

@@ -2,7 +2,7 @@ import gym
import numpy as np
import pytest
from gym import spaces
from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv
from stable_baselines3.common.envs import SimpleMultiObsEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
@@ -224,8 +224,8 @@ def test_vec_normalize(model_class):
Additional tests to check observation space support
for GoalEnv and VecNormalize using MultiInputPolicy.
"""
env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=4, continuous=not (model_class == QRDQN))])
env = VecNormalize(env)
env = DummyVecEnv([lambda: gym.wrappers.TimeLimit(DummyDictEnv(use_discrete_actions=model_class == QRDQN), 100)])
env = VecNormalize(env, norm_obs_keys=["vec"])
kwargs = {}
n_steps = 256