Add Gymnasium support (#152)
* Add support for Gym 0.24 * Fixes for gym 0.24 * Fix for new reset signature * Add tmp SB3 branch * Fixes for gym 0.26 * Remove unused import * Fix dependency * Type annotations fixes * Reformat * Reformat with black 23 * Move to gymnasium * Patch env if needed * Fix types * Fix CI * Fixes for gymnasium * Fix wrapper annotations * Update version * Fix type check * Update QRDQN type hints and bug fix with multi envs * Fix TQC type hints * Fix TRPO type hints * Additional fixes * Update SB3 version * Update issue templates and CI --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
parent
a84ad3aa7d
commit
21cc96cafd
|
|
@ -50,7 +50,8 @@ body:
|
||||||
* GPU models and configuration
|
* GPU models and configuration
|
||||||
* Python version
|
* Python version
|
||||||
* PyTorch version
|
* PyTorch version
|
||||||
* Gym version
|
* Gymnasium version
|
||||||
|
* (if installed) OpenAI Gym version
|
||||||
* Versions of any other relevant libraries
|
* Versions of any other relevant libraries
|
||||||
|
|
||||||
You can use `sb3.get_system_info()` to print relevant packages info:
|
You can use `sb3.get_system_info()` to print relevant packages info:
|
||||||
|
|
|
||||||
|
|
@ -34,9 +34,9 @@ body:
|
||||||
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
|
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
|
||||||
value: |
|
value: |
|
||||||
```python
|
```python
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
|
|
||||||
from stable_baselines3 import A2C
|
from stable_baselines3 import A2C
|
||||||
from stable_baselines3.common.env_checker import check_env
|
from stable_baselines3.common.env_checker import check_env
|
||||||
|
|
@ -49,15 +49,16 @@ body:
|
||||||
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
|
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
|
||||||
self.action_space = spaces.Box(low=-1, high=1, shape=(6,))
|
self.action_space = spaces.Box(low=-1, high=1, shape=(6,))
|
||||||
|
|
||||||
def reset(self):
|
def reset(self, seed=None, options=None):
|
||||||
return self.observation_space.sample()
|
return self.observation_space.sample(), {}
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
obs = self.observation_space.sample()
|
obs = self.observation_space.sample()
|
||||||
reward = 1.0
|
reward = 1.0
|
||||||
done = False
|
terminated = False
|
||||||
|
truncated = False
|
||||||
info = {}
|
info = {}
|
||||||
return obs, reward, done, info
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
env = CustomEnv()
|
env = CustomEnv()
|
||||||
check_env(env)
|
check_env(env)
|
||||||
|
|
@ -84,7 +85,8 @@ body:
|
||||||
* GPU models and configuration
|
* GPU models and configuration
|
||||||
* Python version
|
* Python version
|
||||||
* PyTorch version
|
* PyTorch version
|
||||||
* Gym version
|
* Gymnasium version
|
||||||
|
* (if installed) OpenAI Gym version
|
||||||
* Versions of any other relevant libraries
|
* Versions of any other relevant libraries
|
||||||
|
|
||||||
You can use `sb3.get_system_info()` to print relevant packages info:
|
You can use `sb3.get_system_info()` to print relevant packages info:
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,42 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
Release 2.0.0a4 (WIP)
|
||||||
|
--------------------------
|
||||||
|
|
||||||
|
**Gymnasium support**
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
Stable-Baselines3 (SB3) v2.0 will be the last one supporting Python 3.7 (end of life in June 2023).
|
||||||
|
We highly recommend that you upgrade to Python >= 3.8.
|
||||||
|
|
||||||
|
|
||||||
|
Breaking Changes:
|
||||||
|
^^^^^^^^^^^^^^^^^
|
||||||
|
- Switched to Gymnasium as primary backend; Gym 0.21 and 0.26 are still supported via the ``shimmy`` package (@carlosluis, @arjun-kg, @tlpss)
|
||||||
|
- Upgraded to Stable-Baselines3 >= 2.0.0
|
||||||
|
|
||||||
|
New Features:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Bug Fixes:
|
||||||
|
^^^^^^^^^^
|
||||||
|
- Fixed QRDQN update interval for multi envs
|
||||||
|
|
||||||
|
Deprecations:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Others:
|
||||||
|
^^^^^^^
|
||||||
|
- Fixed ``sb3_contrib/tqc/*.py`` type hints
|
||||||
|
- Fixed ``sb3_contrib/trpo/*.py`` type hints
|
||||||
|
|
||||||
|
|
||||||
|
Documentation:
|
||||||
|
^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|
||||||
Release 1.8.0 (2023-04-07)
|
Release 1.8.0 (2023-04-07)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -91,7 +91,7 @@ to specify the name (see `PR #25 <https://github.com/Stable-Baselines-Team/stabl
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
|
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ Example
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from sb3_contrib import QRDQN
|
from sb3_contrib import QRDQN
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@ Example
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from sb3_contrib import TQC
|
from sb3_contrib import TQC
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ Example
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from sb3_contrib import TRPO
|
from sb3_contrib import TRPO
|
||||||
|
|
|
||||||
|
|
@ -37,9 +37,6 @@ exclude = """(?x)(
|
||||||
| sb3_contrib/common/maskable/buffers.py$
|
| sb3_contrib/common/maskable/buffers.py$
|
||||||
| sb3_contrib/common/envs/invalid_actions_env.py$
|
| sb3_contrib/common/envs/invalid_actions_env.py$
|
||||||
| sb3_contrib/common/vec_env/async_eval.py$
|
| sb3_contrib/common/vec_env/async_eval.py$
|
||||||
| sb3_contrib/tqc/tqc.py$
|
|
||||||
| sb3_contrib/tqc/policies.py$
|
|
||||||
| sb3_contrib/trpo/trpo.py$
|
|
||||||
| sb3_contrib/ppo_mask/ppo_mask.py$
|
| sb3_contrib/ppo_mask/ppo_mask.py$
|
||||||
| tests/test_train_eval_mode.py$
|
| tests/test_train_eval_mode.py$
|
||||||
)"""
|
)"""
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from typing import Any, Dict, Optional, Type, TypeVar, Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
import torch.nn.utils
|
import torch.nn.utils
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||||
from stable_baselines3.common.callbacks import BaseCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
from stable_baselines3.common.evaluation import evaluate_policy
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import Any, Dict, List, Optional, Type
|
from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.preprocessing import get_action_dim
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
|
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp
|
||||||
|
|
@ -50,7 +50,7 @@ class ARSPolicy(BasePolicy):
|
||||||
self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output
|
self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output
|
||||||
)
|
)
|
||||||
elif isinstance(action_space, spaces.Discrete):
|
elif isinstance(action_space, spaces.Discrete):
|
||||||
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn, with_bias=with_bias)
|
actor_net = create_mlp(self.features_dim, int(action_space.n), net_arch, activation_fn, with_bias=with_bias)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
|
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.envs import IdentityEnv
|
from stable_baselines3.common.envs import IdentityEnv
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from typing import Generator, NamedTuple, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
|
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
|
||||||
from stable_baselines3.common.type_aliases import TensorDict
|
from stable_baselines3.common.type_aliases import TensorDict
|
||||||
from stable_baselines3.common.vec_env import VecNormalize
|
from stable_baselines3.common.vec_env import VecNormalize
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from typing import List, Optional, Tuple, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.distributions import Distribution
|
from stable_baselines3.common.distributions import Distribution
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.distributions import Categorical
|
from torch.distributions import Categorical
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from stable_baselines3.common.monitor import Monitor
|
from stable_baselines3.common.monitor import Monitor
|
||||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
|
||||||
|
|
@ -64,7 +64,7 @@ def evaluate_policy(
|
||||||
is_monitor_wrapped = False
|
is_monitor_wrapped = False
|
||||||
|
|
||||||
if not isinstance(env, VecEnv):
|
if not isinstance(env, VecEnv):
|
||||||
env = DummyVecEnv([lambda: env])
|
env = DummyVecEnv([lambda: env]) # type: ignore[list-item, return-value]
|
||||||
|
|
||||||
is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]
|
is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]
|
||||||
|
|
||||||
|
|
@ -93,7 +93,7 @@ def evaluate_policy(
|
||||||
if use_masking:
|
if use_masking:
|
||||||
action_masks = get_action_masks(env)
|
action_masks = get_action_masks(env)
|
||||||
actions, state = model.predict(
|
actions, state = model.predict(
|
||||||
observations,
|
observations, # type: ignore[arg-type]
|
||||||
state=states,
|
state=states,
|
||||||
episode_start=episode_starts,
|
episode_start=episode_starts,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
|
|
@ -101,7 +101,10 @@ def evaluate_policy(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
actions, states = model.predict(
|
actions, states = model.predict(
|
||||||
observations, state=states, episode_start=episode_starts, deterministic=deterministic
|
observations, # type: ignore[arg-type]
|
||||||
|
state=states,
|
||||||
|
episode_start=episode_starts,
|
||||||
|
deterministic=deterministic,
|
||||||
)
|
)
|
||||||
observations, rewards, dones, infos = env.step(actions)
|
observations, rewards, dones, infos = env.step(actions)
|
||||||
current_rewards += rewards
|
current_rewards += rewards
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.torch_layers import (
|
from stable_baselines3.common.torch_layers import (
|
||||||
BaseFeaturesExtractor,
|
BaseFeaturesExtractor,
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from typing import Callable, Generator, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
|
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
|
||||||
from stable_baselines3.common.vec_env import VecNormalize
|
from stable_baselines3.common.vec_env import VecNormalize
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.distributions import Distribution
|
from stable_baselines3.common.distributions import Distribution
|
||||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||||
from stable_baselines3.common.torch_layers import (
|
from stable_baselines3.common.torch_layers import (
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import torch as th
|
||||||
from stable_baselines3.common.evaluation import evaluate_policy
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.running_mean_std import RunningMeanStd
|
from stable_baselines3.common.running_mean_std import RunningMeanStd
|
||||||
|
from stable_baselines3.common.utils import compat_gym_seed
|
||||||
from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize
|
from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize
|
||||||
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper
|
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper
|
||||||
|
|
||||||
|
|
@ -56,7 +57,7 @@ def _worker(
|
||||||
results.append((weights_idx, (episode_rewards, episode_lengths)))
|
results.append((weights_idx, (episode_rewards, episode_lengths)))
|
||||||
remote.send(results)
|
remote.send(results)
|
||||||
elif cmd == "seed":
|
elif cmd == "seed":
|
||||||
remote.send(env.seed(data))
|
remote.send(compat_gym_seed(env, seed=data))
|
||||||
elif cmd == "get_obs_rms":
|
elif cmd == "get_obs_rms":
|
||||||
remote.send(obs_rms)
|
remote.send(obs_rms)
|
||||||
elif cmd == "sync_obs_rms":
|
elif cmd == "sync_obs_rms":
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Callable, Union
|
from typing import Callable, Union
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,14 @@
|
||||||
from typing import Dict, Union
|
from typing import Any, Dict, SupportsFloat, Tuple, Union
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
|
from gymnasium.core import ActType
|
||||||
|
|
||||||
|
TimeFeatureObs = Union[np.ndarray, Dict[str, np.ndarray]]
|
||||||
|
|
||||||
|
|
||||||
class TimeFeatureWrapper(gym.Wrapper):
|
class TimeFeatureWrapper(gym.Wrapper[TimeFeatureObs, ActType, TimeFeatureObs, ActType]):
|
||||||
"""
|
"""
|
||||||
Add remaining, normalized time to observation space for fixed length episodes.
|
Add remaining, normalized time to observation space for fixed length episodes.
|
||||||
See https://arxiv.org/abs/1712.00378 and https://github.com/aravindr93/mjrl/issues/13.
|
See https://arxiv.org/abs/1712.00378 and https://github.com/aravindr93/mjrl/issues/13.
|
||||||
|
|
@ -33,27 +35,30 @@ class TimeFeatureWrapper(gym.Wrapper):
|
||||||
if isinstance(env.observation_space, spaces.Dict):
|
if isinstance(env.observation_space, spaces.Dict):
|
||||||
assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space"
|
assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space"
|
||||||
obs_space = env.observation_space.spaces["observation"]
|
obs_space = env.observation_space.spaces["observation"]
|
||||||
assert isinstance(obs_space, spaces.Box), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
|
|
||||||
obs_space = env.observation_space.spaces["observation"]
|
|
||||||
else:
|
else:
|
||||||
obs_space = env.observation_space
|
obs_space = env.observation_space
|
||||||
|
|
||||||
|
assert isinstance(obs_space, gym.spaces.Box), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
|
||||||
assert len(obs_space.shape) == 1, "Only 1D observation spaces are supported"
|
assert len(obs_space.shape) == 1, "Only 1D observation spaces are supported"
|
||||||
|
|
||||||
low, high = obs_space.low, obs_space.high
|
low, high = obs_space.low, obs_space.high
|
||||||
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0]))
|
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0])) # type: ignore[arg-type]
|
||||||
self.dtype = obs_space.dtype
|
self.dtype = obs_space.dtype
|
||||||
|
|
||||||
if isinstance(env.observation_space, spaces.Dict):
|
if isinstance(env.observation_space, spaces.Dict):
|
||||||
env.observation_space.spaces["observation"] = spaces.Box(low=low, high=high, dtype=self.dtype)
|
env.observation_space.spaces["observation"] = spaces.Box(
|
||||||
|
low=low,
|
||||||
|
high=high,
|
||||||
|
dtype=self.dtype, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
env.observation_space = spaces.Box(low=low, high=high, dtype=self.dtype)
|
env.observation_space = spaces.Box(low=low, high=high, dtype=self.dtype) # type: ignore[arg-type]
|
||||||
|
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
|
||||||
# Try to infer the max number of steps per episode
|
# Try to infer the max number of steps per episode
|
||||||
try:
|
try:
|
||||||
self._max_steps = env.spec.max_episode_steps
|
self._max_steps = env.spec.max_episode_steps # type: ignore[union-attr]
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
self._max_steps = None
|
self._max_steps = None
|
||||||
|
|
||||||
|
|
@ -64,14 +69,15 @@ class TimeFeatureWrapper(gym.Wrapper):
|
||||||
self._current_step = 0
|
self._current_step = 0
|
||||||
self._test_mode = test_mode
|
self._test_mode = test_mode
|
||||||
|
|
||||||
def reset(self) -> GymObs:
|
def reset(self, **kwargs) -> Tuple[TimeFeatureObs, Dict[str, Any]]:
|
||||||
self._current_step = 0
|
self._current_step = 0
|
||||||
return self._get_obs(self.env.reset())
|
obs, info = self.env.reset(**kwargs)
|
||||||
|
return self._get_obs(obs), info
|
||||||
|
|
||||||
def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
|
def step(self, action: ActType) -> Tuple[TimeFeatureObs, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||||
self._current_step += 1
|
self._current_step += 1
|
||||||
obs, reward, done, info = self.env.step(action)
|
obs, reward, done, truncated, info = self.env.step(action)
|
||||||
return self._get_obs(obs), reward, done, info
|
return self._get_obs(obs), reward, done, truncated, info
|
||||||
|
|
||||||
def _get_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
|
def _get_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -80,11 +86,13 @@ class TimeFeatureWrapper(gym.Wrapper):
|
||||||
:param obs:
|
:param obs:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
# for mypy
|
||||||
|
assert self._max_steps is not None
|
||||||
# Remaining time is more general
|
# Remaining time is more general
|
||||||
time_feature = 1 - (self._current_step / self._max_steps)
|
time_feature = 1 - (self._current_step / self._max_steps)
|
||||||
if self._test_mode:
|
if self._test_mode:
|
||||||
time_feature = 1.0
|
time_feature = 1.0
|
||||||
time_feature = np.array(time_feature, dtype=self.dtype)
|
time_feature = np.array(time_feature, dtype=self.dtype) # type: ignore[assignment]
|
||||||
|
|
||||||
if isinstance(obs, dict):
|
if isinstance(obs, dict):
|
||||||
obs["observation"] = np.append(obs["observation"], time_feature)
|
obs["observation"] = np.append(obs["observation"], time_feature)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common import utils
|
from stable_baselines3.common import utils
|
||||||
from stable_baselines3.common.buffers import RolloutBuffer
|
from stable_baselines3.common.buffers import RolloutBuffer
|
||||||
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
|
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
|
||||||
|
|
@ -232,6 +232,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
# Make sure training timesteps are ahead of the internal counter
|
# Make sure training timesteps are ahead of the internal counter
|
||||||
total_timesteps += self.num_timesteps
|
total_timesteps += self.num_timesteps
|
||||||
self._total_timesteps = total_timesteps
|
self._total_timesteps = total_timesteps
|
||||||
|
self._num_timesteps_at_start = self.num_timesteps
|
||||||
|
|
||||||
# Avoid resetting the environment when calling ``.learn()`` consecutive times
|
# Avoid resetting the environment when calling ``.learn()`` consecutive times
|
||||||
if reset_num_timesteps or self._last_obs is None:
|
if reset_num_timesteps or self._last_obs is None:
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Type, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.buffers import RolloutBuffer
|
from stable_baselines3.common.buffers import RolloutBuffer
|
||||||
from stable_baselines3.common.callbacks import BaseCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import Any, Dict, List, Optional, Type
|
from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.torch_layers import (
|
from stable_baselines3.common.torch_layers import (
|
||||||
BaseFeaturesExtractor,
|
BaseFeaturesExtractor,
|
||||||
|
|
@ -27,10 +27,12 @@ class QuantileNetwork(BasePolicy):
|
||||||
dividing by 255.0 (True by default)
|
dividing by 255.0 (True by default)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
action_space: spaces.Discrete
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Discrete,
|
||||||
features_extractor: BaseFeaturesExtractor,
|
features_extractor: BaseFeaturesExtractor,
|
||||||
features_dim: int,
|
features_dim: int,
|
||||||
n_quantiles: int = 200,
|
n_quantiles: int = 200,
|
||||||
|
|
@ -52,7 +54,7 @@ class QuantileNetwork(BasePolicy):
|
||||||
self.activation_fn = activation_fn
|
self.activation_fn = activation_fn
|
||||||
self.features_dim = features_dim
|
self.features_dim = features_dim
|
||||||
self.n_quantiles = n_quantiles
|
self.n_quantiles = n_quantiles
|
||||||
action_dim = self.action_space.n # number of actions
|
action_dim = int(self.action_space.n) # number of actions
|
||||||
quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
|
quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
|
||||||
self.quantile_net = nn.Sequential(*quantile_net)
|
self.quantile_net = nn.Sequential(*quantile_net)
|
||||||
|
|
||||||
|
|
@ -63,10 +65,8 @@ class QuantileNetwork(BasePolicy):
|
||||||
:param obs: Observation
|
:param obs: Observation
|
||||||
:return: The estimated quantiles for each action.
|
:return: The estimated quantiles for each action.
|
||||||
"""
|
"""
|
||||||
# For type checker:
|
|
||||||
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
|
|
||||||
quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
|
quantiles = self.quantile_net(self.extract_features(obs, self.features_extractor))
|
||||||
return quantiles.view(-1, self.n_quantiles, self.action_space.n)
|
return quantiles.view(-1, self.n_quantiles, int(self.action_space.n))
|
||||||
|
|
||||||
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||||
q_values = self(observation).mean(dim=1)
|
q_values = self(observation).mean(dim=1)
|
||||||
|
|
@ -110,10 +110,13 @@ class QRDQNPolicy(BasePolicy):
|
||||||
excluding the learning rate, to pass to the optimizer
|
excluding the learning rate, to pass to the optimizer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
quantile_net: QuantileNetwork
|
||||||
|
quantile_net_target: QuantileNetwork
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Discrete,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
n_quantiles: int = 200,
|
n_quantiles: int = 200,
|
||||||
net_arch: Optional[List[int]] = None,
|
net_arch: Optional[List[int]] = None,
|
||||||
|
|
@ -152,9 +155,6 @@ class QRDQNPolicy(BasePolicy):
|
||||||
"activation_fn": self.activation_fn,
|
"activation_fn": self.activation_fn,
|
||||||
"normalize_images": normalize_images,
|
"normalize_images": normalize_images,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.quantile_net: QuantileNetwork
|
|
||||||
self.quantile_net_target: QuantileNetwork
|
|
||||||
self._build(lr_schedule)
|
self._build(lr_schedule)
|
||||||
|
|
||||||
def _build(self, lr_schedule: Schedule) -> None:
|
def _build(self, lr_schedule: Schedule) -> None:
|
||||||
|
|
@ -239,7 +239,7 @@ class CnnPolicy(QRDQNPolicy):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Discrete,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
n_quantiles: int = 200,
|
n_quantiles: int = 200,
|
||||||
net_arch: Optional[List[int]] = None,
|
net_arch: Optional[List[int]] = None,
|
||||||
|
|
@ -287,7 +287,7 @@ class MultiInputPolicy(QRDQNPolicy):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Discrete,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
n_quantiles: int = 200,
|
n_quantiles: int = 200,
|
||||||
net_arch: Optional[List[int]] = None,
|
net_arch: Optional[List[int]] = None,
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
|
import warnings
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.buffers import ReplayBuffer
|
from stable_baselines3.common.buffers import ReplayBuffer
|
||||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
|
|
@ -10,7 +11,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
||||||
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
|
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
|
||||||
|
|
||||||
from sb3_contrib.common.utils import quantile_huber_loss
|
from sb3_contrib.common.utils import quantile_huber_loss
|
||||||
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy
|
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, QRDQNPolicy, QuantileNetwork
|
||||||
|
|
||||||
SelfQRDQN = TypeVar("SelfQRDQN", bound="QRDQN")
|
SelfQRDQN = TypeVar("SelfQRDQN", bound="QRDQN")
|
||||||
|
|
||||||
|
|
@ -64,6 +65,11 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
"CnnPolicy": CnnPolicy,
|
"CnnPolicy": CnnPolicy,
|
||||||
"MultiInputPolicy": MultiInputPolicy,
|
"MultiInputPolicy": MultiInputPolicy,
|
||||||
}
|
}
|
||||||
|
# Linear schedule will be defined in `_setup_model()`
|
||||||
|
exploration_schedule: Schedule
|
||||||
|
quantile_net: QuantileNetwork
|
||||||
|
quantile_net_target: QuantileNetwork
|
||||||
|
policy: QRDQNPolicy
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -123,13 +129,11 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
self.exploration_final_eps = exploration_final_eps
|
self.exploration_final_eps = exploration_final_eps
|
||||||
self.exploration_fraction = exploration_fraction
|
self.exploration_fraction = exploration_fraction
|
||||||
self.target_update_interval = target_update_interval
|
self.target_update_interval = target_update_interval
|
||||||
|
# For updating the target network with multiple envs:
|
||||||
|
self._n_calls = 0
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
# "epsilon" for the epsilon-greedy exploration
|
# "epsilon" for the epsilon-greedy exploration
|
||||||
self.exploration_rate = 0.0
|
self.exploration_rate = 0.0
|
||||||
# Linear schedule will be defined in `_setup_model()`
|
|
||||||
self.exploration_schedule: Schedule
|
|
||||||
self.quantile_net: th.nn.Module
|
|
||||||
self.quantile_net_target: th.nn.Module
|
|
||||||
|
|
||||||
if "optimizer_class" not in self.policy_kwargs:
|
if "optimizer_class" not in self.policy_kwargs:
|
||||||
self.policy_kwargs["optimizer_class"] = th.optim.Adam
|
self.policy_kwargs["optimizer_class"] = th.optim.Adam
|
||||||
|
|
@ -148,10 +152,20 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
self.exploration_schedule = get_linear_fn(
|
self.exploration_schedule = get_linear_fn(
|
||||||
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
|
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
|
||||||
)
|
)
|
||||||
|
# Account for multiple environments
|
||||||
|
# each call to step() corresponds to n_envs transitions
|
||||||
|
if self.n_envs > 1:
|
||||||
|
if self.n_envs > self.target_update_interval:
|
||||||
|
warnings.warn(
|
||||||
|
"The number of environments used is greater than the target network "
|
||||||
|
f"update interval ({self.n_envs} > {self.target_update_interval}), "
|
||||||
|
"therefore the target network will be updated after each call to env.step() "
|
||||||
|
f"which corresponds to {self.n_envs} steps."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)
|
||||||
|
|
||||||
def _create_aliases(self) -> None:
|
def _create_aliases(self) -> None:
|
||||||
# For type checker:
|
|
||||||
assert isinstance(self.policy, QRDQNPolicy)
|
|
||||||
self.quantile_net = self.policy.quantile_net
|
self.quantile_net = self.policy.quantile_net
|
||||||
self.quantile_net_target = self.policy.quantile_net_target
|
self.quantile_net_target = self.policy.quantile_net_target
|
||||||
self.n_quantiles = self.policy.n_quantiles
|
self.n_quantiles = self.policy.n_quantiles
|
||||||
|
|
@ -161,7 +175,8 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
Update the exploration rate and target network if needed.
|
Update the exploration rate and target network if needed.
|
||||||
This method is called in ``collect_rollouts()`` after each step in the environment.
|
This method is called in ``collect_rollouts()`` after each step in the environment.
|
||||||
"""
|
"""
|
||||||
if self.num_timesteps % self.target_update_interval == 0:
|
self._n_calls += 1
|
||||||
|
if self._n_calls % self.target_update_interval == 0:
|
||||||
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
|
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
|
||||||
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
|
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
|
||||||
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
|
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||||
from stable_baselines3.common.policies import BaseModel, BasePolicy
|
from stable_baselines3.common.policies import BaseModel, BasePolicy
|
||||||
from stable_baselines3.common.preprocessing import get_action_dim
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
|
|
@ -44,10 +44,12 @@ class Actor(BasePolicy):
|
||||||
dividing by 255.0 (True by default)
|
dividing by 255.0 (True by default)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
action_space: spaces.Box
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Box,
|
||||||
net_arch: List[int],
|
net_arch: List[int],
|
||||||
features_extractor: nn.Module,
|
features_extractor: nn.Module,
|
||||||
features_dim: int,
|
features_dim: int,
|
||||||
|
|
@ -95,9 +97,9 @@ class Actor(BasePolicy):
|
||||||
if clip_mean > 0.0:
|
if clip_mean > 0.0:
|
||||||
self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
|
self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
|
||||||
else:
|
else:
|
||||||
self.action_dist = SquashedDiagGaussianDistribution(action_dim)
|
self.action_dist = SquashedDiagGaussianDistribution(action_dim) # type: ignore[assignment]
|
||||||
self.mu = nn.Linear(last_layer_dim, action_dim)
|
self.mu = nn.Linear(last_layer_dim, action_dim)
|
||||||
self.log_std = nn.Linear(last_layer_dim, action_dim)
|
self.log_std = nn.Linear(last_layer_dim, action_dim) # type: ignore[assignment]
|
||||||
|
|
||||||
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||||
data = super()._get_constructor_parameters()
|
data = super()._get_constructor_parameters()
|
||||||
|
|
@ -156,7 +158,7 @@ class Actor(BasePolicy):
|
||||||
if self.use_sde:
|
if self.use_sde:
|
||||||
return mean_actions, self.log_std, dict(latent_sde=latent_pi)
|
return mean_actions, self.log_std, dict(latent_sde=latent_pi)
|
||||||
# Unstructured exploration (Original implementation)
|
# Unstructured exploration (Original implementation)
|
||||||
log_std = self.log_std(latent_pi)
|
log_std = self.log_std(latent_pi) # type: ignore[operator]
|
||||||
# Original Implementation to cap the standard deviation
|
# Original Implementation to cap the standard deviation
|
||||||
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
|
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
|
||||||
return mean_actions, log_std, {}
|
return mean_actions, log_std, {}
|
||||||
|
|
@ -192,12 +194,15 @@ class Critic(BaseModel):
|
||||||
between the actor and the critic (this saves computation time)
|
between the actor and the critic (this saves computation time)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
action_space: spaces.Box
|
||||||
|
features_extractor: BaseFeaturesExtractor
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Box,
|
||||||
net_arch: List[int],
|
net_arch: List[int],
|
||||||
features_extractor: nn.Module,
|
features_extractor: BaseFeaturesExtractor,
|
||||||
features_dim: int,
|
features_dim: int,
|
||||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
normalize_images: bool = True,
|
normalize_images: bool = True,
|
||||||
|
|
@ -221,12 +226,12 @@ class Critic(BaseModel):
|
||||||
self.quantiles_total = n_quantiles * n_critics
|
self.quantiles_total = n_quantiles * n_critics
|
||||||
|
|
||||||
for i in range(n_critics):
|
for i in range(n_critics):
|
||||||
qf_net = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn)
|
qf_net_list = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn)
|
||||||
qf_net = nn.Sequential(*qf_net)
|
qf_net = nn.Sequential(*qf_net_list)
|
||||||
self.add_module(f"qf{i}", qf_net)
|
self.add_module(f"qf{i}", qf_net)
|
||||||
self.q_networks.append(qf_net)
|
self.q_networks.append(qf_net)
|
||||||
|
|
||||||
def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]:
|
def forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor:
|
||||||
# Learn the features extractor using the policy loss only
|
# Learn the features extractor using the policy loss only
|
||||||
# when the features_extractor is shared with the actor
|
# when the features_extractor is shared with the actor
|
||||||
with th.set_grad_enabled(not self.share_features_extractor):
|
with th.set_grad_enabled(not self.share_features_extractor):
|
||||||
|
|
@ -266,10 +271,14 @@ class TQCPolicy(BasePolicy):
|
||||||
between the actor and the critic (this saves computation time)
|
between the actor and the critic (this saves computation time)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
actor: Actor
|
||||||
|
critic: Critic
|
||||||
|
critic_target: Critic
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Box,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
|
|
@ -328,15 +337,17 @@ class TQCPolicy(BasePolicy):
|
||||||
"share_features_extractor": share_features_extractor,
|
"share_features_extractor": share_features_extractor,
|
||||||
}
|
}
|
||||||
self.critic_kwargs.update(tqc_kwargs)
|
self.critic_kwargs.update(tqc_kwargs)
|
||||||
self.actor, self.actor_target = None, None
|
|
||||||
self.critic, self.critic_target = None, None
|
|
||||||
self.share_features_extractor = share_features_extractor
|
self.share_features_extractor = share_features_extractor
|
||||||
|
|
||||||
self._build(lr_schedule)
|
self._build(lr_schedule)
|
||||||
|
|
||||||
def _build(self, lr_schedule: Schedule) -> None:
|
def _build(self, lr_schedule: Schedule) -> None:
|
||||||
self.actor = self.make_actor()
|
self.actor = self.make_actor()
|
||||||
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
self.actor.optimizer = self.optimizer_class( # type: ignore[call-arg]
|
||||||
|
self.actor.parameters(),
|
||||||
|
lr=lr_schedule(1),
|
||||||
|
**self.optimizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
if self.share_features_extractor:
|
if self.share_features_extractor:
|
||||||
self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
|
self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
|
||||||
|
|
@ -347,7 +358,7 @@ class TQCPolicy(BasePolicy):
|
||||||
# Create a separate features extractor for the critic
|
# Create a separate features extractor for the critic
|
||||||
# this requires more memory and computation
|
# this requires more memory and computation
|
||||||
self.critic = self.make_critic(features_extractor=None)
|
self.critic = self.make_critic(features_extractor=None)
|
||||||
critic_parameters = self.critic.parameters()
|
critic_parameters = list(self.critic.parameters())
|
||||||
|
|
||||||
# Critic target should not share the features extractor with critic
|
# Critic target should not share the features extractor with critic
|
||||||
self.critic_target = self.make_critic(features_extractor=None)
|
self.critic_target = self.make_critic(features_extractor=None)
|
||||||
|
|
@ -356,7 +367,11 @@ class TQCPolicy(BasePolicy):
|
||||||
# Target networks should always be in eval mode
|
# Target networks should always be in eval mode
|
||||||
self.critic_target.set_training_mode(False)
|
self.critic_target.set_training_mode(False)
|
||||||
|
|
||||||
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
|
self.critic.optimizer = self.optimizer_class( # type: ignore[call-arg]
|
||||||
|
critic_parameters,
|
||||||
|
lr=lr_schedule(1),
|
||||||
|
**self.optimizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||||
data = super()._get_constructor_parameters()
|
data = super()._get_constructor_parameters()
|
||||||
|
|
@ -447,7 +462,7 @@ class CnnPolicy(TQCPolicy):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Box,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
|
|
@ -516,7 +531,7 @@ class MultiInputPolicy(TQCPolicy):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Box,
|
||||||
lr_schedule: Schedule,
|
lr_schedule: Schedule,
|
||||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Un
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.buffers import ReplayBuffer
|
from stable_baselines3.common.buffers import ReplayBuffer
|
||||||
from stable_baselines3.common.noise import ActionNoise
|
from stable_baselines3.common.noise import ActionNoise
|
||||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||||
|
|
@ -11,7 +11,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
||||||
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
||||||
|
|
||||||
from sb3_contrib.common.utils import quantile_huber_loss
|
from sb3_contrib.common.utils import quantile_huber_loss
|
||||||
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TQCPolicy
|
from sb3_contrib.tqc.policies import Actor, CnnPolicy, Critic, MlpPolicy, MultiInputPolicy, TQCPolicy
|
||||||
|
|
||||||
SelfTQC = TypeVar("SelfTQC", bound="TQC")
|
SelfTQC = TypeVar("SelfTQC", bound="TQC")
|
||||||
|
|
||||||
|
|
@ -73,6 +73,10 @@ class TQC(OffPolicyAlgorithm):
|
||||||
"CnnPolicy": CnnPolicy,
|
"CnnPolicy": CnnPolicy,
|
||||||
"MultiInputPolicy": MultiInputPolicy,
|
"MultiInputPolicy": MultiInputPolicy,
|
||||||
}
|
}
|
||||||
|
policy: TQCPolicy
|
||||||
|
actor: Actor
|
||||||
|
critic: Critic
|
||||||
|
critic_target: Critic
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -87,7 +91,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
train_freq: int = 1,
|
train_freq: int = 1,
|
||||||
gradient_steps: int = 1,
|
gradient_steps: int = 1,
|
||||||
action_noise: Optional[ActionNoise] = None,
|
action_noise: Optional[ActionNoise] = None,
|
||||||
replay_buffer_class: Optional[ReplayBuffer] = None,
|
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
||||||
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
optimize_memory_usage: bool = False,
|
optimize_memory_usage: bool = False,
|
||||||
ent_coef: Union[str, float] = "auto",
|
ent_coef: Union[str, float] = "auto",
|
||||||
|
|
@ -139,7 +143,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
# Inverse of the reward scale
|
# Inverse of the reward scale
|
||||||
self.ent_coef = ent_coef
|
self.ent_coef = ent_coef
|
||||||
self.target_update_interval = target_update_interval
|
self.target_update_interval = target_update_interval
|
||||||
self.ent_coef_optimizer = None
|
self.ent_coef_optimizer: Optional[th.optim.Adam] = None
|
||||||
self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net
|
self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net
|
||||||
|
|
||||||
if _init_setup_model:
|
if _init_setup_model:
|
||||||
|
|
@ -155,7 +159,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
# Target entropy is used when learning the entropy coefficient
|
# Target entropy is used when learning the entropy coefficient
|
||||||
if self.target_entropy == "auto":
|
if self.target_entropy == "auto":
|
||||||
# automatically set target entropy if needed
|
# automatically set target entropy if needed
|
||||||
self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)
|
self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) # type: ignore
|
||||||
else:
|
else:
|
||||||
# Force conversion
|
# Force conversion
|
||||||
# this will also throw an error for unexpected string
|
# this will also throw an error for unexpected string
|
||||||
|
|
@ -202,7 +206,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
|
|
||||||
for gradient_step in range(gradient_steps):
|
for gradient_step in range(gradient_steps):
|
||||||
# Sample replay buffer
|
# Sample replay buffer
|
||||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]
|
||||||
|
|
||||||
# We need to sample because `log_std` may have changed between two gradient steps
|
# We need to sample because `log_std` may have changed between two gradient steps
|
||||||
if self.use_sde:
|
if self.use_sde:
|
||||||
|
|
@ -213,7 +217,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
log_prob = log_prob.reshape(-1, 1)
|
log_prob = log_prob.reshape(-1, 1)
|
||||||
|
|
||||||
ent_coef_loss = None
|
ent_coef_loss = None
|
||||||
if self.ent_coef_optimizer is not None:
|
if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
|
||||||
# Important: detach the variable from the graph
|
# Important: detach the variable from the graph
|
||||||
# so we don't change it with other losses
|
# so we don't change it with other losses
|
||||||
# see https://github.com/rail-berkeley/softlearning/issues/60
|
# see https://github.com/rail-berkeley/softlearning/issues/60
|
||||||
|
|
@ -224,11 +228,10 @@ class TQC(OffPolicyAlgorithm):
|
||||||
ent_coef = self.ent_coef_tensor
|
ent_coef = self.ent_coef_tensor
|
||||||
|
|
||||||
ent_coefs.append(ent_coef.item())
|
ent_coefs.append(ent_coef.item())
|
||||||
self.replay_buffer.ent_coef = ent_coef.item()
|
|
||||||
|
|
||||||
# Optimize entropy coefficient, also called
|
# Optimize entropy coefficient, also called
|
||||||
# entropy temperature or alpha in the paper
|
# entropy temperature or alpha in the paper
|
||||||
if ent_coef_loss is not None:
|
if ent_coef_loss is not None and self.ent_coef_optimizer is not None:
|
||||||
self.ent_coef_optimizer.zero_grad()
|
self.ent_coef_optimizer.zero_grad()
|
||||||
ent_coef_loss.backward()
|
ent_coef_loss.backward()
|
||||||
self.ent_coef_optimizer.step()
|
self.ent_coef_optimizer.step()
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.distributions import kl_divergence
|
from stable_baselines3.common.distributions import kl_divergence
|
||||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||||
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
|
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
|
||||||
|
|
@ -178,16 +178,16 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
"""
|
"""
|
||||||
# This is necessary because not all the parameters in the policy have gradients w.r.t. the KL divergence
|
# This is necessary because not all the parameters in the policy have gradients w.r.t. the KL divergence
|
||||||
# The policy objective is also called surrogate objective
|
# The policy objective is also called surrogate objective
|
||||||
policy_objective_gradients = []
|
policy_objective_gradients_list = []
|
||||||
# Contains the gradients of the KL divergence
|
# Contains the gradients of the KL divergence
|
||||||
grad_kl = []
|
grad_kl_list = []
|
||||||
# Contains the shape of the gradients of the KL divergence w.r.t each parameter
|
# Contains the shape of the gradients of the KL divergence w.r.t each parameter
|
||||||
# This way the flattened gradient can be reshaped back into the original shapes and applied to
|
# This way the flattened gradient can be reshaped back into the original shapes and applied to
|
||||||
# the parameters
|
# the parameters
|
||||||
grad_shape = []
|
grad_shape: List[Tuple[int, ...]] = []
|
||||||
# Contains the parameters which have non-zero KL divergence gradients
|
# Contains the parameters which have non-zero KL divergence gradients
|
||||||
# The list is used during the line-search to apply the step to each parameter
|
# The list is used during the line-search to apply the step to each parameter
|
||||||
actor_params = []
|
actor_params: List[nn.Parameter] = []
|
||||||
|
|
||||||
for name, param in self.policy.named_parameters():
|
for name, param in self.policy.named_parameters():
|
||||||
# Skip parameters related to value function based on name
|
# Skip parameters related to value function based on name
|
||||||
|
|
@ -213,13 +213,13 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
policy_objective_grad, *_ = th.autograd.grad(policy_objective, param, retain_graph=True, only_inputs=True)
|
policy_objective_grad, *_ = th.autograd.grad(policy_objective, param, retain_graph=True, only_inputs=True)
|
||||||
|
|
||||||
grad_shape.append(kl_param_grad.shape)
|
grad_shape.append(kl_param_grad.shape)
|
||||||
grad_kl.append(kl_param_grad.reshape(-1))
|
grad_kl_list.append(kl_param_grad.reshape(-1))
|
||||||
policy_objective_gradients.append(policy_objective_grad.reshape(-1))
|
policy_objective_gradients_list.append(policy_objective_grad.reshape(-1))
|
||||||
actor_params.append(param)
|
actor_params.append(param)
|
||||||
|
|
||||||
# Gradients are concatenated before the conjugate gradient step
|
# Gradients are concatenated before the conjugate gradient step
|
||||||
policy_objective_gradients = th.cat(policy_objective_gradients)
|
policy_objective_gradients = th.cat(policy_objective_gradients_list)
|
||||||
grad_kl = th.cat(grad_kl)
|
grad_kl = th.cat(grad_kl_list)
|
||||||
return actor_params, policy_objective_gradients, grad_kl, grad_shape
|
return actor_params, policy_objective_gradients, grad_kl, grad_shape
|
||||||
|
|
||||||
def train(self) -> None:
|
def train(self) -> None:
|
||||||
|
|
@ -243,10 +243,10 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
rollout_data = RolloutBufferSamples(
|
rollout_data = RolloutBufferSamples(
|
||||||
rollout_data.observations[:: self.sub_sampling_factor],
|
rollout_data.observations[:: self.sub_sampling_factor],
|
||||||
rollout_data.actions[:: self.sub_sampling_factor],
|
rollout_data.actions[:: self.sub_sampling_factor],
|
||||||
None, # old values, not used here
|
None, # type: ignore[arg-type] # old values, not used here
|
||||||
rollout_data.old_log_prob[:: self.sub_sampling_factor],
|
rollout_data.old_log_prob[:: self.sub_sampling_factor],
|
||||||
rollout_data.advantages[:: self.sub_sampling_factor],
|
rollout_data.advantages[:: self.sub_sampling_factor],
|
||||||
None, # returns, not used here
|
None, # type: ignore[arg-type] # returns, not used here
|
||||||
)
|
)
|
||||||
|
|
||||||
actions = rollout_data.actions
|
actions = rollout_data.actions
|
||||||
|
|
@ -301,7 +301,7 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
line_search_max_step_size /= th.matmul(
|
line_search_max_step_size /= th.matmul(
|
||||||
search_direction, hessian_vector_product_fn(search_direction, retain_graph=False)
|
search_direction, hessian_vector_product_fn(search_direction, retain_graph=False)
|
||||||
)
|
)
|
||||||
line_search_max_step_size = th.sqrt(line_search_max_step_size)
|
line_search_max_step_size = th.sqrt(line_search_max_step_size) # type: ignore[assignment, arg-type]
|
||||||
|
|
||||||
line_search_backtrack_coeff = 1.0
|
line_search_backtrack_coeff = 1.0
|
||||||
original_actor_params = [param.detach().clone() for param in actor_params]
|
original_actor_params = [param.detach().clone() for param in actor_params]
|
||||||
|
|
@ -351,7 +351,7 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
param.data = original_param.data.clone()
|
param.data = original_param.data.clone()
|
||||||
|
|
||||||
policy_objective_values.append(policy_objective.item())
|
policy_objective_values.append(policy_objective.item())
|
||||||
kl_divergences.append(0)
|
kl_divergences.append(0.0)
|
||||||
else:
|
else:
|
||||||
policy_objective_values.append(new_policy_objective.item())
|
policy_objective_values.append(new_policy_objective.item())
|
||||||
kl_divergences.append(kl_div.item())
|
kl_divergences.append(kl_div.item())
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.8.0
|
2.0.0a4
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.8.0,<2.0",
|
"stable_baselines3>=2.0.0a4",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ def test_cnn(tmp_path, model_class, share_features_extractor):
|
||||||
|
|
||||||
model = model_class("CnnPolicy", env, **kwargs).learn(250)
|
model = model_class("CnnPolicy", env, **kwargs).learn(250)
|
||||||
|
|
||||||
obs = env.reset()
|
obs, _ = env.reset()
|
||||||
|
|
||||||
# FakeImageEnv is channel last by default and should be wrapped
|
# FakeImageEnv is channel last by default and should be wrapped
|
||||||
assert is_vecenv_wrapped(model.get_env(), VecTransposeImage)
|
assert is_vecenv_wrapped(model.get_env(), VecTransposeImage)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
import gym
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
|
from stable_baselines3.common.env_checker import check_env
|
||||||
from stable_baselines3.common.envs import SimpleMultiObsEnv
|
from stable_baselines3.common.envs import SimpleMultiObsEnv
|
||||||
from stable_baselines3.common.evaluation import evaluate_policy
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
|
||||||
|
|
@ -27,8 +30,8 @@ class DummyDictEnv(gym.Env):
|
||||||
else:
|
else:
|
||||||
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||||
N_CHANNELS = 1
|
N_CHANNELS = 1
|
||||||
HEIGHT = 64
|
HEIGHT = 36
|
||||||
WIDTH = 64
|
WIDTH = 36
|
||||||
|
|
||||||
if channel_last:
|
if channel_last:
|
||||||
obs_shape = (HEIGHT, WIDTH, N_CHANNELS)
|
obs_shape = (HEIGHT, WIDTH, N_CHANNELS)
|
||||||
|
|
@ -65,19 +68,31 @@ class DummyDictEnv(gym.Env):
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
reward = 0.0
|
reward = 0.0
|
||||||
done = False
|
done = truncated = False
|
||||||
return self.observation_space.sample(), reward, done, {}
|
return self.observation_space.sample(), reward, done, truncated, {}
|
||||||
|
|
||||||
def compute_reward(self, achieved_goal, desired_goal, info):
|
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
|
||||||
return np.zeros((len(achieved_goal),))
|
if seed is not None:
|
||||||
|
self.observation_space.seed(seed)
|
||||||
|
return self.observation_space.sample(), {}
|
||||||
|
|
||||||
def reset(self):
|
def render(self):
|
||||||
return self.observation_space.sample()
|
|
||||||
|
|
||||||
def render(self, mode="human"):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("use_discrete_actions", [True, False])
|
||||||
|
@pytest.mark.parametrize("channel_last", [True, False])
|
||||||
|
@pytest.mark.parametrize("nested_dict_obs", [True, False])
|
||||||
|
@pytest.mark.parametrize("vec_only", [True, False])
|
||||||
|
def test_env(use_discrete_actions, channel_last, nested_dict_obs, vec_only):
|
||||||
|
# Check the env used for testing
|
||||||
|
if nested_dict_obs:
|
||||||
|
with pytest.warns(UserWarning, match="Nested observation spaces are not supported"):
|
||||||
|
check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only))
|
||||||
|
else:
|
||||||
|
check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
|
@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
|
||||||
def test_consistency(model_class):
|
def test_consistency(model_class):
|
||||||
"""
|
"""
|
||||||
|
|
@ -89,7 +104,7 @@ def test_consistency(model_class):
|
||||||
dict_env = gym.wrappers.TimeLimit(dict_env, 100)
|
dict_env = gym.wrappers.TimeLimit(dict_env, 100)
|
||||||
env = gym.wrappers.FlattenObservation(dict_env)
|
env = gym.wrappers.FlattenObservation(dict_env)
|
||||||
dict_env.seed(10)
|
dict_env.seed(10)
|
||||||
obs = dict_env.reset()
|
obs, _ = dict_env.reset()
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
n_steps = 256
|
n_steps = 256
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ DIM = 4
|
||||||
@pytest.mark.parametrize("model_class", [QRDQN, TRPO])
|
@pytest.mark.parametrize("model_class", [QRDQN, TRPO])
|
||||||
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
|
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
|
||||||
def test_discrete(model_class, env):
|
def test_discrete(model_class, env):
|
||||||
env_ = DummyVecEnv([lambda: env])
|
vec_env = DummyVecEnv([lambda: env])
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
n_steps = 1500
|
n_steps = 1500
|
||||||
if model_class == QRDQN:
|
if model_class == QRDQN:
|
||||||
|
|
@ -30,9 +30,9 @@ def test_discrete(model_class, env):
|
||||||
elif n_steps == TRPO:
|
elif n_steps == TRPO:
|
||||||
kwargs = dict(n_steps=256, cg_max_steps=5)
|
kwargs = dict(n_steps=256, cg_max_steps=5)
|
||||||
|
|
||||||
model = model_class("MlpPolicy", env_, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps)
|
model = model_class("MlpPolicy", vec_env, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps)
|
||||||
|
|
||||||
evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False)
|
evaluate_policy(model, vec_env, n_eval_episodes=20, reward_threshold=90, warn=False)
|
||||||
obs = env.reset()
|
obs = vec_env.reset()
|
||||||
|
|
||||||
assert np.shape(model.predict(obs)[0]) == np.shape(obs)
|
assert np.shape(model.predict(obs)[0]) == np.shape(obs)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
import random
|
import random
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.callbacks import EventCallback, StopTrainingOnNoModelImprovement
|
from stable_baselines3.common.callbacks import EventCallback, StopTrainingOnNoModelImprovement
|
||||||
from stable_baselines3.common.env_util import make_vec_env
|
from stable_baselines3.common.env_util import make_vec_env
|
||||||
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
||||||
|
|
@ -30,12 +32,12 @@ class ToDictWrapper(gym.Wrapper):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
self.observation_space = spaces.Dict({"obs": self.env.observation_space})
|
self.observation_space = spaces.Dict({"obs": self.env.observation_space})
|
||||||
|
|
||||||
def reset(self):
|
def reset(self, **kwargs) -> Tuple[Dict[str, np.ndarray], Dict]:
|
||||||
return {"obs": self.env.reset()}
|
return {"obs": self.env.reset(seed=kwargs.get("seed", 0))[0]}, {} # type: ignore[dict-item]
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
obs, reward, done, infos = self.env.step(action)
|
obs, reward, terminated, truncated, infos = self.env.step(action)
|
||||||
return {"obs": obs}, reward, done, infos
|
return {"obs": obs}, reward, terminated, truncated, infos
|
||||||
|
|
||||||
|
|
||||||
def test_identity():
|
def test_identity():
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,13 @@
|
||||||
import gym
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from gym.envs.classic_control import CartPoleEnv
|
from gymnasium.envs.classic_control import CartPoleEnv
|
||||||
from gym.wrappers.time_limit import TimeLimit
|
from gymnasium.wrappers.time_limit import TimeLimit
|
||||||
from stable_baselines3.common.callbacks import EvalCallback
|
from stable_baselines3.common.callbacks import EvalCallback
|
||||||
|
from stable_baselines3.common.env_checker import check_env
|
||||||
from stable_baselines3.common.env_util import make_vec_env
|
from stable_baselines3.common.env_util import make_vec_env
|
||||||
from stable_baselines3.common.envs import FakeImageEnv
|
from stable_baselines3.common.envs import FakeImageEnv
|
||||||
from stable_baselines3.common.evaluation import evaluate_policy
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
|
|
@ -22,12 +25,12 @@ class ToDictWrapper(gym.Wrapper):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
self.observation_space = spaces.Dict({"obs": self.env.observation_space})
|
self.observation_space = spaces.Dict({"obs": self.env.observation_space})
|
||||||
|
|
||||||
def reset(self):
|
def reset(self, **kwargs):
|
||||||
return {"obs": self.env.reset()}
|
return {"obs": self.env.reset(**kwargs)[0]}, {}
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
obs, reward, done, infos = self.env.step(action)
|
obs, reward, done, truncated, infos = self.env.step(action)
|
||||||
return {"obs": obs}, reward, done, infos
|
return {"obs": obs}, reward, done, truncated, infos
|
||||||
|
|
||||||
|
|
||||||
class CartPoleNoVelEnv(CartPoleEnv):
|
class CartPoleNoVelEnv(CartPoleEnv):
|
||||||
|
|
@ -46,15 +49,19 @@ class CartPoleNoVelEnv(CartPoleEnv):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _pos_obs(full_obs):
|
def _pos_obs(full_obs):
|
||||||
xpos, _xvel, thetapos, _thetavel = full_obs
|
xpos, _xvel, thetapos, _thetavel = full_obs
|
||||||
return xpos, thetapos
|
return np.array([xpos, thetapos])
|
||||||
|
|
||||||
def reset(self):
|
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
|
||||||
full_obs = super().reset()
|
full_obs, info = super().reset(seed=seed, options=options)
|
||||||
return CartPoleNoVelEnv._pos_obs(full_obs)
|
return CartPoleNoVelEnv._pos_obs(full_obs), info
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
full_obs, rew, done, info = super().step(action)
|
full_obs, rew, terminated, truncated, info = super().step(action)
|
||||||
return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info
|
return CartPoleNoVelEnv._pos_obs(full_obs), rew, terminated, truncated, info
|
||||||
|
|
||||||
|
|
||||||
|
def test_env():
|
||||||
|
check_env(CartPoleNoVelEnv())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import pytest
|
import pytest
|
||||||
from stable_baselines3.common.env_util import make_vec_env
|
from stable_baselines3.common.env_util import make_vec_env
|
||||||
from stable_baselines3.common.vec_env import VecNormalize
|
from stable_baselines3.common.vec_env import VecNormalize
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import pathlib
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch as th
|
import torch as th
|
||||||
|
|
@ -24,7 +24,7 @@ def select_env(model_class: BaseAlgorithm) -> gym.Env:
|
||||||
if model_class == QRDQN:
|
if model_class == QRDQN:
|
||||||
return IdentityEnv(10)
|
return IdentityEnv(10)
|
||||||
else:
|
else:
|
||||||
return IdentityEnvBox(10)
|
return IdentityEnvBox(-10, 10)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch as th
|
import torch as th
|
||||||
|
|
@ -111,7 +111,7 @@ def test_ppo_mask_train_eval_mode():
|
||||||
|
|
||||||
batch_norm_stats_before = clone_on_policy_batch_norm(model)
|
batch_norm_stats_before = clone_on_policy_batch_norm(model)
|
||||||
|
|
||||||
observation = env.reset()
|
observation, _ = env.reset()
|
||||||
action_masks = get_action_masks(env)
|
action_masks = get_action_masks(env)
|
||||||
first_prediction, _ = model.predict(observation, action_masks=action_masks, deterministic=True)
|
first_prediction, _ = model.predict(observation, action_masks=action_masks, deterministic=True)
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.envs import IdentityEnv
|
from stable_baselines3.common.envs import IdentityEnv
|
||||||
|
|
||||||
from sb3_contrib.common.wrappers import ActionMasker
|
from sb3_contrib.common.wrappers import ActionMasker
|
||||||
|
|
@ -20,10 +20,12 @@ class IdentityEnvDiscrete(IdentityEnv):
|
||||||
super().__init__(ep_length=ep_length, space=space)
|
super().__init__(ep_length=ep_length, space=space)
|
||||||
|
|
||||||
def _action_masks(self) -> List[int]:
|
def _action_masks(self) -> List[int]:
|
||||||
|
assert isinstance(self.action_space, spaces.Discrete)
|
||||||
return [i == self.state for i in range(self.action_space.n)]
|
return [i == self.state for i in range(self.action_space.n)]
|
||||||
|
|
||||||
|
|
||||||
def action_mask_fn(env: IdentityEnvDiscrete) -> List[int]:
|
def action_mask_fn(env: IdentityEnvDiscrete) -> List[int]:
|
||||||
|
assert isinstance(env.action_space, spaces.Discrete)
|
||||||
return [i == env.state for i in range(env.action_space.n)]
|
return [i == env.state for i in range(env.action_space.n)]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,14 @@
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from stable_baselines3.common.env_checker import check_env
|
from stable_baselines3.common.env_checker import check_env
|
||||||
from stable_baselines3.common.envs import BitFlippingEnv
|
from stable_baselines3.common.envs import BitFlippingEnv
|
||||||
|
|
||||||
from sb3_contrib.common.wrappers import TimeFeatureWrapper
|
from sb3_contrib.common.wrappers import TimeFeatureWrapper
|
||||||
|
|
||||||
|
|
||||||
class CustomGoalEnv(gym.GoalEnv):
|
class CustomGoalEnv(gym.Env):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.observation_space = spaces.Dict(
|
self.observation_space = spaces.Dict(
|
||||||
|
|
@ -21,10 +21,10 @@ class CustomGoalEnv(gym.GoalEnv):
|
||||||
self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)
|
self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
return self.observation_space.sample()
|
return self.observation_space.sample(), {}
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
return self.observation_space.sample(), 0.0, False, {}
|
return self.observation_space.sample(), 0.0, False, False, {}
|
||||||
|
|
||||||
|
|
||||||
def check_time_feature(obs, timestep, max_timesteps):
|
def check_time_feature(obs, timestep, max_timesteps):
|
||||||
|
|
@ -37,14 +37,16 @@ def test_time_feature():
|
||||||
check_env(env, warn=False)
|
check_env(env, warn=False)
|
||||||
# Check for four episodes
|
# Check for four episodes
|
||||||
max_timesteps = 200
|
max_timesteps = 200
|
||||||
obs = env.reset()
|
obs, _ = env.reset()
|
||||||
for _ in range(4):
|
for _ in range(4):
|
||||||
|
done = False
|
||||||
check_time_feature(obs, timestep=0, max_timesteps=max_timesteps)
|
check_time_feature(obs, timestep=0, max_timesteps=max_timesteps)
|
||||||
for step in range(1, max_timesteps + 1):
|
for step in range(1, max_timesteps + 1):
|
||||||
obs, _, done, _ = env.step(env.action_space.sample())
|
obs, _, terminated, truncated, _ = env.step(env.action_space.sample())
|
||||||
check_time_feature(obs, timestep=step, max_timesteps=max_timesteps)
|
check_time_feature(obs, timestep=step, max_timesteps=max_timesteps)
|
||||||
|
done = terminated or truncated
|
||||||
if done:
|
if done:
|
||||||
obs = env.reset()
|
obs, _ = env.reset()
|
||||||
|
|
||||||
env = BitFlippingEnv()
|
env = BitFlippingEnv()
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
|
|
@ -52,16 +54,16 @@ def test_time_feature():
|
||||||
|
|
||||||
env = CustomGoalEnv()
|
env = CustomGoalEnv()
|
||||||
env = TimeFeatureWrapper(env, max_steps=500)
|
env = TimeFeatureWrapper(env, max_steps=500)
|
||||||
obs = env.reset()
|
obs, _ = env.reset()
|
||||||
check_time_feature(obs["observation"], timestep=0, max_timesteps=500)
|
check_time_feature(obs["observation"], timestep=0, max_timesteps=500)
|
||||||
obs, _, _, _ = env.step(env.action_space.sample())
|
obs = env.step(env.action_space.sample())[0]
|
||||||
check_time_feature(obs["observation"], timestep=1, max_timesteps=500)
|
check_time_feature(obs["observation"], timestep=1, max_timesteps=500)
|
||||||
|
|
||||||
# In test mode, the time feature must be constant
|
# In test mode, the time feature must be constant
|
||||||
env = gym.make("Pendulum-v1")
|
env = gym.make("Pendulum-v1")
|
||||||
env = TimeFeatureWrapper(env, test_mode=True)
|
env = TimeFeatureWrapper(env, test_mode=True)
|
||||||
obs = env.reset()
|
obs, _ = env.reset()
|
||||||
check_time_feature(obs, timestep=0, max_timesteps=200)
|
check_time_feature(obs, timestep=0, max_timesteps=200)
|
||||||
obs, _, _, _ = env.step(env.action_space.sample())
|
obs = env.step(env.action_space.sample())[0]
|
||||||
# Should be the same
|
# Should be the same
|
||||||
check_time_feature(obs, timestep=0, max_timesteps=200)
|
check_time_feature(obs, timestep=0, max_timesteps=200)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue