Update citation (#54)
* Update citation * Fixes for new SB3 version * Fix type hint * Additional fixes
This commit is contained in:
parent
b1397bbb72
commit
cd0a5e516f
|
|
@ -36,6 +36,8 @@ jobs:
|
|||
pip install .
|
||||
# Use headless version
|
||||
pip install opencv-python-headless
|
||||
# Temporary fix: ROMs are missing in the newest atari-py version, so pin an older release
|
||||
pip install atari-py==0.2.5
|
||||
- name: Build the doc
|
||||
run: |
|
||||
make doc
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
@article{stable-baselines3,
|
||||
author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
|
||||
title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
|
||||
journal = {Journal of Machine Learning Research},
|
||||
year = {2021},
|
||||
volume = {22},
|
||||
number = {268},
|
||||
pages = {1-8},
|
||||
url = {http://jmlr.org/papers/v22/20-1364.html}
|
||||
}
|
||||
33
CITATION.cff
33
CITATION.cff
|
|
@ -1,33 +0,0 @@
|
|||
cff-version: 1.2.0
|
||||
message: "If you use this software, please cite it as below."
|
||||
authors:
|
||||
-
|
||||
family-names: Raffin
|
||||
given-names: Antonin
|
||||
orcid: "https://orcid.org/0000-0001-6036-6950"
|
||||
-
|
||||
family-names: Hill
|
||||
given-names: Ashley
|
||||
-
|
||||
family-names: Ernestus
|
||||
given-names: Maximilian
|
||||
-
|
||||
family-names: Gleave
|
||||
given-names: Adam
|
||||
orcid: "https://orcid.org/0000-0002-3467-528X"
|
||||
-
|
||||
family-names: Kanervisto
|
||||
given-names: Anssi
|
||||
orcid: "https://orcid.org/0000-0002-7479-4574"
|
||||
-
|
||||
family-names: Dormann
|
||||
given-names: Noah
|
||||
|
||||
title: "Stable Baselines3"
|
||||
abbreviation: SB3
|
||||
url: "https://github.com/DLR-RM/stable-baselines3"
|
||||
repository-code: "https://github.com/DLR-RM/stable-baselines3"
|
||||
version: 1.1.0
|
||||
date-released: 2020-05-05
|
||||
license: MIT
|
||||
# doi: # TODO when paper is released
|
||||
16
README.md
16
README.md
|
|
@ -68,12 +68,14 @@ If you want to contribute, please read [**CONTRIBUTING.md**](./CONTRIBUTING.md)
|
|||
To cite this repository in publications (please cite SB3 directly):
|
||||
|
||||
```bibtex
|
||||
@misc{stable-baselines3,
|
||||
author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
|
||||
title = {Stable Baselines3},
|
||||
year = {2019},
|
||||
publisher = {GitHub},
|
||||
journal = {GitHub repository},
|
||||
howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}},
|
||||
@article{stable-baselines3,
|
||||
author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
|
||||
title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
|
||||
journal = {Journal of Machine Learning Research},
|
||||
year = {2021},
|
||||
volume = {22},
|
||||
number = {268},
|
||||
pages = {1-8},
|
||||
url = {http://jmlr.org/papers/v22/20-1364.html}
|
||||
}
|
||||
```
|
||||
|
|
|
|||
|
|
@ -3,6 +3,30 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
|
||||
Release 1.3.1a3 (WIP)
|
||||
-------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Upgraded to Stable-Baselines3 >= 1.3.1a3
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
Release 1.3.0 (2021-10-23)
|
||||
-------------------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -227,11 +227,11 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
def predict(
|
||||
self,
|
||||
observation: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||
state: Optional[np.ndarray] = None,
|
||||
mask: Optional[np.ndarray] = None,
|
||||
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||
episode_start: Optional[np.ndarray] = None,
|
||||
deterministic: bool = False,
|
||||
action_masks: Optional[np.ndarray] = None,
|
||||
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
|
||||
"""
|
||||
Get the policy action and state from an observation (and optional state).
|
||||
Includes sugar-coating to handle different observations (e.g. normalizing images).
|
||||
|
|
|
|||
|
|
@ -365,11 +365,11 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
def predict(
|
||||
self,
|
||||
observation: np.ndarray,
|
||||
state: Optional[np.ndarray] = None,
|
||||
mask: Optional[np.ndarray] = None,
|
||||
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||
episode_start: Optional[np.ndarray] = None,
|
||||
deterministic: bool = False,
|
||||
action_masks: Optional[np.ndarray] = None,
|
||||
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
|
||||
"""
|
||||
Get the model's action(s) from an observation.
|
||||
|
||||
|
|
@ -380,7 +380,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
:param action_masks: Action masks to apply to the action distribution.
|
||||
:return: the model's action and the next state (used in recurrent policies)
|
||||
"""
|
||||
return self.policy.predict(observation, state, mask, deterministic, action_masks=action_masks)
|
||||
return self.policy.predict(observation, state, episode_start, deterministic, action_masks=action_masks)
|
||||
|
||||
def train(self) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -206,10 +206,10 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
def predict(
|
||||
self,
|
||||
observation: np.ndarray,
|
||||
state: Optional[np.ndarray] = None,
|
||||
mask: Optional[np.ndarray] = None,
|
||||
state: Optional[Tuple[np.ndarray, ...]] = None,
|
||||
episode_start: Optional[np.ndarray] = None,
|
||||
deterministic: bool = False,
|
||||
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
|
||||
"""
|
||||
Overrides the base_class predict function to include epsilon-greedy exploration.
|
||||
|
||||
|
|
@ -230,7 +230,7 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
else:
|
||||
action = np.array(self.action_space.sample())
|
||||
else:
|
||||
action, state = self.policy.predict(observation, state, mask, deterministic)
|
||||
action, state = self.policy.predict(observation, state, episode_start, deterministic)
|
||||
return action, state
|
||||
|
||||
def learn(
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.3.0
|
||||
1.3.1a3
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -63,7 +63,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=1.3.0",
|
||||
"stable_baselines3>=1.3.1a3",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import gym
|
|||
import numpy as np
|
||||
import pytest
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv
|
||||
from stable_baselines3.common.envs import SimpleMultiObsEnv
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
|
||||
|
||||
|
|
@ -224,8 +224,8 @@ def test_vec_normalize(model_class):
|
|||
Additional tests to check observation space support
|
||||
for GoalEnv and VecNormalize using MultiInputPolicy.
|
||||
"""
|
||||
env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=4, continuous=not (model_class == QRDQN))])
|
||||
env = VecNormalize(env)
|
||||
env = DummyVecEnv([lambda: gym.wrappers.TimeLimit(DummyDictEnv(use_discrete_actions=model_class == QRDQN), 100)])
|
||||
env = VecNormalize(env, norm_obs_keys=["vec"])
|
||||
|
||||
kwargs = {}
|
||||
n_steps = 256
|
||||
|
|
|
|||
Loading…
Reference in New Issue