import asyncio
import traceback


class Job:
    """A unit of work passed between pipeline nodes."""

    def __init__(self, data):
        self.id = None  # assigned by Pipeline.enqueue_job
        self.data = data


class Node:
    """Base class for a pipeline worker. Subclasses implement process_job() as an async generator."""

    def __init__(self, worker_id: int, input_queue, output_queue=None, job_sync=None, sequential_node=False):
        self.worker_id = worker_id
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.buffer = {}
        self.job_sync = job_sync
        self.sequential_node = sequential_node
        self.next_i = 0
        self._jobs_dequeued = 0
        self._jobs_processed = 0
        # job_sync collects results in order, so it requires a sequential node
        if self.job_sync is not None and not self.sequential_node:
            raise ValueError('job_sync is not None and sequential_node is False')

    async def run(self):
        try:
            while True:
                job: Job = await self.input_queue.get()
                self._jobs_dequeued += 1
                if not self.sequential_node:
                    async for result in self.process_job(job):
                        if self.output_queue is not None:
                            await self.output_queue.put(result)
                        if self.job_sync is not None:
                            self.job_sync.append(result)
                        self._jobs_processed += 1
                else:
                    # buffer out-of-order jobs so they are processed in enqueue order
                    self.buffer[job.id] = job
                    while self.next_i in self.buffer:
                        job = self.buffer.pop(self.next_i)
                        async for result in self.process_job(job):
                            if self.output_queue is not None:
                                await self.output_queue.put(result)
                            if self.job_sync is not None:
                                self.job_sync.append(result)
                            self._jobs_processed += 1
                        self.next_i += 1
        except Exception as e:
            print(f"An error occurred in node: {self.__class__.__name__} worker: {self.worker_id}: {e}")
            traceback.print_exc()
            raise  # re-raise the last exception

    async def process_job(self, job: Job):
        """Async generator that yields zero or more result jobs; subclasses must override."""
        raise NotImplementedError()


class Pipeline:
    """Wires Node workers together with asyncio queues and manages their worker tasks."""

    def __init__(self):
        self.input_queues = []
        self.root_queue = None
        self.nodes = []
        self.node_workers = {}
        self.tasks = []
        self._job_id = 0

    async def add_node(self, node: type, num_workers=1, input_queue=None, output_queue=None, job_sync=None, sequential_node=False):
        """Spawn num_workers workers of the given Node subclass; `node` is the class itself, not an instance."""
        # input_queue must not be None
        if input_queue is None:
            raise ValueError('input_queue is None')
        # job_sync nodes must be sequential nodes
        if job_sync is not None and not sequential_node:
            raise ValueError('job_sync is not None and sequential_node is False')
        # sequential nodes should only have 1 worker
        if sequential_node and num_workers != 1:
            raise ValueError('sequential nodes can only have one worker (sequential_node is True and num_workers is not 1)')
        # output_queue must not be the same as input_queue
        if output_queue is input_queue:
            raise ValueError('output_queue must not be the same as input_queue')
        node_name = node.__name__
        if node_name not in self.nodes:
            self.nodes.append(node_name)
        # the first queue added becomes the root queue that enqueue_job feeds
        if len(self.input_queues) == 0:
            self.root_queue = input_queue
        self.input_queues.append(input_queue)
        for worker_id in range(num_workers):
            node_worker = node(worker_id, input_queue, output_queue, job_sync, sequential_node)
            # keep every worker instance, not just the last one created
            self.node_workers.setdefault(node_name, []).append(node_worker)
            task = asyncio.create_task(node_worker.run())
            self.tasks.append(task)

    async def enqueue_job(self, job: Job):
        job.id = self._job_id
        self._job_id += 1
        await self.root_queue.put(job)

    async def close(self):
        for task in self.tasks:
            task.cancel()
        await asyncio.gather(*self.tasks, return_exceptions=True)
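

# Example usage: a minimal sketch assuming jobs carry integer data. The
# DoubleNode class and the queue wiring below are illustrative only and are
# not part of the module above.
if __name__ == "__main__":

    class DoubleNode(Node):
        """Hypothetical example node that doubles each job's data."""

        async def process_job(self, job: Job):
            # process_job is an async generator: yield one result per input job
            yield Job(job.data * 2)

    async def main():
        pipeline = Pipeline()
        in_q = asyncio.Queue()
        out_q = asyncio.Queue()
        results = []
        # the first add_node call makes in_q the pipeline's root queue
        await pipeline.add_node(DoubleNode, num_workers=1, input_queue=in_q,
                                output_queue=out_q, job_sync=results,
                                sequential_node=True)
        for n in range(3):
            await pipeline.enqueue_job(Job(n))
        # poll until the sequential worker has processed every job, then shut down
        while len(results) < 3:
            await asyncio.sleep(0.01)
        await pipeline.close()
        print([job.data for job in results])  # expected: [0, 2, 4]

    asyncio.run(main())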