Spaces:
Sleeping
Sleeping
create a pipeline class with basic test coverage
Browse files- pipeline_test.py → pipeline.py +67 -79
- tests/test_pipeline.py +83 -0
pipeline_test.py → pipeline.py
RENAMED
@@ -1,11 +1,8 @@
|
|
1 |
import asyncio
|
2 |
-
import random
|
3 |
-
import time
|
4 |
-
|
5 |
|
6 |
class Job:
|
7 |
-
def __init__(self,
|
8 |
-
self.
|
9 |
self.data = data
|
10 |
|
11 |
|
@@ -31,6 +28,11 @@ class Node:
|
|
31 |
self._jobs_dequeued += 1
|
32 |
if self.sequential_node == False:
|
33 |
await self.process_job(job)
|
|
|
|
|
|
|
|
|
|
|
34 |
else:
|
35 |
# ensure that jobs are processed in order
|
36 |
self.buffer[job.id] = job
|
@@ -38,79 +40,65 @@ class Node:
|
|
38 |
job = self.buffer.pop(self.next_i)
|
39 |
await self.process_job(job)
|
40 |
self.next_i += 1
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
async def process_job(self, job: Job):
|
48 |
-
raise NotImplementedError
|
49 |
|
50 |
-
|
51 |
-
class Node1(Node):
|
52 |
async def process_job(self, job: Job):
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
start_time = time.time()
|
109 |
-
|
110 |
-
try:
|
111 |
-
asyncio.run(main())
|
112 |
-
except KeyboardInterrupt:
|
113 |
-
print("Pipeline interrupted by user")
|
114 |
-
|
115 |
-
end_time = time.time()
|
116 |
-
print(f"Pipeline processed in {end_time - start_time} seconds.")
|
|
|
1 |
import asyncio
|
|
|
|
|
|
|
2 |
|
3 |
class Job:
    """Unit of work that travels through the pipeline.

    Attributes:
        data: Payload supplied by the caller; nodes read and modify it.
        id: Position of the job in the stream. It is ``None`` until
            ``Pipeline.enqueue_job`` assigns a number; sequential nodes use
            it as the key that restores the original job order.
    """

    def __init__(self, data):
        # Backing field for the ``id`` property; unset until the job is enqueued.
        self._id = None
        self.data = data

    @property
    def id(self):
        """Return the job's sequence number, or ``None`` if not yet assigned."""
        return self._id

    @id.setter
    def id(self, value):
        self._id = value
|
7 |
|
8 |
|
|
|
28 |
self._jobs_dequeued += 1
|
29 |
if self.sequential_node == False:
|
30 |
await self.process_job(job)
|
31 |
+
if self.output_queue is not None:
|
32 |
+
await self.output_queue.put(job)
|
33 |
+
if self.job_sync is not None:
|
34 |
+
self.job_sync.append(job)
|
35 |
+
self._jobs_processed += 1
|
36 |
else:
|
37 |
# ensure that jobs are processed in order
|
38 |
self.buffer[job.id] = job
|
|
|
40 |
job = self.buffer.pop(self.next_i)
|
41 |
await self.process_job(job)
|
42 |
self.next_i += 1
|
43 |
+
if self.output_queue is not None:
|
44 |
+
await self.output_queue.put(job)
|
45 |
+
if self.job_sync is not None:
|
46 |
+
self.job_sync.append(job)
|
47 |
+
self._jobs_processed += 1
|
|
|
|
|
|
|
48 |
|
|
|
|
|
49 |
async def process_job(self, job: Job):
    """Handle a single job; concrete node classes must override this.

    The base implementation does no work and only raises
    ``NotImplementedError``.
    """
    raise NotImplementedError
|
51 |
+
|
52 |
+
class Pipeline:
    """Connects node workers through queues and feeds jobs into the first one.

    Nodes are registered with ``add_node``; the input queue of the first
    registered node becomes the root queue that ``enqueue_job`` writes to.
    ``close`` cancels every worker task that was started.
    """

    def __init__(self):
        # Input queue of every registered node, in registration order.
        self.input_queues = []
        # Queue that enqueue_job() feeds; set by the first add_node() call.
        self.root_queue = None
        # Class names of the registered node types (each name stored once).
        self.nodes = []
        # Worker instance per node class name.
        # NOTE(review): with num_workers > 1 only the last worker created for
        # a class remains reachable through this mapping.
        self.node_workers = {}
        # asyncio tasks running the workers' run() coroutines.
        self.tasks = []
        # Id handed to the next enqueued job.
        self._job_id = 0

    async def add_node(self, node: "type[Node]", num_workers=1, input_queue=None, output_queue=None, job_sync=None, sequential_node=False):
        """Register a node class and start its workers as asyncio tasks.

        Must be awaited from a running event loop because the workers are
        scheduled with ``asyncio.create_task``.

        Args:
            node: Node class (not an instance). It is instantiated as
                ``node(worker_id, input_queue, output_queue, job_sync,
                sequential_node)`` once per worker.
            num_workers: Number of worker instances to start.
            input_queue: Queue the workers read jobs from. Required.
            output_queue: Queue handed to the workers for finished jobs, or
                ``None``. Must be a different object than ``input_queue``.
            job_sync: Optional object handed to every worker; only allowed
                on sequential nodes.
            sequential_node: Whether the node has to handle jobs in id order.
                Sequential nodes are limited to a single worker.

        Raises:
            ValueError: If ``input_queue`` is ``None``, if ``job_sync`` is
                given for a non-sequential node, if a sequential node is
                requested with ``num_workers != 1``, or if ``output_queue``
                is the same object as ``input_queue``.
        """
        # Normalise the flag so this validation and the workers (which receive
        # the same value) agree on what counts as sequential.
        sequential_node = bool(sequential_node)

        # input_queue must not be None
        if input_queue is None:
            raise ValueError('input_queue is None')
        # job_sync nodes must be sequential nodes
        if job_sync is not None and not sequential_node:
            raise ValueError('job_sync is not None and sequential_node is False')
        # sequential nodes may only have one worker
        if sequential_node and num_workers != 1:
            raise ValueError('sequential nodes can only have one worker (sequential_node is True and num_workers is not 1)')
        # a node must not write back into the queue it reads from
        if output_queue is input_queue:
            raise ValueError('output_queue must not be the same as input_queue')

        node_name = node.__name__
        if node_name not in self.nodes:
            self.nodes.append(node_name)

        # the first node ever registered is the root; its input queue receives
        # the jobs passed to enqueue_job()
        if not self.input_queues:
            self.root_queue = input_queue

        self.input_queues.append(input_queue)

        for worker_id in range(num_workers):
            node_worker = node(worker_id, input_queue, output_queue, job_sync, sequential_node)
            self.node_workers[node_name] = node_worker
            self.tasks.append(asyncio.create_task(node_worker.run()))

    async def enqueue_job(self, job: "Job"):
        """Assign the next sequential id to ``job`` and put it on the root queue.

        Raises:
            RuntimeError: If no node has been added yet, so there is no root
                queue to put the job on.
        """
        # Check first so that a failed call does not consume a job id.
        if self.root_queue is None:
            raise RuntimeError('cannot enqueue a job before a node has been added')
        job.id = self._job_id
        self._job_id += 1
        await self.root_queue.put(job)

    async def close(self):
        """Cancel all worker tasks and wait until they have finished."""
        for task in self.tasks:
            task.cancel()
        # return_exceptions=True keeps the CancelledError of each task from
        # propagating to the caller.
        await asyncio.gather(*self.tasks, return_exceptions=True)
|
103 |
+
|
104 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_pipeline.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
import unittest
|
5 |
+
import sys
|
6 |
+
import os
|
7 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
8 |
+
|
9 |
+
from pipeline import Pipeline, Node, Job
|
10 |
+
|
11 |
+
|
12 |
+
class Node1(Node):
    """First test stage: tags the job's data with the handling worker's id."""

    async def process_job(self, job: Job):
        """Append a node-1 marker naming this worker to ``job.data``."""
        marker = f' (processed by node 1, worker {self.worker_id})'
        job.data += marker
|
15 |
+
|
16 |
+
|
17 |
+
class Node2(Node):
    """Second test stage: simulates slow work before tagging the job."""

    async def process_job(self, job: Job):
        """Sleep for a random 0.08-0.12 s, then append a node-2 marker."""
        # The random delay makes concurrent workers finish out of order.
        delay = 0.08 + 0.04 * random.random()
        await asyncio.sleep(delay)
        marker = f' (processed by node 2, worker {self.worker_id})'
        job.data += marker
|
22 |
+
|
23 |
+
|
24 |
+
class Node3(Node):
    """Final test stage: tags the job and prints it."""

    async def process_job(self, job: Job):
        """Append a node-3 marker to ``job.data`` and print the job's id and data."""
        marker = f' (processed by node 3, worker {self.worker_id})'
        job.data += marker
        line = f'{job.id} - {job.data}'
        print(line)
|
28 |
+
|
29 |
+
|
30 |
+
class TestPipeline(unittest.TestCase):
    """Tests for Pipeline argument validation and for job ordering."""

    # Upper bound in seconds on waiting for all jobs to arrive in job_sync,
    # so that a lost job fails the test instead of hanging it forever.
    DRAIN_TIMEOUT = 30

    def setUp(self):
        """Give every test a fresh pipeline and an empty result list."""
        self.pipeline = Pipeline()
        self.job_sync = []

    async def _test_pipeline_edge_cases(self):
        """Check that add_node rejects every invalid argument combination."""
        # a node must have an input queue
        with self.assertRaises(ValueError):
            await self.pipeline.add_node(Node1, 1, None, None)
        # the output queue must not be the same queue as the input queue
        node1_queue = asyncio.Queue()
        with self.assertRaises(ValueError):
            await self.pipeline.add_node(Node1, 1, node1_queue, node1_queue)
        # job_sync is only allowed on sequential nodes
        with self.assertRaises(ValueError):
            await self.pipeline.add_node(Node1, 1, node1_queue, job_sync=[], sequential_node=False)
        # sequential nodes must have exactly one worker
        with self.assertRaises(ValueError):
            await self.pipeline.add_node(Node1, 2, node1_queue, sequential_node=True)

    async def _wait_until_synced(self, num_jobs):
        """Poll until at least ``num_jobs`` jobs have been appended to job_sync."""
        while len(self.job_sync) < num_jobs:
            await asyncio.sleep(0.1)

    async def _test_pipeline(self, num_jobs):
        """Push ``num_jobs`` jobs through a three-stage pipeline and wait for them.

        Raises:
            asyncio.TimeoutError: If not all jobs arrive in ``self.job_sync``
                within ``DRAIN_TIMEOUT`` seconds.
        """
        node1_queue = asyncio.Queue()
        node2_queue = asyncio.Queue()
        node3_queue = asyncio.Queue()
        try:
            await self.pipeline.add_node(Node1, 1, node1_queue, node2_queue)
            await self.pipeline.add_node(Node2, 5, node2_queue, node3_queue)
            await self.pipeline.add_node(Node3, 1, node3_queue, job_sync=self.job_sync, sequential_node=True)
            for _ in range(num_jobs):
                await self.pipeline.enqueue_job(Job(""))
            await asyncio.wait_for(self._wait_until_synced(num_jobs), timeout=self.DRAIN_TIMEOUT)
        finally:
            # always stop the worker tasks, even when the test body failed
            await self.pipeline.close()

    def test_pipeline_edge_cases(self):
        asyncio.run(self._test_pipeline_edge_cases())

    @unittest.skip('disabled in the original commit; remove this decorator to run the ordering check')
    def test_pipeline_keeps_order(self):
        num_jobs = 100
        start_time = time.time()
        asyncio.run(self._test_pipeline(num_jobs))
        end_time = time.time()
        print(f"Pipeline processed in {end_time - start_time} seconds.")
        self.assertEqual(len(self.job_sync), num_jobs)
        for i, job in enumerate(self.job_sync):
            self.assertEqual(i, job.id)
|
77 |
+
|
78 |
+
|
79 |
+
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
|