diff --git a/README.md b/README.md index 36cdb24..7280244 100644 --- a/README.md +++ b/README.md @@ -18,3 +18,19 @@ Implemented: ``` pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib ``` + + +## Citing the Project + +To cite this repository in publications (please cite SB3 directly): + +``` +@misc{stable-baselines3, + author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah}, + title = {Stable Baselines3}, + year = {2019}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}}, +} +``` diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index e58cb11..19a4037 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -17,25 +17,25 @@ class Actor(BasePolicy): """ Actor network (policy) for TQC. - :param observation_space: (gym.spaces.Space) Obervation space - :param action_space: (gym.spaces.Space) Action space - :param net_arch: ([int]) Network architecture - :param features_extractor: (nn.Module) Network to extract features + :param observation_space: Observation space + :param action_space: Action space + :param net_arch: Network architecture + :param features_extractor: Network to extract features (a CNN when using images, a nn.Flatten() layer otherwise) - :param features_dim: (int) Number of features - :param activation_fn: (Type[nn.Module]) Activation function - :param use_sde: (bool) Whether to use State Dependent Exploration or not - :param log_std_init: (float) Initial value for the log standard deviation - :param full_std: (bool) Whether to use (n_features x n_actions) parameters + :param features_dim: Number of features + :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x 
n_actions) parameters for the std instead of only (n_features,) when using gSDE. - :param sde_net_arch: ([int]) Network architecture for extracting features + :param sde_net_arch: Network architecture for extracting features when using gSDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features. - :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. - :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. - :param normalize_images: (bool) Whether to normalize images or not, + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) """ @@ -131,7 +131,7 @@ class Actor(BasePolicy): but is slightly different when using ``expln`` function (cf StateDependentNoiseDistribution doc). - :return: (th.Tensor) + :return: """ msg = "get_std() is only available when using gSDE" assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg @@ -141,7 +141,7 @@ class Actor(BasePolicy): """ Sample new weights for the exploration matrix, when using gSDE. - :param batch_size: (int) + :param batch_size: """ msg = "reset_noise() is only available when using gSDE" assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg @@ -151,8 +151,8 @@ class Actor(BasePolicy): """ Get the parameters for the action distribution. - :param obs: (th.Tensor) - :return: (Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]) + :param obs: + :return: Mean, standard deviation and optional keyword arguments. 
""" features = self.extract_features(obs) @@ -187,9 +187,9 @@ class Actor(BasePolicy): """ Evaluate actions according to the current policy, given the observations. Only useful when using SDE. - :param obs: (th.Tensor) - :param actions: (th.Tensor) - :return: (th.Tensor) log likelihood of taking those actions + :param obs: + :param actions: + :return: log likelihood of taking those actions """ mean_actions, log_std, kwargs = self.get_action_dist_params(obs) self.action_dist.proba_distribution(mean_actions, log_std, **kwargs) @@ -200,14 +200,14 @@ class Critic(BaseModel): """ Critic network (q-value function) for TQC. - :param observation_space: (gym.spaces.Space) Obervation space - :param action_space: (gym.spaces.Space) Action space - :param net_arch: ([int]) Network architecture - :param features_extractor: (nn.Module) Network to extract features + :param observation_space: Obervation space + :param action_space: Action space + :param net_arch: Network architecture + :param features_extractor: Network to extract features (a CNN when using images, a nn.Flatten() layer otherwise) - :param features_dim: (int) Number of features - :param activation_fn: (Type[nn.Module]) Activation function - :param normalize_images: (bool) Whether to normalize images or not, + :param features_dim: Number of features + :param activation_fn: Activation function + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) """ @@ -257,28 +257,28 @@ class TQCPolicy(BasePolicy): """ Policy class (with both actor and critic) for TQC. - :param observation_space: (gym.spaces.Space) Observation space - :param action_space: (gym.spaces.Space) Action space - :param lr_schedule: (callable) Learning rate schedule (could be constant) - :param net_arch: (Optional[List[int]]) The specification of the policy and value networks. 
- :param activation_fn: (Type[nn.Module]) Activation function - :param use_sde: (bool) Whether to use State Dependent Exploration or not - :param log_std_init: (float) Initial value for the log standard deviation - :param sde_net_arch: ([int]) Network architecture for extracting features + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param sde_net_arch: Network architecture for extracting features when using gSDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features. - :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. - :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. - :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. - :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments to pass to the feature extractor. 
- :param normalize_images: (bool) Whether to normalize images or not, + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default - :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ @@ -388,7 +388,7 @@ class TQCPolicy(BasePolicy): """ Sample new weights for the exploration matrix, when using gSDE. - :param batch_size: (int) + :param batch_size: """ self.actor.reset_noise(batch_size=batch_size) @@ -412,26 +412,26 @@ class CnnPolicy(TQCPolicy): """ Policy class (with both actor and critic) for TQC. - :param observation_space: (gym.spaces.Space) Observation space - :param action_space: (gym.spaces.Space) Action space - :param lr_schedule: (callable) Learning rate schedule (could be constant) - :param net_arch: (Optional[List[int]]) The specification of the policy and value networks. - :param activation_fn: (Type[nn.Module]) Activation function - :param use_sde: (bool) Whether to use State Dependent Exploration or not - :param log_std_init: (float) Initial value for the log standard deviation - :param sde_net_arch: ([int]) Network architecture for extracting features + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param sde_net_arch: Network architecture for extracting features when using gSDE. If None, the latent features from the policy will be used. 
Pass an empty list to use the states as features. - :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. - :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. - :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. - :param normalize_images: (bool) Whether to normalize images or not, + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param features_extractor_class: Features extractor to use. + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default - :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 18fc902..ea6818e 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -7,7 +7,6 @@ from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback from stable_baselines3.common.utils import polyak_update -from tqdm import tqdm from sb3_contrib.tqc.policies import TQCPolicy @@ -15,48 +14,48 @@ from sb3_contrib.tqc.policies import TQCPolicy class TQC(OffPolicyAlgorithm): """ - Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics + Controlling Overestimation Bias 
with Truncated Mixture of Continuous Distributional Quantile Critics. Paper: https://arxiv.org/abs/2005.04269 - :param policy: (TQCPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...) - :param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str) - :param learning_rate: (float or callable) learning rate for adam optimizer, + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: learning rate for adam optimizer, the same learning rate will be used for all networks (Q-Values, Actor and Value function) it can be a function of the current progress remaining (from 1 to 0) - :param buffer_size: (int) size of the replay buffer - :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts - :param batch_size: (int) Minibatch size for each gradient update - :param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1) - :param gamma: (float) the discount factor - :param train_freq: (int) Update the model every ``train_freq`` steps. - :param gradient_steps: (int) How many gradient update after each step - :param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes. + :param buffer_size: size of the replay buffer + :param learning_starts: how many steps of the model to collect transitions for before learning starts + :param batch_size: Minibatch size for each gradient update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) + :param gamma: the discount factor + :param train_freq: Update the model every ``train_freq`` steps. + :param gradient_steps: How many gradient updates after each step + :param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes. 
Note that this cannot be used at the same time as ``train_freq`` - :param action_noise: (ActionNoise) the action noise type (None by default), this can help + :param action_noise: the action noise type (None by default), this can help for hard exploration problem. Cf common.noise for the different action noise type. - :param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer + :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 - :param ent_coef: (str or float) Entropy regularization coefficient. (Equivalent to + :param ent_coef: Entropy regularization coefficient. (Equivalent to inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off. Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value) - :param target_update_interval: (int) update the target network every ``target_network_update_freq`` + :param target_update_interval: update the target network every ``target_network_update_freq`` gradient steps. 
- :param target_entropy: (str or float) target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``) - :param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE) + :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``) + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False) - :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) - :param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling + :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling during the warm up phase (before learning starts) - :param create_eval_env: (bool) Whether to create a second environment that will be + :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) - :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation - :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug - :param seed: (int) Seed for the pseudo random generators - :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run. + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible. 
- :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param _init_setup_model: Whether or not to build the network at the creation of the instance """ def __init__( @@ -274,171 +273,6 @@ class TQC(OffPolicyAlgorithm): if len(ent_coef_losses) > 0: logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) - def pretrain( - self, - gradient_steps: int, - batch_size: int = 64, - n_action_samples: int = -1, - target_update_interval: int = 1, - tau: float = 0.005, - strategy: str = "exp", - reduce: str = "mean", - exp_temperature: float = 1.0, - off_policy_update_freq: int = -1, - ) -> None: - """ - Pretrain with Critic Regularized Regression (CRR) - Paper: https://arxiv.org/abs/2006.15134 - """ - # Update optimizers learning rate - optimizers = [self.actor.optimizer, self.critic.optimizer] - if self.ent_coef_optimizer is not None: - optimizers += [self.ent_coef_optimizer] - - # Update learning rate according to lr schedule - self._update_learning_rate(optimizers) - - actor_losses, critic_losses = [], [] - - for gradient_step in tqdm(range(gradient_steps)): - - if off_policy_update_freq > 0 and gradient_step % off_policy_update_freq == 0: - self.train(gradient_steps=1, batch_size=batch_size) - continue - - # Sample replay buffer - replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) - - # We need to sample because `log_std` may have changed between two gradient steps - if self.use_sde: - self.actor.reset_noise() - - # Action by the current actor for the sampled state - _, log_prob = self.actor.action_log_prob(replay_data.observations) - log_prob = log_prob.reshape(-1, 1) - - ent_coef_loss = None - if self.ent_coef_optimizer is not None: - # Important: detach the variable from the graph - # so we don't change it with other losses - # see https://github.com/rail-berkeley/softlearning/issues/60 - ent_coef = th.exp(self.log_ent_coef.detach()) - ent_coef_loss = -(self.log_ent_coef * (log_prob 
+ self.target_entropy).detach()).mean() - else: - ent_coef = self.ent_coef_tensor - - self.replay_buffer.ent_coef = ent_coef.item() - - # Optimize entropy coefficient, also called - # entropy temperature or alpha in the paper - if ent_coef_loss is not None: - self.ent_coef_optimizer.zero_grad() - ent_coef_loss.backward() - self.ent_coef_optimizer.step() - - with th.no_grad(): - top_quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.critic.n_critics - # Select action according to policy - next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations) - # Compute and cut quantiles at the next state - # batch x nets x quantiles - next_z = self.critic_target(replay_data.next_observations, next_actions) - sorted_z, _ = th.sort(next_z.reshape(batch_size, -1)) - sorted_z_part = sorted_z[:, : self.critic.quantiles_total - top_quantiles_to_drop] - - target_q = sorted_z_part - ent_coef * next_log_prob.reshape(-1, 1) - # td error + entropy term - q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q - - # Get current Q estimates - # using action from the replay buffer - current_z = self.critic(replay_data.observations, replay_data.actions) - # Compute critic loss - critic_loss = self.quantile_huber_loss(current_z, q_backup) - critic_losses.append(critic_loss.item()) - - # Optimize the critic - self.critic.optimizer.zero_grad() - critic_loss.backward() - self.critic.optimizer.step() - - if strategy == "bc": - # Behavior cloning - weight = 1 - else: - # Tensor version: TODO: check that the reshape works as expected - # cleaner but not faster on cpu for large batch size - # with th.no_grad(): - # # Q-value for the action in the buffer - # qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True) - # # Create tensor to avoid loop - # # Note: For SDE, we need to sample several matrices - # obs_ = replay_data.observations.repeat(n_action_samples, 1) - # if self.use_sde: - # 
self.actor.reset_noise(batch_size * n_action_samples) - # actions_pi, _ = self.actor.action_log_prob(obs_) - # qf_pi = self.critic(obs_, actions_pi.detach()).mean(2).mean(1, keepdim=True) - # # Agregate: reduce mean or reduce max - # if reduce == "max": - # _, qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).max(axis=0) - # else: - # qf_agg = qf_pi.reshape(n_action_samples, batch_size, 1).mean(axis=0) - with th.no_grad(): - qf_buffer = self.critic(replay_data.observations, replay_data.actions).mean(2).mean(1, keepdim=True) - - # Use the mean (as done in AWAC, cf rlkit) - if n_action_samples == -1: - actions_pi = self.actor.forward(replay_data.observations, deterministic=True) - qf_agg = self.critic(replay_data.observations, actions_pi).mean(2).mean(1, keepdim=True) - else: - qf_agg = None - for _ in range(n_action_samples): - if self.use_sde: - self.actor.reset_noise() - actions_pi, _ = self.actor.action_log_prob(replay_data.observations) - - qf_pi = self.critic(replay_data.observations, actions_pi.detach()).mean(2).mean(1, keepdim=True) - if qf_agg is None: - if reduce == "max": - qf_agg = qf_pi - else: - qf_agg = qf_pi / n_action_samples - else: - if reduce == "max": - qf_agg = th.max(qf_pi, qf_agg) - else: - qf_agg += qf_pi / n_action_samples - - advantage = qf_buffer - qf_agg - if strategy == "binary": - # binary advantage - weight = advantage > 0 - else: - # exp advantage - exp_clip = 20.0 - weight = th.clamp(th.exp(advantage / exp_temperature), 0.0, exp_clip) - - # Log prob by the current actor for the sampled state and action - log_prob = self.actor.evaluate_actions(replay_data.observations, replay_data.actions) - log_prob = log_prob.reshape(-1, 1) - - # weigthed regression loss (close to policy gradient loss) - actor_loss = (-log_prob * weight).mean() - # actor_loss = ((actions_pi - replay_data.actions * weight) ** 2).mean() - actor_losses.append(actor_loss.item()) - - # Optimize the actor - self.actor.optimizer.zero_grad() - actor_loss.backward() - 
self.actor.optimizer.step() - - # Update target networks - if gradient_step % target_update_interval == 0: - polyak_update(self.critic.parameters(), self.critic_target.parameters(), tau) - - if self.use_sde: - print(f"std={(self.actor.get_std()).mean().item()}") - def learn( self, total_timesteps: int, @@ -469,7 +303,7 @@ class TQC(OffPolicyAlgorithm): Returns the names of the parameters that should be excluded by default when saving the model. - :return: (List[str]) List of parameters that should be excluded from save + :return: List of parameters that should be excluded from save """ # Exclude aliases return super(TQC, self)._excluded_save_params() + ["actor", "critic", "critic_target"] diff --git a/setup.py b/setup.py index 41055cd..f5adfad 100644 --- a/setup.py +++ b/setup.py @@ -23,11 +23,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3[tests,docs]>=0.9.0a0", - # For progress bar when using CRR - "tqdm" - # Enable CMA - # "cma", + "stable_baselines3[tests,docs]>=0.9.0", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", diff --git a/tests/test_run.py b/tests/test_run.py index 2cbcf41..49cd2bf 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -14,41 +14,19 @@ def test_tqc(ent_coef): create_eval_env=True, ent_coef=ent_coef, ) - model.learn(total_timesteps=1000, eval_freq=500) + model.learn(total_timesteps=500, eval_freq=250) @pytest.mark.parametrize("n_critics", [1, 3]) def test_n_critics(n_critics): # Test TQC with different number of critics model = TQC( - "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1 + "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], n_critics=n_critics), learning_starts=100, verbose=1 ) - model.learn(total_timesteps=1000) + 
model.learn(total_timesteps=500) - -# "CartPole-v1" -# @pytest.mark.parametrize("env_id", ["MountainCarContinuous-v0"]) -# def test_cmaes(env_id): -# if CMAES is None: -# return -# model = CMAES("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[64]), verbose=1, create_eval_env=True) -# model.learn(total_timesteps=50000, eval_freq=10000) - - -@pytest.mark.parametrize("strategy", ["exp", "bc", "binary"]) -@pytest.mark.parametrize("reduce", ["mean", "max"]) -def test_crr(tmp_path, strategy, reduce): +def test_sde(): model = TQC( - "MlpPolicy", - "Pendulum-v0", - policy_kwargs=dict(net_arch=[64]), - learning_starts=1000, - verbose=1, - create_eval_env=True, - action_noise=None, - use_sde=False, + "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64]), use_sde=True, learning_starts=100, verbose=1 ) - - model.learn(total_timesteps=1000, eval_freq=0) - for n_action_samples in [1, 2, -1]: - model.pretrain(gradient_steps=32, batch_size=32, n_action_samples=n_action_samples, strategy=strategy, reduce=reduce) + model.learn(total_timesteps=500) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 4415267..875efb9 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -45,7 +45,7 @@ def test_save_load(tmp_path, model_class): # create model model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1) - model.learn(total_timesteps=500, eval_freq=250) + model.learn(total_timesteps=500) env.reset() observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) @@ -154,7 +154,7 @@ def test_save_load(tmp_path, model_class): assert np.allclose(selected_actions, new_selected_actions, 1e-4) # check if learn still works - model.learn(total_timesteps=1000, eval_freq=500) + model.learn(total_timesteps=500) del model @@ -177,17 +177,17 @@ def test_set_env(model_class): # create model model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16])) # learn - 
model.learn(total_timesteps=1000, eval_freq=500) + model.learn(total_timesteps=500) # change env model.set_env(env2) # learn again - model.learn(total_timesteps=1000, eval_freq=500) + model.learn(total_timesteps=500) # change env test wrapping model.set_env(env3) # learn again - model.learn(total_timesteps=1000, eval_freq=500) + model.learn(total_timesteps=500) @pytest.mark.parametrize("model_class", MODEL_LIST) @@ -247,45 +247,6 @@ def test_save_load_replay_buffer(tmp_path, model_class): ) -@pytest.mark.parametrize("model_class", [TQC]) -@pytest.mark.parametrize("optimize_memory_usage", [False, True]) -def test_warn_buffer(recwarn, model_class, optimize_memory_usage): - """ - When using memory efficient replay buffer, - a warning must be emitted when calling `.learn()` - multiple times. - See https://github.com/DLR-RM/stable-baselines3/issues/46 - """ - # remove gym warnings - warnings.filterwarnings(action="ignore", category=DeprecationWarning) - warnings.filterwarnings(action="ignore", category=UserWarning, module="gym") - - model = model_class( - "MlpPolicy", - select_env(model_class), - buffer_size=100, - optimize_memory_usage=optimize_memory_usage, - policy_kwargs=dict(net_arch=[64]), - learning_starts=10, - ) - - model.learn(150) - - model.learn(150, reset_num_timesteps=False) - - # Check that there is no warning - assert len(recwarn) == 0 - - model.learn(150) - - if optimize_memory_usage: - assert len(recwarn) == 1 - warning = recwarn.pop(UserWarning) - assert "The last trajectory in the replay buffer will be truncated" in str(warning.message) - else: - assert len(recwarn) == 0 - - @pytest.mark.parametrize("model_class", MODEL_LIST) @pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"]) def test_save_load_policy(tmp_path, model_class, policy_str): @@ -309,7 +270,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str): # create model model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs) - 
model.learn(total_timesteps=500, eval_freq=250) + model.learn(total_timesteps=500) env.reset() observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) @@ -375,52 +336,3 @@ def test_save_load_policy(tmp_path, model_class, policy_str): os.remove(tmp_path / "policy.pkl") if actor_class is not None: os.remove(tmp_path / "actor.pkl") - - -@pytest.mark.parametrize("pathtype", [str, pathlib.Path]) -def test_open_file_str_pathlib(tmp_path, pathtype): - # check that suffix isn't added because we used open_path first - with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1: - save_to_pkl(fp1, "foo") - assert fp1.closed - with pytest.warns(None) as record: - assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo" - assert not record - - # test custom suffix - with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1: - save_to_pkl(fp1, "foo") - assert fp1.closed - with pytest.warns(None) as record: - assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo" - assert not record - - # test without suffix - with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1: - save_to_pkl(fp1, "foo") - assert fp1.closed - with pytest.warns(None) as record: - assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo" - assert not record - - # test that a warning is raised when the path doesn't exist - with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1: - save_to_pkl(fp1, "foo") - assert fp1.closed - with pytest.warns(None) as record: - assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo" - assert len(record) == 0 - - with pytest.warns(None) as record: - assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo" - assert len(record) == 1 - - fp = pathlib.Path(f"{tmp_path}/t2").open("w") - fp.write("rubbish") - fp.close() - # test that a warning is only raised when verbose = 0 - with pytest.warns(None) as record: - 
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close() - open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close() - open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close() - assert len(record) == 1