Commit Graph

97 Commits

Author SHA1 Message Date
Antonin RAFFIN cf4ed5fe65 Release v2.1.0 (#204)
* Release v2.1.0

* Fix mypy

* Fix warnings in tests
2024-05-06 14:20:28 +01:00
Paul Auerbach 4e99b74e90
Merge branch 'Stable-Baselines-Team:master' into master 2023-08-07 14:34:38 +02:00
Paul Auerbach d97dbc727c Added doc page for SACD 2023-08-07 14:03:12 +02:00
PatrickHelm dfa23bdf9c
Bugfix/ppo mask stats window size (#199)
* fixes issue #198 regarding stats_window_size

* updates changelog

* updates test of stats_window_size

* updates test using maskable env

* removes print statement
2023-08-01 11:53:27 +02:00
Antonin RAFFIN 35f06254ba
Drop python 3.7, add 3.11 and update github templates (#194)
* Update issue templates

* Update supported python versions

* Update min PyTorch version

* Skip pytype for py311

* Update conda env for doc
2023-07-03 12:45:20 +02:00
Antonin RAFFIN de92025bb2
Prepare Release v2.0 (#192) 2023-06-23 13:10:17 +02:00
Antonin RAFFIN 6e1aba45e3
Update version and fix #188 (#190) 2023-06-07 16:51:42 +02:00
Antonin RAFFIN d467d7a844
Update AsyncEval seeding (#185)
* Update AsyncEval seeding

* Update random seed in tests
2023-05-20 10:57:31 +02:00
Antonin RAFFIN 86fb056fda
Update doc: switch from Gym to Gymnasium (#182)
* Update documentation to switch to Gymnasium

* Fix custom env type hints

* Fix type check
2023-05-10 11:40:40 +02:00
Antonin RAFFIN 21cc96cafd
Add Gymnasium support (#152)
* Add support for Gym 0.24

* Fixes for gym 0.24

* Fix for new reset signature

* Add tmp SB3 branch

* Fixes for gym 0.26

* Remove unused import

* Fix dependency

* Type annotations fixes

* Reformat

* Reformat with black 23

* Move to gymnasium

* Patch env if needed

* Fix types

* Fix CI

* Fixes for gymnasium

* Fix wrapper annotations

* Update version

* Fix type check

* Update QRDQN type hints and bug fix with multi envs

* Fix TQC type hints

* Fix TRPO type hints

* Additional fixes

* Update SB3 version

* Update issue templates and CI

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2023-04-14 13:52:07 +02:00
Antonin RAFFIN a84ad3aa7d
Release v1.8.0 (#173) 2023-04-08 15:50:04 +02:00
Jonas Reiher aacded79c5
Add stats window argument (#171)
* added missing tensorboard_log docstring

* added stats_window_size argument to all models

* changelog updated

* Update SB3 version

* fixed passing stats_window_size to parent

* added test of stats_window_size

---------

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
2023-04-05 18:47:27 +02:00
Antonin RAFFIN ce115982aa
Fix QR-DQN type hints (#170) 2023-03-30 11:50:26 +02:00
Antonin RAFFIN b5fd6e65ba
Update SB3 and config (#167) 2023-03-20 12:35:37 +01:00
Antonin RAFFIN 1f9568b2da
Fix Atari Roms Download (#164) 2023-03-12 19:06:23 +01:00
Antonin RAFFIN 728c1c5b7f
Issue forms and pyproject.toml (#162)
* Issue forms and pyproject.toml

* [ci skip] Fix typos

* Fix isort config

* Use secret link to download atari roms

* Fix for mypy and update config

* Upgrade SB3 and fix warnings

* Fix doc build

* Update Makefile

* Lint first
2023-03-11 22:57:45 +01:00
Alex Pasquali 376d9551de
Update MaskablePPO docs (#150)
* MaskablePPO docs

Added a warning about possible crashes caused by check_env in case of invalid actions.

* Reformat with black 23

* Rephrase note on action sampling

* Fix action noise

* Update changelog

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2023-02-13 14:31:49 +01:00
Alex Pasquali 6bc8e426bf
Removed shared layers in mlp_extractor (#137)
* Removed shared layers in mlp_extractor

* Add ruff

* Update version and add warning

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2023-01-25 16:28:27 +01:00
Antonin RAFFIN 7bf9cf3f3a
Release v1.7.0 (#134) 2023-01-10 22:35:18 +01:00
Alex Pasquali b5aa9a47ce
Deprecation of shared layers in `mlp_extractor` (#133)
* Deprecation of shared layers in mlp_extractor

* Fix missing import

* Reformat and update tests

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2023-01-05 10:42:22 +01:00
Quentin Gallouédec 7c4a249fa4
Standardize the use of ``from gym import spaces`` (#131)
* Standardize from gym import spaces

* update changelog

* update issue template

* update version

* Update version
2023-01-02 15:35:00 +01:00
Quentin Gallouédec c9bd045d5c
Add support for python3.10 (#129)
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2022-12-23 00:54:35 +01:00
Quentin Gallouédec 9cf8b5076f
Construct tensors directly on GPUs (#128)
* `to(device)` to `device=device` and `float()` to `dtype=th.float32`

* Update changelog

* Fix type checking

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2022-12-23 00:44:25 +01:00
Alex Pasquali ab8684f469
[Feature] Non-shared features extractor in on-policy algorithms (#130)
* Modified sb3_contrib/common/maskable/policies.py

- Added support for non-shared features extractor in file sb3_contrib/common/maskable/policies.py
- updated changelog

* Modified sb3_contrib/common/recurrent/policies.py

* Modified sb3_contrib/qrdqn/policies.py and sb3_contrib/tqc/policies.py

* Updated test_cnn.py

* Upgrade SB3 version

* Revert changes in formatting

* Remove duplicate normalize_images

* Add test for image-like inputs

* Fixes and add more tests

* Update SB3 version

* Fix ARS warnings

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2022-12-23 00:23:45 +01:00
Quentin Gallouédec 6b23c6cfe3
Add `with_bias` parameter to `ARSPolicy` and fix `sb3_contrib/ars/policies.py` type hint (#122)
* Update contribution.md

* New loop struct to make mypy happy

* Update setup.cfg

* Update changelog

* fix squash_output = False in ARS policy

* Add with_bias parameter to ARSPolicy

* Make ARSLinearPolicy a special case of ARSPolicy

* Remove ars_policy from mypy exclude

* Update changelog

* Update SB3 version

* Fix to save ARS linear policy saved with sb3-contrib < 1.7.0

* Fix test

* Turn docstring into comment

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
2022-12-12 13:22:09 +01:00
Quentin Gallouédec 9a728513da
Upgrade CI/github-actions (#125)
* Update ci.yml

* Update changelog.rst
2022-12-09 12:30:22 +01:00
Zikang Xiong ddb3a1355e
Expose modules in `__init__.py` with `__all__` attribute (#124)
* expose modules in __init__.py with __all__ attribute

* Update version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2022-12-05 15:53:57 +01:00
Quentin Gallouédec b3e4ddd09a
Fix `sb3_contrib/common/recurrent/type_aliases.py` type hint (#121)
* Update setup.cfg

* Update changelog

* Update type aliases
2022-11-29 10:41:07 +01:00
Quentin Gallouédec ded9f65bfd
Fix `sb3_contrib/common/utils.py` type hint (#120)
* Update contribution.md

* New loop struct to make mypy happy

* Update setup.cfg

* Update changelog

* Add return statement in cg solver for max_iter=0

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2022-11-29 10:24:44 +01:00
Quentin Gallouédec 3d28d1e5de
Mypy type checking (#119)
* Update Makefile

* Update changelog

* gitignore mypy cache

* mypy config

* Add color to mypy output

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2022-11-28 23:00:31 +01:00
Antonin Raffin 703fd2dd68
Fix for new flake8 version 2022-11-25 18:52:15 +01:00
Antonin RAFFIN a9735b9f31
Fix reshape LSTM states (#112)
* Fix LSTM states reshape

* Fix warnings and update changelog

* Remove unused variable

* Fix runtime error when using n_lstm_layers > 1
2022-10-26 18:03:45 +02:00
Antonin RAFFIN c75ad7dd58
Remove deprecated features (#108)
* Remove deprecated features

* Upgrade SB3

* Fix tests
2022-10-11 13:04:18 +02:00
Antonin RAFFIN 52795a307e
Add progress bar argument (#107)
* Add progress bar argument

* Sort imports
2022-10-10 18:44:13 +02:00
Quentin Gallouédec e9c97948c8
Fixed the return type of ``.load()`` methods (#106)
* Fix return type for learn using TypeVar

* Update changelog

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2022-10-10 17:21:38 +02:00
Quentin Gallouédec dec7b5303a
Deprecate ``create_eval_env``, ``eval_env`` and ``eval_freq`` parameter (#105)
* Deprecate ``eval_env``, ``eval_freq`` and ``create_eval_env``

* Update changelog

* Typo

* Raise deprecation warning in _setup_learn

* Upgrade to latest SB3 version and update changelog

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2022-10-10 17:12:40 +02:00
Antonin RAFFIN 2490468b11
Release v1.6.1 (#104) 2022-09-29 12:30:12 +02:00
Honglu Fan cad9034fdb
Handle batch norm in target update (#99)
* Copy running stats regardless of tau in QRDQN and TQC. See https://github.com/DLR-RM/stable-baselines3/issues/996

* Copy running stats regardless of tau in QRDQN and TQC. See https://github.com/DLR-RM/stable-baselines3/issues/996

* Copy running stats regardless of tau in QRDQN and TQC. See https://github.com/DLR-RM/stable-baselines3/issues/996

* roll back test_cnn.py
2022-08-27 12:31:00 +02:00
Quentin Gallouédec 7993b75781
Support `device="auto"` for buffers and set it as default value (#98)
* Default device for buffer is auto

* `device=auto` in ARS

* Undo ARS change

* Update changelog

* Update min SB3 version

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2022-08-24 09:48:18 +02:00
Burak Demirbilek 049f5a16e9
Fixed missing verbose parameter passing (#97) 2022-08-16 15:54:46 +02:00
CppMaster eb48fec638
Maskable eval callback call callback fix (#93)
* correctly call both self.callback_on_new_best and self.callback, as in EvalCallback

* MaskableEvalCallback - updated sync_envs_normalization handling

* MaskableEvalCallback - updated sync_envs_normalization handling - test
MaskablePPO - register policies (tests fail otherwise)

* MaskableEvalCallback - updated docstring

* updated changelog.rst

* changes for stable-baselines3==1.6.0

* version range

* suggested changes

* Reformat and update version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2022-07-27 19:52:07 +02:00
Max Lodel fc68af8841
Fixed shared_lstm argument in CNN and MultiInput Policies for RecurrentPPO (#90)
* fixed shared_lstm parameter in CNN and MultiInput Policies

* updated tests

* changelog

* Fix FPS for recurrent PPO

* Fix import

* Update changelog

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2022-07-26 00:27:17 +02:00
Adam Gleave 7e687ac47c
Use higher resolution time_ns() and avoid division by zero (#91)
* Use higher resolution time_ns and add max to avoid division by zero

* Add missing imports

* Update changelog
2022-07-25 23:12:20 +02:00
Quentin Gallouédec 3cbd2429be
Fix returned type in predict (#88)
* actions[0] -> actions.squeeze(0)

* Update changelog

* Update changelog

* Update version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2022-07-18 11:49:03 +02:00
Antonin Raffin c9d621b816
Use ICLR url for PPO blog post 2022-07-12 23:49:26 +02:00
Antonin Raffin 5ec9e01b44
Update changelog 2022-07-12 23:15:14 +02:00
Antonin RAFFIN 087951d34b
Release v1.6.0 and bug fix for TRPO (#84) 2022-07-12 23:12:24 +02:00
Antonin RAFFIN db4c0114d0
Update default TQC net arch when using NatureCnn (#79)
* Update default TQC net arch when using NatureCnn

* Bump version
2022-06-18 10:53:29 +02:00
rnederstigt bfa86ce4fe
Fix masked quantities in RecurrentPPO (#78)
* Ignore masked indexes when calculating the loss functions
2022-06-13 16:00:40 +02:00
Antonin RAFFIN 75b2de1399
Recurrent PPO (#53)
* Running (not working yet) version of recurrent PPO

* Fixes for multi envs

* Save WIP, rework the sampling

* Add Box support

* Fix sample order

* Begin cleanup, code is broken (again)

* First working version (no shared lstm)

* Start cleanup

* Try rnn with value function

* Re-enable batch size

* Deactivate vf rnn

* Allow any batch size

* Add support for evaluation

* Add CNN support

* Fix start of sequence

* Allow shared LSTM

* Rename mask to episode_start

* Fix type hint

* Enable LSTM for critic

* Clean code

* Fix for CNN LSTM

* Fix sampling with n_layers > 1

* Add std logger

* Update wording

* Rename and add dict obs support

* Fixes for dict obs support

* Do not run slow tests

* Fix doc

* Update recurrent PPO example

* Update README

* Use Pendulum-v1 for tests

* Fix image env

* Speedup LSTM forward pass (#63)

* added more efficient lstm implementation

* Rename and add comment

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>

* Fixes

* Remove OpenAI sampling and improve coverage

* Sync with SB3 PPO

* Pass state shape and allow lstm kwargs

* Update tests

* Add masking for padded sequences

* Update default in perf test

* Remove TODO, mask is now working

* Add helper to remove duplicated code, remove hack for padding

* Enable LSTM critic and raise threshold for cartpole with no vel

* Fix tests

* Update doc and tests

* Doc fix

* Fix for new Sphinx version

* Fix doc note

* Switch to batch first, no more additional swap

* Add comments and mask entropy loss

Co-authored-by: Neville Walo <43504521+Walon1998@users.noreply.github.com>
2022-05-30 04:31:12 +02:00