Release v1.6.0 and bug fix for TRPO (#84)
This commit is contained in:
parent
db4c0114d0
commit
087951d34b
|
|
@ -6,9 +6,9 @@ dependencies:
|
||||||
- cpuonly=1.0=0
|
- cpuonly=1.0=0
|
||||||
- pip=21.1
|
- pip=21.1
|
||||||
- python=3.7
|
- python=3.7
|
||||||
- pytorch=1.8.1=py3.7_cpu_0
|
- pytorch=1.11=py3.7_cpu_0
|
||||||
- pip:
|
- pip:
|
||||||
- gym>=0.17.2
|
- gym==0.21
|
||||||
- cloudpickle
|
- cloudpickle
|
||||||
- opencv-python-headless
|
- opencv-python-headless
|
||||||
- pandas
|
- pandas
|
||||||
|
|
@ -17,5 +17,5 @@ dependencies:
|
||||||
- sphinx_autodoc_typehints
|
- sphinx_autodoc_typehints
|
||||||
- stable-baselines3>=1.3.0
|
- stable-baselines3>=1.3.0
|
||||||
- sphinx>=4.2
|
- sphinx>=4.2
|
||||||
# See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115
|
|
||||||
- sphinx_rtd_theme>=1.0
|
- sphinx_rtd_theme>=1.0
|
||||||
|
- sphinx_copybutton
|
||||||
|
|
|
||||||
13
docs/conf.py
13
docs/conf.py
|
|
@ -24,6 +24,14 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
enable_spell_check = False
|
enable_spell_check = False
|
||||||
|
|
||||||
|
# Try to enable copy button
|
||||||
|
try:
|
||||||
|
import sphinx_copybutton # noqa: F401
|
||||||
|
|
||||||
|
enable_copy_button = True
|
||||||
|
except ImportError:
|
||||||
|
enable_copy_button = False
|
||||||
|
|
||||||
# source code directory, relative to this file, for sphinx-autobuild
|
# source code directory, relative to this file, for sphinx-autobuild
|
||||||
sys.path.insert(0, os.path.abspath(".."))
|
sys.path.insert(0, os.path.abspath(".."))
|
||||||
|
|
||||||
|
|
@ -51,7 +59,7 @@ with open(version_file) as file_handler:
|
||||||
# -- Project information -----------------------------------------------------
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
project = "Stable Baselines3 - Contrib"
|
project = "Stable Baselines3 - Contrib"
|
||||||
copyright = "2020, Stable Baselines3"
|
copyright = "2022, Stable Baselines3"
|
||||||
author = "Stable Baselines3 Contributors"
|
author = "Stable Baselines3 Contributors"
|
||||||
|
|
||||||
# The short X.Y version
|
# The short X.Y version
|
||||||
|
|
@ -83,6 +91,9 @@ extensions = [
|
||||||
if enable_spell_check:
|
if enable_spell_check:
|
||||||
extensions.append("sphinxcontrib.spelling")
|
extensions.append("sphinxcontrib.spelling")
|
||||||
|
|
||||||
|
if enable_copy_button:
|
||||||
|
extensions.append("sphinx_copybutton")
|
||||||
|
|
||||||
# Add any paths that contain templates here, relative to this directory.
|
# Add any paths that contain templates here, relative to this directory.
|
||||||
templates_path = ["_templates"]
|
templates_path = ["_templates"]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Release 1.5.1a9 (WIP)
|
Release 1.6.0 (2022-07-11)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
**Add RecurrentPPO (aka PPO LSTM)**
|
**Add RecurrentPPO (aka PPO LSTM)**
|
||||||
|
|
@ -25,8 +25,9 @@ New Features:
|
||||||
- Added ``RecurrentPPO`` (aka PPO LSTM)
|
- Added ``RecurrentPPO`` (aka PPO LSTM)
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)
|
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)
|
||||||
|
- Fixed a bug in ``TRPO`` where the KL divergence was not implemented for ``MultiDiscrete`` action spaces
|
||||||
|
|
||||||
Deprecations:
|
Deprecations:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -6,12 +6,12 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
|
from stable_baselines3.common.distributions import kl_divergence
|
||||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
|
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule
|
||||||
from stable_baselines3.common.utils import explained_variance
|
from stable_baselines3.common.utils import explained_variance
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.distributions import kl_divergence
|
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
|
from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
|
||||||
|
|
@ -279,7 +279,7 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
policy_objective = (advantages * ratio).mean()
|
policy_objective = (advantages * ratio).mean()
|
||||||
|
|
||||||
# KL divergence
|
# KL divergence
|
||||||
kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean()
|
kl_div = kl_divergence(distribution, old_distribution).mean()
|
||||||
|
|
||||||
# Surrogate & KL gradient
|
# Surrogate & KL gradient
|
||||||
self.policy.optimizer.zero_grad()
|
self.policy.optimizer.zero_grad()
|
||||||
|
|
@ -332,7 +332,7 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
new_policy_objective = (advantages * ratio).mean()
|
new_policy_objective = (advantages * ratio).mean()
|
||||||
|
|
||||||
# New KL-divergence
|
# New KL-divergence
|
||||||
kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean()
|
kl_div = kl_divergence(distribution, old_distribution).mean()
|
||||||
|
|
||||||
# Constraint criteria:
|
# Constraint criteria:
|
||||||
# we need to improve the surrogate policy objective
|
# we need to improve the surrogate policy objective
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.5.1a9
|
1.6.0
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.5.1a7",
|
"stable_baselines3>=1.6.0",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,38 @@
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from stable_baselines3.common.envs import IdentityEnv, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete
|
||||||
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
|
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||||
|
|
||||||
|
from sb3_contrib import QRDQN, TRPO
|
||||||
|
|
||||||
|
DIM = 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_class", [QRDQN, TRPO])
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
def test_discrete(model_class, env):
    """Check that an algorithm can solve the identity task on discrete-like spaces.

    The model is trained on an identity environment, evaluated against a
    reward threshold, and finally its prediction shape is compared with the
    observation shape.

    :param model_class: Algorithm class under test (``QRDQN`` or ``TRPO``).
    :param env: Identity environment instance (Discrete, MultiDiscrete or MultiBinary).
    """
    env_ = DummyVecEnv([lambda: env])
    kwargs = {}
    # Total number of environment steps used for training.
    # Deliberately NOT called ``n_steps`` to avoid confusion with the
    # ``n_steps`` (rollout length) keyword argument of TRPO below.
    total_timesteps = 1500
    if model_class == QRDQN:
        # QRDQN only supports Discrete actions: nothing to test for the other spaces
        if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
            return
        kwargs = dict(
            learning_starts=0,
            policy_kwargs=dict(n_quantiles=25, net_arch=[32]),
            target_update_interval=10,
            train_freq=2,
            batch_size=256,
        )
    elif model_class == TRPO:
        # The original condition compared the step count to the TRPO class
        # (always False), so these hyperparameters were never applied.
        kwargs = dict(n_steps=256, cg_max_steps=5)

    model = model_class("MlpPolicy", env_, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(total_timesteps)

    evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False)
    obs = env.reset()

    assert np.shape(model.predict(obs)[0]) == np.shape(obs)
|
||||||
Loading…
Reference in New Issue