project_charles / tests /test_pipeline.py
sohojoe's picture
create a pipeline class with basic test coverage
ed232fa
raw
history blame
2.67 kB
import asyncio
import random
import time
import unittest
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pipeline import Pipeline, Node, Job
class Node1(Node):
    """Test node that tags a job as having passed through node 1."""

    async def process_job(self, job: Job):
        """Append the node-1 marker (including this worker's id) to ``job.data``."""
        marker = f' (processed by node 1, worker {self.worker_id})'
        job.data += marker
class Node2(Node):
    """Test node that waits a short random time before tagging the job."""

    async def process_job(self, job: Job):
        """Sleep for 0.08 s plus up to 0.04 s of jitter, then append the node-2 marker."""
        # The random jitter makes concurrent workers take different amounts of time.
        jitter = 0.04 * random.random()
        await asyncio.sleep(0.08 + jitter)
        marker = f' (processed by node 2, worker {self.worker_id})'
        job.data += marker
class Node3(Node):
    """Test node that tags the job for node 3 and prints the result."""

    async def process_job(self, job: Job):
        """Append the node-3 marker to ``job.data`` and print the job id with its data."""
        marker = f' (processed by node 3, worker {self.worker_id})'
        job.data += marker
        summary = f'{job.id} - {job.data}'
        print(summary)
class TestPipeline(unittest.TestCase):
    """Tests for ``Pipeline`` built from the Node1/Node2/Node3 test nodes."""

    def setUp(self):
        """Give every test a fresh pipeline and an empty list of synced jobs."""
        self.pipeline = Pipeline()
        # Handed to the last node as ``job_sync``; ``_test_pipeline`` polls its
        # length to decide when all jobs have come out of the pipeline.
        self.job_sync = []

    async def _test_pipeline_edge_cases(self):
        """Assert that ``add_node`` raises ``ValueError`` for invalid queue wiring."""
        # A node must have an input queue.
        with self.assertRaises(ValueError):
            await self.pipeline.add_node(Node1, 1, None, None)

        # The output queue must not be the same queue as the input queue.
        shared_queue = asyncio.Queue()
        with self.assertRaises(ValueError):
            await self.pipeline.add_node(Node1, 1, shared_queue, shared_queue)

    async def _test_pipeline(self, num_jobs, timeout=30.0):
        """Push ``num_jobs`` jobs through a three-node pipeline and wait for them.

        Args:
            num_jobs: Number of empty-string jobs to enqueue.
            timeout: Maximum number of seconds to wait for ``self.job_sync``
                to reach ``num_jobs`` entries before failing the test.

        Raises:
            AssertionError: If the jobs have not all arrived within ``timeout``
                seconds (raised through ``self.fail``).
        """
        node1_queue = asyncio.Queue()
        node2_queue = asyncio.Queue()
        node3_queue = asyncio.Queue()

        await self.pipeline.add_node(Node1, 1, node1_queue, node2_queue)
        await self.pipeline.add_node(Node2, 5, node2_queue, node3_queue)
        await self.pipeline.add_node(
            Node3, 1, node3_queue, job_sync=self.job_sync, sequential_node=True
        )

        for _ in range(num_jobs):
            await self.pipeline.enqueue_job(Job(""))

        # Poll with a deadline: without one, a single lost job would make the
        # whole test run hang instead of reporting a failure.
        deadline = time.monotonic() + timeout
        while len(self.job_sync) < num_jobs:
            if time.monotonic() >= deadline:
                self.fail(
                    f"only {len(self.job_sync)} of {num_jobs} jobs completed "
                    f"within {timeout} seconds"
                )
            await asyncio.sleep(0.1)

        await self.pipeline.close()

    def test_pipeline_edge_cases(self):
        """Run the invalid-wiring checks on their own event loop."""
        asyncio.run(self._test_pipeline_edge_cases())

    # NOTE(review): this ordering test is disabled in the original file; the
    # reason is not recorded here. It relies on setUp for pipeline/job_sync.
    # def test_pipeline_keeps_order(self):
    #     num_jobs = 100
    #     start_time = time.time()
    #     asyncio.run(self._test_pipeline(num_jobs))
    #     end_time = time.time()
    #     print(f"Pipeline processed in {end_time - start_time} seconds.")
    #     self.assertEqual(len(self.job_sync), num_jobs)
    #     for i, job in enumerate(self.job_sync):
    #         self.assertEqual(i, job.id)
# Running this file directly hands control to unittest's command-line runner.
if __name__ == "__main__":
    unittest.main()
    # A single test can also be driven by hand, e.g. construct
    # TestPipeline("test_pipeline_edge_cases"), call setUp(), then the test method.