project_charles / pipeline_test.py
sohojoe's picture
create a node class to help things be a bit more robust
de8cac1
raw
history blame
3.97 kB
import asyncio
import random
import time
class Job:
    """A unit of work flowing through the pipeline.

    Attributes:
        id: Monotonically increasing integer used by sequential nodes to
            restore arrival order.
        data: Free-form payload; stages append a trace string to it.
    """

    def __init__(self, id, data):
        self.id = id
        self.data = data

    def __repr__(self):
        return f"Job(id={self.id!r}, data={self.data!r})"
class Node:
# def __init__(self, worker_id: int, input_queue, output_queue, buffer=None, job_sync=None):
def __init__(self, worker_id: int, input_queue, output_queue=None, job_sync=None, sequential_node=False ):
self.worker_id = worker_id
self.input_queue = input_queue
self.output_queue = output_queue
self.buffer = {}
self.job_sync = job_sync
self.sequential_node = sequential_node
self.next_i = 0
self._jobs_dequeued = 0
self._jobs_processed = 0
# throw an error if job_sync is not None and sequential_node is False
if self.job_sync is not None and self.sequential_node == False:
raise ValueError('job_sync is not None and sequential_node is False')
async def run(self):
while True:
job: Job = await self.input_queue.get()
self._jobs_dequeued += 1
if self.sequential_node == False:
await self.process_job(job)
else:
# ensure that jobs are processed in order
self.buffer[job.id] = job
while self.next_i in self.buffer:
job = self.buffer.pop(self.next_i)
await self.process_job(job)
self.next_i += 1
if self.output_queue is not None:
await self.output_queue.put(job)
if self.job_sync is not None:
self.job_sync.append(job)
self._jobs_processed += 1
async def process_job(self, job: Job):
raise NotImplementedError
class Node1(Node):
    """First stage: tags each job with this worker's identity."""

    async def process_job(self, job: Job):
        job.data += f' (processed by node 1, worker {self.worker_id})'
class Node2(Node):
    """Middle stage: simulates slow work with a randomized delay."""

    async def process_job(self, job: Job):
        # Sleep for a random duration in the [0.8, 1.2) second range.
        delay = 0.8 + 0.4 * random.random()
        await asyncio.sleep(delay)
        job.data += f' (processed by node 2, worker {self.worker_id})'
class Node3(Node):
    """Final stage: tags the job and prints its accumulated trace."""

    async def process_job(self, job: Job):
        job.data += f' (processed by node 3, worker {self.worker_id})'
        print(f'{job.id} - {job.data}')
async def main():
    """Run a three-stage pipeline over 100 jobs and wait for completion.

    Topology: 1 Node1 worker -> 5 parallel Node2 workers -> 1 sequential
    Node3 worker that restores job order and collects results in job_sync.
    """
    node1_queue = asyncio.Queue()
    node2_queue = asyncio.Queue()
    node3_queue = asyncio.Queue()
    num_jobs = 100
    joe_source = [Job(i, "") for i in range(num_jobs)]
    job_sync = []

    # Create the workers: single fan-in, parallel middle, ordered fan-out.
    num_workers = 5
    node1_workers = [Node1(i + 1, node1_queue, node2_queue) for i in range(1)]
    node2_workers = [Node2(i + 1, node2_queue, node3_queue) for i in range(num_workers)]
    node3_workers = [Node3(i + 1, node3_queue, job_sync=job_sync, sequential_node=True) for i in range(1)]

    # Start one long-running task per worker.
    workers = [*node1_workers, *node2_workers, *node3_workers]
    tasks = [asyncio.create_task(worker.run()) for worker in workers]

    # Feed every job into the head of the pipeline.
    for job in joe_source:
        await node1_queue.put(job)

    try:
        # Poll until the tail node has collected every job.
        while len(job_sync) < num_jobs:
            await asyncio.sleep(0.1)
    except asyncio.CancelledError:
        print("Pipeline cancelled")

    # Shut the workers down; gather with return_exceptions so the
    # CancelledErrors raised inside run() are swallowed.
    for task in tasks:
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)
if __name__ == "__main__":
    # Time the whole pipeline run; Ctrl-C exits cleanly instead of
    # dumping a KeyboardInterrupt traceback.
    start_time = time.time()
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("Pipeline interrupted by user")
    end_time = time.time()
    print(f"Pipeline processed in {end_time - start_time} seconds.")