Commit Graph

105 Commits

Author SHA1 Message Date
Antonin RAFFIN 89d0113037 Update ruff and SB3 dependencies (#232) 2024-05-06 14:20:28 +01:00
Antonin RAFFIN 7dd6c39fba Fix MaskablePPO type annotations (#233)
* Fix `sb3_contrib/common/maskable/callbacks.py` type annotations

* Fix `sb3_contrib/common/vec_env/async_eval.py` type annotations

* Fix `sb3_contrib/common/maskable/buffers.py` type hints

* Fix `sb3_contrib/common/maskable/distributions.py` type hints

* Fix `sb3_contrib/common/maskable/policies.py` type hints

* Fix `sb3_contrib/ppo_mask/ppo_mask.py` type hints

* Update changelog and fix type hints
2024-05-06 14:20:28 +01:00
Armand du Parc Locmaria cd31e89e26 Fix `train_freq` type annotation for TQC and QR-DQN (#229)
* fix train_freq type for tqc and qrdqn

* fix typo

* Update changelog

---------

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
2024-05-06 14:20:28 +01:00
Antonin RAFFIN 3f0c5088b3 Update QRDQN defaults (#225) 2024-05-06 14:20:28 +01:00
Rogério Júnior 1553b66ee4 Update `_process_sequence()` docstring (#219)
Small modification in the documentation of `_process_sequence` to reflect the correct order (`hidden_state`, `cell_state`) of `lstm_states`.

See order here: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
2024-05-06 14:20:28 +01:00
Antonin RAFFIN 94a5daab02 Update SB3 version (#217) 2024-05-06 14:20:28 +01:00
Antonin RAFFIN ebb74c44ec Release v2.2.0 (#216) 2024-05-06 14:20:28 +01:00
Antonin RAFFIN c965ba9d3b Remove PyType and upgrade to latest SB3 version (#215) 2024-05-06 14:20:28 +01:00
M. Ernestus 5e437fc4dc Add rollout_buffer_class to TRPO (#214)
* Add rollout_buffer_class and rollout_buffer_kwargs to TRPO

* Update requirements and changelog

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2024-05-06 14:20:28 +01:00
Antonin RAFFIN 4d7ed004af Sync SB3 Contrib with SB3 (#213)
* Update RTD config

* Switch to ruff for sorting imports

* Evaluate truthiness with `not` rather than `is False`

* Add `features_extractor` argument to maskable policy

* Add set_options for AsyncEval

* Doc fixes
2024-05-06 14:20:28 +01:00
Antonin RAFFIN cf4ed5fe65 Release v2.1.0 (#204)
* Release v2.1.0

* Fix mypy

* Fix warnings in tests
2024-05-06 14:20:28 +01:00
Paul Auerbach 4e99b74e90
Merge branch 'Stable-Baselines-Team:master' into master 2023-08-07 14:34:38 +02:00
Paul Auerbach fca2c6d490 Prepared files for merge request (minor cleanup) 2023-08-04 18:19:34 +02:00
Paul Auerbach 4a37f58259 Code style changes 2023-08-02 13:52:40 +02:00
Paul Auerbach 7711813dbb Reworked code to work with more than 2 critic networks 2023-08-02 13:17:22 +02:00
Paul Auerbach 875b8bca0d Fixed bugs that led to wrong results, currently only working with 2 critics 2023-08-01 15:09:55 +02:00
PatrickHelm dfa23bdf9c
Bugfix/ppo mask stats window size (#199)
* fixes issue #198 regarding stats_window_size

* updates changelog

* updates test of stats_window_size

* updates test using maskable env

* removes print statement
2023-08-01 11:53:27 +02:00
Paul Auerbach a14ae69b6b Added first version of SAC Discrete, which is running but not learning currently 2023-07-31 16:07:08 +02:00
Antonin RAFFIN 35f06254ba
Drop python 3.7, add 3.11 and update github templates (#194)
* Update issue templates

* Update supported python versions

* Update min PyTorch version

* Skip pytype for py311

* Update conda env for doc
2023-07-03 12:45:20 +02:00
Antonin RAFFIN de92025bb2
Prepare Release v2.0 (#192) 2023-06-23 13:10:17 +02:00
Antonin RAFFIN 6e1aba45e3
Update version and fix #188 (#190) 2023-06-07 16:51:42 +02:00
Antonin RAFFIN d467d7a844
Update AsyncEval seeding (#185)
* Update AsyncEval seeding

* Update random seed in tests
2023-05-20 10:57:31 +02:00
Antonin RAFFIN 86fb056fda
Update doc: switch from Gym to Gymnasium (#182)
* Update documentation to switch to Gymnasium

* Fix custom env type hints

* Fix type check
2023-05-10 11:40:40 +02:00
Antonin RAFFIN 21cc96cafd
Add Gymnasium support (#152)
* Add support for Gym 0.24

* Fixes for gym 0.24

* Fix for new reset signature

* Add tmp SB3 branch

* Fixes for gym 0.26

* Remove unused import

* Fix dependency

* Type annotations fixes

* Reformat

* Reformat with black 23

* Move to gymnasium

* Patch env if needed

* Fix types

* Fix CI

* Fixes for gymnasium

* Fix wrapper annotations

* Update version

* Fix type check

* Update QRDQN type hints and bug fix with multi envs

* Fix TQC type hints

* Fix TRPO type hints

* Additional fixes

* Update SB3 version

* Update issue templates and CI

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2023-04-14 13:52:07 +02:00
Antonin RAFFIN a84ad3aa7d
Release v1.8.0 (#173) 2023-04-08 15:50:04 +02:00
Jonas Reiher aacded79c5
Add stats window argument (#171)
* added missing tensorboard_log docstring

* added stats_window_size argument to all models

* changelog updated

* Update SB3 version

* fixed passing stats_window_size to parent

* added test of stats_window_size

---------

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
2023-04-05 18:47:27 +02:00
Antonin RAFFIN ce115982aa
Fix QR-DQN type hints (#170) 2023-03-30 11:50:26 +02:00
Antonin RAFFIN b5fd6e65ba
Update SB3 and config (#167) 2023-03-20 12:35:37 +01:00
Antonin RAFFIN 1f9568b2da
Fix Atari Roms Download (#164) 2023-03-12 19:06:23 +01:00
Antonin RAFFIN 728c1c5b7f
Issue forms and pyproject.toml (#162)
* Issue forms and pyproject.toml

* [ci skip] Fix typos

* Fix isort config

* Use secret link to download atari roms

* Fix for mypy and update config

* Upgrade SB3 and fix warnings

* Fix doc build

* Update Makefile

* Lint first
2023-03-11 22:57:45 +01:00
Alex Pasquali 376d9551de
Update MaskablePPO docs (#150)
* MaskablePPO docs

Added a warning about possible crashes caused by `check_env` in case of invalid actions.

* Reformat with black 23

* Rephrase note on action sampling

* Fix action noise

* Update changelog

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2023-02-13 14:31:49 +01:00
Alex Pasquali 6bc8e426bf
Removed shared layers in mlp_extractor (#137)
* Removed shared layers in mlp_extractor

* Add ruff

* Update version and add warning

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2023-01-25 16:28:27 +01:00
Antonin RAFFIN 7bf9cf3f3a
Release v1.7.0 (#134) 2023-01-10 22:35:18 +01:00
Alex Pasquali b5aa9a47ce
Deprecation of shared layers in `mlp_extractor` (#133)
* Deprecation of shared layers in mlp_extractor

* Fix missing import

* Reformat and update tests

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2023-01-05 10:42:22 +01:00
Quentin Gallouédec 7c4a249fa4
Standardize the use of ``from gym import spaces`` (#131)
* Standardize from gym import spaces

* update changelog

* update issue template

* update version

* Update version
2023-01-02 15:35:00 +01:00
Quentin Gallouédec 9cf8b5076f
Construct tensors directly on GPUs (#128)
* `to(device)` to `device=device` and `float()` to `dtype=th.float32`

* Update changelog

* Fix type checking

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2022-12-23 00:44:25 +01:00
Alex Pasquali ab8684f469
[Feature] Non-shared features extractor in on-policy algorithms (#130)
* Modified sb3_contrib/common/maskable/policies.py

- Added support for non-shared features extractor in file sb3_contrib/common/maskable/policies.py
- updated changelog

* Modified sb3_contrib/common/recurrent/policies.py

* Modified sb3_contrib/qrdqn/policies.py and sb3_contrib/tqc/policies.py

* Updated test_cnn.py

* Upgrade SB3 version

* Revert changes in formatting

* Remove duplicate normalize_images

* Add test for image-like inputs

* Fixes and add more tests

* Update SB3 version

* Fix ARS warnings

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2022-12-23 00:23:45 +01:00
Quentin Gallouédec 6b23c6cfe3
Add `with_bias` parameter to `ARSPolicy` and fix `sb3_contrib/ars/policies.py` type hint (#122)
* Update contribution.md

* New loop struct to make mypy happy

* Update setup.cfg

* Update changelog

* fix squash_output = False in ARS policy

* Add with_bias parameter to ARSPolicy

* Make ARSLinearPolicy a special case of ARSPolicy

* Remove ars_policy from mypy exclude

* Update changelog

* Update SB3 version

* Fix to save ARS linear policy saved with sb3-contrib < 1.7.0

* Fix test

* Turn docstring into comment

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
2022-12-12 13:22:09 +01:00
Zikang Xiong ddb3a1355e
Expose modules in `__init__.py` with `__all__` attribute (#124)
* expose modules in __init__.py with __all__ attribute

* Update version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2022-12-05 15:53:57 +01:00
Quentin Gallouédec b3e4ddd09a
Fix `sb3_contrib/common/recurrent/type_aliases.py` type hint (#121)
* Update setup.cfg

* Update changelog

* Update type aliases
2022-11-29 10:41:07 +01:00
Quentin Gallouédec ded9f65bfd
Fix `sb3_contrib/common/utils.py` type hint (#120)
* Update contribution.md

* New loop struct to make mypy happy

* Update setup.cfg

* Update changelog

* Add return statement in cg solver for max_iter=0

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2022-11-29 10:24:44 +01:00
Quentin Gallouédec 36aeae18b5
Fix `Self` return type (#116)
* Self hint for distributions

* ClassSelf to SelfClass
2022-11-22 13:12:35 +01:00
Antonin RAFFIN a9735b9f31
Fix reshape LSTM states (#112)
* Fix LSTM states reshape

* Fix warnings and update changelog

* Remove unused variable

* Fix runtime error when using n_lstm_layers > 1
2022-10-26 18:03:45 +02:00
Antonin RAFFIN c75ad7dd58
Remove deprecated features (#108)
* Remove deprecated features

* Upgrade SB3

* Fix tests
2022-10-11 13:04:18 +02:00
Antonin RAFFIN 52795a307e
Add progress bar argument (#107)
* Add progress bar argument

* Sort imports
2022-10-10 18:44:13 +02:00
Quentin Gallouédec e9c97948c8
Fixed the return type of ``.load()`` methods (#106)
* Fix return type for learn using TypeVar

* Update changelog

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2022-10-10 17:21:38 +02:00
Quentin Gallouédec dec7b5303a
Deprecate ``create_eval_env``, ``eval_env`` and ``eval_freq`` parameters (#105)
* Deprecate ``eval_env``, ``eval_freq`` and ``create_eval_env``

* Update changelog

* Typo

* Raise deprecation warning in _setup_learn

* Upgrade to latest SB3 version and update changelog

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2022-10-10 17:12:40 +02:00
Antonin RAFFIN 2490468b11
Release v1.6.1 (#104) 2022-09-29 12:30:12 +02:00
Honglu Fan cad9034fdb
Handle batch norm in target update (#99)
* Copy running stats regardless of tau in QRDQN and TQC. See https://github.com/DLR-RM/stable-baselines3/issues/996

* roll back test_cnn.py
2022-08-27 12:31:00 +02:00
Quentin Gallouédec 7993b75781
Support `device="auto"` for buffers and set it as default value (#98)
* Default device for buffer is auto

* `device=auto` in ARS

* Undo ARS change

* Update changelog

* Update min SB3 version

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2022-08-24 09:48:18 +02:00