Commit Graph

4 Commits

Author SHA1 Message Date
Quentin Gallouédec ded9f65bfd
Fix `sb3_contrib/common/utils.py` type hint (#120)
* Update contribution.md

* New loop struct to make mypy happy

* Update setup.cfg

* Update changelog

* Add return statement in cg solver for max_iter=0

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2022-11-29 10:24:44 +01:00
Adam Gleave 901a648507
Upgrade Gym to 0.21 (#59)
* Pendulum-v0 -> Pendulum-v1

* Reformat with black

* Update changelog

* Fix dtype bug in TimeFeatureWrapper

* Update version and removed forward calls

* Update CI

* Fix min version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2022-02-22 16:25:43 +01:00
Cyprien 59be198da0
Add Trust Region Policy Optimization (TRPO) (#40)
* Feat: adding TRPO algorithm (WIP)

WIP - Trust Region Policy Algorithm
Currently the Hessian vector product is not working (see inline comments for more detail)

* Feat: adding TRPO algorithm (WIP)

Adding no_grad block for the line search
Additional assert in the conjugate solver to help debugging

* Feat: adding TRPO algorithm (WIP)

- Adding ActorCriticPolicy.get_distribution
- Using the Distribution object to compute the KL divergence
- Checking for objective improvement in the line search
- Moving magic numbers to instance variables

* Feat: adding TRPO algorithm (WIP)

Improving numerical stability of the conjugate gradient algorithm
Critic updates

* Feat: adding TRPO algorithm (WIP)

Changes around the alpha of the line search
Adding TRPO to __init__ files

* feat: TRPO - addressing PR comments

- renaming cg_solver to conjugate_gradient_solver and renaming parameter Avp_fun to  matrix_vector_dot_func + docstring
- extra comments + better variable names in trpo.py
- defining a method for the hessian vector product instead of an inline function
- fix registering correct policies for TRPO and using correct policy base in constructor

* refactor: TRPO - policier

- refactoring sb3_contrib.common.policies to reuse as much code as possible from sb3

* feat: using updated ActorCriticPolicy from SB3

- get_distribution will be added directly to the SB3 version of ActorCriticPolicy, this commit reflects this

* Bump version for `get_distribution` support

* Add basic test

* Reformat

* [ci skip] Fix changelog

* fix: setting train mode for trpo

* fix: batch_size type hint in trpo.py

* style: renaming variables + docstring in trpo.py

* Rename + cleanup

* Move grad computation to separate method

* Remove grad norm clipping

* Remove n epochs and add sub-sampling

* Update defaults

* Add Doc

* Add more test and fixes for CNN

* Update doc + add benchmark

* Add tests + update doc

* Fix doc

* Improve names for conjugate gradient

* Update comments

* Update changelog

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2021-12-29 11:58:03 +01:00
Toshiki Watanabe b30397fff5
Add QR-DQN (#13)
* Add QR-DQN(WIP)

* Update docstring

* Add quantile_huber_loss

* Fix typo

* Remove unnecessary lines

* Update variable names and comments in quantile_huber_loss

* Fix mutable arguments

* Update variable names

* Ignore import not used warnings

* Fix default parameter of optimizer in QR-DQN

* Update quantile_huber_loss to have more reasonable interface

* update tests

* Add assertion to quantile_huber_loss

* Update variable names of quantile regression

* Update comments

* Reduce the number of quantiles during test

* Update comment

* Update quantile_huber_loss

* Fix isort

* Add document of QR-DQN without results

* Update docs

* Fix bugs

* Update doc

* Add comments about shape

* Minor edits

* Update comments

* Add benchmark

* Doc fixes

* Update doc

* Bug fix in saving/loading + update tests

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-12-21 11:17:48 +01:00