Update AsyncEval seeding (#185)

* Update AsyncEval seeding

* Update random seed in tests
This commit is contained in:
Antonin RAFFIN 2023-05-20 10:57:31 +02:00 committed by GitHub
parent 86fb056fda
commit d467d7a844
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 10 additions and 10 deletions

View File

@@ -3,7 +3,7 @@
Changelog Changelog
========== ==========
Release 2.0.0a8 (WIP) Release 2.0.0a9 (WIP)
-------------------------- --------------------------
**Gymnasium support** **Gymnasium support**

View File

@@ -8,7 +8,6 @@ import torch as th
from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.running_mean_std import RunningMeanStd from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.utils import compat_gym_seed
from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper
@@ -32,9 +31,9 @@ def _worker(
:param n_eval_episodes: Number of evaluation episodes per candidate. :param n_eval_episodes: Number of evaluation episodes per candidate.
""" """
parent_remote.close() parent_remote.close()
env = worker_env_wrapper.var() vec_env: VecEnv = worker_env_wrapper.var()
train_policy = train_policy_wrapper.var train_policy = train_policy_wrapper.var
vec_normalize = unwrap_vec_normalize(env) vec_normalize = unwrap_vec_normalize(vec_env)
if vec_normalize is not None: if vec_normalize is not None:
obs_rms = vec_normalize.obs_rms obs_rms = vec_normalize.obs_rms
else: else:
@@ -49,7 +48,7 @@ def _worker(
train_policy.load_from_vector(candidate_weights.cpu()) train_policy.load_from_vector(candidate_weights.cpu())
episode_rewards, episode_lengths = evaluate_policy( episode_rewards, episode_lengths = evaluate_policy(
train_policy, train_policy,
env, vec_env,
n_eval_episodes=n_eval_episodes, n_eval_episodes=n_eval_episodes,
return_episode_rewards=True, return_episode_rewards=True,
warn=False, warn=False,
@@ -57,14 +56,15 @@ def _worker(
results.append((weights_idx, (episode_rewards, episode_lengths))) results.append((weights_idx, (episode_rewards, episode_lengths)))
remote.send(results) remote.send(results)
elif cmd == "seed": elif cmd == "seed":
remote.send(compat_gym_seed(env, seed=data)) # Note: the seed will only be effective at the next reset
remote.send(vec_env.seed(seed=data))
elif cmd == "get_obs_rms": elif cmd == "get_obs_rms":
remote.send(obs_rms) remote.send(obs_rms)
elif cmd == "sync_obs_rms": elif cmd == "sync_obs_rms":
vec_normalize.obs_rms = data vec_normalize.obs_rms = data
obs_rms = data obs_rms = data
elif cmd == "close": elif cmd == "close":
env.close() vec_env.close()
remote.close() remote.close()
break break
else: else:

View File

@@ -1 +1 @@
2.0.0a8 2.0.0a9

View File

@@ -65,7 +65,7 @@ setup(
packages=[package for package in find_packages() if package.startswith("sb3_contrib")], packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]}, package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[ install_requires=[
"stable_baselines3>=2.0.0a4", "stable_baselines3>=2.0.0a9",
], ],
description="Contrib package of Stable Baselines3, experimental code.", description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin", author="Antonin Raffin",

View File

@@ -30,7 +30,7 @@ def test_discrete(model_class, env):
elif n_steps == TRPO: elif n_steps == TRPO:
kwargs = dict(n_steps=256, cg_max_steps=5) kwargs = dict(n_steps=256, cg_max_steps=5)
model = model_class("MlpPolicy", vec_env, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps) model = model_class("MlpPolicy", vec_env, learning_rate=1e-3, gamma=0.4, seed=0, **kwargs).learn(n_steps)
evaluate_policy(model, vec_env, n_eval_episodes=20, reward_threshold=90, warn=False) evaluate_policy(model, vec_env, n_eval_episodes=20, reward_threshold=90, warn=False)
obs = vec_env.reset() obs = vec_env.reset()