"""Three-stage asyncio pipeline demo.

Jobs flow node 1 -> node 2 -> node 3 through ``asyncio.Queue`` objects.
Node 2 runs several concurrent workers, so jobs may reach node 3 out of
order; node 3 is a *sequential* node that re-orders them by ``Job.id``
before processing.
"""

import asyncio
import random
import time
from typing import Dict, List, Optional


class Job:
    """A unit of work carried through the pipeline.

    Attributes:
        id: Sequence number of the job. Sequential nodes use it as the
            ordering key and expect ids to start at 0 with no gaps.
        data: Payload string; each node appends a processing note to it.
    """

    # NOTE: the parameter is named ``id`` (shadowing the builtin) to keep
    # the constructor signature compatible with existing callers.
    def __init__(self, id, data):
        self.id = id
        self.data = data


class Node:
    """Base class for a pipeline worker.

    A worker repeatedly takes a job from ``input_queue``, processes it via
    :meth:`process_job`, and then forwards it to ``output_queue`` and/or
    appends it to ``job_sync`` when those are provided.

    When ``sequential_node`` is true, jobs are held in a per-instance
    buffer and processed strictly in ascending ``Job.id`` order starting
    from 0.  The buffer and the next-expected id live on the instance, so
    ordering is only guaranteed relative to the jobs this one worker
    dequeues.

    Args:
        worker_id: Identifier included in the processing notes.
        input_queue: Queue the worker reads jobs from.
        output_queue: Optional queue processed jobs are forwarded to.
        job_sync: Optional list that processed jobs are appended to.
        sequential_node: Process jobs in ``Job.id`` order when true.

    Raises:
        ValueError: If ``job_sync`` is given for a non-sequential node.
    """

    def __init__(
        self,
        worker_id: int,
        input_queue,
        output_queue=None,
        job_sync: Optional[List[Job]] = None,
        sequential_node: bool = False,
    ):
        self.worker_id = worker_id
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.buffer: Dict[int, Job] = {}  # out-of-order jobs keyed by id
        self.job_sync = job_sync
        self.sequential_node = sequential_node
        self.next_i = 0  # id of the next job a sequential node may process
        self._jobs_dequeued = 0
        self._jobs_processed = 0

        if self.job_sync is not None and not self.sequential_node:
            raise ValueError('job_sync is not None and sequential_node is False')

    async def run(self):
        """Consume jobs from ``input_queue`` forever (until cancelled)."""
        while True:
            job: Job = await self.input_queue.get()
            self._jobs_dequeued += 1

            if not self.sequential_node:
                await self._process_and_forward(job)
                continue

            # Sequential mode: park the job, then drain every job that is
            # now next in line.  Each drained job is forwarded right after
            # it is processed, so downstream consumers only ever see
            # processed jobs, in id order.
            self.buffer[job.id] = job
            while self.next_i in self.buffer:
                ready = self.buffer.pop(self.next_i)
                await self._process_and_forward(ready)
                self.next_i += 1

    async def _process_and_forward(self, job: Job):
        """Process one job, then hand it to the output queue / sync list."""
        await self.process_job(job)
        if self.output_queue is not None:
            await self.output_queue.put(job)
        if self.job_sync is not None:
            self.job_sync.append(job)
        self._jobs_processed += 1

    async def process_job(self, job: Job):
        """Do the node-specific work on ``job``; subclasses must override."""
        raise NotImplementedError


class Node1(Node):
    """First stage: tags the job without waiting."""

    async def process_job(self, job: Job):
        job.data += f' (processed by node 1, worker {self.worker_id})'


class Node2(Node):
    """Second stage: sleeps 0.8-1.2 s to simulate slow work, then tags."""

    async def process_job(self, job: Job):
        sleep_duration = 0.8 + 0.4 * random.random()
        await asyncio.sleep(sleep_duration)
        job.data += f' (processed by node 2, worker {self.worker_id})'


class Node3(Node):
    """Final stage: tags the job and prints its id and accumulated data."""

    async def process_job(self, job: Job):
        job.data += f' (processed by node 3, worker {self.worker_id})'
        print(f'{job.id} - {job.data}')


async def main(num_jobs: int = 100, num_workers: int = 5):
    """Build the pipeline, feed it ``num_jobs`` jobs and wait for completion.

    Args:
        num_jobs: Number of jobs to push through the pipeline.
        num_workers: Number of concurrent node 2 workers.
    """
    node1_queue = asyncio.Queue()
    node2_queue = asyncio.Queue()
    node3_queue = asyncio.Queue()

    job_source = [Job(i, "") for i in range(num_jobs)]
    job_sync: List[Job] = []

    # Nodes 1 and 3 get a single worker each; only node 2 is fanned out.
    workers = [Node1(1, node1_queue, node2_queue)]
    workers += [
        Node2(i + 1, node2_queue, node3_queue) for i in range(num_workers)
    ]
    workers.append(
        Node3(1, node3_queue, job_sync=job_sync, sequential_node=True)
    )

    tasks = [asyncio.create_task(worker.run()) for worker in workers]

    try:
        for job in job_source:
            await node1_queue.put(job)

        # The final node appends each finished job to job_sync; poll until
        # every job has arrived.
        while len(job_sync) < num_jobs:
            await asyncio.sleep(0.1)
    except asyncio.CancelledError:
        print("Pipeline cancelled")
    finally:
        # Workers loop forever, so they must always be cancelled explicitly.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)


if __name__ == "__main__":
    start_time = time.perf_counter()
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("Pipeline interrupted by user")
    end_time = time.perf_counter()
    print(f"Pipeline processed in {end_time - start_time} seconds.")