210 lines
8.2 KiB
Python
210 lines
8.2 KiB
Python
import multiprocessing
|
|
import multiprocessing as mp
|
|
from collections import defaultdict
|
|
from typing import Callable, List, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import torch as th
|
|
from stable_baselines3.common.evaluation import evaluate_policy
|
|
from stable_baselines3.common.policies import BasePolicy
|
|
from stable_baselines3.common.running_mean_std import RunningMeanStd
|
|
from stable_baselines3.common.utils import compat_gym_seed
|
|
from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize
|
|
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper
|
|
|
|
|
|
def _worker(
    remote: mp.connection.Connection,
    parent_remote: mp.connection.Connection,
    worker_env_wrapper: CloudpickleWrapper,
    train_policy_wrapper: CloudpickleWrapper,
    n_eval_episodes: int = 1,
) -> None:
    """
    Function that will be run in each process.
    It is in charge of creating environments, evaluating candidates
    and communicating with the main process.

    :param remote: Pipe to communicate with the parent process.
    :param parent_remote: Parent end of the pipe; closed here so only
        the parent process keeps it open.
    :param worker_env_wrapper: Callable used to create the environment inside the process.
    :param train_policy_wrapper: Callable used to create the policy inside the process.
    :param n_eval_episodes: Number of evaluation episodes per candidate.
    """
    # Only communicate through ``remote`` in this process
    parent_remote.close()
    # The env wrapper stores a factory (hence the call),
    # the policy wrapper stores the policy object itself.
    env = worker_env_wrapper.var()
    train_policy = train_policy_wrapper.var
    vec_normalize = unwrap_vec_normalize(env)
    if vec_normalize is not None:
        obs_rms = vec_normalize.obs_rms
    else:
        obs_rms = None
    while True:
        try:
            cmd, data = remote.recv()
            if cmd == "eval":
                results = []
                # Evaluate each candidate and save results
                for weights_idx, candidate_weights in data:
                    # Load the candidate weights into the shared policy object
                    train_policy.load_from_vector(candidate_weights.cpu())
                    episode_rewards, episode_lengths = evaluate_policy(
                        train_policy,
                        env,
                        n_eval_episodes=n_eval_episodes,
                        return_episode_rewards=True,
                        warn=False,
                    )
                    results.append((weights_idx, (episode_rewards, episode_lengths)))
                remote.send(results)
            elif cmd == "seed":
                remote.send(compat_gym_seed(env, seed=data))
            elif cmd == "get_obs_rms":
                remote.send(obs_rms)
            elif cmd == "sync_obs_rms":
                # Guard: the env may not be wrapped in a VecNormalize,
                # in which case there is no wrapper attribute to update
                # (``obs_rms`` is still stored for ``get_obs_rms``).
                if vec_normalize is not None:
                    vec_normalize.obs_rms = data
                obs_rms = data
            elif cmd == "close":
                env.close()
                remote.close()
                break
            else:
                raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
        except EOFError:
            # Parent end was closed: exit the worker loop
            break
|
|
|
|
|
|
class AsyncEval:
    """
    Helper class to do asynchronous evaluation of different policies with multiple processes.
    It is useful when implementing population based methods like Evolution Strategies (ES),
    Cross Entropy Method (CEM) or Augmented Random Search (ARS).

    .. warning::

        Only 'forkserver' and 'spawn' start methods are thread-safe,
        which is important to avoid race conditions.
        However, compared to
        'fork' they incur a small start-up cost and have restrictions on
        global variables. With those methods, users must wrap the code in an
        ``if __name__ == "__main__":`` block.
        For more information, see the multiprocessing documentation.

    :param envs_fn: Vectorized environments to run in subprocesses (callable)
    :param train_policy: The policy object that will load the different candidate
        weights.
    :param start_method: method used to start the subprocesses.
        Must be one of the methods returned by ``multiprocessing.get_all_start_methods()``.
        Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
    :param n_eval_episodes: The number of episodes to test each agent
    """

    def __init__(
        self,
        envs_fn: List[Callable[[], VecEnv]],
        train_policy: BasePolicy,
        start_method: Optional[str] = None,
        n_eval_episodes: int = 1,
    ):
        self.waiting = False
        self.closed = False
        n_envs = len(envs_fn)

        if start_method is None:
            # Fork is not a thread safe method (see issue #217)
            # but is more user friendly (does not require to wrap the code in
            # a `if __name__ == "__main__":`)
            forkserver_available = "forkserver" in mp.get_all_start_methods()
            start_method = "forkserver" if forkserver_available else "spawn"
        ctx = mp.get_context(start_method)

        # One pipe per worker: ``remotes`` stay in the main process,
        # ``work_remotes`` are handed to the subprocesses.
        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)])
        self.processes = []
        for work_remote, remote, worker_env in zip(self.work_remotes, self.remotes, envs_fn):
            args = (
                work_remote,
                remote,
                CloudpickleWrapper(worker_env),
                CloudpickleWrapper(train_policy),
                n_eval_episodes,
            )
            # daemon=True: if the main process crashes, we should not cause things to hang
            process = ctx.Process(target=_worker, args=args, daemon=True)  # pytype:disable=attribute-error
            process.start()
            self.processes.append(process)
            # The worker owns this end of the pipe now
            work_remote.close()

    def send_jobs(self, candidate_weights: th.Tensor, pop_size: int) -> None:
        """
        Send jobs to the workers to evaluate new candidates.

        :param candidate_weights: The weights to be evaluated.
        :param pop_size: The number of candidates (size of the population)
        """
        jobs_per_worker = defaultdict(list)
        # Round-robin assignment of the candidates to the workers
        for weights_idx in range(pop_size):
            jobs_per_worker[weights_idx % len(self.remotes)].append((weights_idx, candidate_weights[weights_idx]))

        for remote_idx, remote in enumerate(self.remotes):
            remote.send(("eval", jobs_per_worker[remote_idx]))
        self.waiting = True

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        """
        Seed the environments.

        :param seed: The seed for the pseudo-random generators.
            When an int is given, each worker receives ``seed + worker_index``
            so the environments do not collect identical trajectories.
            ``None`` is forwarded unchanged.
        :return: The seeding result of each worker.
        """
        for idx, remote in enumerate(self.remotes):
            # Do not offset a None seed: ``None + idx`` would raise a TypeError
            remote.send(("seed", seed if seed is None else seed + idx))
        return [remote.recv() for remote in self.remotes]

    def get_results(self) -> List[Tuple[int, Tuple[np.ndarray, np.ndarray]]]:
        """
        Retrieve episode rewards and lengths from each worker
        for all candidates (there might be multiple candidates per worker)

        :return: A list of tuples containing each candidate index and its
            result (episodic reward and episode length)
        """
        results = [remote.recv() for remote in self.remotes]
        # Each worker returns a list of per-candidate results: flatten them
        flat_results = [result for worker_results in results for result in worker_results]
        self.waiting = False
        return flat_results

    def get_obs_rms(self) -> List[RunningMeanStd]:
        """
        Retrieve the observation filters (observation running mean std)
        of each process, they will be combined in the main process.
        Synchronisation is done afterward using ``sync_obs_rms()``.
        :return: A list of ``RunningMeanStd`` objects (one per process)
        """
        for remote in self.remotes:
            remote.send(("get_obs_rms", None))
        return [remote.recv() for remote in self.remotes]

    def sync_obs_rms(self, obs_rms: RunningMeanStd) -> None:
        """
        Synchronise (and update) the observation filters
        (observation running mean std)
        :param obs_rms: The updated ``RunningMeanStd`` to be used
            by workers for normalizing observations.
        """
        for remote in self.remotes:
            remote.send(("sync_obs_rms", obs_rms))

    def close(self) -> None:
        """
        Close the processes.
        """
        if self.closed:
            return
        if self.waiting:
            # Drain pending results so workers are not blocked on send()
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(("close", None))
        for process in self.processes:
            process.join()
        self.closed = True
|