Merge pull request #1 from Stable-Baselines-Team/contrib
Add contributing guides
commit
3fe0a0c2f2
|
|
@ -32,6 +32,7 @@ for both code and stack traces.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from stable_baselines3 import ...
|
from stable_baselines3 import ...
|
||||||
|
from sb3_contrib import ...
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -43,6 +44,7 @@ Traceback (most recent call last): File ...
|
||||||
**System Info**
|
**System Info**
|
||||||
Describe the characteristic of your environment:
|
Describe the characteristic of your environment:
|
||||||
* Describe how the library was installed (pip, docker, source, ...)
|
* Describe how the library was installed (pip, docker, source, ...)
|
||||||
|
* Stable-Baselines3 and sb3-contrib versions
|
||||||
* GPU models and configuration
|
* GPU models and configuration
|
||||||
* Python version
|
* Python version
|
||||||
* PyTorch version
|
* PyTorch version
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,11 @@
|
||||||
<!--- Provide a general summary of your changes in the Title above -->
|
<!--- Provide a general summary of your changes in the Title above (e.g. feature name, fix of a bug)-->
|
||||||
|
|
||||||
## Description
|
## Description
|
||||||
<!--- Describe your changes in detail -->
|
<!--- Describe your changes in detail, including links to any sources for new features -->
|
||||||
|
|
||||||
## Motivation and Context
|
## Context
|
||||||
<!--- Why is this change required? What problem does it solve? -->
|
<!--- Link the related issue here. You can use the syntax `closes #100` if this solves the issue #100 -->
|
||||||
<!--- If it fixes an open issue, please link to the issue here. -->
|
- [ ] I have raised an issue to propose this change ([required](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/CONTRIBUTING.md))
|
||||||
<!--- You can use the syntax `closes #100` if this solves the issue #100 -->
|
|
||||||
- [ ] I have raised an issue to propose this change ([required](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) for new features and bug fixes)
|
|
||||||
|
|
||||||
## Types of changes
|
## Types of changes
|
||||||
<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->
|
<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->
|
||||||
|
|
@ -19,11 +17,13 @@
|
||||||
## Checklist:
|
## Checklist:
|
||||||
<!--- Go over all the following points, and put an `x` in all the boxes that apply. -->
|
<!--- Go over all the following points, and put an `x` in all the boxes that apply. -->
|
||||||
<!--- If you're unsure about any of these, don't hesitate to ask. We're here to help! -->
|
<!--- If you're unsure about any of these, don't hesitate to ask. We're here to help! -->
|
||||||
- [ ] I've read the [CONTRIBUTION](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) guide (**required**)
|
- [ ] I've read the [CONTRIBUTION](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/CONTRIBUTING.md) guide (**required**)
|
||||||
- [ ] I have updated the changelog accordingly (**required**).
|
- [ ] The functionality/performance matches that of the source (**required** for new training algorithms or training-related features).
|
||||||
- [ ] My change requires a change to the documentation.
|
|
||||||
- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
|
- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
|
||||||
|
- [ ] I have included an example of using the feature (*required for new features*).
|
||||||
|
- [ ] I have included baseline results (**required** for new training algorithms or training-related features).
|
||||||
- [ ] I have updated the documentation accordingly.
|
- [ ] I have updated the documentation accordingly.
|
||||||
|
- [ ] I have updated the changelog accordingly (**required**).
|
||||||
- [ ] I have reformatted the code using `make format` (**required**)
|
- [ ] I have reformatted the code using `make format` (**required**)
|
||||||
- [ ] I have checked the codestyle using `make check-codestyle` and `make lint` (**required**)
|
- [ ] I have checked the codestyle using `make check-codestyle` and `make lint` (**required**)
|
||||||
- [ ] I have ensured `make pytest` and `make type` both pass. (**required**)
|
- [ ] I have ensured `make pytest` and `make type` both pass. (**required**)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
## Release 0.9.0a2 (WIP)
|
## Release 0.10.0a0 (WIP)
|
||||||
|
|
||||||
### Breaking Changes
|
### Breaking Changes
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,177 @@
|
||||||
|
## Contributing to Stable-Baselines3 - Contrib
|
||||||
|
|
||||||
|
This contrib repository is designed for experimental implementations of various
|
||||||
|
parts of reinforcement training so that others may make use of them. This includes full
|
||||||
|
RL algorithms, different tools (e.g. new environment wrappers,
|
||||||
|
callbacks) and extending algorithms implemented in stable-baselines3.
|
||||||
|
|
||||||
|
**Before opening a pull request**, open an issue discussing the contribution.
|
||||||
|
Once we agree that the plan looks good, go ahead and implement it.
|
||||||
|
|
||||||
|
Contributions and reviews focus on the following three parts:
|
||||||
|
1) **Implementation quality**
|
||||||
|
- Performance of the RL algorithms should match the one reported by the original authors (if applicable).
|
||||||
|
- This is ensured by including code that replicates an experiment from the original
|
||||||
|
paper or from an established codebase (e.g. the code from authors), as well as
|
||||||
|
a test to check that the implementation works at the program level (does not crash).
|
||||||
|
2) Documentation
|
||||||
|
- Documentation quality should match that of stable-baselines3, with each feature covered
|
||||||
|
in the documentation, in-code documentation to clarify the flow
|
||||||
|
of logic and report of the expected results, where applicable.
|
||||||
|
3) Consistency with stable-baselines3
|
||||||
|
- To ease readability, all contributions need to follow the code style (see below) and
|
||||||
|
idioms used in stable-baselines3.
|
||||||
|
|
||||||
|
Implementation quality is a strict requirement with little room for changes, because
|
||||||
|
otherwise the implementation can do more harm than good (wrong results). Parts two and three
|
||||||
|
are taken into account during review but, this being a repository for more experimental code, these
|
||||||
|
are not very strict.
|
||||||
|
|
||||||
|
See [issues with "experimental" tag](https://github.com/DLR-RM/stable-baselines3/issues?q=is%3Aissue+is%3Aopen+label%3Aexperimental)
|
||||||
|
for community suggestions of possible new features to include in contrib.
|
||||||
|
|
||||||
|
## How to implement your suggestion
|
||||||
|
|
||||||
|
Implement your feature/suggestion/algorithm in one of the following ways, using the first one that applies:
|
||||||
|
1) Environment wrapper: This can be used with any algorithm and even outside stable-baselines3.
|
||||||
|
Place code for these under the `sb3_contrib/common/wrappers` directory.
|
||||||
|
2) [Custom callback](https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html).
|
||||||
|
Place code under the `sb3_contrib/common/callbacks` directory.
|
||||||
|
3) Following the style/naming of `common` files in stable-baselines3. If your suggestion is a specific network architecture
|
||||||
|
for feature extraction from images, place this in `sb3_contrib/common/torch_layers.py`, for example.
|
||||||
|
4) A new learning algorithm. This is the last resort but the most widely applicable solution.
|
||||||
|
Even if your suggestion is a (trivial) modification to an existing algorithm, create a new algorithm for it
|
||||||
|
unless otherwise discussed (the new algorithm should inherit from the base algorithm). The algorithm should use the same API as
|
||||||
|
stable-baselines3 algorithms (e.g. `learn`, `load`), and the code should be placed under
|
||||||
|
the `sb3_contrib/[algorithm_name]` directory.
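As a rough illustration of option 1, the sketch below wraps an environment and changes only the reward it returns. `ToyEnv` and `ClipRewardWrapper` are hypothetical names made up for this example and use only the standard library. A real contribution would presumably subclass `gym.Wrapper` and live under `sb3_contrib/common/wrappers`.

```python
from typing import Any, Dict, Tuple


class ToyEnv:
    """Hypothetical stand-in for a Gym environment (not part of any library)."""

    def reset(self) -> float:
        self.state = 0.0
        return self.state

    def step(self, action: float) -> Tuple[float, float, bool, Dict[str, Any]]:
        self.state += action
        reward = 10.0 * action  # deliberately large so that clipping is visible
        done = abs(self.state) >= 3.0
        return self.state, reward, done, {}


class ClipRewardWrapper:
    """
    Clip rewards to ``[-limit, limit]`` without modifying the wrapped environment.

    :param env: environment to wrap (anything exposing ``reset`` and ``step``)
    :param limit: maximum absolute reward returned to the agent
    """

    def __init__(self, env: ToyEnv, limit: float = 1.0):
        self.env = env
        self.limit = limit

    def reset(self) -> float:
        return self.env.reset()

    def step(self, action: float) -> Tuple[float, float, bool, Dict[str, Any]]:
        obs, reward, done, info = self.env.step(action)
        clipped = max(-self.limit, min(self.limit, reward))
        return obs, clipped, done, info


env = ClipRewardWrapper(ToyEnv(), limit=1.0)
env.reset()
obs, reward, done, info = env.step(0.5)
```

Because the wrapper only touches the value returned by `step`, it stays independent of the learning algorithm, which is why wrappers come first in the list above.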
|
||||||
|
|
||||||
|
Look over stable-baselines3 code for the general naming of variables and try to keep this style.
|
||||||
|
|
||||||
|
If the algorithm you are implementing involves more complex/uncommon equations, comment each part of these
|
||||||
|
calculations with references to the relevant parts of the paper.
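A minimal sketch of what such commenting could look like, using a soft (Polyak) target update as the example calculation. The function name `polyak_update` and the marker "Eq. (N)" are placeholders chosen for illustration. In a real contribution the comment would cite the actual equation or section of the paper being implemented.

```python
from typing import List


def polyak_update(params: List[float], target_params: List[float], tau: float) -> List[float]:
    """
    Soft-update target parameters towards the online parameters.

    :param params: current (online) parameters
    :param target_params: target parameters to be updated
    :param tau: interpolation coefficient in [0, 1]
    :return: the updated target parameters
    """
    updated = []
    for theta, theta_target in zip(params, target_params):
        # "Eq. (N)" is a placeholder: cite the real equation number of the paper here.
        # Eq. (N): theta_target <- tau * theta + (1 - tau) * theta_target
        updated.append(tau * theta + (1.0 - tau) * theta_target)
    return updated


new_target = polyak_update([1.0, 2.0], [0.0, 0.0], tau=0.5)
```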
|
||||||
|
|
||||||
|
## Pull Request (PR) and review
|
||||||
|
|
||||||
|
Before proposing a PR, please open an issue, where the feature will be discussed.
|
||||||
|
This prevents duplicate PRs from being proposed and also eases the code review process.
|
||||||
|
|
||||||
|
Each PR needs to be reviewed and accepted by at least one of the maintainers.
|
||||||
|
A PR must pass the Continuous Integration tests to be merged with the master branch.
|
||||||
|
|
||||||
|
Along with the code, the PR **must** include the following:
|
||||||
|
1) Update to documentation to include a description of the feature. If the feature is a simple tool (e.g. wrapper, callback),
|
||||||
|
this goes under the respective pages in the documentation. If it is a full training algorithm, this goes under a new page following the template below
|
||||||
|
(`docs/modules/[algo_name]`).
|
||||||
|
2) If a training algorithm/improvement: results of a replicated experiment from the original paper in the documentation,
|
||||||
|
**which must match the results from the authors** unless solid arguments can be provided for why they do not match.
|
||||||
|
3) If the above holds: the **exact** code to run the replicated experiment (i.e. it should produce the above results), and inside the
|
||||||
|
code information about the environment used (Python version, library versions, OS, hardware information). If small enough,
|
||||||
|
include this in the documentation. If applicable, use [rl-baselines3-zoo](https://github.com/DLR-RM/rl-baselines3-zoo) to
|
||||||
|
run the agent performance comparison experiments (fork the repository, implement the experiment in a new branch and share a link to
|
||||||
|
that branch). If the above does not apply, create new code to replicate the experiment and include a link to it.
|
||||||
|
4) Updated tests in `tests/test_run.py` and `tests/test_save_load.py` to test that features run as expected and serialize
|
||||||
|
correctly. This is **not** for testing e.g. training performance of a learning algorithm, and
|
||||||
|
should be relatively quick to run.
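The two kinds of tests named in item 4 could look roughly like the sketch below: a smoke test that only checks a short run does not crash, and a save/load round-trip test. `ToyModel` is a hypothetical standard-library stand-in, not an sb3-contrib class. Actual tests would use the contributed algorithm and follow the existing files in `tests/`.

```python
import pickle
import tempfile
from pathlib import Path


class ToyModel:
    """Hypothetical stand-in for an algorithm exposing learn/predict/save/load."""

    def __init__(self) -> None:
        self.weight = 0.0

    def learn(self, total_timesteps: int) -> "ToyModel":
        # Pretend training: nudge the weight once per timestep
        for _ in range(total_timesteps):
            self.weight += 0.01
        return self

    def predict(self, observation: float) -> float:
        return self.weight * observation

    def save(self, path: Path) -> None:
        with open(path, "wb") as file_handler:
            pickle.dump(self.__dict__, file_handler)

    @classmethod
    def load(cls, path: Path) -> "ToyModel":
        model = cls()
        with open(path, "rb") as file_handler:
            model.__dict__.update(pickle.load(file_handler))
        return model


def test_run() -> None:
    # Smoke test: a short training run must not crash
    ToyModel().learn(total_timesteps=300)


def test_save_load() -> None:
    model = ToyModel().learn(total_timesteps=300)
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = Path(tmp_dir) / "toy_model.pkl"
        model.save(path)
        loaded = ToyModel.load(path)
    # The reloaded model must give the same prediction as the original
    assert loaded.predict(2.0) == model.predict(2.0)
```

Keeping the number of timesteps small is what lets such tests stay quick. They check that the code runs and serializes, not that the agent learns well.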
|
||||||
|
|
||||||
|
Below is a template for documentation for full RL algorithms.
|
||||||
|
|
||||||
|
```rst
|
||||||
|
[Feature/Algorithm name]
|
||||||
|
========================
|
||||||
|
|
||||||
|
- Non-abbreviated name and/or one-sentence description of the method.
|
||||||
|
- Link and reference to the original publications that present the feature, or other established source(s).
|
||||||
|
- Links to any codebases that were used for reference (e.g. authors' implementations)
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
|
||||||
|
A minimal example on how to use the feature (full, runnable code).
|
||||||
|
|
||||||
|
Results
|
||||||
|
-------
|
||||||
|
|
||||||
|
A description and comparison of results (e.g. how the change improved results over the non-changed algorithm), if
|
||||||
|
applicable.
|
||||||
|
|
||||||
|
Include the expected results from the work that originally proposed the method (e.g. original paper).
|
||||||
|
|
||||||
|
Include the code to replicate these results or a link to repository/branch where the code can be found.
|
||||||
|
Use `rl-baselines3-zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ if possible, fork it, create a new branch
|
||||||
|
and share the code to replicate results there.
|
||||||
|
|
||||||
|
Comments
|
||||||
|
--------
|
||||||
|
|
||||||
|
Comments regarding the implementation, e.g. missing parts, uncertain parts, differences
|
||||||
|
to the original implementation.
|
||||||
|
```
|
||||||
|
|
||||||
|
If you are not familiar with creating a Pull Request, here are some guides:
|
||||||
|
- http://stackoverflow.com/questions/14680711/how-to-do-a-github-pull-request
|
||||||
|
- https://help.github.com/articles/creating-a-pull-request/
|
||||||
|
|
||||||
|
|
||||||
|
## Codestyle
|
||||||
|
|
||||||
|
We are using [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports.
|
||||||
|
|
||||||
|
**Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`.
|
||||||
|
|
||||||
|
Please document each function/method and [type](https://google.github.io/pytype/user_guide.html) them using the following template:
|
||||||
|
|
||||||
|
```python
|
||||||
|
|
||||||
|
def my_function(arg1: type1, arg2: type2) -> returntype:
|
||||||
|
"""
|
||||||
|
Short description of the function.
|
||||||
|
|
||||||
|
:param arg1: describe what is arg1
|
||||||
|
:param arg2: describe what is arg2
|
||||||
|
:return: describe what is returned
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
return my_variable
|
||||||
|
```
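For illustration, here is the template filled in for a small helper. The function `discounted_returns` is a made-up example and not part of stable-baselines3 or sb3-contrib. It only shows how the type hints and the `:param:` / `:return:` fields might read in practice.

```python
from typing import List


def discounted_returns(rewards: List[float], gamma: float) -> List[float]:
    """
    Compute the discounted return for every step of one episode.

    :param rewards: rewards collected at each step, in chronological order
    :param gamma: discount factor, between 0 and 1
    :return: discounted return at each step, same length as ``rewards``
    """
    returns = [0.0] * len(rewards)
    running_return = 0.0
    # Walk backwards so that each return reuses the one that follows it
    for step in reversed(range(len(rewards))):
        running_return = rewards[step] + gamma * running_return
        returns[step] = running_return
    return returns


returns = discounted_returns([1.0, 1.0, 1.0], gamma=0.5)
```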
|
||||||
|
|
||||||
|
## Tests
|
||||||
|
|
||||||
|
All new features and algorithms must add tests in the `tests/` folder ensuring that everything works fine (on program level).
|
||||||
|
We use [pytest](https://pytest.org/).
|
||||||
|
Also, when a bug fix is proposed, tests should be added to avoid regression.
|
||||||
|
|
||||||
|
To run tests with `pytest`:
|
||||||
|
|
||||||
|
```
|
||||||
|
make pytest
|
||||||
|
```
|
||||||
|
|
||||||
|
Type checking with `pytype`:
|
||||||
|
|
||||||
|
```
|
||||||
|
make type
|
||||||
|
```
|
||||||
|
|
||||||
|
Codestyle check with `black`, `isort` and `flake8`:
|
||||||
|
|
||||||
|
```
|
||||||
|
make check-codestyle
|
||||||
|
make lint
|
||||||
|
```
|
||||||
|
|
||||||
|
To run `pytype`, `format` and `lint` in one command:
|
||||||
|
```
|
||||||
|
make commit-checks
|
||||||
|
```
|
||||||
|
|
||||||
|
Build the documentation:
|
||||||
|
|
||||||
|
```
|
||||||
|
make doc
|
||||||
|
```
|
||||||
|
|
||||||
|
## Changelog and Documentation
|
||||||
|
|
||||||
|
Please do not forget to update the changelog (`CHANGELOG.rst`).
|
||||||
|
|
||||||
|
Credits: this contributing guide is based on the [PyTorch](https://github.com/pytorch/pytorch/) one.
|
||||||
2
Makefile
|
|
@ -5,7 +5,7 @@ pytest:
|
||||||
./scripts/run_tests.sh
|
./scripts/run_tests.sh
|
||||||
|
|
||||||
type:
|
type:
|
||||||
pytype
|
pytype -j auto
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
# stop the build if there are Python syntax errors or undefined names
|
# stop the build if there are Python syntax errors or undefined names
|
||||||
|
|
|
||||||
46
README.md
|
|
@ -1,24 +1,56 @@
|
||||||
|
<img src="docs/\_static/img/logo.png" align="right" width="40%"/>
|
||||||
|
|
||||||
[](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/actions) [](https://github.com/psf/black)
|
[](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/actions) [](https://github.com/psf/black)
|
||||||
|
|
||||||
# Stable-Baselines3 - Contrib
|
# Stable-Baselines3 - Contrib (SB3-Contrib)
|
||||||
|
|
||||||
Contrib package for [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) - Experimental code
|
Contrib package for [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) - Experimental reinforcement learning (RL) code.
|
||||||
|
"sb3-contrib" for short.
|
||||||
|
|
||||||
**You need the master version of Stable Baselines3**:
|
### What is SB3-Contrib?
|
||||||
```
|
|
||||||
pip install git+https://github.com/DLR-RM/stable-baselines3
|
|
||||||
```
|
|
||||||
|
|
||||||
Implemented:
|
A place for RL algorithms and tools that are considered experimental, e.g. implementations of the latest publications. The goal is to keep the simplicity, documentation and style of stable-baselines3 but for less mature implementations.
|
||||||
|
|
||||||
|
### Why create this repository?
|
||||||
|
|
||||||
|
Over the span of stable-baselines and stable-baselines3, the community has been eager to contribute in the form of better logging utilities, environment wrappers, extended support (e.g. different action spaces) and learning algorithms.
|
||||||
|
|
||||||
|
However, sometimes these utilities were too niche to be considered for stable-baselines or
|
||||||
|
proved to be too difficult to integrate well into existing code without creating a mess. sb3-contrib aims to fix this by not requiring the neatest code integration with existing code and not setting limits on what is too niche: almost everything remotely useful goes! We hope this allows us to extend the known quality of stable-baselines style and documentation beyond the relatively small scope of utilities of the main repository.
|
||||||
|
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
See documentation for the full list of included features.
|
||||||
|
|
||||||
|
**RL Algorithms**:
|
||||||
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
|
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
|
||||||
|
|
||||||
|
|
||||||
|
<!-- TODO: uncomment when the repo is public -->
|
||||||
|
<!-- ## Documentation
|
||||||
|
|
||||||
|
Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/) -->
|
||||||
|
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
|
**Note:** You need the `master` version of [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3/).
|
||||||
|
|
||||||
|
To install Stable Baselines3 `master` version:
|
||||||
|
```
|
||||||
|
pip install git+https://github.com/DLR-RM/stable-baselines3
|
||||||
|
```
|
||||||
|
|
||||||
|
Install Stable Baselines3 - Contrib using pip:
|
||||||
```
|
```
|
||||||
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
|
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## How To Contribute
|
||||||
|
|
||||||
|
If you want to contribute, please read [**CONTRIBUTING.md**](./CONTRIBUTING.md) guide first.
|
||||||
|
|
||||||
|
|
||||||
## Citing the Project
|
## Citing the Project
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
## Stable Baselines3 Documentation
|
## Stable Baselines3 Contrib Documentation
|
||||||
|
|
||||||
This folder contains documentation for the RL baselines.
|
This folder contains documentation for the RL baselines contribution repository.
|
||||||
|
|
||||||
|
|
||||||
### Build the Documentation
|
### Build the Documentation
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@
|
||||||
Welcome to Stable Baselines3 Contrib docs!
|
Welcome to Stable Baselines3 Contrib docs!
|
||||||
==========================================
|
==========================================
|
||||||
|
|
||||||
Contrib package for `Stable Baselines3 <https://github.com/DLR-RM/stable-baselines3>`_ - Experimental code.
|
Contrib package for `Stable Baselines3 (SB3) <https://github.com/DLR-RM/stable-baselines3>`_ - Experimental code.
|
||||||
|
|
||||||
|
|
||||||
Github repository: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
|
Github repository: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
|
||||||
|
|
@ -64,11 +64,7 @@ To cite this project in publications:
|
||||||
Contributing
|
Contributing
|
||||||
------------
|
------------
|
||||||
|
|
||||||
To any interested in making the rl baselines better, there are still some improvements
|
If you want to contribute, please read `CONTRIBUTING.md <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/CONTRIBUTING.md>`_ first.
|
||||||
that need to be done.
|
|
||||||
You can check issues in the `repo <https://github.com/DLR-RM/stable-baselines3/issues>`_.
|
|
||||||
|
|
||||||
If you want to contribute, please read `CONTRIBUTING.md <https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md>`_ first.
|
|
||||||
|
|
||||||
Indices and tables
|
Indices and tables
|
||||||
-------------------
|
-------------------
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,8 @@ Others:
|
||||||
|
|
||||||
Documentation:
|
Documentation:
|
||||||
^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^
|
||||||
|
- Added initial documentation
|
||||||
|
- Added contribution guide and related PR templates
|
||||||
|
|
||||||
|
|
||||||
Maintainers
|
Maintainers
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ def test_tqc(ent_coef):
|
||||||
create_eval_env=True,
|
create_eval_env=True,
|
||||||
ent_coef=ent_coef,
|
ent_coef=ent_coef,
|
||||||
)
|
)
|
||||||
model.learn(total_timesteps=500, eval_freq=250)
|
model.learn(total_timesteps=300, eval_freq=250)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n_critics", [1, 3])
|
@pytest.mark.parametrize("n_critics", [1, 3])
|
||||||
|
|
@ -23,7 +23,7 @@ def test_n_critics(n_critics):
|
||||||
model = TQC(
|
model = TQC(
|
||||||
"MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], n_critics=n_critics), learning_starts=100, verbose=1
|
"MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], n_critics=n_critics), learning_starts=100, verbose=1
|
||||||
)
|
)
|
||||||
model.learn(total_timesteps=500)
|
model.learn(total_timesteps=300)
|
||||||
|
|
||||||
|
|
||||||
def test_sde():
|
def test_sde():
|
||||||
|
|
@ -35,6 +35,6 @@ def test_sde():
|
||||||
learning_starts=100,
|
learning_starts=100,
|
||||||
verbose=1,
|
verbose=1,
|
||||||
)
|
)
|
||||||
model.learn(total_timesteps=500)
|
model.learn(total_timesteps=300)
|
||||||
model.policy.reset_noise()
|
model.policy.reset_noise()
|
||||||
model.policy.actor.get_std()
|
model.policy.actor.get_std()
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ def test_save_load(tmp_path, model_class):
|
||||||
|
|
||||||
# create model
|
# create model
|
||||||
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
|
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
|
||||||
model.learn(total_timesteps=500)
|
model.learn(total_timesteps=300)
|
||||||
|
|
||||||
env.reset()
|
env.reset()
|
||||||
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
|
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
|
||||||
|
|
@ -152,7 +152,7 @@ def test_save_load(tmp_path, model_class):
|
||||||
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
|
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
|
||||||
|
|
||||||
# check if learn still works
|
# check if learn still works
|
||||||
model.learn(total_timesteps=500)
|
model.learn(total_timesteps=300)
|
||||||
|
|
||||||
del model
|
del model
|
||||||
|
|
||||||
|
|
@ -224,7 +224,7 @@ def test_save_load_replay_buffer(tmp_path, model_class):
|
||||||
path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
|
path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
|
||||||
path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning
|
path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning
|
||||||
model = model_class("MlpPolicy", select_env(model_class), buffer_size=1000)
|
model = model_class("MlpPolicy", select_env(model_class), buffer_size=1000)
|
||||||
model.learn(500)
|
model.learn(300)
|
||||||
old_replay_buffer = deepcopy(model.replay_buffer)
|
old_replay_buffer = deepcopy(model.replay_buffer)
|
||||||
model.save_replay_buffer(path)
|
model.save_replay_buffer(path)
|
||||||
model.replay_buffer = None
|
model.replay_buffer = None
|
||||||
|
|
@ -268,7 +268,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
||||||
|
|
||||||
# create model
|
# create model
|
||||||
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
|
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
|
||||||
model.learn(total_timesteps=500)
|
model.learn(total_timesteps=300)
|
||||||
|
|
||||||
env.reset()
|
env.reset()
|
||||||
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
|
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
|
||||||
|
|
|
||||||