Update Maskable PPO to match SB3 PPO + improve coverage (#56)

This commit is contained in:
Antonin RAFFIN 2021-12-10 12:48:19 +01:00 committed by GitHub
parent 20b5351086
commit b44689b0ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 99 additions and 36 deletions

View File

@ -4,13 +4,14 @@ Changelog
==========
Release 1.3.1a5 (WIP)
Release 1.3.1a6 (WIP)
-------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Dropped python 3.6 support
- Upgraded to Stable-Baselines3 >= 1.3.1a4
- ``MaskablePPO`` was updated to match latest SB3 ``PPO`` version (timeout handling and new method for the policy object)
New Features:
^^^^^^^^^^^^^
@ -24,6 +25,7 @@ Deprecations:
Others:
^^^^^^^
- Improve test coverage for ``MaskablePPO``
Documentation:

View File

@ -86,7 +86,6 @@ class MaskableActorCriticPolicy(BasePolicy):
self.features_dim = self.features_extractor.features_dim
self.normalize_images = normalize_images
# Action distribution
self.action_dist = make_masked_proba_distribution(action_space)
@ -106,7 +105,9 @@ class MaskableActorCriticPolicy(BasePolicy):
:param action_masks: Action masks to apply to the action distribution
:return: action, value and log probability of the action
"""
latent_pi, latent_vf = self._get_latent(obs)
# Preprocess the observation if needed
features = self.extract_features(obs)
latent_pi, latent_vf = self.mlp_extractor(features)
# Evaluate the values for the given observations
values = self.value_net(latent_vf)
distribution = self._get_action_dist_from_latent(latent_pi)
@ -179,21 +180,6 @@ class MaskableActorCriticPolicy(BasePolicy):
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
    """
    Compute the latent codes (activations of the last layer of each network)
    for the actor and the value function.

    :param obs: Observation
    :return: Latent codes for the actor and for the value function
    """
    # Observation preprocessing is delegated to the features extractor
    processed = self.extract_features(obs)
    actor_latent, critic_latent = self.mlp_extractor(processed)
    return actor_latent, critic_latent
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution:
"""
Retrieve action distribution given the latent codes.
@ -218,11 +204,7 @@ class MaskableActorCriticPolicy(BasePolicy):
:param action_masks: Action masks to apply to the action distribution
:return: Taken action according to the policy
"""
latent_pi, _ = self._get_latent(observation)
distribution = self._get_action_dist_from_latent(latent_pi)
if action_masks is not None:
distribution.apply_masking(action_masks)
return distribution.get_actions(deterministic=deterministic)
return self.get_distribution(observation, action_masks).get_actions(deterministic=deterministic)
def predict(
self,
@ -291,7 +273,8 @@ class MaskableActorCriticPolicy(BasePolicy):
:return: estimated value, log likelihood of taking those actions
and entropy of the action distribution.
"""
latent_pi, latent_vf = self._get_latent(obs)
features = self.extract_features(obs)
latent_pi, latent_vf = self.mlp_extractor(features)
distribution = self._get_action_dist_from_latent(latent_pi)
if action_masks is not None:
distribution.apply_masking(action_masks)
@ -299,6 +282,32 @@ class MaskableActorCriticPolicy(BasePolicy):
values = self.value_net(latent_vf)
return values, log_prob, distribution.entropy()
def get_distribution(self, obs: th.Tensor, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
    """
    Get the current policy distribution given the observations.

    :param obs: Observation
    :param action_masks: Optional mask of invalid actions to apply to the distribution
    :return: the action distribution.
    """
    # Only the actor branch of the shared extractor is needed for the distribution
    actor_latent = self.mlp_extractor.forward_actor(self.extract_features(obs))
    dist = self._get_action_dist_from_latent(actor_latent)
    # Masking is skipped entirely when no mask is supplied
    if action_masks is not None:
        dist.apply_masking(action_masks)
    return dist
def predict_values(self, obs: th.Tensor) -> th.Tensor:
    """
    Get the estimated values according to the current policy given the observations.

    :param obs: Observation
    :return: the estimated values.
    """
    # Only the critic branch is required to produce a value estimate
    critic_latent = self.mlp_extractor.forward_critic(self.extract_features(obs))
    return self.value_net(critic_latent)
class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
"""

View File

@ -336,6 +336,19 @@ class MaskablePPO(OnPolicyAlgorithm):
# Reshape in case of discrete action
actions = actions.reshape(-1, 1)
# Handle timeout by bootstrapping with value function
# see GitHub issue #633
for idx, done in enumerate(dones):
if (
done
and infos[idx].get("terminal_observation") is not None
and infos[idx].get("TimeLimit.truncated", False)
):
terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
with th.no_grad():
terminal_value = self.policy.predict_values(terminal_obs)[0]
rewards[idx] += self.gamma * terminal_value
rollout_buffer.add(
self._last_obs,
actions,
@ -350,11 +363,9 @@ class MaskablePPO(OnPolicyAlgorithm):
with th.no_grad():
# Compute value for the last timestep
obs_tensor = obs_as_tensor(new_obs, self.device)
# Masking is not needed here, the choice of action doesn't matter.
# We only want the value of the current observation.
_, values, _ = self.policy.forward(obs_tensor)
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
@ -541,7 +552,7 @@ class MaskablePPO(OnPolicyAlgorithm):
# Display training infos
if log_interval is not None and iteration % log_interval == 0:
fps = int(self.num_timesteps / (time.time() - self.start_time))
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
self.logger.record("time/iterations", iteration, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))

View File

@ -1 +1 @@
1.3.1a5
1.3.1a6

View File

@ -98,11 +98,16 @@ class TestMaskableCategoricalDistribution:
dist.apply_masking(None)
# But now we can
dist.proba_distribution(th.randn(1, DIMS))
dist.log_prob(th.randint(DIMS - 1, (3, 1)))
action_logits = th.randn(1, DIMS)
dist.proba_distribution(action_logits)
actions = th.randint(DIMS - 1, (3, 1))
dist.log_prob(actions)
dist.entropy()
dist.sample()
dist.mode()
# Test api
dist.actions_from_params(action_logits)
dist.log_prob_from_params(action_logits)
dist.apply_masking(None)
def test_logits_must_align_with_dims(self):
@ -174,11 +179,16 @@ class TestMaskableMultiCategoricalDistribution:
dist.apply_masking(None)
# But now we can
dist.proba_distribution(th.randn(1, DIMS_PER_CAT * NUM_CATS))
dist.log_prob(th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS)))
action_logits = th.randn(1, DIMS_PER_CAT * NUM_CATS)
dist.proba_distribution(action_logits)
actions = th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS))
dist.log_prob(actions)
dist.entropy()
dist.sample()
dist.mode()
# Test api
dist.actions_from_params(action_logits)
dist.log_prob_from_params(action_logits)
dist.apply_masking(None)
def test_logits_must_align_with_dims(self):
@ -258,11 +268,16 @@ class TestMaskableBernoulliDistribution:
dist.apply_masking(None)
# But now we can
dist.proba_distribution(th.randn(1, 2 * DIMS))
dist.log_prob(th.randint(1, (2, DIMS)))
action_logits = th.randn(1, 2 * DIMS)
dist.proba_distribution(action_logits)
actions = th.randint(1, (2, DIMS))
dist.log_prob(actions)
dist.entropy()
dist.sample()
dist.mode()
# Test api
dist.actions_from_params(action_logits)
dist.log_prob_from_params(action_logits)
dist.apply_masking(None)
def test_logits_must_align_with_dims(self):

View File

@ -55,6 +55,22 @@ def test_identity():
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)
def test_bootstraping():
    # Max ep length = 100 by default
    # Truncate episodes early so that timeout bootstrapping is exercised
    base_env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
    truncated_env = gym.wrappers.TimeLimit(base_env, 30)
    model = MaskablePPO("MlpPolicy", truncated_env, n_steps=64, seed=8)
    model.learn(128)
def test_eval_env():
    # Train with a separate evaluation env, then resume training without reset
    train_env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
    evaluation_env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
    model = MaskablePPO("MlpPolicy", train_env, clip_range_vf=0.2, n_steps=32, seed=8)
    model.learn(32, eval_env=evaluation_env, eval_freq=16)
    model.learn(32, reset_num_timesteps=False)
def test_supports_discrete_action_space():
"""
No errors using algorithm with an env that has a discrete action space
@ -221,7 +237,7 @@ def test_dict_obs():
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
env = ToDictWrapper(env)
model = MaskablePPO("MultiInputPolicy", env, n_steps=64, seed=8)
model.learn(100)
model.learn(64)
evaluate_policy(model, env, warn=False)
# Mask all actions except the good one, a random model should succeed
@ -229,3 +245,13 @@ def test_dict_obs():
env = ToDictWrapper(env)
model = MaskablePPO("MultiInputPolicy", env, seed=8)
evaluate_policy(model, env, reward_threshold=99, warn=False)
# MultiDiscrete
env = InvalidActionEnvMultiDiscrete(dims=[2, 3], n_invalid_actions=1)
env = ToDictWrapper(env)
model = MaskablePPO("MultiInputPolicy", env, n_steps=32, seed=8)
model.learn(32)
# MultiBinary
env = InvalidActionEnvMultiBinary(dims=3, n_invalid_actions=1)
env = ToDictWrapper(env)
model = MaskablePPO("MultiInputPolicy", env, n_steps=32, seed=8)
model.learn(32)