File size: 4,532 Bytes
ed232fa
730fe87
ed232fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
730fe87
 
 
 
 
8c9e2db
 
 
 
 
730fe87
 
 
 
 
 
 
 
 
 
 
8c9e2db
730fe87
 
 
 
 
ed232fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162d5c8
ed232fa
 
 
 
730fe87
ed232fa
 
 
 
 
 
 
162d5c8
 
 
ed232fa
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import asyncio
import traceback

class Job:
    """A unit of work carrying an arbitrary payload and a sequence number.

    The sequence number is exposed as ``job.id``, which is the attribute
    ``Pipeline.enqueue_job`` assigns and sequential ``Node`` workers read.
    It is stored in ``_id`` so both spellings always agree.
    """

    def __init__(self, data):
        # Sequence number; stays None until something assigns ``job.id``.
        self._id = None
        # Opaque payload; this class never inspects it.
        self.data = data

    @property
    def id(self):
        """Return the job's sequence number, or ``None`` if none was assigned."""
        return self._id

    @id.setter
    def id(self, value):
        self._id = value


class Node:
    """Base class for a worker that consumes jobs from ``input_queue``.

    Subclasses override ``process_job`` as an async generator. Every job it
    yields is put on ``output_queue`` (when one is set) and appended to
    ``job_sync`` (when one is set).

    When ``sequential_node`` is true, incoming jobs are parked in ``buffer``
    keyed by ``job.id`` and processed strictly in the order 0, 1, 2, ...
    """

    def __init__(self, worker_id: int, input_queue, output_queue=None, job_sync=None, sequential_node=False):
        """Store the worker configuration.

        Args:
            worker_id: Identifier used in error messages printed by ``run``.
            input_queue: Object with an awaitable ``get()`` returning jobs.
            output_queue: Optional object with an awaitable ``put(job)``.
            job_sync: Optional object with ``append(job)``; only allowed on
                sequential nodes.
            sequential_node: Process jobs in ``job.id`` order when true.

        Raises:
            ValueError: If ``job_sync`` is given for a non-sequential node.
        """
        if job_sync is not None and not sequential_node:
            raise ValueError('job_sync is not None and sequential_node is False')
        self.worker_id = worker_id
        self.input_queue = input_queue
        self.output_queue = output_queue
        # Out-of-order jobs waiting for their turn (sequential mode only).
        self.buffer = {}
        self.job_sync = job_sync
        self.sequential_node = sequential_node
        # Next job id a sequential node is allowed to process.
        self.next_i = 0
        self._jobs_dequeued = 0
        self._jobs_processed = 0

    async def run(self):
        """Consume jobs forever; print and re-raise any ``Exception``.

        The loop has no exit condition of its own: it ends when the task is
        cancelled or when processing raises.
        """
        try:
            while True:
                job = await self.input_queue.get()
                self._jobs_dequeued += 1
                if not self.sequential_node:
                    await self._process_and_emit(job)
                    continue
                # Park the job, then drain every job that is now in order.
                self.buffer[job.id] = job
                while self.next_i in self.buffer:
                    await self._process_and_emit(self.buffer.pop(self.next_i))
                    self.next_i += 1
        except Exception as e:
            print(f"An error occurred in node: {self.__class__.__name__} worker: {self.worker_id}: {e}")
            traceback.print_exc()
            raise  # Re-raises the last exception.

    async def _process_and_emit(self, job):
        """Run ``process_job`` on one job and forward everything it yields."""
        async for result in self.process_job(job):
            if self.output_queue is not None:
                await self.output_queue.put(result)
            if self.job_sync is not None:
                self.job_sync.append(result)
        self._jobs_processed += 1

    async def process_job(self, job: "Job"):
        """Yield zero or more result jobs for ``job``; subclasses must override.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()
        # Unreachable, but its presence makes this an async generator so that
        # ``async for`` surfaces NotImplementedError rather than a TypeError.
        yield  # pragma: no cover
    
class Pipeline:
    """Owns the worker tasks of a queue-connected set of nodes.

    Jobs enter through ``enqueue_job``, which numbers them and puts them on
    the root queue: the ``input_queue`` of the first node added.
    """

    def __init__(self):
        # Input queue of every add_node() call, in call order.
        self.input_queues = []
        # Queue that enqueue_job() feeds; set by the first add_node() call.
        self.root_queue = None
        # Distinct node class names, in first-seen order.
        self.nodes = []
        # Node class name -> list of worker instances.
        self.node_workers = {}
        # asyncio tasks running each worker's run() coroutine.
        self.tasks = []
        # Id handed to the next enqueued job.
        self._job_id = 0

    async def add_node(self, node: "type[Node]", num_workers=1, input_queue=None, output_queue=None, job_sync=None, sequential_node=False):
        """Instantiate ``num_workers`` workers of ``node`` and start them.

        Each worker is built as
        ``node(worker_id, input_queue, output_queue, job_sync, sequential_node)``
        and its ``run()`` coroutine is scheduled with ``asyncio.create_task``,
        so this must be awaited from inside a running event loop.

        Args:
            node: The node *class* (not an instance) to instantiate.
            num_workers: Number of workers sharing ``input_queue``.
            input_queue: Queue the workers read from; required.
            output_queue: Optional queue the workers write to.
            job_sync: Optional collector passed through to each worker.
            sequential_node: Passed through to each worker.

        Raises:
            ValueError: If ``input_queue`` is None, if ``job_sync`` is given
                for a non-sequential node, if a sequential node is given more
                than one worker, or if both queues are the same object.
        """
        if input_queue is None:
            raise ValueError('input_queue is None')
        # job_sync nodes must be sequential nodes
        if job_sync is not None and not sequential_node:
            raise ValueError('job_sync is not None and sequential_node is False')
        # a sequential node must have exactly one worker
        if sequential_node and num_workers != 1:
            raise ValueError('sequential nodes can only have one worker (sequential_node is True and num_workers is not 1)')
        # feeding a node its own output would loop jobs forever
        if output_queue is input_queue:
            raise ValueError('output_queue must not be the same as input_queue')

        node_name = node.__name__
        if node_name not in self.nodes:
            self.nodes.append(node_name)

        # The first node added defines the root queue used by enqueue_job().
        if not self.input_queues:
            self.root_queue = input_queue

        self.input_queues.append(input_queue)

        workers = self.node_workers.setdefault(node_name, [])
        for worker_id in range(num_workers):
            node_worker = node(worker_id, input_queue, output_queue, job_sync, sequential_node)
            workers.append(node_worker)
            self.tasks.append(asyncio.create_task(node_worker.run()))

    async def enqueue_job(self, job: "Job"):
        """Assign the next sequential id to ``job`` and put it on the root queue.

        Raises:
            RuntimeError: If no node has been added yet. No id is consumed in
                that case, so the id sequence stays gap-free.
        """
        if self.root_queue is None:
            raise RuntimeError('no root queue: call add_node() before enqueue_job()')
        job.id = self._job_id
        self._job_id += 1
        await self.root_queue.put(job)

    async def close(self):
        """Cancel every worker task and wait for all of them to finish."""
        for task in self.tasks:
            task.cancel()
        # return_exceptions=True keeps the resulting CancelledError (or any
        # worker failure) from propagating out of close().
        await asyncio.gather(*self.tasks, return_exceptions=True)