Add rollout_buffer_class to TRPO (#214)

* Add rollout_buffer_class and rollout_buffer_kwargs to TRPO

* Update requirements and changelog

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
M. Ernestus 2023-10-30 16:10:41 +01:00 committed by Andreas Schaarschmidt
parent 4d7ed004af
commit 5e437fc4dc
4 changed files with 11 additions and 3 deletions

View File

@@ -4,7 +4,7 @@ Changelog
==========
Release 2.2.0a8 (WIP)
Release 2.2.0a9 (WIP)
--------------------------
Breaking Changes:
@@ -16,6 +16,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added ``set_options`` for ``AsyncEval``
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to TRPO
Bug Fixes:
^^^^^^^^^^

View File

@@ -6,6 +6,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Un
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.distributions import kl_divergence
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
@@ -53,6 +54,8 @@ class TRPO(OnPolicyAlgorithm):
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
:param normalize_advantage: Whether to normalize or not the advantage
:param target_kl: Target Kullback-Leibler divergence between updates.
Should be small for stability. Values like 0.01, 0.05.
@@ -91,6 +94,8 @@ class TRPO(OnPolicyAlgorithm):
gae_lambda: float = 0.95,
use_sde: bool = False,
sde_sample_freq: int = -1,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
normalize_advantage: bool = True,
target_kl: float = 0.01,
sub_sampling_factor: int = 1,
@@ -114,6 +119,8 @@ class TRPO(OnPolicyAlgorithm):
max_grad_norm=0.0,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
rollout_buffer_class=rollout_buffer_class,
rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
policy_kwargs=policy_kwargs,

View File

@@ -1 +1 @@
2.2.0a8
2.2.0a9

View File

@@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=2.2.0a8,<3.0",
"stable_baselines3>=2.2.0a9,<3.0",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",