From 5e437fc4dc338c7013610fa93170a0f688216699 Mon Sep 17 00:00:00 2001 From: "M. Ernestus" Date: Mon, 30 Oct 2023 16:10:41 +0100 Subject: [PATCH] Add rollout_buffer_class to TRPO (#214) * Add rollout_buffer_class and rollout_buffer_kwargs to TRPO * Update requirements and changelog --------- Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 3 ++- sb3_contrib/trpo/trpo.py | 7 +++++++ sb3_contrib/version.txt | 2 +- setup.py | 2 +- 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9d4a4d2..be39952 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 2.2.0a8 (WIP) +Release 2.2.0a9 (WIP) -------------------------- Breaking Changes: @@ -16,6 +16,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Added ``set_options`` for ``AsyncEval`` +- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to TRPO Bug Fixes: ^^^^^^^^^^ diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index fd59f04..8165081 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -6,6 +6,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Un import numpy as np import torch as th from gymnasium import spaces +from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.distributions import kl_divergence from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy @@ -53,6 +54,8 @@ class TRPO(OnPolicyAlgorithm): instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) + :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected. + :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation :param normalize_advantage: Whether to normalize or not the advantage :param target_kl: Target Kullback-Leibler divergence between updates. Should be small for stability. Values like 0.01, 0.05. @@ -91,6 +94,8 @@ class TRPO(OnPolicyAlgorithm): gae_lambda: float = 0.95, use_sde: bool = False, sde_sample_freq: int = -1, + rollout_buffer_class: Optional[Type[RolloutBuffer]] = None, + rollout_buffer_kwargs: Optional[Dict[str, Any]] = None, normalize_advantage: bool = True, target_kl: float = 0.01, sub_sampling_factor: int = 1, @@ -114,6 +119,8 @@ class TRPO(OnPolicyAlgorithm): max_grad_norm=0.0, use_sde=use_sde, sde_sample_freq=sde_sample_freq, + rollout_buffer_class=rollout_buffer_class, + rollout_buffer_kwargs=rollout_buffer_kwargs, stats_window_size=stats_window_size, tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs, diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index f1f23b3..b7120ad 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -2.2.0a8 +2.2.0a9 diff --git a/setup.py b/setup.py index 88242e2..a82172f 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=2.2.0a8,<3.0", + "stable_baselines3>=2.2.0a9,<3.0", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin",