* Feat: adding TRPO algorithm (WIP)
WIP - Trust Region Policy Algorithm
Currently the Hessian vector product is not working (see inline comments for more detail)
* Feat: adding TRPO algorithm (WIP)
Adding no_grad block for the line search
Additional assert in the conjugate solver to help debugging
* Feat: adding TRPO algorithm (WIP)
- Adding ActorCriticPolicy.get_distribution
- Using the Distribution object to compute the KL divergence
- Checking for objective improvement in the line search
- Moving magic numbers to instance variables
* Feat: adding TRPO algorithm (WIP)
Improving numerical stability of the conjugate gradient algorithm
Critic updates
* Feat: adding TRPO algorithm (WIP)
Changes around the alpha of the line search
Adding TRPO to __init__ files
* feat: TRPO - addressing PR comments
- renaming cg_solver to conjugate_gradient_solver and renaming parameter Avp_fun to matrix_vector_dot_func + docstring
- extra comments + better variable names in trpo.py
- defining a method for the hessian vector product instead of an inline function
- fix registering correct policies for TRPO and using correct policy base in constructor
* refactor: TRPO - policier
- refactoring sb3_contrib.common.policies to reuse as much code as possible from sb3
* feat: using updated ActorCriticPolicy from SB3
- get_distribution will be added directly to the SB3 version of ActorCriticPolicy, this commit reflects this
* Bump version for `get_distribution` support
* Add basic test
* Reformat
* [ci skip] Fix changelog
* fix: setting train mode for trpo
* fix: batch_size type hint in trpo.py
* style: renaming variables + docstring in trpo.py
* Rename + cleanup
* Move grad computation to separate method
* Remove grad norm clipping
* Remove n epochs and add sub-sampling
* Update defaults
* Add Doc
* Add more test and fixes for CNN
* Update doc + add benchmark
* Add tests + update doc
* Fix doc
* Improve names for conjugate gradient
* Update comments
* Update changelog
Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
* TQC support for multienv
* Add optional layer norm for TQC
* Add layer nprm for all policies
* Revert "Add layer nprm for all policies"
This reverts commit 1306c3c64eb12613464982c66cb416a3bbc66285.
* Revert "Add optional layer norm for TQC"
This reverts commit 200222e3a8878007aa6032d540ae74274a4d0788.
* Add experimental support to train off-policy algorithms with multiple envs
* Bump version
* Update version
* Add wrappers
* Add maskable distributions
* Add mypy configuration
* Add maskable base datastructures
* Add ppo_mask package
* Fix circular dependency and remove test code that slipped in
* Automatically mask vecenv if env is masked
* Fix debugging change that slipped in
* Workaround for subclassing RolloutBufferSamples
* Duplicate lots of policy code in order to swap out the distributions used
* Fix pytype error
* Maintain py 3.6 compatibility
* Fix isort lint errors
* Use pyproject.toml to configure black line length
* Blacken
* Remove mypy.ini
* Fully replace RolloutBufferSamples
* Drop support for continuous distributions, remove SDE-related code
* Eliminate MaskableAlgorithm and MaskableOnPolicyAlgorithm
* Fix formatting
* Override superclass methods as needed, fix circular import, improve naming
* Fix codestyle
* Eliminate VecActionMasker, replace with utils
* Fix codestyle
* Support masking for MultiDiscrete action spaces
* Fix codestyle
* Don't require the env to provide the mask already flattened
* Consistent naming, prefer 'Maskable' to 'Masked'
* Register policy
* Link to abstract instead of pdf
* Allow distribution masking to be unapplied + improved comments and docstrings
* Don't use deprecated implicit optional typing
* Check codestyle
* Add docstring and remove misplaced TODO
* Simplify env masking API, error if API unmet. Make use_masking a learn() kwarg
* Fix codestyle
* Update various internals to be consistent with latest SB3
* Simplify MaskableRolloutBuffer reset
* Add docstring and type annotations
* Ensure old probs aren't cached
* Fix for new logger
* Add test + fixes
* Start doc
* Fix type annotation
* Remove abstract class + add test
* Fix evaluation (add support for multi envs)
* Handle merge conflicts in documentation
* Bugfix: mask updates should apply to original logits, not the last masked output
* Add test of distribution masking behavior
* Reformat
* Add MultiBinary support, remove unneeded distribution type checks
* Remove unused import
* Fix when using multiple envs
* Remove addressed TODO
* Upgrade for SB3 1.2.0
* Update docs with results + how to replicate
* Add action masker tests, move wrapper tests
* Move distributions, add more distribution tests
* Add MaskablePPO tests, simplify and rename discrete test env
* Address TODO
* Add tests for MaskableMultiCategoricalDistribution, fix distributions
* Add maskable identity envs for all supported action spaces, add tests, fix bug
* Formatting fixes
* Update doc env
* Dict support not ready
* Cleanup
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>