From 6bafcf6e88ede242aea4e90d532dac3f25e6732a Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 13 Nov 2020 12:00:56 +0100 Subject: [PATCH] Add TimeFeatureWrapper (#7) * Add TimeFeatureWrapper * Update README * Address comments --- README.md | 3 + docs/common/wrappers.rst | 15 ++++ docs/index.rst | 10 +-- docs/misc/changelog.rst | 24 ++++++ sb3_contrib/common/__init__.py | 0 sb3_contrib/common/wrappers/__init__.py | 1 + sb3_contrib/common/wrappers/time_feature.py | 91 +++++++++++++++++++++ sb3_contrib/version.txt | 2 +- setup.cfg | 1 + tests/test_wrappers.py | 69 ++++++++++++++++ 10 files changed, 210 insertions(+), 6 deletions(-) create mode 100644 docs/common/wrappers.rst create mode 100644 sb3_contrib/common/__init__.py create mode 100644 sb3_contrib/common/wrappers/__init__.py create mode 100644 sb3_contrib/common/wrappers/time_feature.py create mode 100644 tests/test_wrappers.py diff --git a/README.md b/README.md index 9b002df..81349fd 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,9 @@ See documentation for the full list of included features. **RL Algorithms**: - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269) +**Gym Wrappers**: +- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378) + ## Documentation diff --git a/docs/common/wrappers.rst b/docs/common/wrappers.rst new file mode 100644 index 0000000..cc2f11e --- /dev/null +++ b/docs/common/wrappers.rst @@ -0,0 +1,15 @@ +.. _wrappers: + +Gym Wrappers +============ + +Additional `Gym Wrappers `_ to enhance Gym environments. + +.. automodule:: sb3_contrib.common.wrappers + + +TimeFeatureWrapper +------------------ + +.. autoclass:: TimeFeatureWrapper + :members: diff --git a/docs/index.rst b/docs/index.rst index f701110..c198d4d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -33,11 +33,11 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d modules/tqc -.. .. toctree:: -.. :maxdepth: 1 -.. :caption: Common -.. -.. common/atari_wrappers +.. 
toctree:: + :maxdepth: 1 + :caption: Common + + common/wrappers .. toctree:: :maxdepth: 1 diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 95a1381..5302fa0 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,6 +4,30 @@ Changelog ========== +Pre-Release 0.11.0a0 (WIP) +------------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ + +New Features: +^^^^^^^^^^^^^ +- Added ``TimeFeatureWrapper`` to the wrappers + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + +Documentation: +^^^^^^^^^^^^^^ + + + Pre-Release 0.10.0 (2020-10-28) ------------------------------- diff --git a/sb3_contrib/common/__init__.py b/sb3_contrib/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sb3_contrib/common/wrappers/__init__.py b/sb3_contrib/common/wrappers/__init__.py new file mode 100644 index 0000000..5e59c76 --- /dev/null +++ b/sb3_contrib/common/wrappers/__init__.py @@ -0,0 +1 @@ +from sb3_contrib.common.wrappers.time_feature import TimeFeatureWrapper diff --git a/sb3_contrib/common/wrappers/time_feature.py b/sb3_contrib/common/wrappers/time_feature.py new file mode 100644 index 0000000..57a8279 --- /dev/null +++ b/sb3_contrib/common/wrappers/time_feature.py @@ -0,0 +1,91 @@ +from typing import Dict, Union + +import gym +import numpy as np +from stable_baselines3.common.type_aliases import GymObs, GymStepReturn + + +class TimeFeatureWrapper(gym.Wrapper): + """ + Add remaining, normalized time to observation space for fixed length episodes. + See https://arxiv.org/abs/1712.00378 and https://github.com/aravindr93/mjrl/issues/13. + + .. note:: + + Only ``gym.spaces.Box`` and ``gym.spaces.Dict`` (``gym.GoalEnv``) 1D observation spaces + are supported for now. + + :param env: Gym env to wrap. + :param max_steps: Max number of steps of an episode + if it is not wrapped in a ``TimeLimit`` object. + :param test_mode: In test mode, the time feature is constant, + equal to zero. 
This allows checking that the agent did not overfit this feature,
+        learning a deterministic pre-defined sequence of actions.
+    """
+
+    def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False):
+        assert isinstance(
+            env.observation_space, (gym.spaces.Box, gym.spaces.Dict)
+        ), "`TimeFeatureWrapper` only supports `gym.spaces.Box` and `gym.spaces.Dict` (`gym.GoalEnv`) observation spaces."
+
+        # Add a time feature to the observation
+        if isinstance(env.observation_space, gym.spaces.Dict):
+            assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space"
+            obs_space = env.observation_space.spaces["observation"]
+            assert isinstance(
+                obs_space, gym.spaces.Box
+            ), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
+            obs_space = env.observation_space.spaces["observation"]
+        else:
+            obs_space = env.observation_space
+
+        assert len(obs_space.shape) == 1, "Only 1D observation spaces are supported"
+
+        low, high = obs_space.low, obs_space.high
+        low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0]))
+
+        if isinstance(env.observation_space, gym.spaces.Dict):
+            env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=np.float32)
+        else:
+            env.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)
+
+        super(TimeFeatureWrapper, self).__init__(env)
+
+        # Try to infer the max number of steps per episode
+        try:
+            self._max_steps = env.spec.max_episode_steps
+        except AttributeError:
+            self._max_steps = None
+
+        # Fallback to provided value
+        if self._max_steps is None:
+            self._max_steps = max_steps
+
+        self._current_step = 0
+        self._test_mode = test_mode
+
+    def reset(self) -> GymObs:
+        self._current_step = 0
+        return self._get_obs(self.env.reset())
+
+    def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
+        self._current_step += 1
+        obs, reward, done, info = self.env.step(action)
+        return self._get_obs(obs), reward, done, 
info
+
+    def _get_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+        """
+        Concatenate the time feature to the current observation.
+
+        :param obs: Current observation (without the time feature).
+        :return: Observation with the normalized remaining time appended.
+        """
+        # Remaining time is more general
+        time_feature = 1 - (self._current_step / self._max_steps)
+        if self._test_mode:
+            time_feature = 1.0
+
+        if isinstance(obs, dict):
+            obs["observation"] = np.append(obs["observation"], time_feature)
+            return obs
+        return np.append(obs, time_feature)
diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt
index 78bc1ab..d22e31d 100644
--- a/sb3_contrib/version.txt
+++ b/sb3_contrib/version.txt
@@ -1 +1 @@
-0.10.0
+0.11.0a0
diff --git a/setup.cfg b/setup.cfg
index 951fc85..08320b3 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -26,6 +26,7 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators
 per-file-ignores =
     ./sb3_contrib/__init__.py:F401
     ./sb3_contrib/tqc/__init__.py:F401
