Deprecate ``create_eval_env``, ``eval_env`` and ``eval_freq`` parameter (#105)
* Deprecate ``eval_env``, ``eval_freq`` and ``create_eval_env`` * Update changelog * Typo * Raise deprecation warning in _setup_learn * Upgrade to latest SB3 version and update changelog Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
2490468b11
commit
dec7b5303a
|
|
@ -3,6 +3,29 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
|
||||||
|
Release 1.6.2 (2022-10-10)
|
||||||
|
--------------------------
|
||||||
|
|
||||||
|
**Progress bar and upgrade to latest SB3 version**
|
||||||
|
|
||||||
|
Breaking Changes:
|
||||||
|
^^^^^^^^^^^^^^^^^
|
||||||
|
- Upgraded to Stable-Baselines3 >= 1.6.2
|
||||||
|
|
||||||
|
New Features:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Bug Fixes:
|
||||||
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
Deprecations:
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
- Deprecate parameters ``eval_env``, ``eval_freq`` and ``create_eval_env``
|
||||||
|
|
||||||
|
Others:
|
||||||
|
^^^^^^^
|
||||||
|
|
||||||
Release 1.6.1 (2022-09-29)
|
Release 1.6.1 (2022-09-29)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
|
|
@ -13,7 +36,6 @@ Breaking Changes:
|
||||||
- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
|
- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
|
||||||
- Upgraded to Stable-Baselines3 >= 1.6.1
|
- Upgraded to Stable-Baselines3 >= 1.6.1
|
||||||
|
|
||||||
|
|
||||||
New Features:
|
New Features:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|
@ -25,7 +47,6 @@ Bug Fixes:
|
||||||
- Fixed missing verbose parameter passing in the ``MaskableEvalCallback`` constructor (@burakdmb)
|
- Fixed missing verbose parameter passing in the ``MaskableEvalCallback`` constructor (@burakdmb)
|
||||||
- Fixed the issue that when updating the target network in QRDQN, TQC, the ``running_mean`` and ``running_var`` properties of batch norm layers are not updated (@honglu2875)
|
- Fixed the issue that when updating the target network in QRDQN, TQC, the ``running_mean`` and ``running_var`` properties of batch norm layers are not updated (@honglu2875)
|
||||||
|
|
||||||
|
|
||||||
Deprecations:
|
Deprecations:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|
@ -35,7 +56,7 @@ Others:
|
||||||
|
|
||||||
|
|
||||||
Release 1.6.0 (2022-07-11)
|
Release 1.6.0 (2022-07-11)
|
||||||
-------------------------------
|
--------------------------
|
||||||
|
|
||||||
**Add RecurrentPPO (aka PPO LSTM)**
|
**Add RecurrentPPO (aka PPO LSTM)**
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -321,8 +321,12 @@ class ARS(BaseAlgorithm):
|
||||||
:param callback: callback(s) called at every step with state of the algorithm.
|
:param callback: callback(s) called at every step with state of the algorithm.
|
||||||
:param log_interval: The number of timesteps before logging.
|
:param log_interval: The number of timesteps before logging.
|
||||||
:param tb_log_name: the name of the run for TensorBoard logging
|
:param tb_log_name: the name of the run for TensorBoard logging
|
||||||
:param eval_env: Environment that will be used to evaluate the agent
|
:param eval_env: Environment to use for evaluation.
|
||||||
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
|
Caution, this parameter is deprecated and will be removed in the future.
|
||||||
|
Please use `EvalCallback` or a custom Callback instead.
|
||||||
|
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
|
||||||
|
Caution, this parameter is deprecated and will be removed in the future.
|
||||||
|
Please use `EvalCallback` or a custom Callback instead.
|
||||||
:param n_eval_episodes: Number of episode to evaluate the agent
|
:param n_eval_episodes: Number of episode to evaluate the agent
|
||||||
:param eval_log_path: Path to a folder where the evaluations will be saved
|
:param eval_log_path: Path to a folder where the evaluations will be saved
|
||||||
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
|
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import warnings
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Any, Dict, Optional, Tuple, Type, Union
|
from typing import Any, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
|
@ -59,7 +60,8 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
By default, there is no limit on the kl div.
|
By default, there is no limit on the kl div.
|
||||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||||
:param create_eval_env: Whether to create a second environment that will be
|
:param create_eval_env: Whether to create a second environment that will be
|
||||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||||
|
Caution, this parameter is deprecated and will be removed in the future.
|
||||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||||
:param seed: Seed for the pseudo random generators
|
:param seed: Seed for the pseudo random generators
|
||||||
|
|
@ -183,9 +185,13 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
) -> BaseCallback:
|
) -> BaseCallback:
|
||||||
"""
|
"""
|
||||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
:param callback: Callback(s) called at every step with state of the algorithm.
|
||||||
:param eval_freq: How many steps between evaluations; if None, do not evaluate.
|
:param eval_env: Environment to use for evaluation.
|
||||||
|
Caution, this parameter is deprecated and will be removed in the future.
|
||||||
|
Please use `MaskableEvalCallback` or a custom Callback instead.
|
||||||
|
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
|
||||||
|
Caution, this parameter is deprecated and will be removed in the future.
|
||||||
|
Please use `MaskableEvalCallback` or a custom Callback instead.
|
||||||
:param n_eval_episodes: How many episodes to play per evaluation
|
:param n_eval_episodes: How many episodes to play per evaluation
|
||||||
:param n_eval_episodes: Number of episodes to rollout during evaluation.
|
|
||||||
:param log_path: Path to a folder where the evaluations will be saved
|
:param log_path: Path to a folder where the evaluations will be saved
|
||||||
:param use_masking: Whether or not to use invalid action masks during evaluation
|
:param use_masking: Whether or not to use invalid action masks during evaluation
|
||||||
:return: A hybrid callback calling `callback` and performing evaluation.
|
:return: A hybrid callback calling `callback` and performing evaluation.
|
||||||
|
|
@ -234,8 +240,12 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
:param total_timesteps: The total number of samples (env steps) to train on
|
:param total_timesteps: The total number of samples (env steps) to train on
|
||||||
:param eval_env: Environment to use for evaluation.
|
:param eval_env: Environment to use for evaluation.
|
||||||
|
Caution, this parameter is deprecated and will be removed in the future.
|
||||||
|
Please use `MaskableEvalCallback` or a custom Callback instead.
|
||||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
:param callback: Callback(s) called at every step with state of the algorithm.
|
||||||
:param eval_freq: How many steps between evaluations
|
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
|
||||||
|
Caution, this parameter is deprecated and will be removed in the future.
|
||||||
|
Please use `MaskableEvalCallback` or a custom Callback instead.
|
||||||
:param n_eval_episodes: How many episodes to play per evaluation
|
:param n_eval_episodes: How many episodes to play per evaluation
|
||||||
:param log_path: Path to a folder where the evaluations will be saved
|
:param log_path: Path to a folder where the evaluations will be saved
|
||||||
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
|
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
|
||||||
|
|
@ -244,6 +254,17 @@ class MaskablePPO(OnPolicyAlgorithm):
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if eval_env is not None or eval_freq != -1:
|
||||||
|
warnings.warn(
|
||||||
|
"Parameters `eval_env` and `eval_freq` are deprecated and will be removed in the future. "
|
||||||
|
"Please use `MaskableEvalCallback` or a custom Callback instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
# By setting the `stacklevel` we refer to the initial caller of the deprecated feature.
|
||||||
|
# This causes the `DeprecationWarning` to not be ignored and to be shown to the user. See
|
||||||
|
# https://github.com/DLR-RM/stable-baselines3/pull/1082#discussion_r989842855 for more details.
|
||||||
|
stacklevel=4,
|
||||||
|
)
|
||||||
|
|
||||||
self.start_time = time.time_ns()
|
self.start_time = time.time_ns()
|
||||||
if self.ep_info_buffer is None or reset_num_timesteps:
|
if self.ep_info_buffer is None or reset_num_timesteps:
|
||||||
# Initialize buffers if they don't exist, or reinitialize if resetting counters
|
# Initialize buffers if they don't exist, or reinitialize if resetting counters
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,8 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
By default, there is no limit on the kl div.
|
By default, there is no limit on the kl div.
|
||||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||||
:param create_eval_env: Whether to create a second environment that will be
|
:param create_eval_env: Whether to create a second environment that will be
|
||||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||||
|
Caution, this parameter is deprecated and will be removed in the future.
|
||||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||||
:param seed: Seed for the pseudo random generators
|
:param seed: Seed for the pseudo random generators
|
||||||
|
|
@ -211,8 +212,12 @@ class RecurrentPPO(OnPolicyAlgorithm):
|
||||||
|
|
||||||
:param total_timesteps: The total number of samples (env steps) to train on
|
:param total_timesteps: The total number of samples (env steps) to train on
|
||||||
:param eval_env: Environment to use for evaluation.
|
:param eval_env: Environment to use for evaluation.
|
||||||
|
Caution, this parameter is deprecated and will be removed in the future.
|
||||||
|
Please use `EvalCallback` or a custom Callback instead.
|
||||||
:param callback: Callback(s) called at every step with state of the algorithm.
|
:param callback: Callback(s) called at every step with state of the algorithm.
|
||||||
:param eval_freq: How many steps between evaluations
|
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
|
||||||
|
Caution, this parameter is deprecated and will be removed in the future.
|
||||||
|
Please use `EvalCallback` or a custom Callback instead.
|
||||||
:param n_eval_episodes: How many episodes to play per evaluation
|
:param n_eval_episodes: How many episodes to play per evaluation
|
||||||
:param log_path: Path to a folder where the evaluations will be saved
|
:param log_path: Path to a folder where the evaluations will be saved
|
||||||
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
|
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,8 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
:param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping)
|
:param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping)
|
||||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||||
:param create_eval_env: Whether to create a second environment that will be
|
:param create_eval_env: Whether to create a second environment that will be
|
||||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||||
|
Caution, this parameter is deprecated and will be removed in the future.
|
||||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||||
:param seed: Seed for the pseudo random generators
|
:param seed: Seed for the pseudo random generators
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,8 @@ class TQC(OffPolicyAlgorithm):
|
||||||
:param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
|
:param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
|
||||||
during the warm up phase (before learning starts)
|
during the warm up phase (before learning starts)
|
||||||
:param create_eval_env: Whether to create a second environment that will be
|
:param create_eval_env: Whether to create a second environment that will be
|
||||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||||
|
Caution, this parameter is deprecated and will be removed in the future.
|
||||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||||
:param seed: Seed for the pseudo random generators
|
:param seed: Seed for the pseudo random generators
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,8 @@ class TRPO(OnPolicyAlgorithm):
|
||||||
see p40-42 of John Schulman thesis http://joschu.net/docs/thesis.pdf
|
see p40-42 of John Schulman thesis http://joschu.net/docs/thesis.pdf
|
||||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||||
:param create_eval_env: Whether to create a second environment that will be
|
:param create_eval_env: Whether to create a second environment that will be
|
||||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
used for evaluating the agent periodically (Only available when passing string for the environment).
|
||||||
|
Caution, this parameter is deprecated and will be removed in the future.
|
||||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||||
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
||||||
:param seed: Seed for the pseudo random generators
|
:param seed: Seed for the pseudo random generators
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
1.6.1
|
1.6.2
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[metadata]
|
[metadata]
|
||||||
# This includes the license file in the wheel.
|
# This includes the license file in the wheel.
|
||||||
license_file = LICENSE
|
license_files = LICENSE
|
||||||
|
|
||||||
[tool:pytest]
|
[tool:pytest]
|
||||||
# Deterministic ordering for tests; useful for pytest-xdist.
|
# Deterministic ordering for tests; useful for pytest-xdist.
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=1.6.1",
|
"stable_baselines3>=1.6.2",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,8 @@ def test_eval_env():
|
||||||
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
||||||
eval_env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
eval_env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
|
||||||
model = MaskablePPO("MlpPolicy", env, clip_range_vf=0.2, n_steps=32, seed=8)
|
model = MaskablePPO("MlpPolicy", env, clip_range_vf=0.2, n_steps=32, seed=8)
|
||||||
model.learn(32, eval_env=eval_env, eval_freq=16)
|
with pytest.warns(DeprecationWarning): # `eval_env` is deprecated
|
||||||
|
model.learn(32, eval_env=eval_env, eval_freq=16)
|
||||||
model.learn(32, reset_num_timesteps=False)
|
model.learn(32, reset_num_timesteps=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -129,30 +129,32 @@ def test_check():
|
||||||
|
|
||||||
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
|
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
|
||||||
def test_run(env):
|
def test_run(env):
|
||||||
model = RecurrentPPO(
|
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||||
"MlpLstmPolicy",
|
model = RecurrentPPO(
|
||||||
env,
|
"MlpLstmPolicy",
|
||||||
n_steps=16,
|
env,
|
||||||
seed=0,
|
n_steps=16,
|
||||||
create_eval_env=True,
|
seed=0,
|
||||||
)
|
create_eval_env=True,
|
||||||
|
)
|
||||||
|
|
||||||
model.learn(total_timesteps=32, eval_freq=16)
|
model.learn(total_timesteps=32, eval_freq=16)
|
||||||
|
|
||||||
|
|
||||||
def test_run_sde():
|
def test_run_sde():
|
||||||
model = RecurrentPPO(
|
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||||
"MlpLstmPolicy",
|
model = RecurrentPPO(
|
||||||
"Pendulum-v1",
|
"MlpLstmPolicy",
|
||||||
n_steps=16,
|
"Pendulum-v1",
|
||||||
seed=0,
|
n_steps=16,
|
||||||
create_eval_env=True,
|
seed=0,
|
||||||
sde_sample_freq=4,
|
create_eval_env=True,
|
||||||
use_sde=True,
|
sde_sample_freq=4,
|
||||||
clip_range_vf=0.1,
|
use_sde=True,
|
||||||
)
|
clip_range_vf=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
model.learn(total_timesteps=200, eval_freq=150)
|
model.learn(total_timesteps=200, eval_freq=150)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
||||||
|
|
@ -10,16 +10,17 @@ from sb3_contrib.common.vec_env import AsyncEval
|
||||||
|
|
||||||
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
|
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
|
||||||
def test_tqc(ent_coef):
|
def test_tqc(ent_coef):
|
||||||
model = TQC(
|
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||||
"MlpPolicy",
|
model = TQC(
|
||||||
"Pendulum-v1",
|
"MlpPolicy",
|
||||||
policy_kwargs=dict(net_arch=[64, 64]),
|
"Pendulum-v1",
|
||||||
learning_starts=100,
|
policy_kwargs=dict(net_arch=[64, 64]),
|
||||||
verbose=1,
|
learning_starts=100,
|
||||||
create_eval_env=True,
|
verbose=1,
|
||||||
ent_coef=ent_coef,
|
create_eval_env=True,
|
||||||
)
|
ent_coef=ent_coef,
|
||||||
model.learn(total_timesteps=300, eval_freq=250)
|
)
|
||||||
|
model.learn(total_timesteps=300, eval_freq=250)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n_critics", [1, 3])
|
@pytest.mark.parametrize("n_critics", [1, 3])
|
||||||
|
|
@ -50,17 +51,18 @@ def test_sde():
|
||||||
|
|
||||||
|
|
||||||
def test_qrdqn():
|
def test_qrdqn():
|
||||||
model = QRDQN(
|
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||||
"MlpPolicy",
|
model = QRDQN(
|
||||||
"CartPole-v1",
|
"MlpPolicy",
|
||||||
policy_kwargs=dict(n_quantiles=25, net_arch=[64, 64]),
|
"CartPole-v1",
|
||||||
learning_starts=100,
|
policy_kwargs=dict(n_quantiles=25, net_arch=[64, 64]),
|
||||||
buffer_size=500,
|
learning_starts=100,
|
||||||
learning_rate=3e-4,
|
buffer_size=500,
|
||||||
verbose=1,
|
learning_rate=3e-4,
|
||||||
create_eval_env=True,
|
verbose=1,
|
||||||
)
|
create_eval_env=True,
|
||||||
model.learn(total_timesteps=500, eval_freq=250)
|
)
|
||||||
|
model.learn(total_timesteps=500, eval_freq=250)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
|
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
|
||||||
|
|
@ -89,7 +91,8 @@ def test_trpo_params():
|
||||||
@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"])
|
@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"])
|
||||||
def test_ars(policy_str, env_id):
|
def test_ars(policy_str, env_id):
|
||||||
model = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0)
|
model = ARS(policy_str, env_id, n_delta=1, verbose=1, seed=0)
|
||||||
model.learn(total_timesteps=500, log_interval=1, eval_freq=250)
|
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
|
||||||
|
model.learn(total_timesteps=500, log_interval=1, eval_freq=250)
|
||||||
|
|
||||||
|
|
||||||
def test_ars_multi_env():
|
def test_ars_multi_env():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue