Add Gymnasium support (#152)
* Add support for Gym 0.24 * Fixes for gym 0.24 * Fix for new reset signature * Add tmp SB3 branch * Fixes for gym 0.26 * Remove unused import * Fix dependency * Type annotations fixes * Reformat * Reformat with black 23 * Move to gymnasium * Patch env if needed * Fix types * Fix CI * Fixes for gymnasium * Fix wrapper annotations * Update version * Fix type check * Update QRDQN type hints and bug fix with multi envs * Fix TQC type hints * Fix TRPO type hints * Additional fixes * Update SB3 version * Update issue templates and CI --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
parent
a84ad3aa7d
commit
21cc96cafd
|
|
@ -50,7 +50,8 @@ body:
|
|||
* GPU models and configuration
|
||||
* Python version
|
||||
* PyTorch version
|
||||
* Gym version
|
||||
* Gymnasium version
|
||||
* (if installed) OpenAI Gym version
|
||||
* Versions of any other relevant libraries
|
||||
|
||||
You can use `sb3.get_system_info()` to print relevant packages info:
|
||||
|
|
|
|||
|
|
@ -34,9 +34,9 @@ body:
|
|||
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
|
||||
value: |
|
||||
```python
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
|
|
@ -49,15 +49,16 @@ body:
|
|||
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(6,))
|
||||
|
||||
def reset(self):
|
||||
return self.observation_space.sample()
|
||||
def reset(self, seed=None, options=None):
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
def step(self, action):
|
||||
obs = self.observation_space.sample()
|
||||
reward = 1.0
|
||||
done = False
|
||||
terminated = False
|
||||
truncated = False
|
||||
info = {}
|
||||
return obs, reward, done, info
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
env = CustomEnv()
|
||||
check_env(env)
|
||||
|
|
@ -84,7 +85,8 @@ body:
|
|||
* GPU models and configuration
|
||||
* Python version
|
||||
* PyTorch version
|
||||
* Gym version
|
||||
* Gymnasium version
|
||||
* (if installed) OpenAI Gym version
|
||||
* Versions of any other relevant libraries
|
||||
|
||||
You can use `sb3.get_system_info()` to print relevant packages info:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,42 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.0.0a4 (WIP)
|
||||
--------------------------
|
||||
|
||||
**Gymnasium support**
|
||||
|
||||
.. warning::
|
||||
|
||||
Stable-Baselines3 (SB3) v2.0 will be the last one supporting python 3.7 (end of life in June 2023).
|
||||
We highly recommended you to upgrade to Python >= 3.8.
|
||||
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Switched to Gymnasium as primary backend, Gym 0.21 and 0.26 are still supported via the ``shimmy`` package (@carlosluis, @arjun-kg, @tlpss)
|
||||
- Upgraded to Stable-Baselines3 >= 2.0.0
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed QRDQN update interval for multi envs
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Fixed ``sb3_contrib/tqc/*.py`` type hints
|
||||
- Fixed ``sb3_contrib/trpo/*.py`` type hints
|
||||
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
Release 1.8.0 (2023-04-07)
|
||||
--------------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ to specify the name (see `PR #25 <https://github.com/Stable-Baselines-Team/stabl
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
||||
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ Example
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
|
||||
from sb3_contrib import QRDQN
|
||||
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ Example
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
||||
from sb3_contrib import TQC
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ Example
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
||||
from sb3_contrib import TRPO
|
||||
|
|
|
|||
|
|
@ -37,9 +37,6 @@ exclude = """(?x)(
|
|||
| sb3_contrib/common/maskable/buffers.py$
|
||||
| sb3_contrib/common/envs/invalid_actions_env.py$
|
||||
| sb3_contrib/common/vec_env/async_eval.py$
|
||||
| sb3_contrib/tqc/tqc.py$
|
||||
| sb3_contrib/tqc/policies.py$
|
||||
| sb3_contrib/trpo/trpo.py$
|
||||
| sb3_contrib/ppo_mask/ppo_mask.py$
|
||||
| tests/test_train_eval_mode.py$
|
||||
)"""
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import Any, Dict, Optional, Type, TypeVar, Union
|
|||
import numpy as np
|
||||
import torch as th
|
||||
import torch.nn.utils
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
|
||||
|
|
@ -50,7 +50,7 @@ class ARSPolicy(BasePolicy):
|
|||
self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output
|
||||
)
|
||||
elif isinstance(action_space, spaces.Discrete):
|
||||
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn, with_bias=with_bias)
|
||||
actor_net = create_mlp(self.features_dim, int(action_space.n), net_arch, activation_fn, with_bias=with_bias)
|
||||
else:
|
||||
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.envs import IdentityEnv
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Generator, NamedTuple, Optional, Union
|
|||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
|
||||
from stable_baselines3.common.type_aliases import TensorDict
|
||||
from stable_baselines3.common.vec_env import VecNormalize
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import List, Optional, Tuple, TypeVar
|
|||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.distributions import Distribution
|
||||
from torch import nn
|
||||
from torch.distributions import Categorical
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
|
||||
|
|
@ -64,7 +64,7 @@ def evaluate_policy(
|
|||
is_monitor_wrapped = False
|
||||
|
||||
if not isinstance(env, VecEnv):
|
||||
env = DummyVecEnv([lambda: env])
|
||||
env = DummyVecEnv([lambda: env]) # type: ignore[list-item, return-value]
|
||||
|
||||
is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]
|
||||
|
||||
|
|
@ -93,7 +93,7 @@ def evaluate_policy(
|
|||
if use_masking:
|
||||
action_masks = get_action_masks(env)
|
||||
actions, state = model.predict(
|
||||
observations,
|
||||
observations, # type: ignore[arg-type]
|
||||
state=states,
|
||||
episode_start=episode_starts,
|
||||
deterministic=deterministic,
|
||||
|
|
@ -101,7 +101,10 @@ def evaluate_policy(
|
|||
)
|
||||
else:
|
||||
actions, states = model.predict(
|
||||
observations, state=states, episode_start=episode_starts, deterministic=deterministic
|
||||
observations, # type: ignore[arg-type]
|
||||
state=states,
|
||||
episode_start=episode_starts,
|
||||
deterministic=deterministic,
|
||||
)
|
||||
observations, rewards, dones, infos = env.step(actions)
|
||||
current_rewards += rewards
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
|||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.torch_layers import (
|
||||
BaseFeaturesExtractor,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Callable, Generator, Optional, Tuple, Union
|
|||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
|
||||
from stable_baselines3.common.vec_env import VecNormalize
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
|||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.distributions import Distribution
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||
from stable_baselines3.common.torch_layers import (
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import torch as th
|
|||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.running_mean_std import RunningMeanStd
|
||||
from stable_baselines3.common.utils import compat_gym_seed
|
||||
from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize
|
||||
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper
|
||||
|
||||
|
|
@ -56,7 +57,7 @@ def _worker(
|
|||
results.append((weights_idx, (episode_rewards, episode_lengths)))
|
||||
remote.send(results)
|
||||
elif cmd == "seed":
|
||||
remote.send(env.seed(data))
|
||||
remote.send(compat_gym_seed(env, seed=data))
|
||||
elif cmd == "get_obs_rms":
|
||||
remote.send(obs_rms)
|
||||
elif cmd == "sync_obs_rms":
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Callable, Union
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
from typing import Dict, Union
|
||||
from typing import Any, Dict, SupportsFloat, Tuple, Union
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ActType
|
||||
|
||||
TimeFeatureObs = Union[np.ndarray, Dict[str, np.ndarray]]
|
||||
|
||||
|
||||
class TimeFeatureWrapper(gym.Wrapper):
|
||||
class TimeFeatureWrapper(gym.Wrapper[TimeFeatureObs, ActType, TimeFeatureObs, ActType]):
|
||||
"""
|
||||
Add remaining, normalized time to observation space for fixed length episodes.
|
||||
See https://arxiv.org/abs/1712.00378 and https://github.com/aravindr93/mjrl/issues/13.
|
||||
|
|
@ -33,27 +35,30 @@ class TimeFeatureWrapper(gym.Wrapper):
|
|||
if isinstance(env.observation_space, spaces.Dict):
|
||||
assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space"
|
||||
obs_space = env.observation_space.spaces["observation"]
|
||||
assert isinstance(obs_space, spaces.Box), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
|
||||
obs_space = env.observation_space.spaces["observation"]
|
||||
else:
|
||||
obs_space = env.observation_space
|
||||
|
||||
assert isinstance(obs_space, gym.spaces.Box), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
|
||||
assert len(obs_space.shape) == 1, "Only 1D observation spaces are supported"
|
||||
|
||||
low, high = obs_space.low, obs_space.high
|
||||
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0]))
|
||||
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0])) # type: ignore[arg-type]
|
||||
self.dtype = obs_space.dtype
|
||||
|
||||
if isinstance(env.observation_space, spaces.Dict):
|
||||
env.observation_space.spaces["observation"] = spaces.Box(low=low, high=high, dtype=self.dtype)
|
||||
env.observation_space.spaces["observation"] = spaces.Box(
|
||||
low=low,
|
||||
high=high,
|
||||
dtype=self.dtype, # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
env.observation_space = spaces.Box(low=low, high=high, dtype=self.dtype)
|
||||
env.observation_space = spaces.Box(low=low, high=high, dtype=self.dtype) # type: ignore[arg-type]
|
||||
|
||||
super().__init__(env)
|
||||
|
||||
# Try to infer the max number of steps per episode
|
||||
try:
|
||||
self._max_steps = env.spec.max_episode_steps
|
||||
self._max_steps = env.spec.max_episode_steps # type: ignore[union-attr]
|
||||
except AttributeError:
|
||||
self._max_steps = None
|
||||
|
||||
|
|
@ -64,14 +69,15 @@ class TimeFeatureWrapper(gym.Wrapper):
|
|||
self._current_step = 0
|
||||
self._test_mode = test_mode
|
||||
|
||||
def reset(self) -> GymObs:
|
||||
def reset(self, **kwargs) -> Tuple[TimeFeatureObs, Dict[str, Any]]:
|
||||
self._current_step = 0
|
||||
return self._get_obs(self.env.reset())
|
||||
obs, info = self.env.reset(**kwargs)
|
||||
return self._get_obs(obs), info
|
||||
|
||||
def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
|
||||
def step(self, action: ActType) -> Tuple[TimeFeatureObs, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||
self._current_step += 1
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
return self._get_obs(obs), reward, done, info
|
||||
obs, reward, done, truncated, info = self.env.step(action)
|
||||
return self._get_obs(obs), reward, done, truncated, info
|
||||
|
||||
def _get_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
|
||||
"""
|
||||
|
|
@ -80,11 +86,13 @@ class TimeFeatureWrapper(gym.Wrapper):
|
|||
:param obs:
|
||||
:return:
|
||||
"""
|
||||
# for mypy
|
||||
assert self._max_steps is not None
|
||||
# Remaining time is more general
|
||||
time_feature = 1 - (self._current_step / self._max_steps)
|
||||
if self._test_mode:
|
||||
time_feature = 1.0
|
||||
time_feature = np.array(time_feature, dtype=self.dtype)
|
||||
time_feature = np.array(time_feature, dtype=self.dtype) # type: ignore[assignment]
|
||||
|
||||
if isinstance(obs, dict):
|
||||
obs["observation"] = np.append(obs["observation"], time_feature)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
|
|||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common import utils
|
||||
from stable_baselines3.common.buffers import RolloutBuffer
|
||||
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
|
||||
|
|
@ -232,6 +232,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
# Make sure training timesteps are ahead of the internal counter
|
||||
total_timesteps += self.num_timesteps
|
||||
self._total_timesteps = total_timesteps
|
||||
self._num_timesteps_at_start = self.num_timesteps
|
||||
|
||||
# Avoid resetting the environment when calling ``.learn()`` consecutive times
|
||||
if reset_num_timesteps or self._last_obs is None:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Type, TypeVar, Union
|
|||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.buffers import RolloutBuffer
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.torch_layers import (
|
||||
BaseFeaturesExtractor,
|
||||
|
|
@ -27,10 +27,12 @@ class QuantileNetwork(BasePolicy):
|
|||
dividing by 255.0 (True by default)
|
||||
"""
|
||||
|
||||
action_space: spaces.Discrete
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
action_space: spaces.Discrete,
|
||||
features_extractor: BaseFeaturesExtractor,
|
||||
features_dim: int,
|
||||
n_quantiles: int = 200,
|
||||
|
|
@ -52,7 +54,7 @@ class QuantileNetwork(BasePolicy):
|
|||
self.activation_fn = activation_fn
|
||||
self.features_dim = features_dim
|
||||
self.n_quantiles = n_quantiles
|
||||
action_dim = self.action_space.n # number of actions
|
||||
action_dim = int(self.action_space.n) # number of actions
|
||||
quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
|
||||
self.quantile_net = nn.Sequential(*quantile_net)
|
||||
|
||||
|
|
@ -63,10 +65,8 @@ class QuantileNetwork(BasePolicy):
|
|||
:param obs: Observation
|
||||
:return: The estimated quantiles for each action.
|
||||
"""
|
||||
# For type checker:
|
||||
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
|
||||
quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
|
||||
return quantiles.view(-1, self.n_quantiles, self.action_space.n)
|
||||
return quantiles.view(-1, self.n_quantiles, int(self.action_space.n))
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
q_values = self(observation).mean(dim=1)
|
||||
|
|
@ -110,10 +110,13 @@ class QRDQNPolicy(BasePolicy):
|
|||
excluding the learning rate, to pass to the optimizer
|
||||
"""
|
||||
|
||||
quantile_net: QuantileNetwork
|
||||
quantile_net_target: QuantileNetwork
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
action_space: spaces.Discrete,
|
||||
lr_schedule: Schedule,
|
||||
n_quantiles: int = 200,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
|
|
@ -152,9 +155,6 @@ class QRDQNPolicy(BasePolicy):
|
|||
"activation_fn": self.activation_fn,
|
||||
"normalize_images": normalize_images,
|
||||
}
|
||||
|
||||
self.quantile_net: QuantileNetwork
|
||||
self.quantile_net_target: QuantileNetwork
|
||||
self._build(lr_schedule)
|
||||
|
||||
def _build(self, lr_schedule: Schedule) -> None:
|
||||
|
|
@ -239,7 +239,7 @@ class CnnPolicy(QRDQNPolicy):
|
|||
def __init__(
|
||||
self,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
action_space: spaces.Discrete,
|
||||
lr_schedule: Schedule,
|
||||
n_quantiles: int = 200,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
|
|
@ -287,7 +287,7 @@ class MultiInputPolicy(QRDQNPolicy):
|
|||
def __init__(
|
||||
self,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
action_space: spaces.Discrete,
|
||||
lr_schedule: Schedule,
|
||||
n_quantiles: int = 200,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
|
|
@ -10,7 +11,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
|||
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
|
||||
|
||||
from sb3_contrib.common.utils import quantile_huber_loss
|
||||
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
|
||||
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy, QuantileNetwork
|
||||
|
||||
SelfQRDQN = TypeVar("SelfQRDQN", bound="QRDQN")
|
||||
|
||||
|
|
@ -64,6 +65,11 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
"CnnPolicy": CnnPolicy,
|
||||
"MultiInputPolicy": MultiInputPolicy,
|
||||
}
|
||||
# Linear schedule will be defined in `_setup_model()`
|
||||
exploration_schedule: Schedule
|
||||
quantile_net: QuantileNetwork
|
||||
quantile_net_target: QuantileNetwork
|
||||
policy: QRDQNPolicy
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -123,13 +129,11 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
self.exploration_final_eps = exploration_final_eps
|
||||
self.exploration_fraction = exploration_fraction
|
||||
self.target_update_interval = target_update_interval
|
||||
# For updating the target network with multiple envs:
|
||||
self._n_calls = 0
|
||||
self.max_grad_norm = max_grad_norm
|
||||
# "epsilon" for the epsilon-greedy exploration
|
||||
self.exploration_rate = 0.0
|
||||
# Linear schedule will be defined in `_setup_model()`
|
||||
self.exploration_schedule: Schedule
|
||||
self.quantile_net: th.nn.Module
|
||||
self.quantile_net_target: th.nn.Module
|
||||
|
||||
if "optimizer_class" not in self.policy_kwargs:
|
||||
self.policy_kwargs["optimizer_class"] = th.optim.Adam
|
||||
|
|
@ -148,10 +152,20 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
self.exploration_schedule = get_linear_fn(
|
||||
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
|
||||
)
|
||||
# Account for multiple environments
|
||||
# each call to step() corresponds to n_envs transitions
|
||||
if self.n_envs > 1:
|
||||
if self.n_envs > self.target_update_interval:
|
||||
warnings.warn(
|
||||
"The number of environments used is greater than the target network "
|
||||
f"update interval ({self.n_envs} > {self.target_update_interval}), "
|
||||
"therefore the target network will be updated after each call to env.step() "
|
||||
f"which corresponds to {self.n_envs} steps."
|
||||
)
|
||||
|
||||
self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)
|
||||
|
||||
def _create_aliases(self) -> None:
|
||||
# For type checker:
|
||||
assert isinstance(self.policy, QRDQNPolicy)
|
||||
self.quantile_net = self.policy.quantile_net
|
||||
self.quantile_net_target = self.policy.quantile_net_target
|
||||
self.n_quantiles = self.policy.n_quantiles
|
||||
|
|
@ -161,7 +175,8 @@ class QRDQN(OffPolicyAlgorithm):
|
|||
Update the exploration rate and target network if needed.
|
||||
This method is called in ``collect_rollouts()`` after each step in the environment.
|
||||
"""
|
||||
if self.num_timesteps % self.target_update_interval == 0:
|
||||
self._n_calls += 1
|
||||
if self._n_calls % self.target_update_interval == 0:
|
||||
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
|
||||
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
|
||||
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||
from stable_baselines3.common.policies import BaseModel, BasePolicy
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
|
|
@ -44,10 +44,12 @@ class Actor(BasePolicy):
|
|||
dividing by 255.0 (True by default)
|
||||
"""
|
||||
|
||||
action_space: spaces.Box
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
action_space: spaces.Box,
|
||||
net_arch: List[int],
|
||||
features_extractor: nn.Module,
|
||||
features_dim: int,
|
||||
|
|
@ -95,9 +97,9 @@ class Actor(BasePolicy):
|
|||
if clip_mean > 0.0:
|
||||
self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
|
||||
else:
|
||||
self.action_dist = SquashedDiagGaussianDistribution(action_dim)
|
||||
self.action_dist = SquashedDiagGaussianDistribution(action_dim) # type: ignore[assignment]
|
||||
self.mu = nn.Linear(last_layer_dim, action_dim)
|
||||
self.log_std = nn.Linear(last_layer_dim, action_dim)
|
||||
self.log_std = nn.Linear(last_layer_dim, action_dim) # type: ignore[assignment]
|
||||
|
||||
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||
data = super()._get_constructor_parameters()
|
||||
|
|
@ -156,7 +158,7 @@ class Actor(BasePolicy):
|
|||
if self.use_sde:
|
||||
return mean_actions, self.log_std, dict(latent_sde=latent_pi)
|
||||
# Unstructured exploration (Original implementation)
|
||||
log_std = self.log_std(latent_pi)
|
||||
log_std = self.log_std(latent_pi) # type: ignore[operator]
|
||||
# Original Implementation to cap the standard deviation
|
||||
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
|
||||
return mean_actions, log_std, {}
|
||||
|
|
@ -192,12 +194,15 @@ class Critic(BaseModel):
|
|||
between the actor and the critic (this saves computation time)
|
||||
"""
|
||||
|
||||
action_space: spaces.Box
|
||||
features_extractor: BaseFeaturesExtractor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
action_space: spaces.Box,
|
||||
net_arch: List[int],
|
||||
features_extractor: nn.Module,
|
||||
features_extractor: BaseFeaturesExtractor,
|
||||
features_dim: int,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
normalize_images: bool = True,
|
||||
|
|
@ -221,12 +226,12 @@ class Critic(BaseModel):
|
|||
self.quantiles_total = n_quantiles * n_critics
|
||||
|
||||
for i in range(n_critics):
|
||||
qf_net = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn)
|
||||
qf_net = nn.Sequential(*qf_net)
|
||||
qf_net_list = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn)
|
||||
qf_net = nn.Sequential(*qf_net_list)
|
||||
self.add_module(f"qf{i}", qf_net)
|
||||
self.q_networks.append(qf_net)
|
||||
|
||||
def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]:
|
||||
def forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor:
|
||||
# Learn the features extractor using the policy loss only
|
||||
# when the features_extractor is shared with the actor
|
||||
with th.set_grad_enabled(not self.share_features_extractor):
|
||||
|
|
@ -266,10 +271,14 @@ class TQCPolicy(BasePolicy):
|
|||
between the actor and the critic (this saves computation time)
|
||||
"""
|
||||
|
||||
actor: Actor
|
||||
critic: Critic
|
||||
critic_target: Critic
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
action_space: spaces.Box,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
|
|
@ -328,15 +337,17 @@ class TQCPolicy(BasePolicy):
|
|||
"share_features_extractor": share_features_extractor,
|
||||
}
|
||||
self.critic_kwargs.update(tqc_kwargs)
|
||||
self.actor, self.actor_target = None, None
|
||||
self.critic, self.critic_target = None, None
|
||||
self.share_features_extractor = share_features_extractor
|
||||
|
||||
self._build(lr_schedule)
|
||||
|
||||
def _build(self, lr_schedule: Schedule) -> None:
|
||||
self.actor = self.make_actor()
|
||||
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
self.actor.optimizer = self.optimizer_class( # type: ignore[call-arg]
|
||||
self.actor.parameters(),
|
||||
lr=lr_schedule(1),
|
||||
**self.optimizer_kwargs,
|
||||
)
|
||||
|
||||
if self.share_features_extractor:
|
||||
self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
|
||||
|
|
@ -347,7 +358,7 @@ class TQCPolicy(BasePolicy):
|
|||
# Create a separate features extractor for the critic
|
||||
# this requires more memory and computation
|
||||
self.critic = self.make_critic(features_extractor=None)
|
||||
critic_parameters = self.critic.parameters()
|
||||
critic_parameters = list(self.critic.parameters())
|
||||
|
||||
# Critic target should not share the feature extactor with critic
|
||||
self.critic_target = self.make_critic(features_extractor=None)
|
||||
|
|
@ -356,7 +367,11 @@ class TQCPolicy(BasePolicy):
|
|||
# Target networks should always be in eval mode
|
||||
self.critic_target.set_training_mode(False)
|
||||
|
||||
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
self.critic.optimizer = self.optimizer_class( # type: ignore[call-arg]
|
||||
critic_parameters,
|
||||
lr=lr_schedule(1),
|
||||
**self.optimizer_kwargs,
|
||||
)
|
||||
|
||||
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||
data = super()._get_constructor_parameters()
|
||||
|
|
@ -447,7 +462,7 @@ class CnnPolicy(TQCPolicy):
|
|||
def __init__(
|
||||
self,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
action_space: spaces.Box,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
|
|
@ -516,7 +531,7 @@ class MultiInputPolicy(TQCPolicy):
|
|||
def __init__(
|
||||
self,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
action_space: spaces.Box,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Un
|
|||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
|
|
@ -11,7 +11,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
|||
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
||||
|
||||
from sb3_contrib.common.utils import quantile_huber_loss
|
||||
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
|
||||
from sb3_contrib.tqc.policies import Actor, CnnPolicy, Critic, MlpPolicy, MultiInputPolicy, TQCPolicy
|
||||
|
||||
SelfTQC = TypeVar("SelfTQC", bound="TQC")
|
||||
|
||||
|
|
@ -73,6 +73,10 @@ class TQC(OffPolicyAlgorithm):
|
|||
"CnnPolicy": CnnPolicy,
|
||||
"MultiInputPolicy": MultiInputPolicy,
|
||||
}
|
||||
policy: TQCPolicy
|
||||
actor: Actor
|
||||
critic: Critic
|
||||
critic_target: Critic
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -87,7 +91,7 @@ class TQC(OffPolicyAlgorithm):
|
|||
train_freq: int = 1,
|
||||
gradient_steps: int = 1,
|
||||
action_noise: Optional[ActionNoise] = None,
|
||||
replay_buffer_class: Optional[ReplayBuffer] = None,
|
||||
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
||||
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
optimize_memory_usage: bool = False,
|
||||
ent_coef: Union[str, float] = "auto",
|
||||
|
|
@ -139,7 +143,7 @@ class TQC(OffPolicyAlgorithm):
|
|||
# Inverse of the reward scale
|
||||
self.ent_coef = ent_coef
|
||||
self.target_update_interval = target_update_interval
|
||||
self.ent_coef_optimizer = None
|
||||
self.ent_coef_optimizer: Optional[th.optim.Adam] = None
|
||||
self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net
|
||||
|
||||
if _init_setup_model:
|
||||
|
|
@ -155,7 +159,7 @@ class TQC(OffPolicyAlgorithm):
|
|||
# Target entropy is used when learning the entropy coefficient
|
||||
if self.target_entropy == "auto":
|
||||
# automatically set target entropy if needed
|
||||
self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)
|
||||
self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) # type: ignore
|
||||
else:
|
||||
# Force conversion
|
||||
# this will also throw an error for unexpected string
|
||||
|
|
@ -202,7 +206,7 @@ class TQC(OffPolicyAlgorithm):
|
|||
|
||||
for gradient_step in range(gradient_steps):
|
||||
# Sample replay buffer
|
||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]
|
||||
|
||||
# We need to sample because `log_std` may have changed between two gradient steps
|
||||
if self.use_sde:
|
||||
|
|
@ -213,7 +217,7 @@ class TQC(OffPolicyAlgorithm):
|
|||
log_prob = log_prob.reshape(-1, 1)
|
||||
|
||||
ent_coef_loss = None
|
||||
if self.ent_coef_optimizer is not None:
|
||||
if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
|
||||
# Important: detach the variable from the graph
|
||||
# so we don't change it with other losses
|
||||
# see https://github.com/rail-berkeley/softlearning/issues/60
|
||||
|
|
@ -224,11 +228,10 @@ class TQC(OffPolicyAlgorithm):
|
|||
ent_coef = self.ent_coef_tensor
|
||||
|
||||
ent_coefs.append(ent_coef.item())
|
||||
self.replay_buffer.ent_coef = ent_coef.item()
|
||||
|
||||
# Optimize entropy coefficient, also called
|
||||
# entropy temperature or alpha in the paper
|
||||
if ent_coef_loss is not None:
|
||||
if ent_coef_loss is not None and self.ent_coef_optimizer is not None:
|
||||
self.ent_coef_optimizer.zero_grad()
|
||||
ent_coef_loss.backward()
|
||||
self.ent_coef_optimizer.step()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
|||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.distributions import kl_divergence
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
|
||||
|
|
@ -178,16 +178,16 @@ class TRPO(OnPolicyAlgorithm):
|
|||
"""
|
||||
# This is necessary because not all the parameters in the policy have gradients w.r.t. the KL divergence
|
||||
# The policy objective is also called surrogate objective
|
||||
policy_objective_gradients = []
|
||||
policy_objective_gradients_list = []
|
||||
# Contains the gradients of the KL divergence
|
||||
grad_kl = []
|
||||
grad_kl_list = []
|
||||
# Contains the shape of the gradients of the KL divergence w.r.t each parameter
|
||||
# This way the flattened gradient can be reshaped back into the original shapes and applied to
|
||||
# the parameters
|
||||
grad_shape = []
|
||||
grad_shape: List[Tuple[int, ...]] = []
|
||||
# Contains the parameters which have non-zeros KL divergence gradients
|
||||
# The list is used during the line-search to apply the step to each parameters
|
||||
actor_params = []
|
||||
actor_params: List[nn.Parameter] = []
|
||||
|
||||
for name, param in self.policy.named_parameters():
|
||||
# Skip parameters related to value function based on name
|
||||
|
|
@ -213,13 +213,13 @@ class TRPO(OnPolicyAlgorithm):
|
|||
policy_objective_grad, *_ = th.autograd.grad(policy_objective, param, retain_graph=True, only_inputs=True)
|
||||
|
||||
grad_shape.append(kl_param_grad.shape)
|
||||
grad_kl.append(kl_param_grad.reshape(-1))
|
||||
policy_objective_gradients.append(policy_objective_grad.reshape(-1))
|
||||
grad_kl_list.append(kl_param_grad.reshape(-1))
|
||||
policy_objective_gradients_list.append(policy_objective_grad.reshape(-1))
|
||||
actor_params.append(param)
|
||||
|
||||
# Gradients are concatenated before the conjugate gradient step
|
||||
policy_objective_gradients = th.cat(policy_objective_gradients)
|
||||
grad_kl = th.cat(grad_kl)
|
||||
policy_objective_gradients = th.cat(policy_objective_gradients_list)
|
||||
grad_kl = th.cat(grad_kl_list)
|
||||
return actor_params, policy_objective_gradients, grad_kl, grad_shape
|
||||
|
||||
def train(self) -> None:
|
||||
|
|
@ -243,10 +243,10 @@ class TRPO(OnPolicyAlgorithm):
|
|||
rollout_data = RolloutBufferSamples(
|
||||
rollout_data.observations[:: self.sub_sampling_factor],
|
||||
rollout_data.actions[:: self.sub_sampling_factor],
|
||||
None, # old values, not used here
|
||||
None, # type: ignore[arg-type] # old values, not used here
|
||||
rollout_data.old_log_prob[:: self.sub_sampling_factor],
|
||||
rollout_data.advantages[:: self.sub_sampling_factor],
|
||||
None, # returns, not used here
|
||||
None, # type: ignore[arg-type] # returns, not used here
|
||||
)
|
||||
|
||||
actions = rollout_data.actions
|
||||
|
|
@ -301,7 +301,7 @@ class TRPO(OnPolicyAlgorithm):
|
|||
line_search_max_step_size /= th.matmul(
|
||||
search_direction, hessian_vector_product_fn(search_direction, retain_graph=False)
|
||||
)
|
||||
line_search_max_step_size = th.sqrt(line_search_max_step_size)
|
||||
line_search_max_step_size = th.sqrt(line_search_max_step_size) # type: ignore[assignment, arg-type]
|
||||
|
||||
line_search_backtrack_coeff = 1.0
|
||||
original_actor_params = [param.detach().clone() for param in actor_params]
|
||||
|
|
@ -351,7 +351,7 @@ class TRPO(OnPolicyAlgorithm):
|
|||
param.data = original_param.data.clone()
|
||||
|
||||
policy_objective_values.append(policy_objective.item())
|
||||
kl_divergences.append(0)
|
||||
kl_divergences.append(0.0)
|
||||
else:
|
||||
policy_objective_values.append(new_policy_objective.item())
|
||||
kl_divergences.append(kl_div.item())
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.8.0
|
||||
2.0.0a4
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=1.8.0,<2.0",
|
||||
"stable_baselines3>=2.0.0a4",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ def test_cnn(tmp_path, model_class, share_features_extractor):
|
|||
|
||||
model = model_class("CnnPolicy", env, **kwargs).learn(250)
|
||||
|
||||
obs = env.reset()
|
||||
obs, _ = env.reset()
|
||||
|
||||
# FakeImageEnv is channel last by default and should be wrapped
|
||||
assert is_vecenv_wrapped(model.get_env(), VecTransposeImage)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
import gym
|
||||
from typing import Dict, Optional
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
from stable_baselines3.common.envs import SimpleMultiObsEnv
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
|
||||
|
|
@ -27,8 +30,8 @@ class DummyDictEnv(gym.Env):
|
|||
else:
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
N_CHANNELS = 1
|
||||
HEIGHT = 64
|
||||
WIDTH = 64
|
||||
HEIGHT = 36
|
||||
WIDTH = 36
|
||||
|
||||
if channel_last:
|
||||
obs_shape = (HEIGHT, WIDTH, N_CHANNELS)
|
||||
|
|
@ -65,19 +68,31 @@ class DummyDictEnv(gym.Env):
|
|||
|
||||
def step(self, action):
|
||||
reward = 0.0
|
||||
done = False
|
||||
return self.observation_space.sample(), reward, done, {}
|
||||
done = truncated = False
|
||||
return self.observation_space.sample(), reward, done, truncated, {}
|
||||
|
||||
def compute_reward(self, achieved_goal, desired_goal, info):
|
||||
return np.zeros((len(achieved_goal),))
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
|
||||
if seed is not None:
|
||||
self.observation_space.seed(seed)
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
def reset(self):
|
||||
return self.observation_space.sample()
|
||||
|
||||
def render(self, mode="human"):
|
||||
def render(self):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_discrete_actions", [True, False])
|
||||
@pytest.mark.parametrize("channel_last", [True, False])
|
||||
@pytest.mark.parametrize("nested_dict_obs", [True, False])
|
||||
@pytest.mark.parametrize("vec_only", [True, False])
|
||||
def test_env(use_discrete_actions, channel_last, nested_dict_obs, vec_only):
|
||||
# Check the env used for testing
|
||||
if nested_dict_obs:
|
||||
with pytest.warns(UserWarning, match="Nested observation spaces are not supported"):
|
||||
check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only))
|
||||
else:
|
||||
check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
|
||||
def test_consistency(model_class):
|
||||
"""
|
||||
|
|
@ -89,7 +104,7 @@ def test_consistency(model_class):
|
|||
dict_env = gym.wrappers.TimeLimit(dict_env, 100)
|
||||
env = gym.wrappers.FlattenObservation(dict_env)
|
||||
dict_env.seed(10)
|
||||
obs = dict_env.reset()
|
||||
obs, _ = dict_env.reset()
|
||||
|
||||
kwargs = {}
|
||||
n_steps = 256
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ DIM = 4
|
|||
@pytest.mark.parametrize("model_class", [QRDQN, TRPO])
|
||||
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
|
||||
def test_discrete(model_class, env):
|
||||
env_ = DummyVecEnv([lambda: env])
|
||||
vec_env = DummyVecEnv([lambda: env])
|
||||
kwargs = {}
|
||||
n_steps = 1500
|
||||
if model_class == QRDQN:
|
||||
|
|
@ -30,9 +30,9 @@ def test_discrete(model_class, env):
|
|||
elif n_steps == TRPO:
|
||||
kwargs = dict(n_steps=256, cg_max_steps=5)
|
||||
|
||||
model = model_class("MlpPolicy", env_, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps)
|
||||
model = model_class("MlpPolicy", vec_env, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps)
|
||||
|
||||
evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False)
|
||||
obs = env.reset()
|
||||
evaluate_policy(model, vec_env, n_eval_episodes=20, reward_threshold=90, warn=False)
|
||||
obs = vec_env.reset()
|
||||
|
||||
assert np.shape(model.predict(obs)[0]) == np.shape(obs)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import random
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.callbacks import EventCallback, StopTrainingOnNoModelImprovement
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
||||
|
|
@ -30,12 +32,12 @@ class ToDictWrapper(gym.Wrapper):
|
|||
super().__init__(env)
|
||||
self.observation_space = spaces.Dict({"obs": self.env.observation_space})
|
||||
|
||||
def reset(self):
|
||||
return {"obs": self.env.reset()}
|
||||
def reset(self, **kwargs) -> Tuple[Dict[str, np.ndarray], Dict]:
|
||||
return {"obs": self.env.reset(seed=kwargs.get("seed", 0))[0]}, {} # type: ignore[dict-item]
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, done, infos = self.env.step(action)
|
||||
return {"obs": obs}, reward, done, infos
|
||||
obs, reward, terminated, truncated, infos = self.env.step(action)
|
||||
return {"obs": obs}, reward, terminated, truncated, infos
|
||||
|
||||
|
||||
def test_identity():
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
import gym
|
||||
from typing import Dict, Optional
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
from gym import spaces
|
||||
from gym.envs.classic_control import CartPoleEnv
|
||||
from gym.wrappers.time_limit import TimeLimit
|
||||
from gymnasium import spaces
|
||||
from gymnasium.envs.classic_control import CartPoleEnv
|
||||
from gymnasium.wrappers.time_limit import TimeLimit
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.envs import FakeImageEnv
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
|
|
@ -22,12 +25,12 @@ class ToDictWrapper(gym.Wrapper):
|
|||
super().__init__(env)
|
||||
self.observation_space = spaces.Dict({"obs": self.env.observation_space})
|
||||
|
||||
def reset(self):
|
||||
return {"obs": self.env.reset()}
|
||||
def reset(self, **kwargs):
|
||||
return {"obs": self.env.reset(**kwargs)[0]}, {}
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, done, infos = self.env.step(action)
|
||||
return {"obs": obs}, reward, done, infos
|
||||
obs, reward, done, truncated, infos = self.env.step(action)
|
||||
return {"obs": obs}, reward, done, truncated, infos
|
||||
|
||||
|
||||
class CartPoleNoVelEnv(CartPoleEnv):
|
||||
|
|
@ -46,15 +49,19 @@ class CartPoleNoVelEnv(CartPoleEnv):
|
|||
@staticmethod
|
||||
def _pos_obs(full_obs):
|
||||
xpos, _xvel, thetapos, _thetavel = full_obs
|
||||
return xpos, thetapos
|
||||
return np.array([xpos, thetapos])
|
||||
|
||||
def reset(self):
|
||||
full_obs = super().reset()
|
||||
return CartPoleNoVelEnv._pos_obs(full_obs)
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
|
||||
full_obs, info = super().reset(seed=seed, options=options)
|
||||
return CartPoleNoVelEnv._pos_obs(full_obs), info
|
||||
|
||||
def step(self, action):
|
||||
full_obs, rew, done, info = super().step(action)
|
||||
return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info
|
||||
full_obs, rew, terminated, truncated, info = super().step(action)
|
||||
return CartPoleNoVelEnv._pos_obs(full_obs), rew, terminated, truncated, info
|
||||
|
||||
|
||||
def test_env():
|
||||
check_env(CartPoleNoVelEnv())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import gym
|
||||
import gymnasium as gym
|
||||
import pytest
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.vec_env import VecNormalize
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import pathlib
|
|||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch as th
|
||||
|
|
@ -24,7 +24,7 @@ def select_env(model_class: BaseAlgorithm) -> gym.Env:
|
|||
if model_class == QRDQN:
|
||||
return IdentityEnv(10)
|
||||
else:
|
||||
return IdentityEnvBox(10)
|
||||
return IdentityEnvBox(-10, 10)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Union
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch as th
|
||||
|
|
@ -111,7 +111,7 @@ def test_ppo_mask_train_eval_mode():
|
|||
|
||||
batch_norm_stats_before = clone_on_policy_batch_norm(model)
|
||||
|
||||
observation = env.reset()
|
||||
observation, _ = env.reset()
|
||||
action_masks = get_action_masks(env)
|
||||
first_prediction, _ = model.predict(observation, action_masks=action_masks, deterministic=True)
|
||||
for _ in range(5):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import List
|
||||
|
||||
import pytest
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.envs import IdentityEnv
|
||||
|
||||
from sb3_contrib.common.wrappers import ActionMasker
|
||||
|
|
@ -20,10 +20,12 @@ class IdentityEnvDiscrete(IdentityEnv):
|
|||
super().__init__(ep_length=ep_length, space=space)
|
||||
|
||||
def _action_masks(self) -> List[int]:
|
||||
assert isinstance(self.action_space, spaces.Discrete)
|
||||
return [i == self.state for i in range(self.action_space.n)]
|
||||
|
||||
|
||||
def action_mask_fn(env: IdentityEnvDiscrete) -> List[int]:
|
||||
assert isinstance(env.action_space, spaces.Discrete)
|
||||
return [i == env.state for i in range(env.action_space.n)]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
from stable_baselines3.common.envs import BitFlippingEnv
|
||||
|
||||
from sb3_contrib.common.wrappers import TimeFeatureWrapper
|
||||
|
||||
|
||||
class CustomGoalEnv(gym.GoalEnv):
|
||||
class CustomGoalEnv(gym.Env):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.observation_space = spaces.Dict(
|
||||
|
|
@ -21,10 +21,10 @@ class CustomGoalEnv(gym.GoalEnv):
|
|||
self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)
|
||||
|
||||
def reset(self):
|
||||
return self.observation_space.sample()
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
def step(self, action):
|
||||
return self.observation_space.sample(), 0.0, False, {}
|
||||
return self.observation_space.sample(), 0.0, False, False, {}
|
||||
|
||||
|
||||
def check_time_feature(obs, timestep, max_timesteps):
|
||||
|
|
@ -37,14 +37,16 @@ def test_time_feature():
|
|||
check_env(env, warn=False)
|
||||
# Check for four episodes
|
||||
max_timesteps = 200
|
||||
obs = env.reset()
|
||||
obs, _ = env.reset()
|
||||
for _ in range(4):
|
||||
done = False
|
||||
check_time_feature(obs, timestep=0, max_timesteps=max_timesteps)
|
||||
for step in range(1, max_timesteps + 1):
|
||||
obs, _, done, _ = env.step(env.action_space.sample())
|
||||
obs, _, terminated, truncated, _ = env.step(env.action_space.sample())
|
||||
check_time_feature(obs, timestep=step, max_timesteps=max_timesteps)
|
||||
done = terminated or truncated
|
||||
if done:
|
||||
obs = env.reset()
|
||||
obs, _ = env.reset()
|
||||
|
||||
env = BitFlippingEnv()
|
||||
with pytest.raises(AssertionError):
|
||||
|
|
@ -52,16 +54,16 @@ def test_time_feature():
|
|||
|
||||
env = CustomGoalEnv()
|
||||
env = TimeFeatureWrapper(env, max_steps=500)
|
||||
obs = env.reset()
|
||||
obs, _ = env.reset()
|
||||
check_time_feature(obs["observation"], timestep=0, max_timesteps=500)
|
||||
obs, _, _, _ = env.step(env.action_space.sample())
|
||||
obs = env.step(env.action_space.sample())[0]
|
||||
check_time_feature(obs["observation"], timestep=1, max_timesteps=500)
|
||||
|
||||
# In test mode, the time feature must be constant
|
||||
env = gym.make("Pendulum-v1")
|
||||
env = TimeFeatureWrapper(env, test_mode=True)
|
||||
obs = env.reset()
|
||||
obs, _ = env.reset()
|
||||
check_time_feature(obs, timestep=0, max_timesteps=200)
|
||||
obs, _, _, _ = env.step(env.action_space.sample())
|
||||
obs = env.step(env.action_space.sample())[0]
|
||||
# Should be the same
|
||||
check_time_feature(obs, timestep=0, max_timesteps=200)
|
||||
|
|
|
|||
Loading…
Reference in New Issue