Release v1.6.0 and bug fix for TRPO (#84)
This commit is contained in:
parent
db4c0114d0
commit
087951d34b
|
|
@ -6,9 +6,9 @@ dependencies:
|
|||
- cpuonly=1.0=0
|
||||
- pip=21.1
|
||||
- python=3.7
|
||||
- pytorch=1.8.1=py3.7_cpu_0
|
||||
- pytorch=1.11=py3.7_cpu_0
|
||||
- pip:
|
||||
- gym>=0.17.2
|
||||
- gym==0.21
|
||||
- cloudpickle
|
||||
- opencv-python-headless
|
||||
- pandas
|
||||
|
|
@ -17,5 +17,5 @@ dependencies:
|
|||
- sphinx_autodoc_typehints
|
||||
- stable-baselines3>=1.3.0
|
||||
- sphinx>=4.2
|
||||
# See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115
|
||||
- sphinx_rtd_theme>=1.0
|
||||
- sphinx_copybutton
|
||||
|
|
|
|||
13
docs/conf.py
13
docs/conf.py
|
|
@ -24,6 +24,14 @@ try:
|
|||
except ImportError:
|
||||
enable_spell_check = False
|
||||
|
||||
# Try to enable copy button
|
||||
try:
|
||||
import sphinx_copybutton # noqa: F401
|
||||
|
||||
enable_copy_button = True
|
||||
except ImportError:
|
||||
enable_copy_button = False
|
||||
|
||||
# source code directory, relative to this file, for sphinx-autobuild
|
||||
sys.path.insert(0, os.path.abspath(".."))
|
||||
|
||||
|
|
@ -51,7 +59,7 @@ with open(version_file) as file_handler:
|
|||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = "Stable Baselines3 - Contrib"
|
||||
copyright = "2020, Stable Baselines3"
|
||||
copyright = "2022, Stable Baselines3"
|
||||
author = "Stable Baselines3 Contributors"
|
||||
|
||||
# The short X.Y version
|
||||
|
|
@ -83,6 +91,9 @@ extensions = [
|
|||
if enable_spell_check:
|
||||
extensions.append("sphinxcontrib.spelling")
|
||||
|
||||
if enable_copy_button:
|
||||
extensions.append("sphinx_copybutton")
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ["_templates"]
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.5.1a9 (WIP)
|
||||
Release 1.6.0 (2022-07-11)
|
||||
-------------------------------
|
||||
|
||||
**Add RecurrentPPO (aka PPO LSTM)**
|
||||
|
|
@ -25,8 +25,9 @@ New Features:
|
|||
- Added ``RecurrentPPO`` (aka PPO LSTM)
|
||||
|
||||
Bug Fixes:
|
||||
- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)
|
||||
^^^^^^^^^^
|
||||
- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)
|
||||
- Fixed a bug in ``TRPO`` where kl divergence was not implemented for ``MultiDiscrete`` space
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -6,12 +6,12 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
|||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.distributions import kl_divergence
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule
|
||||
from stable_baselines3.common.utils import explained_variance
|
||||
from torch import nn
|
||||
from torch.distributions import kl_divergence
|
||||
from torch.nn import functional as F
|
||||
|
||||
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
|
||||
|
|
@ -279,7 +279,7 @@ class TRPO(OnPolicyAlgorithm):
|
|||
policy_objective = (advantages * ratio).mean()
|
||||
|
||||
# KL divergence
|
||||
kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean()
|
||||
kl_div = kl_divergence(distribution, old_distribution).mean()
|
||||
|
||||
# Surrogate & KL gradient
|
||||
self.policy.optimizer.zero_grad()
|
||||
|
|
@ -332,7 +332,7 @@ class TRPO(OnPolicyAlgorithm):
|
|||
new_policy_objective = (advantages * ratio).mean()
|
||||
|
||||
# New KL-divergence
|
||||
kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean()
|
||||
kl_div = kl_divergence(distribution, old_distribution).mean()
|
||||
|
||||
# Constraint criteria:
|
||||
# we need to improve the surrogate policy objective
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.5.1a9
|
||||
1.6.0
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=1.5.1a7",
|
||||
"stable_baselines3>=1.6.0",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
from stable_baselines3.common.envs import IdentityEnv, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
from sb3_contrib import QRDQN, TRPO
|
||||
|
||||
DIM = 4
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [QRDQN, TRPO])
|
||||
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
|
||||
def test_discrete(model_class, env):
|
||||
env_ = DummyVecEnv([lambda: env])
|
||||
kwargs = {}
|
||||
n_steps = 1500
|
||||
if model_class == QRDQN:
|
||||
kwargs = dict(
|
||||
learning_starts=0,
|
||||
policy_kwargs=dict(n_quantiles=25, net_arch=[32]),
|
||||
target_update_interval=10,
|
||||
train_freq=2,
|
||||
batch_size=256,
|
||||
)
|
||||
n_steps = 1500
|
||||
# DQN only support discrete actions
|
||||
if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
|
||||
return
|
||||
elif n_steps == TRPO:
|
||||
kwargs = dict(n_steps=256, cg_max_steps=5)
|
||||
|
||||
model = model_class("MlpPolicy", env_, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps)
|
||||
|
||||
evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False)
|
||||
obs = env.reset()
|
||||
|
||||
assert np.shape(model.predict(obs)[0]) == np.shape(obs)
|
||||
Loading…
Reference in New Issue