Commit Graph

14 Commits

Author SHA1 Message Date
Antonin RAFFIN 4d7ed004af Sync SB3 Contrib with SB3 (#213)
* Update RTD config

* Switch to ruff for sorting imports

* Evaluate falsy to truthy with not rather than `is False`

* Add `features_extractor` argument to maskable policy

* Add set_options for AsyncEval

* Doc fixes
2024-05-06 14:20:28 +01:00
Antonin RAFFIN 21cc96cafd
Add Gymnasium support (#152)
* Add support for Gym 0.24

* Fixes for gym 0.24

* Fix for new reset signature

* Add tmp SB3 branch

* Fixes for gym 0.26

* Remove unused import

* Fix dependency

* Type annotations fixes

* Reformat

* Reformat with black 23

* Move to gymnasium

* Patch env if needed

* Fix types

* Fix CI

* Fixes for gymnasium

* Fix wrapper annotations

* Update version

* Fix type check

* Update QRDQN type hints and bug fix with multi envs

* Fix TQC type hints

* Fix TRPO type hints

* Additional fixes

* Update SB3 version

* Update issue templates and CI

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2023-04-14 13:52:07 +02:00
Antonin RAFFIN a84ad3aa7d
Release v1.8.0 (#173) 2023-04-08 15:50:04 +02:00
Alex Pasquali 376d9551de
Update MaskablePPO docs (#150)
* MaskablePPO docs

Added a warning about possible crashes caused by chack_env in case of invalid actions.

* Reformat with black 23

* Rephrase note on action sampling

* Fix action noise

* Update changelog

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2023-02-13 14:31:49 +01:00
Alex Pasquali 6bc8e426bf
Removed shared layers in mlp_extractor (#137)
* Removed shared layers in mlp_extractor

* Add ruff

* Update version and add warning

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2023-01-25 16:28:27 +01:00
Alex Pasquali b5aa9a47ce
Deprecation of shared layers in `mlp_extractor` (#133)
* Deprecation of shared layers in mlp_extractor

* Fix missing import

* Reformat and update tests

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2023-01-05 10:42:22 +01:00
Quentin Gallouédec 7c4a249fa4
Standardize the use of ``from gym import spaces`` (#131)
* Standardize from gym import spaces

* update changelog

* update issue template

* update version

* Update version
2023-01-02 15:35:00 +01:00
Alex Pasquali ab8684f469
[Feature] Non-shared features extractor in on-policy algorithms (#130)
* Modified sb3_contrib/common/maskable/policies.py

- Added support for non-shared features extractor in file sb3_contrib/common/maskable/policies.py
- updated changelog

* Modified sb3_contrib/common/recurrent/policies.py

* Modified sb3_contrib/qrdqn/policies.py and sb3_contrib/tqc/policies.py

* Updated test_cnn.py

* Upgrade SB3 version

* Revert changes in formatting

* Remove duplicate normalize_images

* Add test for image-like inputs

* Fixes and add more tests

* Update SB3 version

* Fix ARS warnings

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2022-12-23 00:23:45 +01:00
Quentin Gallouédec 3cbd2429be
Fix returned type in predict (#88)
* actions[0] -> actions.squeeze(0)

* Update changelog

* Update changelog

* Update version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2022-07-18 11:49:03 +02:00
Antonin RAFFIN 75b2de1399
Recurrent PPO (#53)
* Running (not working yet) version of recurrent PPO

* Fixes for multi envs

* Save WIP, rework the sampling

* Add Box support

* Fix sample order

* Being cleanup, code is broken (again)

* First working version (no shared lstm)

* Start cleanup

* Try rnn with value function

* Re-enable batch size

* Deactivate vf rnn

* Allow any batch size

* Add support for evaluation

* Add CNN support

* Fix start of sequence

* Allow shared LSTM

* Rename mask to episode_start

* Fix type hint

* Enable LSTM for critic

* Clean code

* Fix for CNN LSTM

* Fix sampling with n_layers > 1

* Add std logger

* Update wording

* Rename and add dict obs support

* Fixes for dict obs support

* Do not run slow tests

* Fix doc

* Update recurrent PPO example

* Update README

* Use Pendulum-v1 for tests

* Fix image env

* Speedup LSTM forward pass (#63)

* added more efficient lstm implementation

* Rename and add comment

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>

* Fixes

* Remove OpenAI sampling and improve coverage

* Sync with SB3 PPO

* Pass state shape and allow lstm kwargs

* Update tests

* Add masking for padded sequences

* Update default in perf test

* Remove TODO, mask is now working

* Add helper to remove duplicated code, remove hack for padding

* Enable LSTM critic and raise threshold for cartpole with no vel

* Fix tests

* Update doc and tests

* Doc fix

* Fix for new Sphinx version

* Fix doc note

* Switch to batch first, no more additional swap

* Add comments and mask entropy loss

Co-authored-by: Neville Walo <43504521+Walon1998@users.noreply.github.com>
2022-05-30 04:31:12 +02:00
Antonin RAFFIN bec00386d1
Upgrade to python 3.7+ syntax (#69)
* Upgrade to python 3.7+ syntax

* Switch to PyTorch 1.11
2022-04-25 13:02:07 +02:00
Antonin RAFFIN b44689b0ea
Update Maskable PPO to match SB3 PPO + improve coverage (#56) 2021-12-10 12:48:19 +01:00
Antonin RAFFIN cd0a5e516f
Update citation (#54)
* Update citation

* Fixes for new SB3 version

* Fix type hint

* Additional fixes
2021-12-01 19:09:32 +01:00
kronion ab24f8039f
PPO variant with invalid action masking (#25)
* Add wrappers

* Add maskable distributions

* Add mypy configuration

* Add maskable base datastructures

* Add ppo_mask package

* Fix circular dependency and remove test code that slipped in

* Automatically mask vecenv if env is masked

* Fix debugging change that slipped in

* Workaround for subclassing RolloutBufferSamples

* Duplicate lots of policy code in order to swap out the distributions used

* Fix pytype error

* Maintain py 3.6 compatibility

* Fix isort lint errors

* Use pyproject.toml to configure black line length

* Blacken

* Remove mypy.ini

* Fully replace RolloutBufferSamples

* Drop support for continuous distributions, remove SDE-related code

* Eliminate MaskableAlgorithm and MaskableOnPolicyAlgorithm

* Fix formatting

* Override superclass methods as needed, fix circular import, improve naming

* Fix codestyle

* Eliminate VecActionMasker, replace with utils

* Fix codestyle

* Support masking for MultiDiscrete action spaces

* Fix codestyle

* Don't require the env to provide the mask already flattened

* Consistent naming, prefer 'Maskable' to 'Masked'

* Register policy

* Link to abstract instead of pdf

* Allow distribution masking to be unapplied + improved comments and docstrings

* Don't use deprecated implicit optional typing

* Check codestyle

* Add docstring and remove misplaced TODO

* Simplify env masking API, error if API unmet. Make use_masking a learn() kwarg

* Fix codestyle

* Update various internals to be consistent with latest SB3

* Simplify MaskableRolloutBuffer reset

* Add docstring and type annotations

* Ensure old probs aren't cached

* Fix for new logger

* Add test + fixes

* Start doc

* Fix type annotation

* Remove abstract class + add test

* Fix evaluation (add support for multi envs)

* Handle merge conflicts in documentation

* Bugfix: mask updates should apply to original logits, not the last masked output

* Add test of distribution masking behavior

* Reformat

* Add MultiBinary support, remove unneeded distribution type checks

* Remove unused import

* Fix when using multiple envs

* Remove addressed TODO

* Upgrade for SB3 1.2.0

* Update docs with results + how to replicate

* Add action masker tests, move wrapper tests

* Move distributions, add more distribution tests

* Add MaskablePPO tests, simplify and rename discrete test env

* Address TODO

* Add tests for MaskableMultiCategoricalDistribution, fix distributions

* Add maskable identity envs for all supported action spaces, add tests, fix bug

* Formatting fixes

* Update doc env

* Dict support not ready

* Cleanup

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2021-09-23 14:50:10 +02:00