Add rollout_buffer_class to TRPO (#214)
* Add rollout_buffer_class and rollout_buffer_kwargs to TRPO
* Update requirements and changelog

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
4d7ed004af
commit
5e437fc4dc
|
|
@ -4,7 +4,7 @@ Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
|
||||||
Release 2.2.0a8 (WIP)
|
Release 2.2.0a9 (WIP)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
|
|
@ -16,6 +16,7 @@ Breaking Changes:
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
- Added ``set_options`` for ``AsyncEval``
|
- Added ``set_options`` for ``AsyncEval``
|
||||||
|
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to TRPO
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Un
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
|
from stable_baselines3.common.buffers import RolloutBuffer
|
||||||
from stable_baselines3.common.distributions import kl_divergence
|
from stable_baselines3.common.distributions import kl_divergence
|
||||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
|
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
|
||||||
|
|
@ -53,6 +54,8 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
instead of action noise exploration (default: False)
|
instead of action noise exploration (default: False)
|
||||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
||||||
Default: -1 (only sample at the beginning of the rollout)
|
Default: -1 (only sample at the beginning of the rollout)
|
||||||
|
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
|
||||||
|
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
|
||||||
:param normalize_advantage: Whether to normalize or not the advantage
|
:param normalize_advantage: Whether to normalize or not the advantage
|
||||||
:param target_kl: Target Kullback-Leibler divergence between updates.
|
:param target_kl: Target Kullback-Leibler divergence between updates.
|
||||||
Should be small for stability. Values like 0.01, 0.05.
|
Should be small for stability. Values like 0.01, 0.05.
|
||||||
|
|
@ -91,6 +94,8 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
gae_lambda: float = 0.95,
|
gae_lambda: float = 0.95,
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
sde_sample_freq: int = -1,
|
sde_sample_freq: int = -1,
|
||||||
|
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
|
||||||
|
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
normalize_advantage: bool = True,
|
normalize_advantage: bool = True,
|
||||||
target_kl: float = 0.01,
|
target_kl: float = 0.01,
|
||||||
sub_sampling_factor: int = 1,
|
sub_sampling_factor: int = 1,
|
||||||
|
|
@ -114,6 +119,8 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
max_grad_norm=0.0,
|
max_grad_norm=0.0,
|
||||||
use_sde=use_sde,
|
use_sde=use_sde,
|
||||||
sde_sample_freq=sde_sample_freq,
|
sde_sample_freq=sde_sample_freq,
|
||||||
|
rollout_buffer_class=rollout_buffer_class,
|
||||||
|
rollout_buffer_kwargs=rollout_buffer_kwargs,
|
||||||
stats_window_size=stats_window_size,
|
stats_window_size=stats_window_size,
|
||||||
tensorboard_log=tensorboard_log,
|
tensorboard_log=tensorboard_log,
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
2.2.0a8
|
2.2.0a9
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=2.2.0a8,<3.0",
|
"stable_baselines3>=2.2.0a9,<3.0",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue