Add TimeFeatureWrapper (#7)
* Add TimeFeatureWrapper * Update README * Address comments
This commit is contained in:
parent
aac20bd1e6
commit
6bafcf6e88
|
|
@ -26,6 +26,9 @@ See documentation for the full list of included features.
|
||||||
**RL Algorithms**:
|
**RL Algorithms**:
|
||||||
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
|
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
|
||||||
|
|
||||||
|
**Gym Wrappers**:
|
||||||
|
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
|
||||||
|
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
.. _wrappers:
|
||||||
|
|
||||||
|
Gym Wrappers
|
||||||
|
============
|
||||||
|
|
||||||
|
Additional `Gym Wrappers <https://github.com/openai/gym/tree/master/gym/wrappers>`_ to enhance Gym environments.
|
||||||
|
|
||||||
|
.. automodule:: sb3_contrib.common.wrappers
|
||||||
|
|
||||||
|
|
||||||
|
TimeFeatureWrapper
|
||||||
|
------------------
|
||||||
|
|
||||||
|
.. autoclass:: TimeFeatureWrapper
|
||||||
|
:members:
|
||||||
|
|
@ -33,11 +33,11 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
|
||||||
|
|
||||||
modules/tqc
|
modules/tqc
|
||||||
|
|
||||||
.. .. toctree::
|
.. toctree::
|
||||||
.. :maxdepth: 1
|
:maxdepth: 1
|
||||||
.. :caption: Common
|
:caption: Common
|
||||||
..
|
|
||||||
.. common/atari_wrappers
|
common/wrappers
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,30 @@ Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
|
||||||
|
Pre-Release 0.11.0a0 (WIP)
|
||||||
|
-------------------------------
|
||||||
|
|
||||||
|
Breaking Changes:
|
||||||
|
^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
New Features:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
- Added ``TimeFeatureWrapper`` to the wrappers
|
||||||
|
|
||||||
|
Bug Fixes:
|
||||||
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
Deprecations:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Others:
|
||||||
|
^^^^^^^
|
||||||
|
|
||||||
|
Documentation:
|
||||||
|
^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Pre-Release 0.10.0 (2020-10-28)
|
Pre-Release 0.10.0 (2020-10-28)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from sb3_contrib.common.wrappers.time_feature import TimeFeatureWrapper
|
||||||
|
|
@ -0,0 +1,91 @@
|
||||||
|
from typing import Dict, Union
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
|
||||||
|
|
||||||
|
|
||||||
|
class TimeFeatureWrapper(gym.Wrapper):
    """
    Add remaining, normalized time to observation space for fixed length episodes.
    See https://arxiv.org/abs/1712.00378 and https://github.com/aravindr93/mjrl/issues/13.

    The time feature is appended as the last element of the observation. It starts
    at one when the episode is reset and decreases by ``1 / max_steps`` at every step.

    .. note::

        Only ``gym.spaces.Box`` and ``gym.spaces.Dict`` (``gym.GoalEnv``) 1D observation spaces
        are supported for now.

    .. note::

        The observation space of the wrapped env is replaced in place
        (for ``Dict`` spaces, only its ``"observation"`` entry is replaced).

    :param env: Gym env to wrap.
    :param max_steps: Max number of steps of an episode. Only used as a fallback
        when the value cannot be read from ``env.spec.max_episode_steps``.
    :param test_mode: In test mode, the time feature is constant,
        equal to one (its value at the beginning of an episode).
        This allows to check that the agent did not overfit this feature,
        learning a deterministic pre-defined sequence of actions.
    """

    def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False):
        # NOTE: the checks below intentionally stay as assertions,
        # callers may rely on ``AssertionError`` being raised for unsupported spaces.
        assert isinstance(
            env.observation_space, (gym.spaces.Box, gym.spaces.Dict)
        ), "`TimeFeatureWrapper` only supports `gym.spaces.Box` and `gym.spaces.Dict` (`gym.GoalEnv`) observation spaces."

        # Select the Box space that will receive the time feature
        if isinstance(env.observation_space, gym.spaces.Dict):
            assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space"
            obs_space = env.observation_space.spaces["observation"]
            assert isinstance(
                obs_space, gym.spaces.Box
            ), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
        else:
            obs_space = env.observation_space

        assert len(obs_space.shape) == 1, "Only 1D observation spaces are supported"

        # Extend the bounds with the range of the time feature: [0, 1]
        low = np.concatenate((obs_space.low, [0.0]))
        high = np.concatenate((obs_space.high, [1.0]))

        if isinstance(env.observation_space, gym.spaces.Dict):
            env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=np.float32)
        else:
            env.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

        super().__init__(env)

        # Try to infer the max number of steps per episode
        # (``env.spec`` may be missing or ``None``, both cases raise ``AttributeError``)
        try:
            self._max_steps = env.spec.max_episode_steps
        except AttributeError:
            self._max_steps = None

        # Fallback to provided value
        if self._max_steps is None:
            self._max_steps = max_steps

        self._current_step = 0
        self._test_mode = test_mode

    def reset(self) -> GymObs:
        """
        Reset the wrapped env and the step counter.

        :return: The first observation, with the time feature appended.
        """
        self._current_step = 0
        return self._get_obs(self.env.reset())

    def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
        """
        Step the wrapped env and increment the step counter.

        :param action: Action forwarded to the wrapped env.
        :return: ``(observation, reward, done, info)`` where the time feature
            has been appended to the observation.
        """
        self._current_step += 1
        obs, reward, done, info = self.env.step(action)
        return self._get_obs(obs), reward, done, info

    def _get_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Concatenate the time feature to the current observation.

        :param obs: Observation from the wrapped env: an array,
            or a dict with an ``"observation"`` key.
        :return: The observation with the time feature appended as last element.
            A dict observation is updated in place and returned.
        """
        if self._test_mode:
            # Constant feature: same value as at the beginning of an episode
            time_feature = 1.0
        else:
            # Remaining time is more general
            time_feature = 1 - (self._current_step / self._max_steps)

        # The observation space is declared as float32: cast the feature so that
        # ``np.append`` does not promote float32 observations to float64
        time_feature = np.float32(time_feature)

        if isinstance(obs, dict):
            obs["observation"] = np.append(obs["observation"], time_feature)
            return obs
        return np.append(obs, time_feature)
|
||||||
|
|
@ -1 +1 @@
|
||||||
0.10.0
|
0.11.0a0
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators
|
||||||
per-file-ignores =
|
per-file-ignores =
|
||||||
./sb3_contrib/__init__.py:F401
|
./sb3_contrib/__init__.py:F401
|
||||||
./sb3_contrib/tqc/__init__.py:F401
|
./sb3_contrib/tqc/__init__.py:F401
|
||||||
|
./sb3_contrib/common/wrappers/__init__.py:F401
|
||||||
exclude =
|
exclude =
|
||||||
# No need to traverse our git directory
|
# No need to traverse our git directory
|
||||||
.git,
|
.git,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from gym import spaces
|
||||||
|
from stable_baselines3.common.bit_flipping_env import BitFlippingEnv
|
||||||
|
from stable_baselines3.common.env_checker import check_env
|
||||||
|
|
||||||
|
from sb3_contrib.common.wrappers import TimeFeatureWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class CustomGoalEnv(gym.GoalEnv):
    """Minimal goal env with a ``Dict`` observation space, used to test the wrapper."""

    def __init__(self):
        super().__init__()
        # The three entries share the same unbounded Box of shape (3,)
        goal_env_keys = ("observation", "achieved_goal", "desired_goal")
        self.observation_space = spaces.Dict(
            {key: spaces.Box(low=-np.inf, high=np.inf, shape=(3,)) for key in goal_env_keys}
        )
        self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)

    def reset(self):
        """Return a random observation sampled from the observation space."""
        return self.observation_space.sample()

    def step(self, action):
        """Ignore the action: random observation, zero reward, never done, empty info."""
        next_obs = self.observation_space.sample()
        return next_obs, 0.0, False, {}
|
||||||
|
|
||||||
|
|
||||||
|
def check_time_feature(obs, timestep, max_timesteps):
    """
    Assert that the last entry of ``obs`` is the remaining normalized time.

    :param obs: Observation whose last element is the time feature.
    :param timestep: Number of steps taken since the last reset.
    :param max_timesteps: Number of steps used to normalize the time feature.
    """
    expected_time_feature = 1.0 - timestep / max_timesteps
    assert np.allclose(obs[-1], expected_time_feature)
|
||||||
|
|
||||||
|
|
||||||
|
def test_time_feature():
    """Check the time feature on a Box env, an unsupported env, a goal env and in test mode."""
    # Box observation space
    pendulum = TimeFeatureWrapper(gym.make("Pendulum-v0"))
    check_env(pendulum, warn=False)

    # Check for four episodes
    # (the wrapper is expected to use this episode length for "Pendulum-v0")
    episode_length = 200
    obs = pendulum.reset()
    for _ in range(4):
        check_time_feature(obs, timestep=0, max_timesteps=episode_length)
        n_steps = 0
        while n_steps < episode_length:
            n_steps += 1
            obs, _, done, _ = pendulum.step(pendulum.action_space.sample())
            check_time_feature(obs, timestep=n_steps, max_timesteps=episode_length)
            if done:
                obs = pendulum.reset()

    # Wrapping this env must be rejected by the wrapper
    bit_flipping_env = BitFlippingEnv()
    with pytest.raises(AssertionError):
        TimeFeatureWrapper(bit_flipping_env)

    # Dict observation space, with a user-provided episode length
    goal_env = TimeFeatureWrapper(CustomGoalEnv(), max_steps=500)
    goal_obs = goal_env.reset()
    check_time_feature(goal_obs["observation"], timestep=0, max_timesteps=500)
    goal_obs, _, _, _ = goal_env.step(goal_env.action_space.sample())
    check_time_feature(goal_obs["observation"], timestep=1, max_timesteps=500)

    # In test mode, the time feature must be constant
    constant_env = TimeFeatureWrapper(gym.make("Pendulum-v0"), test_mode=True)
    first_obs = constant_env.reset()
    check_time_feature(first_obs, timestep=0, max_timesteps=200)
    second_obs, _, _, _ = constant_env.step(constant_env.action_space.sample())
    # Should be the same
    check_time_feature(second_obs, timestep=0, max_timesteps=200)
|
||||||
Loading…
Reference in New Issue