import asyncio

class Job:
    def __init__(self, data):
        self.id = None  # assigned by Pipeline.enqueue_job
        self.data = data


class Node:
    def __init__(self, worker_id: int, input_queue, output_queue=None, job_sync=None, sequential_node=False):
        self.worker_id = worker_id
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.buffer = {}            # holds out-of-order jobs for sequential nodes
        self.job_sync = job_sync    # optional list that collects finished jobs
        self.sequential_node = sequential_node
        self.next_i = 0             # id of the next job a sequential node may emit
        self._jobs_dequeued = 0
        self._jobs_processed = 0
        # job_sync only makes sense when jobs are emitted in order
        if self.job_sync is not None and not self.sequential_node:
            raise ValueError('job_sync requires sequential_node=True')

    async def run(self):
        while True:
            job: Job = await self.input_queue.get()
            self._jobs_dequeued += 1
            if not self.sequential_node:
                await self.process_job(job)
                if self.output_queue is not None:
                    await self.output_queue.put(job)
                if self.job_sync is not None:
                    self.job_sync.append(job)
                self._jobs_processed += 1
            else:
                # buffer jobs and emit them strictly in job-id order
                self.buffer[job.id] = job
                while self.next_i in self.buffer:
                    job = self.buffer.pop(self.next_i)
                    await self.process_job(job)
                    self.next_i += 1
                    if self.output_queue is not None:
                        await self.output_queue.put(job)
                    if self.job_sync is not None:
                        self.job_sync.append(job)
                    self._jobs_processed += 1

    async def process_job(self, job: Job):
        # subclasses implement the actual per-job work here
        raise NotImplementedError()


class Pipeline:
    def __init__(self):
        self.input_queues = []
        self.root_queue = None   # input queue of the first node added; jobs enter here
        self.nodes = []          # names of the node classes added to the pipeline
        self.node_workers = {}   # node class name -> list of worker instances
        self.tasks = []
        self._job_id = 0

    async def add_node(self, node, num_workers=1, input_queue=None, output_queue=None, job_sync=None, sequential_node=False):
        # `node` is a Node subclass (the class itself, not an instance);
        # one worker instance is created per num_workers
        # input_queue must not be None
        if input_queue is None:
            raise ValueError('input_queue is None')
        # job_sync nodes must be sequential nodes
        if job_sync is not None and not sequential_node:
            raise ValueError('job_sync requires sequential_node=True')
        # sequential nodes can only have a single worker, otherwise ordering is lost
        if sequential_node and num_workers != 1:
            raise ValueError('sequential nodes can only have one worker (sequential_node is True and num_workers is not 1)')
        # output_queue must not be the same as input_queue
        if output_queue == input_queue:
            raise ValueError('output_queue must not be the same as input_queue')

        node_name = node.__name__
        if node_name not in self.nodes:
            self.nodes.append(node_name)

        # the first node added is the root node; jobs are enqueued on its input queue
        if len(self.input_queues) == 0:
            self.root_queue = input_queue

        self.input_queues.append(input_queue)

        for worker_id in range(num_workers):
            node_worker = node(worker_id, input_queue, output_queue, job_sync, sequential_node)
            self.node_workers.setdefault(node_name, []).append(node_worker)
            task = asyncio.create_task(node_worker.run())
            self.tasks.append(task)

    async def enqueue_job(self, job: Job):
        # assign a monotonically increasing id so sequential nodes can restore order
        job.id = self._job_id
        self._job_id += 1
        await self.root_queue.put(job)

    async def close(self):
        for task in self.tasks:
            task.cancel()
        await asyncio.gather(*self.tasks, return_exceptions=True)
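
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes a
# hypothetical `Double` node and builds a one-stage pipeline that doubles each
# job's data, collecting results in submission order via job_sync.
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    class Double(Node):
        async def process_job(self, job: Job):
            # the per-job work: double the payload
            job.data = job.data * 2

    async def main():
        pipeline = Pipeline()
        results = []
        await pipeline.add_node(Double, num_workers=1,
                                input_queue=asyncio.Queue(),
                                job_sync=results, sequential_node=True)
        for n in range(5):
            await pipeline.enqueue_job(Job(n))
        # poll until every job has reached the job_sync list
        while len(results) < 5:
            await asyncio.sleep(0.01)
        print([job.data for job in results])  # expected: [0, 2, 4, 6, 8]
        await pipeline.close()

    asyncio.run(main())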