parent
7c2eb833c0
commit
9824daca44
|
|
@ -16,7 +16,7 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.6, 3.7] # 3.8 not supported yet by pytype
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
|
@ -28,7 +28,7 @@ jobs:
|
|||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# cpu version of pytorch
|
||||
pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
pip install torch==1.8.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
# Install dependencies for docs and tests
|
||||
pip install stable_baselines3[extra,tests,docs]
|
||||
# Install master version
|
||||
|
|
|
|||
|
|
@ -3,6 +3,18 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.0rc1 (WIP)
|
||||
-------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Upgraded to Stable-Baselines3 >= 1.0rc1
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed a bug in the ``QR-DQN`` ``predict`` method when using ``deterministic=False`` with an image observation space
|
||||
|
||||
|
||||
Pre-Release 0.11.1 (2021-02-27)
|
||||
-------------------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -129,10 +129,10 @@ class QRDQNPolicy(BasePolicy):
|
|||
)
|
||||
|
||||
if net_arch is None:
|
||||
if features_extractor_class == FlattenExtractor:
|
||||
net_arch = [64, 64]
|
||||
else:
|
||||
if features_extractor_class == NatureCNN:
|
||||
net_arch = []
|
||||
else:
|
||||
net_arch = [64, 64]
|
||||
|
||||
self.n_quantiles = n_quantiles
|
||||
self.net_arch = net_arch
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import numpy as np
|
|||
import torch as th
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.preprocessing import maybe_transpose
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
|
||||
|
||||
|
|
@ -211,7 +212,7 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
(used in recurrent policies)
|
||||
"""
|
||||
if not deterministic and np.random.rand() < self.exploration_rate:
|
||||
if is_vectorized_observation(observation, self.observation_space):
|
||||
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
|
||||
n_batch = observation.shape[0]
|
||||
action = np.array([self.action_space.sample() for _ in range(n_batch)])
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -314,10 +314,10 @@ class TQCPolicy(BasePolicy):
|
|||
)
|
||||
|
||||
if net_arch is None:
|
||||
if features_extractor_class == FlattenExtractor:
|
||||
net_arch = [256, 256]
|
||||
else:
|
||||
if features_extractor_class == NatureCNN:
|
||||
net_arch = []
|
||||
else:
|
||||
net_arch = [256, 256]
|
||||
|
||||
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
0.11.1
|
||||
1.0rc1
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -62,7 +62,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3[tests,docs]>=0.11.1",
|
||||
"stable_baselines3[tests,docs]>=1.0rc1",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
import torch as th
|
||||
from stable_baselines3.common.identity_env import FakeImageEnv
|
||||
from stable_baselines3.common.utils import zip_strict
|
||||
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped
|
||||
|
||||
from sb3_contrib import QRDQN, TQC
|
||||
|
||||
|
|
@ -16,19 +17,37 @@ def test_cnn(tmp_path, model_class):
|
|||
# Fake grayscale with frameskip
|
||||
# Atari after preprocessing: 84x84x1, here we are using lower resolution
|
||||
# to check that the network handles it automatically
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC})
|
||||
env = FakeImageEnv(
|
||||
screen_height=40,
|
||||
screen_width=40,
|
||||
n_channels=1,
|
||||
discrete=model_class not in {TQC},
|
||||
)
|
||||
kwargs = {}
|
||||
if model_class in {TQC, QRDQN}:
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features and the number of quantiles
|
||||
kwargs = dict(
|
||||
buffer_size=250,
|
||||
policy_kwargs=dict(n_quantiles=25, features_extractor_kwargs=dict(features_dim=32)),
|
||||
policy_kwargs=dict(
|
||||
n_quantiles=25,
|
||||
features_extractor_kwargs=dict(features_dim=32),
|
||||
),
|
||||
)
|
||||
model = model_class("CnnPolicy", env, **kwargs).learn(250)
|
||||
|
||||
obs = env.reset()
|
||||
|
||||
# FakeImageEnv is channel last by default and should be wrapped
|
||||
assert is_vecenv_wrapped(model.get_env(), VecTransposeImage)
|
||||
|
||||
# Test stochastic predict with channel last input
|
||||
if model_class == QRDQN:
|
||||
model.exploration_rate = 0.9
|
||||
|
||||
for _ in range(10):
|
||||
model.predict(obs, deterministic=False)
|
||||
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
|
||||
model.save(tmp_path / SAVE_NAME)
|
||||
|
|
|
|||
|
|
@ -184,17 +184,17 @@ def test_set_env(model_class):
|
|||
# create model
|
||||
model = model_class("MlpPolicy", env, **kwargs)
|
||||
# learn
|
||||
model.learn(total_timesteps=300)
|
||||
model.learn(total_timesteps=150)
|
||||
|
||||
# change env
|
||||
model.set_env(env2)
|
||||
# learn again
|
||||
model.learn(total_timesteps=300)
|
||||
model.learn(total_timesteps=150)
|
||||
|
||||
# change env test wrapping
|
||||
model.set_env(env3)
|
||||
# learn again
|
||||
model.learn(total_timesteps=300)
|
||||
model.learn(total_timesteps=150)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
|
|
|
|||
Loading…
Reference in New Issue