Dictionary Observations (#29)
* Add TQC support for new HER version * Add dict obs support * Add support for dict obs
This commit is contained in:
parent
61bfdbc00a
commit
3665695d1e
|
|
@ -3,11 +3,14 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Release 1.1.0a5 (WIP)
|
Release 1.1.0a6 (WIP)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^
|
||||||
|
- Added support for Dictionary observation spaces (cf SB3 doc)
|
||||||
|
- Upgraded to Stable-Baselines3 >= 1.1.0
|
||||||
|
- Added proper handling of timeouts for off-policy algorithms (cf SB3 doc)
|
||||||
|
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ instead of predicting the mean return (DQN).
|
||||||
|
|
||||||
MlpPolicy
|
MlpPolicy
|
||||||
CnnPolicy
|
CnnPolicy
|
||||||
|
MultiInputPolicy
|
||||||
|
|
||||||
|
|
||||||
Notes
|
Notes
|
||||||
|
|
@ -43,6 +44,7 @@ Discrete ✔ ✔
|
||||||
Box ❌ ✔
|
Box ❌ ✔
|
||||||
MultiDiscrete ❌ ✔
|
MultiDiscrete ❌ ✔
|
||||||
MultiBinary ❌ ✔
|
MultiBinary ❌ ✔
|
||||||
|
Dict ❌ ✔️
|
||||||
============= ====== ===========
|
============= ====== ===========
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -149,3 +151,6 @@ QR-DQN Policies
|
||||||
|
|
||||||
.. autoclass:: CnnPolicy
|
.. autoclass:: CnnPolicy
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
.. autoclass:: MultiInputPolicy
|
||||||
|
:members:
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ It truncates the quantiles predicted by different networks (a bit as it is done
|
||||||
|
|
||||||
MlpPolicy
|
MlpPolicy
|
||||||
CnnPolicy
|
CnnPolicy
|
||||||
|
MultiInputPolicy
|
||||||
|
|
||||||
|
|
||||||
Notes
|
Notes
|
||||||
|
|
@ -42,6 +43,7 @@ Discrete ❌ ✔️
|
||||||
Box ✔️ ✔️
|
Box ✔️ ✔️
|
||||||
MultiDiscrete ❌ ✔️
|
MultiDiscrete ❌ ✔️
|
||||||
MultiBinary ❌ ✔️
|
MultiBinary ❌ ✔️
|
||||||
|
Dict ❌ ✔️
|
||||||
============= ====== ===========
|
============= ====== ===========
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,2 @@
|
||||||
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy
|
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||||
from sb3_contrib.qrdqn.qrdqn import QRDQN
|
from sb3_contrib.qrdqn.qrdqn import QRDQN
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,13 @@ from typing import Any, Dict, List, Optional, Type
|
||||||
import gym
|
import gym
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common.policies import BasePolicy, register_policy
|
from stable_baselines3.common.policies import BasePolicy, register_policy
|
||||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
|
from stable_baselines3.common.torch_layers import (
|
||||||
|
BaseFeaturesExtractor,
|
||||||
|
CombinedExtractor,
|
||||||
|
FlattenExtractor,
|
||||||
|
NatureCNN,
|
||||||
|
create_mlp,
|
||||||
|
)
|
||||||
from stable_baselines3.common.type_aliases import Schedule
|
from stable_baselines3.common.type_aliases import Schedule
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
@ -245,5 +251,54 @@ class CnnPolicy(QRDQNPolicy):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiInputPolicy(QRDQNPolicy):
|
||||||
|
"""
|
||||||
|
Policy class for QR-DQN when using dict observations as input.
|
||||||
|
|
||||||
|
:param observation_space: Observation space
|
||||||
|
:param action_space: Action space
|
||||||
|
:param lr_schedule: Learning rate schedule (could be constant)
|
||||||
|
:param n_quantiles: Number of quantiles
|
||||||
|
:param net_arch: The specification of the network architecture.
|
||||||
|
:param activation_fn: Activation function
|
||||||
|
:param features_extractor_class: Features extractor to use.
|
||||||
|
:param normalize_images: Whether to normalize images or not,
|
||||||
|
dividing by 255.0 (True by default)
|
||||||
|
:param optimizer_class: The optimizer to use,
|
||||||
|
``th.optim.Adam`` by default
|
||||||
|
:param optimizer_kwargs: Additional keyword arguments,
|
||||||
|
excluding the learning rate, to pass to the optimizer
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
observation_space: gym.spaces.Space,
|
||||||
|
action_space: gym.spaces.Space,
|
||||||
|
lr_schedule: Schedule,
|
||||||
|
n_quantiles: int = 200,
|
||||||
|
net_arch: Optional[List[int]] = None,
|
||||||
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
|
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
||||||
|
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
normalize_images: bool = True,
|
||||||
|
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||||
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
super(MultiInputPolicy, self).__init__(
|
||||||
|
observation_space,
|
||||||
|
action_space,
|
||||||
|
lr_schedule,
|
||||||
|
n_quantiles,
|
||||||
|
net_arch,
|
||||||
|
activation_fn,
|
||||||
|
features_extractor_class,
|
||||||
|
features_extractor_kwargs,
|
||||||
|
normalize_images,
|
||||||
|
optimizer_class,
|
||||||
|
optimizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_policy("MlpPolicy", MlpPolicy)
|
register_policy("MlpPolicy", MlpPolicy)
|
||||||
register_policy("CnnPolicy", CnnPolicy)
|
register_policy("CnnPolicy", CnnPolicy)
|
||||||
|
register_policy("MultiInputPolicy", MultiInputPolicy)
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common import logger
|
from stable_baselines3.common import logger
|
||||||
|
from stable_baselines3.common.buffers import ReplayBuffer
|
||||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||||
from stable_baselines3.common.preprocessing import maybe_transpose
|
from stable_baselines3.common.preprocessing import maybe_transpose
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
|
|
@ -34,6 +35,9 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
(see ``train_freq`` and ``n_episodes_rollout``)
|
(see ``train_freq`` and ``n_episodes_rollout``)
|
||||||
Set to ``-1`` means to do as many gradient steps as steps done in the environment
|
Set to ``-1`` means to do as many gradient steps as steps done in the environment
|
||||||
during the rollout.
|
during the rollout.
|
||||||
|
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
|
||||||
|
If ``None``, it will be automatically selected.
|
||||||
|
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
|
||||||
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
|
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
|
||||||
at a cost of more complexity.
|
at a cost of more complexity.
|
||||||
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
|
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
|
||||||
|
|
@ -59,13 +63,15 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
policy: Union[str, Type[QRDQNPolicy]],
|
policy: Union[str, Type[QRDQNPolicy]],
|
||||||
env: Union[GymEnv, str],
|
env: Union[GymEnv, str],
|
||||||
learning_rate: Union[float, Schedule] = 5e-5,
|
learning_rate: Union[float, Schedule] = 5e-5,
|
||||||
buffer_size: int = 1000000,
|
buffer_size: int = 1000000, # 1e6
|
||||||
learning_starts: int = 50000,
|
learning_starts: int = 50000,
|
||||||
batch_size: Optional[int] = 32,
|
batch_size: Optional[int] = 32,
|
||||||
tau: float = 1.0,
|
tau: float = 1.0,
|
||||||
gamma: float = 0.99,
|
gamma: float = 0.99,
|
||||||
train_freq: int = 4,
|
train_freq: int = 4,
|
||||||
gradient_steps: int = 1,
|
gradient_steps: int = 1,
|
||||||
|
replay_buffer_class: Optional[ReplayBuffer] = None,
|
||||||
|
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
optimize_memory_usage: bool = False,
|
optimize_memory_usage: bool = False,
|
||||||
target_update_interval: int = 10000,
|
target_update_interval: int = 10000,
|
||||||
exploration_fraction: float = 0.005,
|
exploration_fraction: float = 0.005,
|
||||||
|
|
@ -94,6 +100,8 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
train_freq,
|
train_freq,
|
||||||
gradient_steps,
|
gradient_steps,
|
||||||
action_noise=None, # No action noise
|
action_noise=None, # No action noise
|
||||||
|
replay_buffer_class=replay_buffer_class,
|
||||||
|
replay_buffer_kwargs=replay_buffer_kwargs,
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
tensorboard_log=tensorboard_log,
|
tensorboard_log=tensorboard_log,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
|
|
@ -152,7 +160,7 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
self._update_learning_rate(self.policy.optimizer)
|
self._update_learning_rate(self.policy.optimizer)
|
||||||
|
|
||||||
losses = []
|
losses = []
|
||||||
for gradient_step in range(gradient_steps):
|
for _ in range(gradient_steps):
|
||||||
# Sample replay buffer
|
# Sample replay buffer
|
||||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
||||||
|
|
||||||
|
|
@ -213,7 +221,10 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
"""
|
"""
|
||||||
if not deterministic and np.random.rand() < self.exploration_rate:
|
if not deterministic and np.random.rand() < self.exploration_rate:
|
||||||
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
|
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
|
||||||
n_batch = observation.shape[0]
|
if isinstance(self.observation_space, gym.spaces.Dict):
|
||||||
|
n_batch = observation[list(observation.keys())[0]].shape[0]
|
||||||
|
else:
|
||||||
|
n_batch = observation.shape[0]
|
||||||
action = np.array([self.action_space.sample() for _ in range(n_batch)])
|
action = np.array([self.action_space.sample() for _ in range(n_batch)])
|
||||||
else:
|
else:
|
||||||
action = np.array(self.action_space.sample())
|
action = np.array(self.action_space.sample())
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,2 @@
|
||||||
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy
|
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||||
from sb3_contrib.tqc.tqc import TQC
|
from sb3_contrib.tqc.tqc import TQC
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import torch as th
|
import torch as th
|
||||||
|
|
@ -7,11 +7,13 @@ from stable_baselines3.common.policies import BaseModel, BasePolicy, create_sde_
|
||||||
from stable_baselines3.common.preprocessing import get_action_dim
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
from stable_baselines3.common.torch_layers import (
|
from stable_baselines3.common.torch_layers import (
|
||||||
BaseFeaturesExtractor,
|
BaseFeaturesExtractor,
|
||||||
|
CombinedExtractor,
|
||||||
FlattenExtractor,
|
FlattenExtractor,
|
||||||
NatureCNN,
|
NatureCNN,
|
||||||
create_mlp,
|
create_mlp,
|
||||||
get_actor_critic_arch,
|
get_actor_critic_arch,
|
||||||
)
|
)
|
||||||
|
from stable_baselines3.common.type_aliases import Schedule
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
# CAP the standard deviation of the actor
|
# CAP the standard deviation of the actor
|
||||||
|
|
@ -278,6 +280,8 @@ class TQCPolicy(BasePolicy):
|
||||||
``th.optim.Adam`` by default
|
``th.optim.Adam`` by default
|
||||||
:param optimizer_kwargs: Additional keyword arguments,
|
:param optimizer_kwargs: Additional keyword arguments,
|
||||||
excluding the learning rate, to pass to the optimizer
|
excluding the learning rate, to pass to the optimizer
|
||||||
|
:param n_quantiles: Number of quantiles for the critic.
|
||||||
|
:param n_critics: Number of critic networks to create.
|
||||||
:param share_features_extractor: Whether to share or not the features extractor
|
:param share_features_extractor: Whether to share or not the features extractor
|
||||||
between the actor and the critic (this saves computation time)
|
between the actor and the critic (this saves computation time)
|
||||||
"""
|
"""
|
||||||
|
|
@ -286,8 +290,8 @@ class TQCPolicy(BasePolicy):
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: gym.spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: gym.spaces.Space,
|
||||||
lr_schedule: Callable,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[List[int]] = None,
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
log_std_init: float = -3,
|
log_std_init: float = -3,
|
||||||
|
|
@ -353,7 +357,7 @@ class TQCPolicy(BasePolicy):
|
||||||
|
|
||||||
self._build(lr_schedule)
|
self._build(lr_schedule)
|
||||||
|
|
||||||
def _build(self, lr_schedule: Callable) -> None:
|
def _build(self, lr_schedule: Schedule) -> None:
|
||||||
self.actor = self.make_actor()
|
self.actor = self.make_actor()
|
||||||
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||||
|
|
||||||
|
|
@ -448,6 +452,8 @@ class CnnPolicy(TQCPolicy):
|
||||||
``th.optim.Adam`` by default
|
``th.optim.Adam`` by default
|
||||||
:param optimizer_kwargs: Additional keyword arguments,
|
:param optimizer_kwargs: Additional keyword arguments,
|
||||||
excluding the learning rate, to pass to the optimizer
|
excluding the learning rate, to pass to the optimizer
|
||||||
|
:param n_quantiles: Number of quantiles for the critic.
|
||||||
|
:param n_critics: Number of critic networks to create.
|
||||||
:param share_features_extractor: Whether to share or not the features extractor
|
:param share_features_extractor: Whether to share or not the features extractor
|
||||||
between the actor and the critic (this saves computation time)
|
between the actor and the critic (this saves computation time)
|
||||||
"""
|
"""
|
||||||
|
|
@ -456,8 +462,8 @@ class CnnPolicy(TQCPolicy):
|
||||||
self,
|
self,
|
||||||
observation_space: gym.spaces.Space,
|
observation_space: gym.spaces.Space,
|
||||||
action_space: gym.spaces.Space,
|
action_space: gym.spaces.Space,
|
||||||
lr_schedule: Callable,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[List[int]] = None,
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
use_sde: bool = False,
|
use_sde: bool = False,
|
||||||
log_std_init: float = -3,
|
log_std_init: float = -3,
|
||||||
|
|
@ -495,5 +501,80 @@ class CnnPolicy(TQCPolicy):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiInputPolicy(TQCPolicy):
|
||||||
|
"""
|
||||||
|
Policy class (with both actor and critic) for TQC when using dict observations as input.
|
||||||
|
|
||||||
|
:param observation_space: Observation space
|
||||||
|
:param action_space: Action space
|
||||||
|
:param lr_schedule: Learning rate schedule (could be constant)
|
||||||
|
:param net_arch: The specification of the policy and value networks.
|
||||||
|
:param activation_fn: Activation function
|
||||||
|
:param use_sde: Whether to use State Dependent Exploration or not
|
||||||
|
:param log_std_init: Initial value for the log standard deviation
|
||||||
|
:param sde_net_arch: Network architecture for extracting features
|
||||||
|
when using gSDE. If None, the latent features from the policy will be used.
|
||||||
|
Pass an empty list to use the states as features.
|
||||||
|
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
||||||
|
a positive standard deviation (cf paper). It allows to keep variance
|
||||||
|
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||||
|
:param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
|
||||||
|
:param features_extractor_class: Features extractor to use.
|
||||||
|
:param normalize_images: Whether to normalize images or not,
|
||||||
|
dividing by 255.0 (True by default)
|
||||||
|
:param optimizer_class: The optimizer to use,
|
||||||
|
``th.optim.Adam`` by default
|
||||||
|
:param optimizer_kwargs: Additional keyword arguments,
|
||||||
|
excluding the learning rate, to pass to the optimizer
|
||||||
|
:param n_quantiles: Number of quantiles for the critic.
|
||||||
|
:param n_critics: Number of critic networks to create.
|
||||||
|
:param share_features_extractor: Whether to share or not the features extractor
|
||||||
|
between the actor and the critic (this saves computation time)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
observation_space: gym.spaces.Space,
|
||||||
|
action_space: gym.spaces.Space,
|
||||||
|
lr_schedule: Schedule,
|
||||||
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
|
use_sde: bool = False,
|
||||||
|
log_std_init: float = -3,
|
||||||
|
sde_net_arch: Optional[List[int]] = None,
|
||||||
|
use_expln: bool = False,
|
||||||
|
clip_mean: float = 2.0,
|
||||||
|
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
|
||||||
|
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
normalize_images: bool = True,
|
||||||
|
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||||
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
n_quantiles: int = 25,
|
||||||
|
n_critics: int = 2,
|
||||||
|
share_features_extractor: bool = True,
|
||||||
|
):
|
||||||
|
super(MultiInputPolicy, self).__init__(
|
||||||
|
observation_space,
|
||||||
|
action_space,
|
||||||
|
lr_schedule,
|
||||||
|
net_arch,
|
||||||
|
activation_fn,
|
||||||
|
use_sde,
|
||||||
|
log_std_init,
|
||||||
|
sde_net_arch,
|
||||||
|
use_expln,
|
||||||
|
clip_mean,
|
||||||
|
features_extractor_class,
|
||||||
|
features_extractor_kwargs,
|
||||||
|
normalize_images,
|
||||||
|
optimizer_class,
|
||||||
|
optimizer_kwargs,
|
||||||
|
n_quantiles,
|
||||||
|
n_critics,
|
||||||
|
share_features_extractor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_policy("MlpPolicy", MlpPolicy)
|
register_policy("MlpPolicy", MlpPolicy)
|
||||||
register_policy("CnnPolicy", CnnPolicy)
|
register_policy("CnnPolicy", CnnPolicy)
|
||||||
|
register_policy("MultiInputPolicy", MultiInputPolicy)
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common import logger
|
from stable_baselines3.common import logger
|
||||||
|
from stable_baselines3.common.buffers import ReplayBuffer
|
||||||
from stable_baselines3.common.noise import ActionNoise
|
from stable_baselines3.common.noise import ActionNoise
|
||||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
||||||
|
|
@ -35,6 +36,9 @@ class TQC(OffPolicyAlgorithm):
|
||||||
:param gradient_steps: How many gradient updates after each step
|
:param gradient_steps: How many gradient updates after each step
|
||||||
:param action_noise: the action noise type (None by default), this can help
|
:param action_noise: the action noise type (None by default), this can help
|
||||||
for hard exploration problem. Cf common.noise for the different action noise type.
|
for hard exploration problem. Cf common.noise for the different action noise type.
|
||||||
|
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
|
||||||
|
If ``None``, it will be automatically selected.
|
||||||
|
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
|
||||||
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
|
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
|
||||||
at a cost of more complexity.
|
at a cost of more complexity.
|
||||||
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
|
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
|
||||||
|
|
@ -66,7 +70,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
policy: Union[str, Type[TQCPolicy]],
|
policy: Union[str, Type[TQCPolicy]],
|
||||||
env: Union[GymEnv, str],
|
env: Union[GymEnv, str],
|
||||||
learning_rate: Union[float, Callable] = 3e-4,
|
learning_rate: Union[float, Callable] = 3e-4,
|
||||||
buffer_size: int = int(1e6),
|
buffer_size: int = 1000000, # 1e6
|
||||||
learning_starts: int = 100,
|
learning_starts: int = 100,
|
||||||
batch_size: int = 256,
|
batch_size: int = 256,
|
||||||
tau: float = 0.005,
|
tau: float = 0.005,
|
||||||
|
|
@ -74,6 +78,8 @@ class TQC(OffPolicyAlgorithm):
|
||||||
train_freq: int = 1,
|
train_freq: int = 1,
|
||||||
gradient_steps: int = 1,
|
gradient_steps: int = 1,
|
||||||
action_noise: Optional[ActionNoise] = None,
|
action_noise: Optional[ActionNoise] = None,
|
||||||
|
replay_buffer_class: Optional[ReplayBuffer] = None,
|
||||||
|
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
optimize_memory_usage: bool = False,
|
optimize_memory_usage: bool = False,
|
||||||
ent_coef: Union[str, float] = "auto",
|
ent_coef: Union[str, float] = "auto",
|
||||||
target_update_interval: int = 1,
|
target_update_interval: int = 1,
|
||||||
|
|
@ -103,7 +109,9 @@ class TQC(OffPolicyAlgorithm):
|
||||||
gamma,
|
gamma,
|
||||||
train_freq,
|
train_freq,
|
||||||
gradient_steps,
|
gradient_steps,
|
||||||
action_noise,
|
action_noise=action_noise,
|
||||||
|
replay_buffer_class=replay_buffer_class,
|
||||||
|
replay_buffer_kwargs=replay_buffer_kwargs,
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
tensorboard_log=tensorboard_log,
|
tensorboard_log=tensorboard_log,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.1.0a5
|
1.1.0a6
|
||||||
|
|
|
||||||
|
|
@ -7,11 +7,8 @@ license_file = LICENSE
|
||||||
env =
|
env =
|
||||||
PYTHONHASHSEED=0
|
PYTHONHASHSEED=0
|
||||||
filterwarnings =
|
filterwarnings =
|
||||||
# Tensorboard/Tensorflow warnings
|
# Tensorboard warnings
|
||||||
ignore:inspect.getargspec:DeprecationWarning:tensorflow
|
ignore::DeprecationWarning:tensorboard
|
||||||
ignore:builtin type EagerTensor has no __module__ attribute:DeprecationWarning
|
|
||||||
ignore:The binary mode of fromstring is deprecated:DeprecationWarning
|
|
||||||
ignore::FutureWarning:tensorflow
|
|
||||||
# Gym warnings
|
# Gym warnings
|
||||||
ignore:Parameters to load are deprecated.:DeprecationWarning
|
ignore:Parameters to load are deprecated.:DeprecationWarning
|
||||||
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
|
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -62,7 +62,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.0",
|
"stable_baselines3>=1.1.0a6",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from copy import deepcopy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common.identity_env import FakeImageEnv
|
from stable_baselines3.common.envs import FakeImageEnv
|
||||||
from stable_baselines3.common.utils import zip_strict
|
from stable_baselines3.common.utils import zip_strict
|
||||||
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped
|
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,258 @@
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from gym import spaces
|
||||||
|
from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv
|
||||||
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
|
||||||
|
|
||||||
|
from sb3_contrib import QRDQN, TQC
|
||||||
|
|
||||||
|
|
||||||
|
class DummyDictEnv(gym.Env):
|
||||||
|
"""Custom Environment for testing purposes only"""
|
||||||
|
|
||||||
|
metadata = {"render.modes": ["human"]}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
use_discrete_actions=False,
|
||||||
|
channel_last=False,
|
||||||
|
nested_dict_obs=False,
|
||||||
|
vec_only=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if use_discrete_actions:
|
||||||
|
self.action_space = spaces.Discrete(3)
|
||||||
|
else:
|
||||||
|
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||||
|
N_CHANNELS = 1
|
||||||
|
HEIGHT = 64
|
||||||
|
WIDTH = 64
|
||||||
|
|
||||||
|
if channel_last:
|
||||||
|
obs_shape = (HEIGHT, WIDTH, N_CHANNELS)
|
||||||
|
else:
|
||||||
|
obs_shape = (N_CHANNELS, HEIGHT, WIDTH)
|
||||||
|
|
||||||
|
self.observation_space = spaces.Dict(
|
||||||
|
{
|
||||||
|
# Image obs
|
||||||
|
"img": spaces.Box(low=0, high=255, shape=obs_shape, dtype=np.uint8),
|
||||||
|
# Vector obs
|
||||||
|
"vec": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32),
|
||||||
|
# Discrete obs
|
||||||
|
"discrete": spaces.Discrete(4),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# For checking consistency with normal MlpPolicy
|
||||||
|
if vec_only:
|
||||||
|
self.observation_space = spaces.Dict(
|
||||||
|
{
|
||||||
|
# Vector obs
|
||||||
|
"vec": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if nested_dict_obs:
|
||||||
|
# Add dictionary observation inside observation space
|
||||||
|
self.observation_space.spaces["nested-dict"] = spaces.Dict({"nested-dict-discrete": spaces.Discrete(4)})
|
||||||
|
|
||||||
|
def seed(self, seed=None):
|
||||||
|
if seed is not None:
|
||||||
|
self.observation_space.seed(seed)
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
reward = 0.0
|
||||||
|
done = False
|
||||||
|
return self.observation_space.sample(), reward, done, {}
|
||||||
|
|
||||||
|
def compute_reward(self, achieved_goal, desired_goal, info):
|
||||||
|
return np.zeros((len(achieved_goal),))
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
return self.observation_space.sample()
|
||||||
|
|
||||||
|
def render(self, mode="human"):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
|
||||||
|
def test_consistency(model_class):
|
||||||
|
"""
|
||||||
|
Make sure that dict obs with vector only vs using flatten obs is equivalent.
|
||||||
|
This ensures notably that the network architectures are the same.
|
||||||
|
"""
|
||||||
|
use_discrete_actions = model_class == QRDQN
|
||||||
|
dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True)
|
||||||
|
dict_env = gym.wrappers.TimeLimit(dict_env, 100)
|
||||||
|
env = gym.wrappers.FlattenObservation(dict_env)
|
||||||
|
dict_env.seed(10)
|
||||||
|
obs = dict_env.reset()
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
n_steps = 256
|
||||||
|
|
||||||
|
if model_class in {}:
|
||||||
|
kwargs = dict(
|
||||||
|
n_steps=128,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Avoid memory error when using replay buffer
|
||||||
|
# Reduce the size of the features and make learning faster
|
||||||
|
kwargs = dict(
|
||||||
|
buffer_size=250,
|
||||||
|
train_freq=8,
|
||||||
|
gradient_steps=1,
|
||||||
|
)
|
||||||
|
if model_class == QRDQN:
|
||||||
|
kwargs["learning_starts"] = 0
|
||||||
|
|
||||||
|
dict_model = model_class("MultiInputPolicy", dict_env, gamma=0.5, seed=1, **kwargs)
|
||||||
|
action_before_learning_1, _ = dict_model.predict(obs, deterministic=True)
|
||||||
|
dict_model.learn(total_timesteps=n_steps)
|
||||||
|
|
||||||
|
normal_model = model_class("MlpPolicy", env, gamma=0.5, seed=1, **kwargs)
|
||||||
|
action_before_learning_2, _ = normal_model.predict(obs["vec"], deterministic=True)
|
||||||
|
normal_model.learn(total_timesteps=n_steps)
|
||||||
|
|
||||||
|
action_1, _ = dict_model.predict(obs, deterministic=True)
|
||||||
|
action_2, _ = normal_model.predict(obs["vec"], deterministic=True)
|
||||||
|
|
||||||
|
assert np.allclose(action_before_learning_1, action_before_learning_2)
|
||||||
|
assert np.allclose(action_1, action_2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
|
||||||
|
@pytest.mark.parametrize("channel_last", [False, True])
|
||||||
|
def test_dict_spaces(model_class, channel_last):
|
||||||
|
"""
|
||||||
|
Additional tests to check observation space support
|
||||||
|
with mixed observation.
|
||||||
|
"""
|
||||||
|
use_discrete_actions = model_class not in [TQC]
|
||||||
|
env = DummyDictEnv(use_discrete_actions=use_discrete_actions, channel_last=channel_last)
|
||||||
|
env = gym.wrappers.TimeLimit(env, 100)
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
n_steps = 256
|
||||||
|
|
||||||
|
if model_class in {}:
|
||||||
|
kwargs = dict(
|
||||||
|
n_steps=128,
|
||||||
|
policy_kwargs=dict(
|
||||||
|
net_arch=[32],
|
||||||
|
features_extractor_kwargs=dict(cnn_output_dim=32),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Avoid memory error when using replay buffer
|
||||||
|
# Reduce the size of the features and make learning faster
|
||||||
|
kwargs = dict(
|
||||||
|
buffer_size=250,
|
||||||
|
policy_kwargs=dict(
|
||||||
|
net_arch=[32],
|
||||||
|
features_extractor_kwargs=dict(cnn_output_dim=32),
|
||||||
|
n_quantiles=20,
|
||||||
|
),
|
||||||
|
train_freq=8,
|
||||||
|
gradient_steps=1,
|
||||||
|
)
|
||||||
|
if model_class == QRDQN:
|
||||||
|
kwargs["learning_starts"] = 0
|
||||||
|
|
||||||
|
model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)
|
||||||
|
|
||||||
|
model.learn(total_timesteps=n_steps)
|
||||||
|
|
||||||
|
evaluate_policy(model, env, n_eval_episodes=5, warn=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
@pytest.mark.parametrize("channel_last", [False, True])
def test_dict_vec_framestack(model_class, channel_last):
    """
    Additional tests to check observation space support
    for Dictionary spaces and VecEnvWrapper using MultiInputPolicy.

    Wraps a single ``SimpleMultiObsEnv`` in ``DummyVecEnv`` and
    ``VecFrameStack`` (3 stacked frames), trains ``model_class`` for a
    small number of timesteps, then runs a short policy evaluation.

    :param model_class: algorithm class under test (QRDQN or TQC)
    :param channel_last: forwarded to ``SimpleMultiObsEnv`` and used to pick
        the ``"img"`` channel order given to ``VecFrameStack``
    """
    # Every algorithm except TQC is run with discrete actions
    use_discrete_actions = model_class not in [TQC]
    # Per-key channel order passed to VecFrameStack; "vec" is left as None
    channels_order = {"vec": None, "img": "last" if channel_last else "first"}
    env = DummyVecEnv(
        [lambda: SimpleMultiObsEnv(random_start=True, discrete_actions=use_discrete_actions, channel_last=channel_last)]
    )

    env = VecFrameStack(env, n_stack=3, channels_order=channels_order)

    n_steps = 256

    # NOTE: the former ``if model_class in {}:`` branch was removed: ``{}`` is
    # an empty dict, so the membership test was always False and that branch
    # could never run.

    # Avoid memory error when using replay buffer
    # Reduce the size of the features and make learning faster
    kwargs = dict(
        buffer_size=250,
        policy_kwargs=dict(
            net_arch=[32],
            features_extractor_kwargs=dict(cnn_output_dim=32),
            n_quantiles=20,
        ),
        train_freq=8,
        gradient_steps=1,
    )
    if model_class == QRDQN:
        # Start training from the very first collected step
        kwargs["learning_starts"] = 0

    model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)

    model.learn(total_timesteps=n_steps)

    evaluate_policy(model, env, n_eval_episodes=5, warn=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_class", [QRDQN, TQC])
def test_vec_normalize(model_class):
    """
    Additional tests to check observation space support
    for GoalEnv and VecNormalize using MultiInputPolicy.

    Wraps a 4-bit ``BitFlippingEnv`` in ``DummyVecEnv`` and
    ``VecNormalize``, trains ``model_class`` for a small number of
    timesteps, then runs a short policy evaluation.

    :param model_class: algorithm class under test (QRDQN or TQC)
    """
    # QRDQN gets the non-continuous variant of the env, TQC the continuous one
    env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=4, continuous=not (model_class == QRDQN))])
    env = VecNormalize(env)

    n_steps = 256

    # NOTE: the former ``if model_class in {}:`` branch was removed: ``{}`` is
    # an empty dict, so the membership test was always False and that branch
    # could never run.

    # Avoid memory error when using replay buffer
    # Reduce the size of the features and make learning faster
    kwargs = dict(
        buffer_size=250,
        policy_kwargs=dict(
            net_arch=[32],
        ),
        train_freq=8,
        gradient_steps=1,
    )
    if model_class == QRDQN:
        # Start training from the very first collected step
        kwargs["learning_starts"] = 0

    model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)

    model.learn(total_timesteps=n_steps)

    evaluate_policy(model, env, n_eval_episodes=5, warn=False)
|
||||||
|
|
@ -8,7 +8,7 @@ import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||||
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
||||||
from stable_baselines3.common.utils import get_device
|
from stable_baselines3.common.utils import get_device
|
||||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||||
|
|
||||||
|
|
@ -243,6 +243,7 @@ def test_save_load_replay_buffer(tmp_path, model_class):
|
||||||
assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions)
|
assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions)
|
||||||
assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards)
|
assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards)
|
||||||
assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones)
|
assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones)
|
||||||
|
infos = [[{"TimeLimit.truncated": truncated}] for truncated in old_replay_buffer.timeouts]
|
||||||
|
|
||||||
# test extending replay buffer
|
# test extending replay buffer
|
||||||
model.replay_buffer.extend(
|
model.replay_buffer.extend(
|
||||||
|
|
@ -251,6 +252,7 @@ def test_save_load_replay_buffer(tmp_path, model_class):
|
||||||
old_replay_buffer.actions,
|
old_replay_buffer.actions,
|
||||||
old_replay_buffer.rewards,
|
old_replay_buffer.rewards,
|
||||||
old_replay_buffer.dones,
|
old_replay_buffer.dones,
|
||||||
|
infos,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,13 @@ import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
from stable_baselines3.common.bit_flipping_env import BitFlippingEnv
|
|
||||||
from stable_baselines3.common.env_checker import check_env
|
from stable_baselines3.common.env_checker import check_env
|
||||||
|
from stable_baselines3.common.envs import BitFlippingEnv
|
||||||
|
|
||||||
from sb3_contrib.common.wrappers import TimeFeatureWrapper
|
from sb3_contrib.common.wrappers import TimeFeatureWrapper
|
||||||
|
|
||||||
|
|
||||||
class CustomGoalEnv(gym.GoalEnv):
|
class CustomGoalEnv(gym.GoalEnv):
|
||||||
"""docstring for CustomGoalEnv."""
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(CustomGoalEnv, self).__init__()
|
super(CustomGoalEnv, self).__init__()
|
||||||
self.observation_space = spaces.Dict(
|
self.observation_space = spaces.Dict(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue