Update citation (#54)
* Update citation * Fixes for new SB3 version * Fix type hint * Additional fixes
This commit is contained in:
parent
b1397bbb72
commit
cd0a5e516f
|
|
@ -36,6 +36,8 @@ jobs:
|
||||||
pip install .
|
pip install .
|
||||||
# Use headless version
|
# Use headless version
|
||||||
pip install opencv-python-headless
|
pip install opencv-python-headless
|
||||||
|
# Temporary fix: ROMs are missing in the newest atari-py version
|
||||||
|
pip install atari-py==0.2.5
|
||||||
- name: Build the doc
|
- name: Build the doc
|
||||||
run: |
|
run: |
|
||||||
make doc
|
make doc
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
@article{stable-baselines3,
|
||||||
|
author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
|
||||||
|
title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
|
||||||
|
journal = {Journal of Machine Learning Research},
|
||||||
|
year = {2021},
|
||||||
|
volume = {22},
|
||||||
|
number = {268},
|
||||||
|
pages = {1-8},
|
||||||
|
url = {http://jmlr.org/papers/v22/20-1364.html}
|
||||||
|
}
|
||||||
33
CITATION.cff
33
CITATION.cff
|
|
@ -1,33 +0,0 @@
|
||||||
cff-version: 1.2.0
|
|
||||||
message: "If you use this software, please cite it as below."
|
|
||||||
authors:
|
|
||||||
-
|
|
||||||
family-names: Raffin
|
|
||||||
given-names: Antonin
|
|
||||||
orcid: "https://orcid.org/0000-0001-6036-6950"
|
|
||||||
-
|
|
||||||
family-names: Hill
|
|
||||||
given-names: Ashley
|
|
||||||
-
|
|
||||||
family-names: Ernestus
|
|
||||||
given-names: Maximilian
|
|
||||||
-
|
|
||||||
family-names: Gleave
|
|
||||||
given-names: Adam
|
|
||||||
orcid: "https://orcid.org/0000-0002-3467-528X"
|
|
||||||
-
|
|
||||||
family-names: Kanervisto
|
|
||||||
given-names: Anssi
|
|
||||||
orcid: "https://orcid.org/0000-0002-7479-4574"
|
|
||||||
-
|
|
||||||
family-names: Dormann
|
|
||||||
given-names: Noah
|
|
||||||
|
|
||||||
title: "Stable Baselines3"
|
|
||||||
abbreviation: SB3
|
|
||||||
url: "https://github.com/DLR-RM/stable-baselines3"
|
|
||||||
repository-code: "https://github.com/DLR-RM/stable-baselines3"
|
|
||||||
version: 1.1.0
|
|
||||||
date-released: 2020-05-05
|
|
||||||
license: MIT
|
|
||||||
# doi: # TODO when paper is released
|
|
||||||
16
README.md
16
README.md
|
|
@ -68,12 +68,14 @@ If you want to contribute, please read [**CONTRIBUTING.md**](./CONTRIBUTING.md)
|
||||||
To cite this repository in publications (please cite SB3 directly):
|
To cite this repository in publications (please cite SB3 directly):
|
||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@misc{stable-baselines3,
|
@article{stable-baselines3,
|
||||||
author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
|
author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
|
||||||
title = {Stable Baselines3},
|
title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
|
||||||
year = {2019},
|
journal = {Journal of Machine Learning Research},
|
||||||
publisher = {GitHub},
|
year = {2021},
|
||||||
journal = {GitHub repository},
|
volume = {22},
|
||||||
howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}},
|
number = {268},
|
||||||
|
pages = {1-8},
|
||||||
|
url = {http://jmlr.org/papers/v22/20-1364.html}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,30 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
|
||||||
|
Release 1.3.1a3 (WIP)
|
||||||
|
-------------------------------
|
||||||
|
|
||||||
|
Breaking Changes:
|
||||||
|
^^^^^^^^^^^^^^^^^
|
||||||
|
- Upgraded to Stable-Baselines3 >= 1.3.1a3
|
||||||
|
|
||||||
|
New Features:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Bug Fixes:
|
||||||
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
Deprecations:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Others:
|
||||||
|
^^^^^^^
|
||||||
|
|
||||||
|
|
||||||
|
Documentation:
|
||||||
|
^^^^^^^^^^^^^^
|
||||||
|
|
||||||
Release 1.3.0 (2021-10-23)
|
Release 1.3.0 (2021-10-23)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -227,11 +227,11 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
observation: Union[np.ndarray, Dict[str, np.ndarray]],
|
observation: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||||
state: Optional[np.ndarray] = None,
|
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||||
mask: Optional[np.ndarray] = None,
|
episode_start: Optional[np.ndarray] = None,
|
||||||
deterministic: bool = False,
|
deterministic: bool = False,
|
||||||
action_masks: Optional[np.ndarray] = None,
|
action_masks: Optional[np.ndarray] = None,
|
||||||
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
|
||||||
"""
|
"""
|
||||||
Get the policy action and state from an observation (and optional state).
|
Get the policy action and state from an observation (and optional state).
|
||||||
Includes sugar-coating to handle different observations (e.g. normalizing images).
|
Includes sugar-coating to handle different observations (e.g. normalizing images).
|
||||||
|
|
|
||||||
|
|
@ -365,11 +365,11 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
observation: np.ndarray,
|
observation: np.ndarray,
|
||||||
state: Optional[np.ndarray] = None,
|
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||||
mask: Optional[np.ndarray] = None,
|
episode_start: Optional[np.ndarray] = None,
|
||||||
deterministic: bool = False,
|
deterministic: bool = False,
|
||||||
action_masks: Optional[np.ndarray] = None,
|
action_masks: Optional[np.ndarray] = None,
|
||||||
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
|
||||||
"""
|
"""
|
||||||
Get the model's action(s) from an observation.
|
Get the model's action(s) from an observation.
|
||||||
|
|
||||||
|
|
@ -380,7 +380,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
:param action_masks: Action masks to apply to the action distribution.
|
:param action_masks: Action masks to apply to the action distribution.
|
||||||
:return: the model's action and the next state (used in recurrent policies)
|
:return: the model's action and the next state (used in recurrent policies)
|
||||||
"""
|
"""
|
||||||
return self.policy.predict(observation, state, mask, deterministic, action_masks=action_masks)
|
return self.policy.predict(observation, state, episode_start, deterministic, action_masks=action_masks)
|
||||||
|
|
||||||
def train(self) -> None:
|
def train(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -206,10 +206,10 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
observation: np.ndarray,
|
observation: np.ndarray,
|
||||||
state: Optional[np.ndarray] = None,
|
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||||
mask: Optional[np.ndarray] = None,
|
episode_start: Optional[np.ndarray] = None,
|
||||||
deterministic: bool = False,
|
deterministic: bool = False,
|
||||||
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
|
||||||
"""
|
"""
|
||||||
Overrides the base_class predict function to include epsilon-greedy exploration.
|
Overrides the base_class predict function to include epsilon-greedy exploration.
|
||||||
|
|
||||||
|
|
@ -230,7 +230,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
else:
|
else:
|
||||||
action = np.array(self.action_space.sample())
|
action = np.array(self.action_space.sample())
|
||||||
else:
|
else:
|
||||||
action, state = self.policy.predict(observation, state, mask, deterministic)
|
action, state = self.policy.predict(observation, state, episode_start, deterministic)
|
||||||
return action, state
|
return action, state
|
||||||
|
|
||||||
def learn(
|
def learn(
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.3.0
|
1.3.1a3
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -63,7 +63,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.3.0",
|
"stable_baselines3>=1.3.1a3",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv
|
from stable_baselines3.common.envs import SimpleMultiObsEnv
|
||||||
from stable_baselines3.common.evaluation import evaluate_policy
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
|
||||||
|
|
||||||
|
|
@ -224,8 +224,8 @@ def test_vec_normalize(model_class):
|
||||||
Additional tests to check observation space support
|
Additional tests to check observation space support
|
||||||
for GoalEnv and VecNormalize using MultiInputPolicy.
|
for GoalEnv and VecNormalize using MultiInputPolicy.
|
||||||
"""
|
"""
|
||||||
env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=4, continuous=not (model_class == QRDQN))])
|
env = DummyVecEnv([lambda: gym.wrappers.TimeLimit(DummyDictEnv(use_discrete_actions=model_class == QRDQN), 100)])
|
||||||
env = VecNormalize(env)
|
env = VecNormalize(env, norm_obs_keys=["vec"])
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
n_steps = 256
|
n_steps = 256
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue