parent
7c2eb833c0
commit
9824daca44
|
|
@ -16,7 +16,7 @@ jobs:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: [3.6, 3.7] # 3.8 not supported yet by pytype
|
python-version: [3.6, 3.7, 3.8]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
|
|
@ -28,7 +28,7 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
# cpu version of pytorch
|
# cpu version of pytorch
|
||||||
pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
pip install torch==1.8.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
# Install dependencies for docs and tests
|
# Install dependencies for docs and tests
|
||||||
pip install stable_baselines3[extra,tests,docs]
|
pip install stable_baselines3[extra,tests,docs]
|
||||||
# Install master version
|
# Install master version
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,18 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
Release 1.0rc1 (WIP)
|
||||||
|
-------------------------------
|
||||||
|
|
||||||
|
Breaking Changes:
|
||||||
|
^^^^^^^^^^^^^^^^^
|
||||||
|
- Upgraded to Stable-Baselines3 >= 1.0rc1
|
||||||
|
|
||||||
|
Bug Fixes:
|
||||||
|
^^^^^^^^^^
|
||||||
|
- Fixed a bug with ``QR-DQN`` predict method when using ``deterministic=False`` with image space
|
||||||
|
|
||||||
|
|
||||||
Pre-Release 0.11.1 (2021-02-27)
|
Pre-Release 0.11.1 (2021-02-27)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -129,10 +129,10 @@ class QRDQNPolicy(BasePolicy):
|
||||||
)
|
)
|
||||||
|
|
||||||
if net_arch is None:
|
if net_arch is None:
|
||||||
if features_extractor_class == FlattenExtractor:
|
if features_extractor_class == NatureCNN:
|
||||||
net_arch = [64, 64]
|
|
||||||
else:
|
|
||||||
net_arch = []
|
net_arch = []
|
||||||
|
else:
|
||||||
|
net_arch = [64, 64]
|
||||||
|
|
||||||
self.n_quantiles = n_quantiles
|
self.n_quantiles = n_quantiles
|
||||||
self.net_arch = net_arch
|
self.net_arch = net_arch
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common import logger
|
from stable_baselines3.common import logger
|
||||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||||
|
from stable_baselines3.common.preprocessing import maybe_transpose
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
|
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
|
||||||
|
|
||||||
|
|
@ -211,7 +212,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
(used in recurrent policies)
|
(used in recurrent policies)
|
||||||
"""
|
"""
|
||||||
if not deterministic and np.random.rand() < self.exploration_rate:
|
if not deterministic and np.random.rand() < self.exploration_rate:
|
||||||
if is_vectorized_observation(observation, self.observation_space):
|
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
|
||||||
n_batch = observation.shape[0]
|
n_batch = observation.shape[0]
|
||||||
action = np.array([self.action_space.sample() for _ in range(n_batch)])
|
action = np.array([self.action_space.sample() for _ in range(n_batch)])
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -314,10 +314,10 @@ class TQCPolicy(BasePolicy):
|
||||||
)
|
)
|
||||||
|
|
||||||
if net_arch is None:
|
if net_arch is None:
|
||||||
if features_extractor_class == FlattenExtractor:
|
if features_extractor_class == NatureCNN:
|
||||||
net_arch = [256, 256]
|
|
||||||
else:
|
|
||||||
net_arch = []
|
net_arch = []
|
||||||
|
else:
|
||||||
|
net_arch = [256, 256]
|
||||||
|
|
||||||
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
|
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
0.11.1
|
1.0rc1
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -62,7 +62,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3[tests,docs]>=0.11.1",
|
"stable_baselines3[tests,docs]>=1.0rc1",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import pytest
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common.identity_env import FakeImageEnv
|
from stable_baselines3.common.identity_env import FakeImageEnv
|
||||||
from stable_baselines3.common.utils import zip_strict
|
from stable_baselines3.common.utils import zip_strict
|
||||||
|
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped
|
||||||
|
|
||||||
from sb3_contrib import QRDQN, TQC
|
from sb3_contrib import QRDQN, TQC
|
||||||
|
|
||||||
|
|
@ -16,19 +17,37 @@ def test_cnn(tmp_path, model_class):
|
||||||
# Fake grayscale with frameskip
|
# Fake grayscale with frameskip
|
||||||
# Atari after preprocessing: 84x84x1, here we are using lower resolution
|
# Atari after preprocessing: 84x84x1, here we are using lower resolution
|
||||||
# to check that the network handle it automatically
|
# to check that the network handle it automatically
|
||||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC})
|
env = FakeImageEnv(
|
||||||
|
screen_height=40,
|
||||||
|
screen_width=40,
|
||||||
|
n_channels=1,
|
||||||
|
discrete=model_class not in {TQC},
|
||||||
|
)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if model_class in {TQC, QRDQN}:
|
if model_class in {TQC, QRDQN}:
|
||||||
# Avoid memory error when using replay buffer
|
# Avoid memory error when using replay buffer
|
||||||
# Reduce the size of the features and the number of quantiles
|
# Reduce the size of the features and the number of quantiles
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
buffer_size=250,
|
buffer_size=250,
|
||||||
policy_kwargs=dict(n_quantiles=25, features_extractor_kwargs=dict(features_dim=32)),
|
policy_kwargs=dict(
|
||||||
|
n_quantiles=25,
|
||||||
|
features_extractor_kwargs=dict(features_dim=32),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
model = model_class("CnnPolicy", env, **kwargs).learn(250)
|
model = model_class("CnnPolicy", env, **kwargs).learn(250)
|
||||||
|
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
|
||||||
|
# FakeImageEnv is channel last by default and should be wrapped
|
||||||
|
assert is_vecenv_wrapped(model.get_env(), VecTransposeImage)
|
||||||
|
|
||||||
|
# Test stochastic predict with channel last input
|
||||||
|
if model_class == QRDQN:
|
||||||
|
model.exploration_rate = 0.9
|
||||||
|
|
||||||
|
for _ in range(10):
|
||||||
|
model.predict(obs, deterministic=False)
|
||||||
|
|
||||||
action, _ = model.predict(obs, deterministic=True)
|
action, _ = model.predict(obs, deterministic=True)
|
||||||
|
|
||||||
model.save(tmp_path / SAVE_NAME)
|
model.save(tmp_path / SAVE_NAME)
|
||||||
|
|
|
||||||
|
|
@ -184,17 +184,17 @@ def test_set_env(model_class):
|
||||||
# create model
|
# create model
|
||||||
model = model_class("MlpPolicy", env, **kwargs)
|
model = model_class("MlpPolicy", env, **kwargs)
|
||||||
# learn
|
# learn
|
||||||
model.learn(total_timesteps=300)
|
model.learn(total_timesteps=150)
|
||||||
|
|
||||||
# change env
|
# change env
|
||||||
model.set_env(env2)
|
model.set_env(env2)
|
||||||
# learn again
|
# learn again
|
||||||
model.learn(total_timesteps=300)
|
model.learn(total_timesteps=150)
|
||||||
|
|
||||||
# change env test wrapping
|
# change env test wrapping
|
||||||
model.set_env(env3)
|
model.set_env(env3)
|
||||||
# learn again
|
# learn again
|
||||||
model.learn(total_timesteps=300)
|
model.learn(total_timesteps=150)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue