Update Maskable PPO to match SB3 PPO + improve coverage (#56)
This commit is contained in:
parent
20b5351086
commit
b44689b0ea
|
|
@ -4,13 +4,14 @@ Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
|
||||||
Release 1.3.1a5 (WIP)
|
Release 1.3.1a6 (WIP)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^
|
||||||
- Dropped python 3.6 support
|
- Dropped python 3.6 support
|
||||||
- Upgraded to Stable-Baselines3 >= 1.3.1a4
|
- Upgraded to Stable-Baselines3 >= 1.3.1a4
|
||||||
|
- ``MaskablePPO`` was updated to match the latest SB3 ``PPO`` version (timeout handling and new ``get_distribution``/``predict_values`` methods for the policy object)
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
@ -24,6 +25,7 @@ Deprecations:
|
||||||
|
|
||||||
Others:
|
Others:
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
- Improved test coverage for ``MaskablePPO``
|
||||||
|
|
||||||
|
|
||||||
Documentation:
|
Documentation:
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,6 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
self.features_dim = self.features_extractor.features_dim
|
self.features_dim = self.features_extractor.features_dim
|
||||||
|
|
||||||
self.normalize_images = normalize_images
|
self.normalize_images = normalize_images
|
||||||
|
|
||||||
# Action distribution
|
# Action distribution
|
||||||
self.action_dist = make_masked_proba_distribution(action_space)
|
self.action_dist = make_masked_proba_distribution(action_space)
|
||||||
|
|
||||||
|
|
@ -106,7 +105,9 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
:param action_masks: Action masks to apply to the action distribution
|
:param action_masks: Action masks to apply to the action distribution
|
||||||
:return: action, value and log probability of the action
|
:return: action, value and log probability of the action
|
||||||
"""
|
"""
|
||||||
latent_pi, latent_vf = self._get_latent(obs)
|
# Preprocess the observation if needed
|
||||||
|
features = self.extract_features(obs)
|
||||||
|
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||||
# Evaluate the values for the given observations
|
# Evaluate the values for the given observations
|
||||||
values = self.value_net(latent_vf)
|
values = self.value_net(latent_vf)
|
||||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||||
|
|
@ -179,21 +180,6 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
# Setup optimizer with initial learning rate
|
# Setup optimizer with initial learning rate
|
||||||
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||||
|
|
||||||
def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
|
||||||
"""
|
|
||||||
Get the latent code (i.e., activations of the last layer of each network)
|
|
||||||
for the different networks.
|
|
||||||
|
|
||||||
:param obs: Observation
|
|
||||||
:return: Latent codes
|
|
||||||
for the actor, the value function and for gSDE function
|
|
||||||
"""
|
|
||||||
# Preprocess the observation if needed
|
|
||||||
features = self.extract_features(obs)
|
|
||||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
|
||||||
|
|
||||||
return latent_pi, latent_vf
|
|
||||||
|
|
||||||
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution:
|
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution:
|
||||||
"""
|
"""
|
||||||
Retrieve action distribution given the latent codes.
|
Retrieve action distribution given the latent codes.
|
||||||
|
|
@ -218,11 +204,7 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
:param action_masks: Action masks to apply to the action distribution
|
:param action_masks: Action masks to apply to the action distribution
|
||||||
:return: Taken action according to the policy
|
:return: Taken action according to the policy
|
||||||
"""
|
"""
|
||||||
latent_pi, _ = self._get_latent(observation)
|
return self.get_distribution(observation, action_masks).get_actions(deterministic=deterministic)
|
||||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
|
||||||
if action_masks is not None:
|
|
||||||
distribution.apply_masking(action_masks)
|
|
||||||
return distribution.get_actions(deterministic=deterministic)
|
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
|
|
@ -291,7 +273,8 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
:return: estimated value, log likelihood of taking those actions
|
:return: estimated value, log likelihood of taking those actions
|
||||||
and entropy of the action distribution.
|
and entropy of the action distribution.
|
||||||
"""
|
"""
|
||||||
latent_pi, latent_vf = self._get_latent(obs)
|
features = self.extract_features(obs)
|
||||||
|
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||||
distribution = self._get_action_dist_from_latent(latent_pi)
|
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||||
if action_masks is not None:
|
if action_masks is not None:
|
||||||
distribution.apply_masking(action_masks)
|
distribution.apply_masking(action_masks)
|
||||||
|
|
@ -299,6 +282,32 @@ class MaskableActorCriticPolicy(BasePolicy):
|
||||||
values = self.value_net(latent_vf)
|
values = self.value_net(latent_vf)
|
||||||
return values, log_prob, distribution.entropy()
|
return values, log_prob, distribution.entropy()
|
||||||
|
|
||||||
|
def get_distribution(self, obs: th.Tensor, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
|
||||||
|
"""
|
||||||
|
Get the current policy distribution given the observations.
|
||||||
|
|
||||||
|
:param obs: Observation
|
||||||
|
:param action_masks: Action masks to apply to the action distribution
|
||||||
|
:return: the action distribution.
|
||||||
|
"""
|
||||||
|
features = self.extract_features(obs)
|
||||||
|
latent_pi = self.mlp_extractor.forward_actor(features)
|
||||||
|
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||||
|
if action_masks is not None:
|
||||||
|
distribution.apply_masking(action_masks)
|
||||||
|
return distribution
|
||||||
|
|
||||||
|
def predict_values(self, obs: th.Tensor) -> th.Tensor:
|
||||||
|
"""
|
||||||
|
Get the estimated values according to the current policy given the observations.
|
||||||
|
|
||||||
|
:param obs: Observation
|
||||||
|
:return: the estimated values.
|
||||||
|
"""
|
||||||
|
features = self.extract_features(obs)
|
||||||
|
latent_vf = self.mlp_extractor.forward_critic(features)
|
||||||
|
return self.value_net(latent_vf)
|
||||||
|
|
||||||
|
|
||||||
class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
|
class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -336,6 +336,19 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
# Reshape in case of discrete action
|
# Reshape in case of discrete action
|
||||||
actions = actions.reshape(-1, 1)
|
actions = actions.reshape(-1, 1)
|
||||||
|
|
||||||
|
# Handle timeout by bootstrapping with value function
|
||||||
|
# see GitHub issue #633
|
||||||
|
for idx, done in enumerate(dones):
|
||||||
|
if (
|
||||||
|
done
|
||||||
|
and infos[idx].get("terminal_observation") is not None
|
||||||
|
and infos[idx].get("TimeLimit.truncated", False)
|
||||||
|
):
|
||||||
|
terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
|
||||||
|
with th.no_grad():
|
||||||
|
terminal_value = self.policy.predict_values(terminal_obs)[0]
|
||||||
|
rewards[idx] += self.gamma * terminal_value
|
||||||
|
|
||||||
rollout_buffer.add(
|
rollout_buffer.add(
|
||||||
self._last_obs,
|
self._last_obs,
|
||||||
actions,
|
actions,
|
||||||
|
|
@ -350,11 +363,9 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
# Compute value for the last timestep
|
# Compute value for the last timestep
|
||||||
obs_tensor = obs_as_tensor(new_obs, self.device)
|
|
||||||
|
|
||||||
# Masking is not needed here, the choice of action doesn't matter.
|
# Masking is not needed here, the choice of action doesn't matter.
|
||||||
# We only want the value of the current observation.
|
# We only want the value of the current observation.
|
||||||
_, values, _ = self.policy.forward(obs_tensor)
|
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))
|
||||||
|
|
||||||
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
|
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
|
||||||
|
|
||||||
|
|
@ -541,7 +552,7 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
# Display training infos
|
# Display training infos
|
||||||
if log_interval is not None and iteration % log_interval == 0:
|
if log_interval is not None and iteration % log_interval == 0:
|
||||||
fps = int(self.num_timesteps / (time.time() - self.start_time))
|
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
|
||||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
||||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.3.1a5
|
1.3.1a6
|
||||||
|
|
|
||||||
|
|
@ -98,11 +98,16 @@ class TestMaskableCategoricalDistribution:
|
||||||
dist.apply_masking(None)
|
dist.apply_masking(None)
|
||||||
|
|
||||||
# But now we can
|
# But now we can
|
||||||
dist.proba_distribution(th.randn(1, DIMS))
|
action_logits = th.randn(1, DIMS)
|
||||||
dist.log_prob(th.randint(DIMS - 1, (3, 1)))
|
dist.proba_distribution(action_logits)
|
||||||
|
actions = th.randint(DIMS - 1, (3, 1))
|
||||||
|
dist.log_prob(actions)
|
||||||
dist.entropy()
|
dist.entropy()
|
||||||
dist.sample()
|
dist.sample()
|
||||||
dist.mode()
|
dist.mode()
|
||||||
|
# Test api
|
||||||
|
dist.actions_from_params(action_logits)
|
||||||
|
dist.log_prob_from_params(action_logits)
|
||||||
dist.apply_masking(None)
|
dist.apply_masking(None)
|
||||||
|
|
||||||
def test_logits_must_align_with_dims(self):
|
def test_logits_must_align_with_dims(self):
|
||||||
|
|
@ -174,11 +179,16 @@ class TestMaskableMultiCategoricalDistribution:
|
||||||
dist.apply_masking(None)
|
dist.apply_masking(None)
|
||||||
|
|
||||||
# But now we can
|
# But now we can
|
||||||
dist.proba_distribution(th.randn(1, DIMS_PER_CAT * NUM_CATS))
|
action_logits = th.randn(1, DIMS_PER_CAT * NUM_CATS)
|
||||||
dist.log_prob(th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS)))
|
dist.proba_distribution(action_logits)
|
||||||
|
actions = th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS))
|
||||||
|
dist.log_prob(actions)
|
||||||
dist.entropy()
|
dist.entropy()
|
||||||
dist.sample()
|
dist.sample()
|
||||||
dist.mode()
|
dist.mode()
|
||||||
|
# Test api
|
||||||
|
dist.actions_from_params(action_logits)
|
||||||
|
dist.log_prob_from_params(action_logits)
|
||||||
dist.apply_masking(None)
|
dist.apply_masking(None)
|
||||||
|
|
||||||
def test_logits_must_align_with_dims(self):
|
def test_logits_must_align_with_dims(self):
|
||||||
|
|
@ -258,11 +268,16 @@ class TestMaskableBernoulliDistribution:
|
||||||
dist.apply_masking(None)
|
dist.apply_masking(None)
|
||||||
|
|
||||||
# But now we can
|
# But now we can
|
||||||
dist.proba_distribution(th.randn(1, 2 * DIMS))
|
action_logits = th.randn(1, 2 * DIMS)
|
||||||
dist.log_prob(th.randint(1, (2, DIMS)))
|
dist.proba_distribution(action_logits)
|
||||||
|
actions = th.randint(1, (2, DIMS))
|
||||||
|
dist.log_prob(actions)
|
||||||
dist.entropy()
|
dist.entropy()
|
||||||
dist.sample()
|
dist.sample()
|
||||||
dist.mode()
|
dist.mode()
|
||||||
|
# Test api
|
||||||
|
dist.actions_from_params(action_logits)
|
||||||
|
dist.log_prob_from_params(action_logits)
|
||||||
dist.apply_masking(None)
|
dist.apply_masking(None)
|
||||||
|
|
||||||
def test_logits_must_align_with_dims(self):
|
def test_logits_must_align_with_dims(self):
|
||||||
|
|
|
||||||
|
|
@ -55,6 +55,22 @@ def test_identity():
|
||||||
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)
|
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bootstraping():
|
||||||
|
# Max ep length = 100 by default
|
||||||
|
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
||||||
|
env = gym.wrappers.TimeLimit(env, 30)
|
||||||
|
model = MaskablePPO("MlpPolicy", env, n_steps=64, seed=8)
|
||||||
|
model.learn(128)
|
||||||
|
|
||||||
|
|
||||||
|
def test_eval_env():
|
||||||
|
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
||||||
|
eval_env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
||||||
|
model = MaskablePPO("MlpPolicy", env, clip_range_vf=0.2, n_steps=32, seed=8)
|
||||||
|
model.learn(32, eval_env=eval_env, eval_freq=16)
|
||||||
|
model.learn(32, reset_num_timesteps=False)
|
||||||
|
|
||||||
|
|
||||||
def test_supports_discrete_action_space():
|
def test_supports_discrete_action_space():
|
||||||
"""
|
"""
|
||||||
No errors using algorithm with an env that has a discrete action space
|
No errors using algorithm with an env that has a discrete action space
|
||||||
|
|
@ -221,7 +237,7 @@ def test_dict_obs():
|
||||||
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
||||||
env = ToDictWrapper(env)
|
env = ToDictWrapper(env)
|
||||||
model = MaskablePPO("MultiInputPolicy", env, n_steps=64, seed=8)
|
model = MaskablePPO("MultiInputPolicy", env, n_steps=64, seed=8)
|
||||||
model.learn(100)
|
model.learn(64)
|
||||||
evaluate_policy(model, env, warn=False)
|
evaluate_policy(model, env, warn=False)
|
||||||
|
|
||||||
# Mask all actions except the good one, a random model should succeed
|
# Mask all actions except the good one, a random model should succeed
|
||||||
|
|
@ -229,3 +245,13 @@ def test_dict_obs():
|
||||||
env = ToDictWrapper(env)
|
env = ToDictWrapper(env)
|
||||||
model = MaskablePPO("MultiInputPolicy", env, seed=8)
|
model = MaskablePPO("MultiInputPolicy", env, seed=8)
|
||||||
evaluate_policy(model, env, reward_threshold=99, warn=False)
|
evaluate_policy(model, env, reward_threshold=99, warn=False)
|
||||||
|
# MultiDiscrete
|
||||||
|
env = InvalidActionEnvMultiDiscrete(dims=[2, 3], n_invalid_actions=1)
|
||||||
|
env = ToDictWrapper(env)
|
||||||
|
model = MaskablePPO("MultiInputPolicy", env, n_steps=32, seed=8)
|
||||||
|
model.learn(32)
|
||||||
|
# MultiBinary
|
||||||
|
env = InvalidActionEnvMultiBinary(dims=3, n_invalid_actions=1)
|
||||||
|
env = ToDictWrapper(env)
|
||||||
|
model = MaskablePPO("MultiInputPolicy", env, n_steps=32, seed=8)
|
||||||
|
model.learn(32)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue