From d467d7a844f060dae9d3826f3e7f793fece6e982 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 20 May 2023 10:57:31 +0200 Subject: [PATCH] Update AsyncEval seeding (#185) * Update AsyncEval seeding * Update random seed in tests --- docs/misc/changelog.rst | 2 +- sb3_contrib/common/vec_env/async_eval.py | 12 ++++++------ sb3_contrib/version.txt | 2 +- setup.py | 2 +- tests/test_identity.py | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b54eb65..5fe10cc 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.0.0a8 (WIP) +Release 2.0.0a9 (WIP) -------------------------- **Gymnasium support** diff --git a/sb3_contrib/common/vec_env/async_eval.py b/sb3_contrib/common/vec_env/async_eval.py index 5fd88ae..f2500ad 100644 --- a/sb3_contrib/common/vec_env/async_eval.py +++ b/sb3_contrib/common/vec_env/async_eval.py @@ -8,7 +8,6 @@ import torch as th from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.running_mean_std import RunningMeanStd -from stable_baselines3.common.utils import compat_gym_seed from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper @@ -32,9 +31,9 @@ def _worker( :param n_eval_episodes: Number of evaluation episodes per candidate. """ parent_remote.close() - env = worker_env_wrapper.var() + vec_env: VecEnv = worker_env_wrapper.var() train_policy = train_policy_wrapper.var - vec_normalize = unwrap_vec_normalize(env) + vec_normalize = unwrap_vec_normalize(vec_env) if vec_normalize is not None: obs_rms = vec_normalize.obs_rms else: @@ -49,7 +48,7 @@ def _worker( train_policy.load_from_vector(candidate_weights.cpu()) episode_rewards, episode_lengths = evaluate_policy( train_policy, - env, + vec_env, n_eval_episodes=n_eval_episodes, return_episode_rewards=True, warn=False, @@ -57,14 +56,15 @@ def _worker( results.append((weights_idx, (episode_rewards, episode_lengths))) remote.send(results) elif cmd == "seed": - remote.send(compat_gym_seed(env, seed=data)) + # Note: the seed will only be effective at the next reset + remote.send(vec_env.seed(seed=data)) elif cmd == "get_obs_rms": remote.send(obs_rms) elif cmd == "sync_obs_rms": vec_normalize.obs_rms = data obs_rms = data elif cmd == "close": - env.close() + vec_env.close() remote.close() break else: diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 55a349a..caf8413 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -2.0.0a8 +2.0.0a9 diff --git a/setup.py b/setup.py index 5b7e392..77e9c17 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ setup( packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=2.0.0a4", + "stable_baselines3>=2.0.0a9", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", diff --git a/tests/test_identity.py b/tests/test_identity.py index 3725db8..0e00cd5 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -30,7 +30,7 @@ def test_discrete(model_class, env): elif n_steps == TRPO: kwargs = dict(n_steps=256, cg_max_steps=5) - model = model_class("MlpPolicy", vec_env, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps) + model = model_class("MlpPolicy", vec_env, learning_rate=1e-3, gamma=0.4, seed=0, **kwargs).learn(n_steps) evaluate_policy(model, vec_env, n_eval_episodes=20, reward_threshold=90, warn=False) obs = vec_env.reset()