Add TimeFeatureWrapper (#7)

* Add TimeFeatureWrapper

* Update README

* Address comments
This commit is contained in:
Antonin RAFFIN 2020-11-13 12:00:56 +01:00 committed by GitHub
parent aac20bd1e6
commit 6bafcf6e88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 210 additions and 6 deletions

View File

@ -26,6 +26,9 @@ See documentation for the full list of included features.
**RL Algorithms**:
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
**Gym Wrappers**:
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
## Documentation

15
docs/common/wrappers.rst Normal file
View File

@ -0,0 +1,15 @@
.. _wrappers:
Gym Wrappers
============
Additional `Gym Wrappers <https://github.com/openai/gym/tree/master/gym/wrappers>`_ to enhance Gym environments.
.. automodule:: sb3_contrib.common.wrappers
TimeFeatureWrapper
------------------
.. autoclass:: TimeFeatureWrapper
:members:

View File

@ -33,11 +33,11 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
modules/tqc
.. .. toctree::
.. :maxdepth: 1
.. :caption: Common
..
.. common/atari_wrappers
.. toctree::
:maxdepth: 1
:caption: Common
common/wrappers
.. toctree::
:maxdepth: 1

View File

@ -4,6 +4,30 @@ Changelog
==========
Pre-Release 0.11.0a0 (WIP)
-------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
New Features:
^^^^^^^^^^^^^
- Added ``TimeFeatureWrapper`` to the common wrappers
Bug Fixes:
^^^^^^^^^^
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
Documentation:
^^^^^^^^^^^^^^
Pre-Release 0.10.0 (2020-10-28)
-------------------------------

View File

View File

@ -0,0 +1 @@
from sb3_contrib.common.wrappers.time_feature import TimeFeatureWrapper

View File

@ -0,0 +1,91 @@
from typing import Dict, Union
import gym
import numpy as np
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
class TimeFeatureWrapper(gym.Wrapper):
    """
    Add remaining, normalized time to observation space for fixed length episodes.
    See https://arxiv.org/abs/1712.00378 and https://github.com/aravindr93/mjrl/issues/13.

    .. note::

        Only ``gym.spaces.Box`` and ``gym.spaces.Dict`` (``gym.GoalEnv``) 1D observation spaces
        are supported for now.

    :param env: Gym env to wrap.
    :param max_steps: Max number of steps of an episode
        if it is not wrapped in a ``TimeLimit`` object.
    :param test_mode: In test mode, the time feature is constant,
        equal to zero. This allows checking that the agent did not overfit this feature,
        learning a deterministic pre-defined sequence of actions.
    """

    def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False):
        assert isinstance(
            env.observation_space, (gym.spaces.Box, gym.spaces.Dict)
        ), "`TimeFeatureWrapper` only supports `gym.spaces.Box` and `gym.spaces.Dict` (`gym.GoalEnv`) observation spaces."

        # Add a time feature to the observation
        if isinstance(env.observation_space, gym.spaces.Dict):
            assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space"
            obs_space = env.observation_space.spaces["observation"]
            assert isinstance(
                obs_space, gym.spaces.Box
            ), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
        else:
            obs_space = env.observation_space

        assert len(obs_space.shape) == 1, "Only 1D observation spaces are supported"

        # Extend the bounds with one extra dimension for the time feature,
        # which is normalized to live in [0, 1]
        low, high = obs_space.low, obs_space.high
        low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0]))

        if isinstance(env.observation_space, gym.spaces.Dict):
            env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=np.float32)
        else:
            env.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

        super(TimeFeatureWrapper, self).__init__(env)

        # Try to infer the max number of steps per episode from the
        # ``TimeLimit`` wrapper spec (``env.spec`` may be None or may not
        # define ``max_episode_steps``, hence the AttributeError guard)
        try:
            self._max_steps = env.spec.max_episode_steps
        except AttributeError:
            self._max_steps = None

        # Fallback to provided value
        if self._max_steps is None:
            self._max_steps = max_steps

        self._current_step = 0
        self._test_mode = test_mode

    def reset(self) -> GymObs:
        # Reset the step counter before augmenting the initial observation
        self._current_step = 0
        return self._get_obs(self.env.reset())

    def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
        self._current_step += 1
        obs, reward, done, info = self.env.step(action)
        return self._get_obs(obs), reward, done, info

    def _get_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Concatenate the time feature to the current observation.

        :param obs: Original observation (array, or dict for ``gym.GoalEnv``).
        :return: Observation with the time feature appended.
        """
        # Remaining time is more general
        time_feature = 1 - (self._current_step / self._max_steps)
        if self._test_mode:
            time_feature = 1.0
        if isinstance(obs, dict):
            obs["observation"] = np.append(obs["observation"], time_feature)
            return obs
        return np.append(obs, time_feature)

View File

@ -1 +1 @@
0.10.0
0.11.0a0

View File

@ -26,6 +26,7 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators
per-file-ignores =
./sb3_contrib/__init__.py:F401
./sb3_contrib/tqc/__init__.py:F401
./sb3_contrib/common/wrappers/__init__.py:F401
exclude =
# No need to traverse our git directory
.git,

69
tests/test_wrappers.py Normal file
View File

@ -0,0 +1,69 @@
import gym
import numpy as np
import pytest
from gym import spaces
from stable_baselines3.common.bit_flipping_env import BitFlippingEnv
from stable_baselines3.common.env_checker import check_env
from sb3_contrib.common.wrappers import TimeFeatureWrapper
class CustomGoalEnv(gym.GoalEnv):
    """Minimal goal-conditioned env with a dict observation space, used for wrapper tests."""

    def __init__(self):
        super(CustomGoalEnv, self).__init__()

        def unbounded_box():
            # Each entry of the dict space is an unbounded 3D Box
            return spaces.Box(low=-np.inf, high=np.inf, shape=(3,))

        self.observation_space = spaces.Dict(
            {key: unbounded_box() for key in ("observation", "achieved_goal", "desired_goal")}
        )
        self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)

    def reset(self):
        # Dynamics are irrelevant here: return a random observation
        return self.observation_space.sample()

    def step(self, action):
        # Random observation, zero reward, never done
        return self.observation_space.sample(), 0.0, False, {}
def check_time_feature(obs, timestep, max_timesteps):
    """Assert that the last entry of ``obs`` equals the expected remaining-time feature."""
    expected = 1.0 - timestep / max_timesteps
    assert np.allclose(obs[-1], expected)
def test_time_feature():
    # Standard Box observation space
    env = TimeFeatureWrapper(gym.make("Pendulum-v0"))
    check_env(env, warn=False)

    max_timesteps = 200  # Pendulum-v0 comes wrapped in a 200-step TimeLimit
    obs = env.reset()
    # Run four full episodes and verify the time feature at every step
    for _ in range(4):
        check_time_feature(obs, timestep=0, max_timesteps=max_timesteps)
        for step in range(1, max_timesteps + 1):
            obs, _, done, _ = env.step(env.action_space.sample())
            check_time_feature(obs, timestep=step, max_timesteps=max_timesteps)
            if done:
                obs = env.reset()

    # Unsupported observation space must be rejected at construction time
    env = BitFlippingEnv()
    with pytest.raises(AssertionError):
        env = TimeFeatureWrapper(env)

    # Dict (GoalEnv) observation space, with a user-provided episode length
    env = TimeFeatureWrapper(CustomGoalEnv(), max_steps=500)
    obs = env.reset()
    check_time_feature(obs["observation"], timestep=0, max_timesteps=500)
    obs, _, _, _ = env.step(env.action_space.sample())
    check_time_feature(obs["observation"], timestep=1, max_timesteps=500)

    # In test mode, the time feature must stay constant
    env = TimeFeatureWrapper(gym.make("Pendulum-v0"), test_mode=True)
    obs = env.reset()
    check_time_feature(obs, timestep=0, max_timesteps=200)
    obs, _, _, _ = env.step(env.action_space.sample())
    # Should be the same as at reset
    check_time_feature(obs, timestep=0, max_timesteps=200)