Update AsyncEval seeding (#185)
* Update AsyncEval seeding
* Update random seed in tests
This commit is contained in:
parent
86fb056fda
commit
d467d7a844
|
|
@@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.0.0a8 (WIP)
|
||||
Release 2.0.0a9 (WIP)
|
||||
--------------------------
|
||||
|
||||
**Gymnasium support**
|
||||
|
|
|
|||
|
|
@@ -8,7 +8,6 @@ import torch as th
|
|||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.running_mean_std import RunningMeanStd
|
||||
from stable_baselines3.common.utils import compat_gym_seed
|
||||
from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize
|
||||
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper
|
||||
|
||||
|
|
@@ -32,9 +31,9 @@ def _worker(
|
|||
:param n_eval_episodes: Number of evaluation episodes per candidate.
|
||||
"""
|
||||
parent_remote.close()
|
||||
env = worker_env_wrapper.var()
|
||||
vec_env: VecEnv = worker_env_wrapper.var()
|
||||
train_policy = train_policy_wrapper.var
|
||||
vec_normalize = unwrap_vec_normalize(env)
|
||||
vec_normalize = unwrap_vec_normalize(vec_env)
|
||||
if vec_normalize is not None:
|
||||
obs_rms = vec_normalize.obs_rms
|
||||
else:
|
||||
|
|
@@ -49,7 +48,7 @@ def _worker(
|
|||
train_policy.load_from_vector(candidate_weights.cpu())
|
||||
episode_rewards, episode_lengths = evaluate_policy(
|
||||
train_policy,
|
||||
env,
|
||||
vec_env,
|
||||
n_eval_episodes=n_eval_episodes,
|
||||
return_episode_rewards=True,
|
||||
warn=False,
|
||||
|
|
@@ -57,14 +56,15 @@ def _worker(
|
|||
results.append((weights_idx, (episode_rewards, episode_lengths)))
|
||||
remote.send(results)
|
||||
elif cmd == "seed":
|
||||
remote.send(compat_gym_seed(env, seed=data))
|
||||
# Note: the seed will only be effective at the next reset
|
||||
remote.send(vec_env.seed(seed=data))
|
||||
elif cmd == "get_obs_rms":
|
||||
remote.send(obs_rms)
|
||||
elif cmd == "sync_obs_rms":
|
||||
vec_normalize.obs_rms = data
|
||||
obs_rms = data
|
||||
elif cmd == "close":
|
||||
env.close()
|
||||
vec_env.close()
|
||||
remote.close()
|
||||
break
|
||||
else:
|
||||
|
|
|
|||
|
|
@@ -1 +1 @@
|
|||
2.0.0a8
|
||||
2.0.0a9
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@@ -65,7 +65,7 @@ setup(
|
|||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"stable_baselines3>=2.0.0a4",
|
||||
"stable_baselines3>=2.0.0a9",
|
||||
],
|
||||
description="Contrib package of Stable Baselines3, experimental code.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
|
|
@@ -30,7 +30,7 @@ def test_discrete(model_class, env):
|
|||
elif n_steps == TRPO:
|
||||
kwargs = dict(n_steps=256, cg_max_steps=5)
|
||||
|
||||
model = model_class("MlpPolicy", vec_env, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps)
|
||||
model = model_class("MlpPolicy", vec_env, learning_rate=1e-3, gamma=0.4, seed=0, **kwargs).learn(n_steps)
|
||||
|
||||
evaluate_policy(model, vec_env, n_eval_episodes=20, reward_threshold=90, warn=False)
|
||||
obs = vec_env.reset()
|
||||
|
|
|
|||
Loading…
Reference in New Issue