Update Maskable PPO to match SB3 PPO + improve coverage (#56)
This commit is contained in:
parent
20b5351086
commit
b44689b0ea
|
|
@ -4,13 +4,14 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.3.1a5 (WIP)
|
||||
Release 1.3.1a6 (WIP)
|
||||
-------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Dropped python 3.6 support
|
||||
- Upgraded to Stable-Baselines3 >= 1.3.1a4
|
||||
- ``MaskablePPO`` was updated to match latest SB3 ``PPO`` version (timeout handling and new method for the policy object)
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -24,6 +25,7 @@ Deprecations:
|
|||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Improve test coverage for ``MaskablePPO``
|
||||
|
||||
|
||||
Documentation:
|
||||
|
|
|
|||
|
|
@ -86,7 +86,6 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
self.features_dim = self.features_extractor.features_dim
|
||||
|
||||
self.normalize_images = normalize_images
|
||||
|
||||
# Action distribution
|
||||
self.action_dist = make_masked_proba_distribution(action_space)
|
||||
|
||||
|
|
@ -106,7 +105,9 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
:param action_masks: Action masks to apply to the action distribution
|
||||
:return: action, value and log probability of the action
|
||||
"""
|
||||
latent_pi, latent_vf = self._get_latent(obs)
|
||||
# Preprocess the observation if needed
|
||||
features = self.extract_features(obs)
|
||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
# Evaluate the values for the given observations
|
||||
values = self.value_net(latent_vf)
|
||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
|
|
@ -179,21 +180,6 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
# Setup optimizer with initial learning rate
|
||||
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
|
||||
def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
"""
|
||||
Get the latent code (i.e., activations of the last layer of each network)
|
||||
for the different networks.
|
||||
|
||||
:param obs: Observation
|
||||
:return: Latent codes
|
||||
for the actor, the value function and for gSDE function
|
||||
"""
|
||||
# Preprocess the observation if needed
|
||||
features = self.extract_features(obs)
|
||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
|
||||
return latent_pi, latent_vf
|
||||
|
||||
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution:
|
||||
"""
|
||||
Retrieve action distribution given the latent codes.
|
||||
|
|
@ -218,11 +204,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
:param action_masks: Action masks to apply to the action distribution
|
||||
:return: Taken action according to the policy
|
||||
"""
|
||||
latent_pi, _ = self._get_latent(observation)
|
||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
if action_masks is not None:
|
||||
distribution.apply_masking(action_masks)
|
||||
return distribution.get_actions(deterministic=deterministic)
|
||||
return self.get_distribution(observation, action_masks).get_actions(deterministic=deterministic)
|
||||
|
||||
def predict(
|
||||
self,
|
||||
|
|
@ -291,7 +273,8 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
:return: estimated value, log likelihood of taking those actions
|
||||
and entropy of the action distribution.
|
||||
"""
|
||||
latent_pi, latent_vf = self._get_latent(obs)
|
||||
features = self.extract_features(obs)
|
||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
if action_masks is not None:
|
||||
distribution.apply_masking(action_masks)
|
||||
|
|
@ -299,6 +282,32 @@ class MaskableActorCriticPolicy(BasePolicy):
|
|||
values = self.value_net(latent_vf)
|
||||
return values, log_prob, distribution.entropy()
|
||||
|
||||
def get_distribution(self, obs: th.Tensor, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
|
||||
"""
|
||||
Get the current policy distribution given the observations.
|
||||
|
||||
:param obs:
|
||||
:param action_masks:
|
||||
:return: the action distribution.
|
||||
"""
|
||||
features = self.extract_features(obs)
|
||||
latent_pi = self.mlp_extractor.forward_actor(features)
|
||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
if action_masks is not None:
|
||||
distribution.apply_masking(action_masks)
|
||||
return distribution
|
||||
|
||||
def predict_values(self, obs: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Get the estimated values according to the current policy given the observations.
|
||||
|
||||
:param obs:
|
||||
:return: the estimated values.
|
||||
"""
|
||||
features = self.extract_features(obs)
|
||||
latent_vf = self.mlp_extractor.forward_critic(features)
|
||||
return self.value_net(latent_vf)
|
||||
|
||||
|
||||
class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -336,6 +336,19 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
# Reshape in case of discrete action
|
||||
actions = actions.reshape(-1, 1)
|
||||
|
||||
# Handle timeout by bootstraping with value function
|
||||
# see GitHub issue #633
|
||||
for idx, done in enumerate(dones):
|
||||
if (
|
||||
done
|
||||
and infos[idx].get("terminal_observation") is not None
|
||||
and infos[idx].get("TimeLimit.truncated", False)
|
||||
):
|
||||
terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
|
||||
with th.no_grad():
|
||||
terminal_value = self.policy.predict_values(terminal_obs)[0]
|
||||
rewards[idx] += self.gamma * terminal_value
|
||||
|
||||
rollout_buffer.add(
|
||||
self._last_obs,
|
||||
actions,
|
||||
|
|
@ -350,11 +363,9 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
|
||||
with th.no_grad():
|
||||
# Compute value for the last timestep
|
||||
obs_tensor = obs_as_tensor(new_obs, self.device)
|
||||
|
||||
# Masking is not needed here, the choice of action doesn't matter.
|
||||
# We only want the value of the current observation.
|
||||
_, values, _ = self.policy.forward(obs_tensor)
|
||||
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))
|
||||
|
||||
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
|
||||
|
||||
|
|
@ -541,7 +552,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
|||
|
||||
# Display training infos
|
||||
if log_interval is not None and iteration % log_interval == 0:
|
||||
fps = int(self.num_timesteps / (time.time() - self.start_time))
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
|
||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.3.1a5
|
||||
1.3.1a6
|
||||
|
|
|
|||
|
|
@ -98,11 +98,16 @@ class TestMaskableCategoricalDistribution:
|
|||
dist.apply_masking(None)
|
||||
|
||||
# But now we can
|
||||
dist.proba_distribution(th.randn(1, DIMS))
|
||||
dist.log_prob(th.randint(DIMS - 1, (3, 1)))
|
||||
action_logits = th.randn(1, DIMS)
|
||||
dist.proba_distribution(action_logits)
|
||||
actions = th.randint(DIMS - 1, (3, 1))
|
||||
dist.log_prob(actions)
|
||||
dist.entropy()
|
||||
dist.sample()
|
||||
dist.mode()
|
||||
# Test api
|
||||
dist.actions_from_params(action_logits)
|
||||
dist.log_prob_from_params(action_logits)
|
||||
dist.apply_masking(None)
|
||||
|
||||
def test_logits_must_align_with_dims(self):
|
||||
|
|
@ -174,11 +179,16 @@ class TestMaskableMultiCategoricalDistribution:
|
|||
dist.apply_masking(None)
|
||||
|
||||
# But now we can
|
||||
dist.proba_distribution(th.randn(1, DIMS_PER_CAT * NUM_CATS))
|
||||
dist.log_prob(th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS)))
|
||||
action_logits = th.randn(1, DIMS_PER_CAT * NUM_CATS)
|
||||
dist.proba_distribution(action_logits)
|
||||
actions = th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS))
|
||||
dist.log_prob(actions)
|
||||
dist.entropy()
|
||||
dist.sample()
|
||||
dist.mode()
|
||||
# Test api
|
||||
dist.actions_from_params(action_logits)
|
||||
dist.log_prob_from_params(action_logits)
|
||||
dist.apply_masking(None)
|
||||
|
||||
def test_logits_must_align_with_dims(self):
|
||||
|
|
@ -258,11 +268,16 @@ class TestMaskableBernoulliDistribution:
|
|||
dist.apply_masking(None)
|
||||
|
||||
# But now we can
|
||||
dist.proba_distribution(th.randn(1, 2 * DIMS))
|
||||
dist.log_prob(th.randint(1, (2, DIMS)))
|
||||
action_logits = th.randn(1, 2 * DIMS)
|
||||
dist.proba_distribution(action_logits)
|
||||
actions = th.randint(1, (2, DIMS))
|
||||
dist.log_prob(actions)
|
||||
dist.entropy()
|
||||
dist.sample()
|
||||
dist.mode()
|
||||
# Test api
|
||||
dist.actions_from_params(action_logits)
|
||||
dist.log_prob_from_params(action_logits)
|
||||
dist.apply_masking(None)
|
||||
|
||||
def test_logits_must_align_with_dims(self):
|
||||
|
|
|
|||
|
|
@ -55,6 +55,22 @@ def test_identity():
|
|||
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)
|
||||
|
||||
|
||||
def test_bootstraping():
|
||||
# Max ep length = 100 by default
|
||||
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
||||
env = gym.wrappers.TimeLimit(env, 30)
|
||||
model = MaskablePPO("MlpPolicy", env, n_steps=64, seed=8)
|
||||
model.learn(128)
|
||||
|
||||
|
||||
def test_eval_env():
|
||||
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
||||
eval_env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
||||
model = MaskablePPO("MlpPolicy", env, clip_range_vf=0.2, n_steps=32, seed=8)
|
||||
model.learn(32, eval_env=eval_env, eval_freq=16)
|
||||
model.learn(32, reset_num_timesteps=False)
|
||||
|
||||
|
||||
def test_supports_discrete_action_space():
|
||||
"""
|
||||
No errors using algorithm with an env that has a discrete action space
|
||||
|
|
@ -221,7 +237,7 @@ def test_dict_obs():
|
|||
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
||||
env = ToDictWrapper(env)
|
||||
model = MaskablePPO("MultiInputPolicy", env, n_steps=64, seed=8)
|
||||
model.learn(100)
|
||||
model.learn(64)
|
||||
evaluate_policy(model, env, warn=False)
|
||||
|
||||
# Mask all actions except the good one, a random model should succeed
|
||||
|
|
@ -229,3 +245,13 @@ def test_dict_obs():
|
|||
env = ToDictWrapper(env)
|
||||
model = MaskablePPO("MultiInputPolicy", env, seed=8)
|
||||
evaluate_policy(model, env, reward_threshold=99, warn=False)
|
||||
# MultiDiscrete
|
||||
env = InvalidActionEnvMultiDiscrete(dims=[2, 3], n_invalid_actions=1)
|
||||
env = ToDictWrapper(env)
|
||||
model = MaskablePPO("MultiInputPolicy", env, n_steps=32, seed=8)
|
||||
model.learn(32)
|
||||
# MultiBinary
|
||||
env = InvalidActionEnvMultiBinary(dims=3, n_invalid_actions=1)
|
||||
env = ToDictWrapper(env)
|
||||
model = MaskablePPO("MultiInputPolicy", env, n_steps=32, seed=8)
|
||||
model.learn(32)
|
||||
|
|
|
|||
Loading…
Reference in New Issue