diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a1b5a68..14493ac 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 1.8.0a2 (WIP) +Release 1.8.0a4 (WIP) -------------------------- Breaking Changes: @@ -23,6 +23,10 @@ Deprecations: Others: ^^^^^^^ +Documentation: +^^^^^^^^^^^^^^ +- Added warning about potential crashes caused by ``check_env`` in the ``MaskablePPO`` docs (@AlexPasqua) + Release 1.7.0 (2023-01-10) -------------------------- diff --git a/docs/modules/ppo_mask.rst b/docs/modules/ppo_mask.rst index a43f596..87309e1 100644 --- a/docs/modules/ppo_mask.rst +++ b/docs/modules/ppo_mask.rst @@ -83,6 +83,11 @@ returns the invalid action mask (``True`` if the action is valid, ``False`` othe If the environment implements the invalid action mask but using a different name, you can use the ``ActionMasker`` to specify the name (see `PR #25 `_): +.. note:: + If you are using a custom environment and you want to debug it with ``check_env``, + it will execute the method ``step`` passing a random action to it (using ``action_space.sample()``), + without taking into account the invalid actions mask (see `issue #145 `_). + .. code-block:: python diff --git a/sb3_contrib/ars/ars.py b/sb3_contrib/ars/ars.py index 4b05412..e5ea9ce 100644 --- a/sb3_contrib/ars/ars.py +++ b/sb3_contrib/ars/ars.py @@ -71,7 +71,6 @@ class ARS(BaseAlgorithm): device: Union[th.device, str] = "cpu", _init_setup_model: bool = True, ): - super().__init__( policy, env, @@ -186,7 +185,6 @@ class ARS(BaseAlgorithm): results = async_eval.get_results() for weights_idx, (episode_rewards, episode_lengths) in results: - # Update reward to cancel out alive bonus if needed candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths) batch_steps += np.sum(episode_lengths) @@ -214,7 +212,6 @@ class ARS(BaseAlgorithm): else: # Single process, synchronous version for weights_idx in range(self.pop_size): - # Load current candidate weights train_policy.load_from_vector(candidate_weights[weights_idx].cpu()) # Evaluate the candidate diff --git a/sb3_contrib/ars/policies.py b/sb3_contrib/ars/policies.py index 782d441..f484cf4 100644 --- a/sb3_contrib/ars/policies.py +++ b/sb3_contrib/ars/policies.py @@ -30,7 +30,6 @@ class ARSPolicy(BasePolicy): with_bias: bool = True, squash_output: bool = True, ): - super().__init__( observation_space, action_space, @@ -103,7 +102,6 @@ class ARSLinearPolicy(ARSPolicy): with_bias: bool = False, squash_output: bool = False, ): - super().__init__(observation_space, action_space, net_arch=[], with_bias=with_bias, squash_output=squash_output) diff --git a/sb3_contrib/common/maskable/buffers.py b/sb3_contrib/common/maskable/buffers.py index 1a43539..3bd1327 100644 --- a/sb3_contrib/common/maskable/buffers.py +++ b/sb3_contrib/common/maskable/buffers.py @@ -185,7 +185,6 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer): indices = np.random.permutation(self.buffer_size * self.n_envs) # Prepare the data if not self.generator_ready: - for key, obs in self.observations.items(): self.observations[key] = self.swap_and_flatten(obs) @@ -205,7 +204,6 @@ class MaskableDictRolloutBuffer(DictRolloutBuffer): start_idx += batch_size def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples: - return MaskableDictRolloutBufferSamples( observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, actions=self.to_torch(self.actions[batch_inds]), diff --git a/sb3_contrib/common/maskable/callbacks.py b/sb3_contrib/common/maskable/callbacks.py index 46927c6..f0839e8 100644 --- a/sb3_contrib/common/maskable/callbacks.py +++ b/sb3_contrib/common/maskable/callbacks.py @@ -36,11 +36,9 @@ class MaskableEvalCallback(EvalCallback): self.use_masking = use_masking def _on_step(self) -> bool: - continue_training = True if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: - # Sync training and eval env if there is VecNormalize if self.model.get_vec_normalize_env() is not None: try: diff --git a/sb3_contrib/common/maskable/evaluation.py b/sb3_contrib/common/maskable/evaluation.py index 63ff0b4..4790d5c 100644 --- a/sb3_contrib/common/maskable/evaluation.py +++ b/sb3_contrib/common/maskable/evaluation.py @@ -108,7 +108,6 @@ def evaluate_policy( # noqa: C901 current_lengths += 1 for i in range(n_envs): if episode_counts[i] < episode_count_targets[i]: - # unpack values so that the callback can access the local variables reward = rewards[i] done = dones[i] diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index b0b434b..920449b 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -57,7 +57,6 @@ class MaskableActorCriticPolicy(BasePolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - if optimizer_kwargs is None: optimizer_kwargs = {} # Small values to avoid NaN in Adam optimizer diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 9cff998..78d67d9 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -462,7 +462,6 @@ class RecurrentPPO(OnPolicyAlgorithm): callback.on_training_start(locals(), globals()) while self.num_timesteps < total_timesteps: - continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) if continue_training is False: diff --git a/sb3_contrib/qrdqn/policies.py b/sb3_contrib/qrdqn/policies.py index cc78dfb..6090097 100644 --- a/sb3_contrib/qrdqn/policies.py +++ b/sb3_contrib/qrdqn/policies.py @@ -123,7 +123,6 @@ class QRDQNPolicy(BasePolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super().__init__( observation_space, action_space, diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index 66ea3c8..f9f0a40 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -91,7 +91,6 @@ class QRDQN(OffPolicyAlgorithm): device: Union[th.device, str] = "auto", _init_setup_model: bool = True, ): - super().__init__( policy, env, @@ -255,7 +254,6 @@ class QRDQN(OffPolicyAlgorithm): reset_num_timesteps: bool = True, progress_bar: bool = False, ) -> SelfQRDQN: - return super().learn( total_timesteps=total_timesteps, callback=callback, diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 9d452d4..8e19fa2 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -101,7 +101,6 @@ class TQC(OffPolicyAlgorithm): device: Union[th.device, str] = "auto", _init_setup_model: bool = True, ): - super().__init__( policy, env, @@ -292,7 +291,6 @@ class TQC(OffPolicyAlgorithm): reset_num_timesteps: bool = True, progress_bar: bool = False, ) -> SelfTQC: - return super().learn( total_timesteps=total_timesteps, callback=callback, diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index eee1f98..2a7986d 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -99,7 +99,6 @@ class TRPO(OnPolicyAlgorithm): device: Union[th.device, str] = "auto", _init_setup_model: bool = True, ): - super().__init__( policy, env, @@ -235,7 +234,6 @@ class TRPO(OnPolicyAlgorithm): # This will only loop once (get all data in one go) for rollout_data in self.rollout_buffer.get(batch_size=None): - # Optional: sub-sample data for faster computation if self.sub_sampling_factor > 1: rollout_data = RolloutBufferSamples( @@ -308,7 +306,6 @@ class TRPO(OnPolicyAlgorithm): with th.no_grad(): # Line-search (backtracking) for _ in range(self.line_search_max_iter): - start_idx = 0 # Applying the scaled step direction for param, original_param, shape in zip(actor_params, original_actor_params, grad_shape): @@ -408,7 +405,6 @@ class TRPO(OnPolicyAlgorithm): reset_num_timesteps: bool = True, progress_bar: bool = False, ) -> SelfTRPO: - return super().learn( total_timesteps=total_timesteps, callback=callback, diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index c3d22c0..e8175d3 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.8.0a2 +1.8.0a4 diff --git a/tests/test_deterministic.py b/tests/test_deterministic.py index 458d3f0..87c374f 100644 --- a/tests/test_deterministic.py +++ b/tests/test_deterministic.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.noise import NormalActionNoise @@ -25,7 +26,13 @@ def test_deterministic_training_common(algo): ars_multi = False if algo in [TQC]: - kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4}) + kwargs.update( + { + "action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)), + "learning_starts": 100, + "train_freq": 4, + } + ) else: if algo == QRDQN: env_id = "CartPole-v1"