Spaces:
Sleeping
Sleeping
import asyncio | |
class Job:
    """A unit of work passed between pipeline nodes.

    Args:
        data: Arbitrary payload carried by the job.

    Attributes:
        id: Sequence number, ``None`` until assigned (``Pipeline.enqueue_job``
            assigns consecutive ids starting at 0). Sequential nodes process
            jobs in order of this value.
        data: The payload.
    """

    def __init__(self, data):
        # Nodes and the pipeline read/write ``job.id``; initialise it here so
        # the attribute always exists.
        self.id = None
        self.data = data

    @property
    def _id(self):
        """Backward-compatible alias for :attr:`id`."""
        return self.id

    @_id.setter
    def _id(self, value):
        self.id = value
class Node:
    """Base class for a pipeline stage that consumes jobs from a queue.

    Subclasses override :meth:`process_job` as an async generator that yields
    zero or more output jobs per input job. Every yielded job is put on
    ``output_queue`` (if given) and appended to ``job_sync`` (if given).

    Args:
        worker_id: Index of this worker among the workers created for the
            same node class.
        input_queue: Queue the worker reads jobs from with ``await get()``.
        output_queue: Optional queue that receives every yielded job.
        job_sync: Optional list that collects every yielded job; only
            permitted on sequential nodes.
        sequential_node: If true, jobs are processed strictly in order of
            their ``id`` attribute starting from 0; jobs that arrive early
            are held in ``buffer`` until their turn.

    Raises:
        ValueError: If ``job_sync`` is given for a non-sequential node.
    """

    def __init__(self, worker_id: int, input_queue, output_queue=None, job_sync=None, sequential_node=False):
        self.worker_id = worker_id
        self.input_queue = input_queue
        self.output_queue = output_queue
        # Jobs waiting for their turn, keyed by job id (sequential nodes only).
        self.buffer = {}
        self.job_sync = job_sync
        self.sequential_node = sequential_node
        # Id of the next job a sequential node may process.
        self.next_i = 0
        self._jobs_dequeued = 0
        self._jobs_processed = 0
        # job_sync is only permitted on sequential nodes.
        if self.job_sync is not None and not self.sequential_node:
            raise ValueError('job_sync is not None and sequential_node is False')

    async def run(self):
        """Consume jobs from ``input_queue`` until the task is cancelled.

        Non-sequential nodes process each job as soon as it is dequeued.
        Sequential nodes store each job in ``buffer`` under ``job.id`` and
        then process buffered jobs for as long as the one with id
        ``next_i`` is present, so results are emitted in id order.

        NOTE(review): if an id never arrives (e.g. an upstream node yields
        nothing for it), every later job stays buffered indefinitely, and a
        repeated id overwrites the earlier buffered job.
        """
        while True:
            job = await self.input_queue.get()
            self._jobs_dequeued += 1
            if not self.sequential_node:
                await self._process_and_emit(job)
                continue
            self.buffer[job.id] = job
            while self.next_i in self.buffer:
                await self._process_and_emit(self.buffer.pop(self.next_i))
                self.next_i += 1

    async def _process_and_emit(self, job):
        """Run ``process_job`` on one job and forward every result it yields."""
        async for result in self.process_job(job):
            if self.output_queue is not None:
                await self.output_queue.put(result)
            if self.job_sync is not None:
                self.job_sync.append(result)
            self._jobs_processed += 1

    async def process_job(self, job: "Job"):
        """Yield the output job(s) for ``job``; subclasses must override this.

        Overrides must be async generators (``async def`` containing
        ``yield``) because :meth:`run` consumes this with ``async for``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
        # Unreachable, but its presence makes this an async generator, so
        # ``async for`` surfaces NotImplementedError instead of a TypeError.
        yield
class Pipeline:
    """Wire node classes together with queues and run their workers as tasks.

    The input queue of the first node added becomes the root queue that
    :meth:`enqueue_job` feeds. Each worker runs as an ``asyncio`` task until
    :meth:`close` cancels it.
    """

    def __init__(self):
        # Input queue of every added node, in the order the nodes were added.
        self.input_queues = []
        # Queue fed by enqueue_job(); set by the first add_node() call.
        self.root_queue = None
        # Class names of the added nodes, without duplicates.
        self.nodes = []
        # Node class name -> a worker instance of that class.
        # NOTE(review): only the most recently created worker per class is kept.
        self.node_workers = {}
        # One running asyncio task per worker.
        self.tasks = []
        # Id assigned to the next enqueued job.
        self._job_id = 0

    async def add_node(self, node: "type[Node]", num_workers=1, input_queue=None, output_queue=None, job_sync=None, sequential_node=False):
        """Instantiate ``num_workers`` workers of ``node`` and start them as tasks.

        Must be awaited inside a running event loop (workers are started with
        ``asyncio.create_task``).

        Args:
            node: Node class (not instance); each worker is created as
                ``node(worker_id, input_queue, output_queue, job_sync,
                sequential_node)``.
            num_workers: Number of workers to start; must be at least 1, and
                exactly 1 for sequential nodes.
            input_queue: Queue the workers read from; required.
            output_queue: Optional queue the workers write results to; must be
                a different object from ``input_queue``.
            job_sync: Optional list collecting results; sequential nodes only.
            sequential_node: Whether jobs are processed strictly in id order.

        Raises:
            ValueError: If any of the constraints above is violated.
        """
        if input_queue is None:
            raise ValueError('input_queue is None')
        if job_sync is not None and not sequential_node:
            raise ValueError('job_sync is not None and sequential_node is False')
        # Each worker keeps its own ordering buffer, so ordering only holds
        # with a single worker.
        if sequential_node and num_workers != 1:
            raise ValueError('sequential nodes can only have one worker (sequential_node is True and num_workers is not 1)')
        # Feeding a node its own output would make it consume its results.
        if output_queue is input_queue:
            raise ValueError('output_queue must not be the same as input_queue')
        if num_workers < 1:
            raise ValueError(f'num_workers must be at least 1, got {num_workers}')

        # ``node`` is a class, so its own __name__ is wanted (its __class__ is
        # just ``type``); fall back for callables without a __name__.
        node_name = getattr(node, '__name__', type(node).__name__)
        if node_name not in self.nodes:
            self.nodes.append(node_name)
        # The first node added is the pipeline's entry point.
        if not self.input_queues:
            self.root_queue = input_queue
        self.input_queues.append(input_queue)
        for worker_id in range(num_workers):
            node_worker = node(worker_id, input_queue, output_queue, job_sync, sequential_node)
            self.node_workers[node_name] = node_worker
            self.tasks.append(asyncio.create_task(node_worker.run()))

    async def enqueue_job(self, job: "Job"):
        """Assign the next sequential id to ``job`` and put it on the root queue.

        Raises:
            RuntimeError: If no node has been added yet.
        """
        if self.root_queue is None:
            raise RuntimeError('no nodes have been added to the pipeline')
        job.id = self._job_id
        self._job_id += 1
        await self.root_queue.put(job)

    async def close(self):
        """Cancel every worker task and wait for all of them to finish.

        Exceptions raised by the tasks (including CancelledError) are collected
        by ``gather`` and discarded.
        """
        for task in self.tasks:
            task.cancel()
        await asyncio.gather(*self.tasks, return_exceptions=True)