Bug fix for QR-DQN (#21)

* Bug fix for QR-DQN

* Upgrade SB3
This commit is contained in:
Antonin RAFFIN 2021-03-06 14:54:43 +01:00 committed by GitHub
parent 7c2eb833c0
commit 9824daca44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 48 additions and 16 deletions

View File

@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: [3.6, 3.7] # 3.8 not supported yet by pytype python-version: [3.6, 3.7, 3.8]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
@ -28,7 +28,7 @@ jobs:
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
# cpu version of pytorch # cpu version of pytorch
pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install torch==1.8.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
# Install dependencies for docs and tests # Install dependencies for docs and tests
pip install stable_baselines3[extra,tests,docs] pip install stable_baselines3[extra,tests,docs]
# Install master version # Install master version

View File

@ -3,6 +3,18 @@
Changelog Changelog
========== ==========
Release 1.0rc1 (WIP)
-------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 1.0rc1
Bug Fixes:
^^^^^^^^^^
- Fixed a bug in the ``QR-DQN`` ``predict`` method when using ``deterministic=False`` with an image observation space
Pre-Release 0.11.1 (2021-02-27) Pre-Release 0.11.1 (2021-02-27)
------------------------------- -------------------------------

View File

@ -129,10 +129,10 @@ class QRDQNPolicy(BasePolicy):
) )
if net_arch is None: if net_arch is None:
if features_extractor_class == FlattenExtractor: if features_extractor_class == NatureCNN:
net_arch = [64, 64]
else:
net_arch = [] net_arch = []
else:
net_arch = [64, 64]
self.n_quantiles = n_quantiles self.n_quantiles = n_quantiles
self.net_arch = net_arch self.net_arch = net_arch

View File

@ -5,6 +5,7 @@ import numpy as np
import torch as th import torch as th
from stable_baselines3.common import logger from stable_baselines3.common import logger
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
@ -211,7 +212,7 @@ class QRDQN(OffPolicyAlgorithm):
(used in recurrent policies) (used in recurrent policies)
""" """
if not deterministic and np.random.rand() < self.exploration_rate: if not deterministic and np.random.rand() < self.exploration_rate:
if is_vectorized_observation(observation, self.observation_space): if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
n_batch = observation.shape[0] n_batch = observation.shape[0]
action = np.array([self.action_space.sample() for _ in range(n_batch)]) action = np.array([self.action_space.sample() for _ in range(n_batch)])
else: else:

View File

@ -314,10 +314,10 @@ class TQCPolicy(BasePolicy):
) )
if net_arch is None: if net_arch is None:
if features_extractor_class == FlattenExtractor: if features_extractor_class == NatureCNN:
net_arch = [256, 256]
else:
net_arch = [] net_arch = []
else:
net_arch = [256, 256]
actor_arch, critic_arch = get_actor_critic_arch(net_arch) actor_arch, critic_arch = get_actor_critic_arch(net_arch)

View File

@ -1 +1 @@
0.11.1 1.0rc1

View File

@ -62,7 +62,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")], packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]}, package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[ install_requires=[
"stable_baselines3[tests,docs]>=0.11.1", "stable_baselines3[tests,docs]>=1.0rc1",
], ],
description="Contrib package of Stable Baselines3, experimental code.", description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin", author="Antonin Raffin",

View File

@ -6,6 +6,7 @@ import pytest
import torch as th import torch as th
from stable_baselines3.common.identity_env import FakeImageEnv from stable_baselines3.common.identity_env import FakeImageEnv
from stable_baselines3.common.utils import zip_strict from stable_baselines3.common.utils import zip_strict
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped
from sb3_contrib import QRDQN, TQC from sb3_contrib import QRDQN, TQC
@ -16,19 +17,37 @@ def test_cnn(tmp_path, model_class):
# Fake grayscale with frameskip # Fake grayscale with frameskip
# Atari after preprocessing: 84x84x1, here we are using lower resolution # Atari after preprocessing: 84x84x1, here we are using lower resolution
# to check that the network handles it automatically # to check that the network handles it automatically
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC}) env = FakeImageEnv(
screen_height=40,
screen_width=40,
n_channels=1,
discrete=model_class not in {TQC},
)
kwargs = {} kwargs = {}
if model_class in {TQC, QRDQN}: if model_class in {TQC, QRDQN}:
# Avoid memory error when using replay buffer # Avoid memory error when using replay buffer
# Reduce the size of the features and the number of quantiles # Reduce the size of the features and the number of quantiles
kwargs = dict( kwargs = dict(
buffer_size=250, buffer_size=250,
policy_kwargs=dict(n_quantiles=25, features_extractor_kwargs=dict(features_dim=32)), policy_kwargs=dict(
n_quantiles=25,
features_extractor_kwargs=dict(features_dim=32),
),
) )
model = model_class("CnnPolicy", env, **kwargs).learn(250) model = model_class("CnnPolicy", env, **kwargs).learn(250)
obs = env.reset() obs = env.reset()
# FakeImageEnv is channel-last by default, so the model's env should be wrapped in VecTransposeImage
assert is_vecenv_wrapped(model.get_env(), VecTransposeImage)
# Test stochastic predict with channel-last input
if model_class == QRDQN:
model.exploration_rate = 0.9
for _ in range(10):
model.predict(obs, deterministic=False)
action, _ = model.predict(obs, deterministic=True) action, _ = model.predict(obs, deterministic=True)
model.save(tmp_path / SAVE_NAME) model.save(tmp_path / SAVE_NAME)

View File

@ -184,17 +184,17 @@ def test_set_env(model_class):
# create model # create model
model = model_class("MlpPolicy", env, **kwargs) model = model_class("MlpPolicy", env, **kwargs)
# learn # learn
model.learn(total_timesteps=300) model.learn(total_timesteps=150)
# change env # change env
model.set_env(env2) model.set_env(env2)
# learn again # learn again
model.learn(total_timesteps=300) model.learn(total_timesteps=150)
# change env (tests wrapping) # change env (tests wrapping)
model.set_env(env3) model.set_env(env3)
# learn again # learn again
model.learn(total_timesteps=300) model.learn(total_timesteps=150)
@pytest.mark.parametrize("model_class", MODEL_LIST) @pytest.mark.parametrize("model_class", MODEL_LIST)