Update AsyncEval seeding (#185)
* Update AsyncEval seeding
* Update random seed in tests
This commit is contained in:
parent
86fb056fda
commit
d467d7a844
|
|
@ -3,7 +3,7 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Release 2.0.0a8 (WIP)
|
Release 2.0.0a9 (WIP)
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
||||||
**Gymnasium support**
|
**Gymnasium support**
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ import torch as th
|
||||||
from stable_baselines3.common.evaluation import evaluate_policy
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.running_mean_std import RunningMeanStd
|
from stable_baselines3.common.running_mean_std import RunningMeanStd
|
||||||
from stable_baselines3.common.utils import compat_gym_seed
|
|
||||||
from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize
|
from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize
|
||||||
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper
|
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper
|
||||||
|
|
||||||
|
|
@ -32,9 +31,9 @@ def _worker(
|
||||||
:param n_eval_episodes: Number of evaluation episodes per candidate.
|
:param n_eval_episodes: Number of evaluation episodes per candidate.
|
||||||
"""
|
"""
|
||||||
parent_remote.close()
|
parent_remote.close()
|
||||||
env = worker_env_wrapper.var()
|
vec_env: VecEnv = worker_env_wrapper.var()
|
||||||
train_policy = train_policy_wrapper.var
|
train_policy = train_policy_wrapper.var
|
||||||
vec_normalize = unwrap_vec_normalize(env)
|
vec_normalize = unwrap_vec_normalize(vec_env)
|
||||||
if vec_normalize is not None:
|
if vec_normalize is not None:
|
||||||
obs_rms = vec_normalize.obs_rms
|
obs_rms = vec_normalize.obs_rms
|
||||||
else:
|
else:
|
||||||
|
|
@ -49,7 +48,7 @@ def _worker(
|
||||||
train_policy.load_from_vector(candidate_weights.cpu())
|
train_policy.load_from_vector(candidate_weights.cpu())
|
||||||
episode_rewards, episode_lengths = evaluate_policy(
|
episode_rewards, episode_lengths = evaluate_policy(
|
||||||
train_policy,
|
train_policy,
|
||||||
env,
|
vec_env,
|
||||||
n_eval_episodes=n_eval_episodes,
|
n_eval_episodes=n_eval_episodes,
|
||||||
return_episode_rewards=True,
|
return_episode_rewards=True,
|
||||||
warn=False,
|
warn=False,
|
||||||
|
|
@ -57,14 +56,15 @@ def _worker(
|
||||||
results.append((weights_idx, (episode_rewards, episode_lengths)))
|
results.append((weights_idx, (episode_rewards, episode_lengths)))
|
||||||
remote.send(results)
|
remote.send(results)
|
||||||
elif cmd == "seed":
|
elif cmd == "seed":
|
||||||
remote.send(compat_gym_seed(env, seed=data))
|
# Note: the seed will only be effective at the next reset
|
||||||
|
remote.send(vec_env.seed(seed=data))
|
||||||
elif cmd == "get_obs_rms":
|
elif cmd == "get_obs_rms":
|
||||||
remote.send(obs_rms)
|
remote.send(obs_rms)
|
||||||
elif cmd == "sync_obs_rms":
|
elif cmd == "sync_obs_rms":
|
||||||
vec_normalize.obs_rms = data
|
vec_normalize.obs_rms = data
|
||||||
obs_rms = data
|
obs_rms = data
|
||||||
elif cmd == "close":
|
elif cmd == "close":
|
||||||
env.close()
|
vec_env.close()
|
||||||
remote.close()
|
remote.close()
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
2.0.0a8
|
2.0.0a9
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
|
||||||
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"stable_baselines3>=2.0.0a4",
|
"stable_baselines3>=2.0.0a9",
|
||||||
],
|
],
|
||||||
description="Contrib package of Stable Baselines3, experimental code.",
|
description="Contrib package of Stable Baselines3, experimental code.",
|
||||||
author="Antonin Raffin",
|
author="Antonin Raffin",
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ def test_discrete(model_class, env):
|
||||||
elif n_steps == TRPO:
|
elif n_steps == TRPO:
|
||||||
kwargs = dict(n_steps=256, cg_max_steps=5)
|
kwargs = dict(n_steps=256, cg_max_steps=5)
|
||||||
|
|
||||||
model = model_class("MlpPolicy", vec_env, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps)
|
model = model_class("MlpPolicy", vec_env, learning_rate=1e-3, gamma=0.4, seed=0, **kwargs).learn(n_steps)
|
||||||
|
|
||||||
evaluate_policy(model, vec_env, n_eval_episodes=20, reward_threshold=90, warn=False)
|
evaluate_policy(model, vec_env, n_eval_episodes=20, reward_threshold=90, warn=False)
|
||||||
obs = vec_env.reset()
|
obs = vec_env.reset()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue