File size: 3,626 Bytes
d1b91e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import traceback
from functools import partial
from tqdm import tqdm


def chunked_worker(worker_id, args_queue=None, results_queue=None, init_ctx_func=None):
    ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None
    while True:
        args = args_queue.get()
        if args == '<KILL>':
            return
        job_idx, map_func, arg = args
        try:
            map_func_ = partial(map_func, ctx=ctx) if ctx is not None else map_func
            if isinstance(arg, dict):
                res = map_func_(**arg)
            elif isinstance(arg, (list, tuple)):
                res = map_func_(*arg)
            else:
                res = map_func_(arg)
            results_queue.put((job_idx, res))
        except:
            traceback.print_exc()
            results_queue.put((job_idx, None))


class MultiprocessManager:
    def __init__(self, num_workers=None, init_ctx_func=None, multithread=False):
        if multithread:
            from multiprocessing.dummy import Queue, Process
        else:
            from multiprocessing import Queue, Process
        if num_workers is None:
            num_workers = int(os.getenv('N_PROC', os.cpu_count()))
        self.num_workers = num_workers
        self.results_queue = Queue(maxsize=-1)
        self.args_queue = Queue(maxsize=-1)
        self.workers = []
        self.total_jobs = 0
        for i in range(num_workers):
            p = Process(target=chunked_worker,
                        args=(i, self.args_queue, self.results_queue, init_ctx_func),
                        daemon=True)
            self.workers.append(p)
            p.start()

    def add_job(self, func, args):
        self.args_queue.put((self.total_jobs, func, args))
        self.total_jobs += 1

    def get_results(self):
        for w in range(self.num_workers):
            self.args_queue.put("<KILL>")
        self.n_finished = 0
        while self.n_finished < self.total_jobs:
            job_id, res = self.results_queue.get()
            yield job_id, res
            self.n_finished += 1
        for w in self.workers:
            w.join()

    def __len__(self):
        return self.total_jobs


def multiprocess_run_tqdm(map_func, args, num_workers=None, ordered=True, init_ctx_func=None,
                          multithread=False, desc=None):
    for i, res in tqdm(enumerate(
            multiprocess_run(map_func, args, num_workers, ordered, init_ctx_func, multithread)),
            total=len(args), desc=desc):
        yield i, res


def multiprocess_run(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, multithread=False):
    """
    Multiprocessing running chunked jobs.

    Examples:
    >>> for res in tqdm(multiprocess_run(job_func, args):
    >>>     print(res)

    :param map_func:
    :param args:
    :param num_workers:
    :param ordered:
    :param init_ctx_func:
    :param q_max_size:
    :param multithread:
    :return:
    """
    if num_workers is None:
        num_workers = int(os.getenv('N_PROC', os.cpu_count()))
    manager = MultiprocessManager(num_workers, init_ctx_func, multithread)
    for arg in args:
        manager.add_job(map_func, arg)
    if ordered:
        n_jobs = len(args)
        results = ['<WAIT>' for _ in range(n_jobs)]
        i_now = 0
        for job_i, res in manager.get_results():
            results[job_i] = res
            while i_now < n_jobs and (not isinstance(results[i_now], str) or results[i_now] != '<WAIT>'):
                yield results[i_now]
                i_now += 1
    else:
        for res in manager.get_results():
            yield res