Release v1.6.0 and bug fix for TRPO (#84)

This commit is contained in:
Antonin RAFFIN 2022-07-12 23:12:24 +02:00 committed by GitHub
parent db4c0114d0
commit 087951d34b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 61 additions and 11 deletions

View File

@ -6,9 +6,9 @@ dependencies:
- cpuonly=1.0=0
- pip=21.1
- python=3.7
- pytorch=1.8.1=py3.7_cpu_0
- pytorch=1.11=py3.7_cpu_0
- pip:
- gym>=0.17.2
- gym==0.21
- cloudpickle
- opencv-python-headless
- pandas
@ -17,5 +17,5 @@ dependencies:
- sphinx_autodoc_typehints
- stable-baselines3>=1.3.0
- sphinx>=4.2
# See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115
- sphinx_rtd_theme>=1.0
- sphinx_copybutton

View File

@ -24,6 +24,14 @@ try:
except ImportError:
enable_spell_check = False
# Try to enable copy button
try:
import sphinx_copybutton # noqa: F401
enable_copy_button = True
except ImportError:
enable_copy_button = False
# source code directory, relative to this file, for sphinx-autobuild
sys.path.insert(0, os.path.abspath(".."))
@ -51,7 +59,7 @@ with open(version_file) as file_handler:
# -- Project information -----------------------------------------------------
project = "Stable Baselines3 - Contrib"
copyright = "2020, Stable Baselines3"
copyright = "2022, Stable Baselines3"
author = "Stable Baselines3 Contributors"
# The short X.Y version
@ -83,6 +91,9 @@ extensions = [
if enable_spell_check:
extensions.append("sphinxcontrib.spelling")
if enable_copy_button:
extensions.append("sphinx_copybutton")
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

View File

@ -3,7 +3,7 @@
Changelog
==========
Release 1.5.1a9 (WIP)
Release 1.6.0 (2022-07-11)
-------------------------------
**Add RecurrentPPO (aka PPO LSTM)**
@ -25,8 +25,9 @@ New Features:
- Added ``RecurrentPPO`` (aka PPO LSTM)
Bug Fixes:
- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)
^^^^^^^^^^
- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)
- Fixed a bug in ``TRPO`` where kl divergence was not implemented for ``MultiDiscrete`` space
Deprecations:
^^^^^^^^^^^^^

View File

@ -6,12 +6,12 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.distributions import kl_divergence
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule
from stable_baselines3.common.utils import explained_variance
from torch import nn
from torch.distributions import kl_divergence
from torch.nn import functional as F
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
@ -279,7 +279,7 @@ class TRPO(OnPolicyAlgorithm):
policy_objective = (advantages * ratio).mean()
# KL divergence
kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean()
kl_div = kl_divergence(distribution, old_distribution).mean()
# Surrogate & KL gradient
self.policy.optimizer.zero_grad()
@ -332,7 +332,7 @@ class TRPO(OnPolicyAlgorithm):
new_policy_objective = (advantages * ratio).mean()
# New KL-divergence
kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean()
kl_div = kl_divergence(distribution, old_distribution).mean()
# Constraint criteria:
# we need to improve the surrogate policy objective

View File

@ -1 +1 @@
1.5.1a9
1.6.0

View File

@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.5.1a7",
"stable_baselines3>=1.6.0",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",

38
tests/test_identity.py Normal file
View File

@ -0,0 +1,38 @@
import numpy as np
import pytest
from stable_baselines3.common.envs import IdentityEnv, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv
from sb3_contrib import QRDQN, TRPO
DIM = 4
@pytest.mark.parametrize("model_class", [QRDQN, TRPO])
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
def test_discrete(model_class, env):
env_ = DummyVecEnv([lambda: env])
kwargs = {}
n_steps = 1500
if model_class == QRDQN:
kwargs = dict(
learning_starts=0,
policy_kwargs=dict(n_quantiles=25, net_arch=[32]),
target_update_interval=10,
train_freq=2,
batch_size=256,
)
n_steps = 1500
# DQN only support discrete actions
if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
return
elif n_steps == TRPO:
kwargs = dict(n_steps=256, cg_max_steps=5)
model = model_class("MlpPolicy", env_, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps)
evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False)
obs = env.reset()
assert np.shape(model.predict(obs)[0]) == np.shape(obs)