diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3fc4a7f..0723975 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,13 +4,14 @@ Changelog ========== -Release 1.3.1a5 (WIP) +Release 1.3.1a6 (WIP) ------------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ - Dropped python 3.6 support - Upgraded to Stable-Baselines3 >= 1.3.1a4 +- ``MaskablePPO`` was updated to match latest SB3 ``PPO`` version (timeout handling and new method for the policy object) New Features: ^^^^^^^^^^^^^ @@ -24,6 +25,7 @@ Deprecations: Others: ^^^^^^^ +- Improve test coverage for ``MaskablePPO`` Documentation: diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index 449d362..1461eb6 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -86,7 +86,6 @@ class MaskableActorCriticPolicy(BasePolicy): self.features_dim = self.features_extractor.features_dim self.normalize_images = normalize_images - # Action distribution self.action_dist = make_masked_proba_distribution(action_space) @@ -106,7 +105,9 @@ class MaskableActorCriticPolicy(BasePolicy): :param action_masks: Action masks to apply to the action distribution :return: action, value and log probability of the action """ - latent_pi, latent_vf = self._get_latent(obs) + # Preprocess the observation if needed + features = self.extract_features(obs) + latent_pi, latent_vf = self.mlp_extractor(features) # Evaluate the values for the given observations values = self.value_net(latent_vf) distribution = self._get_action_dist_from_latent(latent_pi) @@ -179,21 +180,6 @@ class MaskableActorCriticPolicy(BasePolicy): # Setup optimizer with initial learning rate self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) - def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: - """ - Get the latent code (i.e., activations of the last layer of each network) - for 
the different networks. - - :param obs: Observation - :return: Latent codes - for the actor, the value function and for gSDE function - """ - # Preprocess the observation if needed - features = self.extract_features(obs) - latent_pi, latent_vf = self.mlp_extractor(features) - - return latent_pi, latent_vf - def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> MaskableDistribution: """ Retrieve action distribution given the latent codes. @@ -218,11 +204,7 @@ class MaskableActorCriticPolicy(BasePolicy): :param action_masks: Action masks to apply to the action distribution :return: Taken action according to the policy """ - latent_pi, _ = self._get_latent(observation) - distribution = self._get_action_dist_from_latent(latent_pi) - if action_masks is not None: - distribution.apply_masking(action_masks) - return distribution.get_actions(deterministic=deterministic) + return self.get_distribution(observation, action_masks).get_actions(deterministic=deterministic) def predict( self, @@ -291,7 +273,8 @@ class MaskableActorCriticPolicy(BasePolicy): :return: estimated value, log likelihood of taking those actions and entropy of the action distribution. """ - latent_pi, latent_vf = self._get_latent(obs) + features = self.extract_features(obs) + latent_pi, latent_vf = self.mlp_extractor(features) distribution = self._get_action_dist_from_latent(latent_pi) if action_masks is not None: distribution.apply_masking(action_masks) @@ -299,6 +282,32 @@ class MaskableActorCriticPolicy(BasePolicy): values = self.value_net(latent_vf) return values, log_prob, distribution.entropy() + def get_distribution(self, obs: th.Tensor, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution: + """ + Get the current policy distribution given the observations. + + :param obs: + :param action_masks: + :return: the action distribution. 
+ """ + features = self.extract_features(obs) + latent_pi = self.mlp_extractor.forward_actor(features) + distribution = self._get_action_dist_from_latent(latent_pi) + if action_masks is not None: + distribution.apply_masking(action_masks) + return distribution + + def predict_values(self, obs: th.Tensor) -> th.Tensor: + """ + Get the estimated values according to the current policy given the observations. + + :param obs: + :return: the estimated values. + """ + features = self.extract_features(obs) + latent_vf = self.mlp_extractor.forward_critic(features) + return self.value_net(latent_vf) + class MaskableActorCriticCnnPolicy(MaskableActorCriticPolicy): """ diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index ad08054..26c2d8e 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -336,6 +336,19 @@ class MaskablePPO(OnPolicyAlgorithm): # Reshape in case of discrete action actions = actions.reshape(-1, 1) + # Handle timeout by bootstraping with value function + # see GitHub issue #633 + for idx, done in enumerate(dones): + if ( + done + and infos[idx].get("terminal_observation") is not None + and infos[idx].get("TimeLimit.truncated", False) + ): + terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] + with th.no_grad(): + terminal_value = self.policy.predict_values(terminal_obs)[0] + rewards[idx] += self.gamma * terminal_value + rollout_buffer.add( self._last_obs, actions, @@ -350,11 +363,9 @@ class MaskablePPO(OnPolicyAlgorithm): with th.no_grad(): # Compute value for the last timestep - obs_tensor = obs_as_tensor(new_obs, self.device) - # Masking is not needed here, the choice of action doesn't matter. # We only want the value of the current observation. 
- _, values, _ = self.policy.forward(obs_tensor) + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) @@ -541,7 +552,7 @@ class MaskablePPO(OnPolicyAlgorithm): # Display training infos if log_interval is not None and iteration % log_interval == 0: - fps = int(self.num_timesteps / (time.time() - self.start_time)) + fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time)) self.logger.record("time/iterations", iteration, exclude="tensorboard") if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index dbc4395..e6eaed8 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.3.1a5 +1.3.1a6 diff --git a/tests/test_distributions.py b/tests/test_distributions.py index a54a9a5..183952d 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -98,11 +98,16 @@ class TestMaskableCategoricalDistribution: dist.apply_masking(None) # But now we can - dist.proba_distribution(th.randn(1, DIMS)) - dist.log_prob(th.randint(DIMS - 1, (3, 1))) + action_logits = th.randn(1, DIMS) + dist.proba_distribution(action_logits) + actions = th.randint(DIMS - 1, (3, 1)) + dist.log_prob(actions) dist.entropy() dist.sample() dist.mode() + # Test api + dist.actions_from_params(action_logits) + dist.log_prob_from_params(action_logits) dist.apply_masking(None) def test_logits_must_align_with_dims(self): @@ -174,11 +179,16 @@ class TestMaskableMultiCategoricalDistribution: dist.apply_masking(None) # But now we can - dist.proba_distribution(th.randn(1, DIMS_PER_CAT * NUM_CATS)) - dist.log_prob(th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS))) + action_logits = th.randn(1, DIMS_PER_CAT * NUM_CATS) + dist.proba_distribution(action_logits) + actions 
= th.randint(DIMS_PER_CAT - 1, (3, NUM_CATS)) + dist.log_prob(actions) dist.entropy() dist.sample() dist.mode() + # Test api + dist.actions_from_params(action_logits) + dist.log_prob_from_params(action_logits) dist.apply_masking(None) def test_logits_must_align_with_dims(self): @@ -258,11 +268,16 @@ class TestMaskableBernoulliDistribution: dist.apply_masking(None) # But now we can - dist.proba_distribution(th.randn(1, 2 * DIMS)) - dist.log_prob(th.randint(1, (2, DIMS))) + action_logits = th.randn(1, 2 * DIMS) + dist.proba_distribution(action_logits) + actions = th.randint(1, (2, DIMS)) + dist.log_prob(actions) dist.entropy() dist.sample() dist.mode() + # Test api + dist.actions_from_params(action_logits) + dist.log_prob_from_params(action_logits) dist.apply_masking(None) def test_logits_must_align_with_dims(self): diff --git a/tests/test_invalid_actions.py b/tests/test_invalid_actions.py index 222191f..d530cfc 100644 --- a/tests/test_invalid_actions.py +++ b/tests/test_invalid_actions.py @@ -55,6 +55,22 @@ def test_identity(): evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False) +def test_bootstrapping(): + # Max ep length = 100 by default + env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10) + env = gym.wrappers.TimeLimit(env, 30) + model = MaskablePPO("MlpPolicy", env, n_steps=64, seed=8) + model.learn(128) + + +def test_eval_env(): + env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10) + eval_env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10) + model = MaskablePPO("MlpPolicy", env, clip_range_vf=0.2, n_steps=32, seed=8) + model.learn(32, eval_env=eval_env, eval_freq=16) + model.learn(32, reset_num_timesteps=False) + + def test_supports_discrete_action_space(): """ No errors using algorithm with an env that has a discrete action space @@ -221,7 +237,7 @@ def test_dict_obs(): env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10) env = ToDictWrapper(env) model = MaskablePPO("MultiInputPolicy", env, 
n_steps=64, seed=8) - model.learn(100) + model.learn(64) evaluate_policy(model, env, warn=False) # Mask all actions except the good one, a random model should succeed @@ -229,3 +245,13 @@ def test_dict_obs(): env = ToDictWrapper(env) model = MaskablePPO("MultiInputPolicy", env, seed=8) evaluate_policy(model, env, reward_threshold=99, warn=False) + # MultiDiscrete + env = InvalidActionEnvMultiDiscrete(dims=[2, 3], n_invalid_actions=1) + env = ToDictWrapper(env) + model = MaskablePPO("MultiInputPolicy", env, n_steps=32, seed=8) + model.learn(32) + # MultiBinary + env = InvalidActionEnvMultiBinary(dims=3, n_invalid_actions=1) + env = ToDictWrapper(env) + model = MaskablePPO("MultiInputPolicy", env, n_steps=32, seed=8) + model.learn(32)