10 Commits

Author SHA1 Message Date
Antonin RAFFIN 4d7ed004af
Sync SB3 Contrib with SB3 (#213)
* Update RTD config

* Switch to ruff for sorting imports

* Evaluate falsy to truthy with not rather than `is False`

* Add `features_extractor` argument to maskable policy

* Add set_options for AsyncEval

* Doc fixes
2024-05-06 14:20:28 +01:00
Antonin RAFFIN 35f06254ba
Drop python 3.7, add 3.11 and update github templates (#194)
* Update issue templates

* Update supported python versions

* Update min PyTorch version

* Skip pytype for py311

* Update conda env for doc
2023-07-03 12:45:20 +02:00
Antonin RAFFIN de92025bb2
Prepare Release v2.0 (#192)
2023-06-23 13:10:17 +02:00
Antonin RAFFIN 86fb056fda
Update doc: switch from Gym to Gymnasium (#182)
* Update documentation to switch to Gymnasium

* Fix custom env type hints

* Fix type check
2023-05-10 11:40:40 +02:00
Antonin RAFFIN 21cc96cafd
Add Gymnasium support (#152)
* Add support for Gym 0.24

* Fixes for gym 0.24

* Fix for new reset signature

* Add tmp SB3 branch

* Fixes for gym 0.26

* Remove unused import

* Fix dependency

* Type annotations fixes

* Reformat

* Reformat with black 23

* Move to gymnasium

* Patch env if needed

* Fix types

* Fix CI

* Fixes for gymnasium

* Fix wrapper annotations

* Update version

* Fix type check

* Update QRDQN type hints and bug fix with multi envs

* Fix TQC type hints

* Fix TRPO type hints

* Additional fixes

* Update SB3 version

* Update issue templates and CI

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2023-04-14 13:52:07 +02:00
Antonin RAFFIN ce115982aa
Fix QR-DQN type hints (#170)
2023-03-30 11:50:26 +02:00
Antonin RAFFIN b5fd6e65ba
Update SB3 and config (#167)
2023-03-20 12:35:37 +01:00
Antonin RAFFIN 728c1c5b7f
Issue forms and pyproject.toml (#162)
* Issue forms and pyproject.toml

* [ci skip] Fix typos

* Fix isort config

* Use secret link to download atari roms

* Fix for mypy and update config

* Upgrade SB3 and fix warnings

* Fix doc build

* Update Makefile

* Lint first
2023-03-11 22:57:45 +01:00
Antonin Raffin 1d0edd2dab
Move pytype to pyproject.toml
2023-01-10 22:55:12 +01:00
kronion ab24f8039f
PPO variant with invalid action masking (#25)
* Add wrappers

* Add maskable distributions

* Add mypy configuration

* Add maskable base datastructures

* Add ppo_mask package

* Fix circular dependency and remove test code that slipped in

* Automatically mask vecenv if env is masked

* Fix debugging change that slipped in

* Workaround for subclassing RolloutBufferSamples

* Duplicate lots of policy code in order to swap out the distributions used

* Fix pytype error

* Maintain py 3.6 compatibility

* Fix isort lint errors

* Use pyproject.toml to configure black line length

* Blacken

* Remove mypy.ini

* Fully replace RolloutBufferSamples

* Drop support for continuous distributions, remove SDE-related code

* Eliminate MaskableAlgorithm and MaskableOnPolicyAlgorithm

* Fix formatting

* Override superclass methods as needed, fix circular import, improve naming

* Fix codestyle

* Eliminate VecActionMasker, replace with utils

* Fix codestyle

* Support masking for MultiDiscrete action spaces

* Fix codestyle

* Don't require the env to provide the mask already flattened

* Consistent naming, prefer 'Maskable' to 'Masked'

* Register policy

* Link to abstract instead of pdf

* Allow distribution masking to be unapplied + improved comments and docstrings

* Don't use deprecated implicit optional typing

* Check codestyle

* Add docstring and remove misplaced TODO

* Simplify env masking API, error if API unmet. Make use_masking a learn() kwarg

* Fix codestyle

* Update various internals to be consistent with latest SB3

* Simplify MaskableRolloutBuffer reset

* Add docstring and type annotations

* Ensure old probs aren't cached

* Fix for new logger

* Add test + fixes

* Start doc

* Fix type annotation

* Remove abstract class + add test

* Fix evaluation (add support for multi envs)

* Handle merge conflicts in documentation

* Bugfix: mask updates should apply to original logits, not the last masked output

* Add test of distribution masking behavior

* Reformat

* Add MultiBinary support, remove unneeded distribution type checks

* Remove unused import

* Fix when using multiple envs

* Remove addressed TODO

* Upgrade for SB3 1.2.0

* Update docs with results + how to replicate

* Add action masker tests, move wrapper tests

* Move distributions, add more distribution tests

* Add MaskablePPO tests, simplify and rename discrete test env

* Address TODO

* Add tests for MaskableMultiCategoricalDistribution, fix distributions

* Add maskable identity envs for all supported action spaces, add tests, fix bug

* Formatting fixes

* Update doc env

* Dict support not ready

* Cleanup

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2021-09-23 14:50:10 +02:00
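
Two items in the masking commit above ("Allow distribution masking to be unapplied" and "Bugfix: mask updates should apply to original logits, not the last masked output") describe behaviour that is easier to see in code. The following is a minimal pure-Python sketch of that idea, not the repository's implementation: the names `masked_probs` and `MaskableCategorical` are hypothetical, and the real code operates on PyTorch tensors rather than lists.

```python
import math


def masked_probs(logits, mask):
    """Softmax over logits; actions whose mask entry is False get probability 0."""
    masked = [l if m else -math.inf for l, m in zip(logits, mask)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(l - peak) for l in masked]
    total = sum(exps)
    return [e / total for e in exps]


class MaskableCategorical:
    """Hypothetical helper that keeps the original logits untouched."""

    def __init__(self, logits):
        self.original_logits = list(logits)
        self.mask = [True] * len(self.original_logits)

    def apply_mask(self, mask=None):
        # Passing None removes the mask entirely ("unapplied").
        if mask is None:
            mask = [True] * len(self.original_logits)
        mask = list(mask)
        if not any(mask):
            raise ValueError("at least one action must remain valid")
        # Only the mask is stored; the original logits are never overwritten,
        # so a new mask cannot inherit -inf entries from an earlier one.
        self.mask = mask

    def probs(self):
        # Always recompute from the original logits and the current mask.
        return masked_probs(self.original_logits, self.mask)


dist = MaskableCategorical([1.0, 2.0, 3.0])
dist.apply_mask([True, False, True])   # action 1 is invalid
first = dist.probs()
dist.apply_mask([False, True, True])   # action 1 is valid again
second = dist.probs()
```

If the second mask were applied to the already masked logits instead, action 1 would stay at probability zero even though the new mask marks it as valid, which is the failure mode the bugfix entry describes.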