+    ./sb3_contrib/common/wrappers/__init__.py:F401
 exclude =
     # No need to traverse our git directory
    .git,
diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py
new file mode 100644
index 0000000..dfd1e24
--- /dev/null
+++ b/tests/test_wrappers.py
@@ -0,0 +1,69 @@
+import gym
+import numpy as np
+import pytest
+from gym import spaces
+from stable_baselines3.common.bit_flipping_env import BitFlippingEnv
+from stable_baselines3.common.env_checker import check_env
+
+from sb3_contrib.common.wrappers import TimeFeatureWrapper
+
+
+class CustomGoalEnv(gym.GoalEnv):
+    """Minimal goal env with a Dict observation space, for wrapper tests."""
+
+    def __init__(self):
+        super(CustomGoalEnv, self).__init__()
+        self.observation_space = spaces.Dict(
+            {
+                "observation": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
+                "achieved_goal": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
+                "desired_goal": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
+            }
+        )
+        self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)
+
+    def reset(self):
+        
return self.observation_space.sample() + + def step(self, action): + return self.observation_space.sample(), 0.0, False, {} + + +def check_time_feature(obs, timestep, max_timesteps): + assert np.allclose(obs[-1], 1.0 - timestep / max_timesteps) + + +def test_time_feature(): + env = gym.make("Pendulum-v0") + env = TimeFeatureWrapper(env) + check_env(env, warn=False) + # Check for four episodes + max_timesteps = 200 + obs = env.reset() + for _ in range(4): + check_time_feature(obs, timestep=0, max_timesteps=max_timesteps) + for step in range(1, max_timesteps + 1): + obs, _, done, _ = env.step(env.action_space.sample()) + check_time_feature(obs, timestep=step, max_timesteps=max_timesteps) + if done: + obs = env.reset() + + env = BitFlippingEnv() + with pytest.raises(AssertionError): + env = TimeFeatureWrapper(env) + + env = CustomGoalEnv() + env = TimeFeatureWrapper(env, max_steps=500) + obs = env.reset() + check_time_feature(obs["observation"], timestep=0, max_timesteps=500) + obs, _, _, _ = env.step(env.action_space.sample()) + check_time_feature(obs["observation"], timestep=1, max_timesteps=500) + + # In test mode, the time feature must be constant + env = gym.make("Pendulum-v0") + env = TimeFeatureWrapper(env, test_mode=True) + obs = env.reset() + check_time_feature(obs, timestep=0, max_timesteps=200) + obs, _, _, _ = env.step(env.action_space.sample()) + # Should be the same + check_time_feature(obs, timestep=0, max_timesteps=200)