Release v1.6.0 and bug fix for TRPO (#84)

This commit is contained in:
Antonin RAFFIN 2022-07-12 23:12:24 +02:00 committed by GitHub
parent db4c0114d0
commit 087951d34b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 61 additions and 11 deletions

View File

@ -6,9 +6,9 @@ dependencies:
- cpuonly=1.0=0 - cpuonly=1.0=0
- pip=21.1 - pip=21.1
- python=3.7 - python=3.7
- pytorch=1.8.1=py3.7_cpu_0 - pytorch=1.11=py3.7_cpu_0
- pip: - pip:
- gym>=0.17.2 - gym==0.21
- cloudpickle - cloudpickle
- opencv-python-headless - opencv-python-headless
- pandas - pandas
@ -17,5 +17,5 @@ dependencies:
- sphinx_autodoc_typehints - sphinx_autodoc_typehints
- stable-baselines3>=1.3.0 - stable-baselines3>=1.3.0
- sphinx>=4.2 - sphinx>=4.2
# See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115
- sphinx_rtd_theme>=1.0 - sphinx_rtd_theme>=1.0
- sphinx_copybutton

View File

@ -24,6 +24,14 @@ try:
except ImportError: except ImportError:
enable_spell_check = False enable_spell_check = False
# Try to enable copy button
try:
import sphinx_copybutton # noqa: F401
enable_copy_button = True
except ImportError:
enable_copy_button = False
# source code directory, relative to this file, for sphinx-autobuild # source code directory, relative to this file, for sphinx-autobuild
sys.path.insert(0, os.path.abspath("..")) sys.path.insert(0, os.path.abspath(".."))
@ -51,7 +59,7 @@ with open(version_file) as file_handler:
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
project = "Stable Baselines3 - Contrib" project = "Stable Baselines3 - Contrib"
copyright = "2020, Stable Baselines3" copyright = "2022, Stable Baselines3"
author = "Stable Baselines3 Contributors" author = "Stable Baselines3 Contributors"
# The short X.Y version # The short X.Y version
@ -83,6 +91,9 @@ extensions = [
if enable_spell_check: if enable_spell_check:
extensions.append("sphinxcontrib.spelling") extensions.append("sphinxcontrib.spelling")
if enable_copy_button:
extensions.append("sphinx_copybutton")
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"] templates_path = ["_templates"]

View File

@ -3,7 +3,7 @@
Changelog Changelog
========== ==========
Release 1.5.1a9 (WIP) Release 1.6.0 (2022-07-11)
------------------------------- -------------------------------
**Add RecurrentPPO (aka PPO LSTM)** **Add RecurrentPPO (aka PPO LSTM)**
@ -25,8 +25,9 @@ New Features:
- Added ``RecurrentPPO`` (aka PPO LSTM) - Added ``RecurrentPPO`` (aka PPO LSTM)
Bug Fixes: Bug Fixes:
- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)
^^^^^^^^^^ ^^^^^^^^^^
- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)
- Fixed a bug in ``TRPO`` where kl divergence was not implemented for ``MultiDiscrete`` space
Deprecations: Deprecations:
^^^^^^^^^^^^^ ^^^^^^^^^^^^^

View File

@ -6,12 +6,12 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
import numpy as np import numpy as np
import torch as th import torch as th
from gym import spaces from gym import spaces
from stable_baselines3.common.distributions import kl_divergence
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule
from stable_baselines3.common.utils import explained_variance from stable_baselines3.common.utils import explained_variance
from torch import nn from torch import nn
from torch.distributions import kl_divergence
from torch.nn import functional as F from torch.nn import functional as F
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
@ -279,7 +279,7 @@ class TRPO(OnPolicyAlgorithm):
policy_objective = (advantages * ratio).mean() policy_objective = (advantages * ratio).mean()
# KL divergence # KL divergence
kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean() kl_div = kl_divergence(distribution, old_distribution).mean()
# Surrogate & KL gradient # Surrogate & KL gradient
self.policy.optimizer.zero_grad() self.policy.optimizer.zero_grad()
@ -332,7 +332,7 @@ class TRPO(OnPolicyAlgorithm):
new_policy_objective = (advantages * ratio).mean() new_policy_objective = (advantages * ratio).mean()
# New KL-divergence # New KL-divergence
kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean() kl_div = kl_divergence(distribution, old_distribution).mean()
# Constraint criteria: # Constraint criteria:
# we need to improve the surrogate policy objective # we need to improve the surrogate policy objective

View File

@ -1 +1 @@
1.5.1a9 1.6.0

View File

@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")], packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]}, package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[ install_requires=[
"stable_baselines3>=1.5.1a7", "stable_baselines3>=1.6.0",
], ],
description="Contrib package of Stable Baselines3, experimental code.", description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin", author="Antonin Raffin",

38
tests/test_identity.py Normal file
View File

@ -0,0 +1,38 @@
import numpy as np
import pytest
from stable_baselines3.common.envs import IdentityEnv, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv
from sb3_contrib import QRDQN, TRPO
DIM = 4
@pytest.mark.parametrize("model_class", [QRDQN, TRPO])
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
def test_discrete(model_class, env):
    """
    Check that QRDQN and TRPO can solve simple identity environments
    with Discrete / MultiDiscrete / MultiBinary action spaces.

    :param model_class: the algorithm class under test (QRDQN or TRPO)
    :param env: an identity env instance whose action space matches its observation
    """
    env_ = DummyVecEnv([lambda: env])
    kwargs = {}
    n_steps = 1500
    if model_class == QRDQN:
        kwargs = dict(
            learning_starts=0,
            policy_kwargs=dict(n_quantiles=25, net_arch=[32]),
            target_update_interval=10,
            train_freq=2,
            batch_size=256,
        )
        n_steps = 1500
        # QRDQN (DQN family) only supports Discrete action spaces,
        # so skip the MultiDiscrete/MultiBinary variants.
        if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
            return
    elif model_class == TRPO:
        # Bug fix: the original condition was `n_steps == TRPO`, comparing an
        # int (1500) against a class — always False — so these TRPO-specific
        # kwargs were silently never applied. Compare the model class instead.
        kwargs = dict(n_steps=256, cg_max_steps=5)

    model = model_class("MlpPolicy", env_, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps)

    evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False)
    obs = env.reset()

    # The predicted action must have the same shape as the observation
    # (identity task: the optimal action equals the observation).
    assert np.shape(model.predict(obs)[0]) == np.shape(obs)