* Feat: adding TRPO algorithm (WIP)
  WIP - Trust Region Policy Algorithm
  Currently the Hessian vector product is not working (see inline comments for more detail)
* Feat: adding TRPO algorithm (WIP)
  Adding no_grad block for the line search
  Additional assert in the conjugate solver to help debugging
* Feat: adding TRPO algorithm (WIP)
  - Adding ActorCriticPolicy.get_distribution
  - Using the Distribution object to compute the KL divergence
  - Checking for objective improvement in the line search
  - Moving magic numbers to instance variables
* Feat: adding TRPO algorithm (WIP)
  Improving numerical stability of the conjugate gradient algorithm
  Critic updates
* Feat: adding TRPO algorithm (WIP)
  Changes around the alpha of the line search
  Adding TRPO to __init__ files
* feat: TRPO - addressing PR comments
  - renaming cg_solver to conjugate_gradient_solver and renaming parameter Avp_fun to matrix_vector_dot_func + docstring
  - extra comments + better variable names in trpo.py
  - defining a method for the hessian vector product instead of an inline function
  - fix registering correct policies for TRPO and using correct policy base in constructor
* refactor: TRPO - policier
  - refactoring sb3_contrib.common.policies to reuse as much code as possible from sb3
* feat: using updated ActorCriticPolicy from SB3
  - get_distribution will be added directly to the SB3 version of ActorCriticPolicy, this commit reflects this
* Bump version for `get_distribution` support
* Add basic test
* Reformat
* [ci skip] Fix changelog
* fix: setting train mode for trpo
* fix: batch_size type hint in trpo.py
* style: renaming variables + docstring in trpo.py
* Rename + cleanup
* Move grad computation to separate method
* Remove grad norm clipping
* Remove n epochs and add sub-sampling
* Update defaults
* Add Doc
* Add more test and fixes for CNN
* Update doc + add benchmark
* Add tests + update doc
* Fix doc
* Improve names for conjugate gradient
* Update comments
* Update changelog

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>

| Name |
|---|
| common |
| ppo_mask |
| qrdqn |
| tqc |
| trpo |
| __init__.py |
| py.typed |
| version.txt |