Add rollout_buffer_class to TRPO (#214)
* Add rollout_buffer_class and rollout_buffer_kwargs to TRPO * Update requirements and changelog --------- Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
4d7ed004af
commit
5e437fc4dc
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 2.2.0a8 (WIP)
|
||||
Release 2.2.0a9 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -16,6 +16,7 @@ Breaking Changes:
|
|||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added ``set_options`` for ``AsyncEval``
|
||||
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to TRPO
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Un
|
|||
import numpy as np
|
||||
import torch as th
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.buffers import RolloutBuffer
|
||||
from stable_baselines3.common.distributions import kl_divergence
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
|
||||
|
|
@ -53,6 +54,8 @@ class TRPO(OnPolicyAlgorithm):
|
|||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
|
||||
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
|
||||
:param normalize_advantage: Whether to normalize or not the advantage
|
||||
:param target_kl: Target Kullback-Leibler divergence between updates.
|
||||
Should be small for stability. Values like 0.01, 0.05.
|
||||
|
|
@ -91,6 +94,8 @@ class TRPO(OnPolicyAlgorithm):
|
|||
gae_lambda: float = 0.95,
|
||||
use_sde: bool = False,
|
||||
sde_sample_freq: int = -1,
|
||||
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
|
||||
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
normalize_advantage: bool = True,
|
||||
target_kl: float = 0.01,
|
||||
sub_sampling_factor: int = 1,
|
||||
|
|
@ -114,6 +119,8 @@ class TRPO(OnPolicyAlgorithm):
|
|||
max_grad_norm=0.0,
|
||||
use_sde=use_sde,
|
||||
sde_sample_freq=sde_sample_freq,
|
||||
rollout_buffer_class=rollout_buffer_class,
|
||||
rollout_buffer_kwargs=rollout_buffer_kwargs,
|
||||
stats_window_size=stats_window_size,
|
||||
tensorboard_log=tensorboard_log,
|
||||
policy_kwargs=policy_kwargs,
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.2.0a8
|
||||
2.2.0a9
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=2.2.0a8,<3.0",
|
||||
"stable_baselines3>=2.2.0a9,<3.0",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
Loading…
Reference in New Issue