Add rollout_buffer_class to TRPO (#214)

* Add rollout_buffer_class and rollout_buffer_kwargs to TRPO

* Update requirements and changelog

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
M. Ernestus 2023-10-30 16:10:41 +01:00 committed by Andreas Schaarschmidt
parent 4d7ed004af
commit 5e437fc4dc
4 changed files with 11 additions and 3 deletions

View File

@ -4,7 +4,7 @@ Changelog
========== ==========
Release 2.2.0a8 (WIP) Release 2.2.0a9 (WIP)
-------------------------- --------------------------
Breaking Changes: Breaking Changes:
@ -16,6 +16,7 @@ Breaking Changes:
New Features: New Features:
^^^^^^^^^^^^^ ^^^^^^^^^^^^^
- Added ``set_options`` for ``AsyncEval`` - Added ``set_options`` for ``AsyncEval``
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to TRPO
Bug Fixes: Bug Fixes:
^^^^^^^^^^ ^^^^^^^^^^

View File

@ -6,6 +6,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Un
import numpy as np import numpy as np
import torch as th import torch as th
from gymnasium import spaces from gymnasium import spaces
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.distributions import kl_divergence from stable_baselines3.common.distributions import kl_divergence
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
@ -53,6 +54,8 @@ class TRPO(OnPolicyAlgorithm):
instead of action noise exploration (default: False) instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout) Default: -1 (only sample at the beginning of the rollout)
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
:param normalize_advantage: Whether to normalize or not the advantage :param normalize_advantage: Whether to normalize or not the advantage
:param target_kl: Target Kullback-Leibler divergence between updates. :param target_kl: Target Kullback-Leibler divergence between updates.
Should be small for stability. Values like 0.01, 0.05. Should be small for stability. Values like 0.01, 0.05.
@ -91,6 +94,8 @@ class TRPO(OnPolicyAlgorithm):
gae_lambda: float = 0.95, gae_lambda: float = 0.95,
use_sde: bool = False, use_sde: bool = False,
sde_sample_freq: int = -1, sde_sample_freq: int = -1,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
normalize_advantage: bool = True, normalize_advantage: bool = True,
target_kl: float = 0.01, target_kl: float = 0.01,
sub_sampling_factor: int = 1, sub_sampling_factor: int = 1,
@ -114,6 +119,8 @@ class TRPO(OnPolicyAlgorithm):
max_grad_norm=0.0, max_grad_norm=0.0,
use_sde=use_sde, use_sde=use_sde,
sde_sample_freq=sde_sample_freq, sde_sample_freq=sde_sample_freq,
rollout_buffer_class=rollout_buffer_class,
rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size, stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log, tensorboard_log=tensorboard_log,
policy_kwargs=policy_kwargs, policy_kwargs=policy_kwargs,

View File

@ -1 +1 @@
2.2.0a8 2.2.0a9

View File

@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")], packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]}, package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[ install_requires=[
"stable_baselines3>=2.2.0a8,<3.0", "stable_baselines3>=2.2.0a9,<3.0",
], ],
description="Contrib package of Stable Baselines3, experimental code.", description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin", author="Antonin Raffin",