sohojoe committed on
Commit
ed232fa
•
1 Parent(s): de8cac1

create a pipeline class with basic test coverage

pipeline_test.py → pipeline.py RENAMED
@@ -1,11 +1,8 @@
 import asyncio
-import random
-import time
-
 
 class Job:
-    def __init__(self, id, data):
-        self.id = id
+    def __init__(self, data):
+        self.id = None  # assigned by Pipeline.enqueue_job
         self.data = data
 
 
@@ -31,6 +28,11 @@ class Node:
             self._jobs_dequeued += 1
             if self.sequential_node == False:
                 await self.process_job(job)
+                if self.output_queue is not None:
+                    await self.output_queue.put(job)
+                if self.job_sync is not None:
+                    self.job_sync.append(job)
+                self._jobs_processed += 1
             else:
                 # ensure that jobs are processed in order
                 self.buffer[job.id] = job
@@ -38,79 +40,65 @@ class Node:
                 job = self.buffer.pop(self.next_i)
                 await self.process_job(job)
                 self.next_i += 1
-        if self.output_queue is not None:
-            await self.output_queue.put(job)
-        if self.job_sync is not None:
-            self.job_sync.append(job)
-        self._jobs_processed += 1
-
-    async def process_job(self, job: Job):
-        raise NotImplementedError
+                if self.output_queue is not None:
+                    await self.output_queue.put(job)
+                if self.job_sync is not None:
+                    self.job_sync.append(job)
+                self._jobs_processed += 1
 
-
-class Node1(Node):
     async def process_job(self, job: Job):
-        job.data += f' (processed by node 1, worker {self.worker_id})'
-
-
-class Node2(Node):
-    async def process_job(self, job: Job):
-        sleep_duration = 0.8 + 0.4 * random.random()
-        await asyncio.sleep(sleep_duration)
-        job.data += f' (processed by node 2, worker {self.worker_id})'
-
-
-class Node3(Node):
-    async def process_job(self, job: Job):
-        job.data += f' (processed by node 3, worker {self.worker_id})'
-        print(f'{job.id} - {job.data}')
-
-
-async def main():
-    node1_queue = asyncio.Queue()
-    node2_queue = asyncio.Queue()
-    node3_queue = asyncio.Queue()
-
-    num_jobs = 100
-    joe_source = [Job(i, "") for i in range(num_jobs)]
-    job_sync = []
-
-    # create the workers
-    num_workers = 5
-    node1_workers = [Node1(i + 1, node1_queue, node2_queue) for i in range(1)]
-    node2_workers = [Node2(i + 1, node2_queue, node3_queue) for i in range(num_workers)]
-    node3_workers = [Node3(i + 1, node3_queue, job_sync=job_sync, sequential_node=True) for i in range(1)]
-
-    # create tasks for the workers
-    tasks1 = [asyncio.create_task(worker.run()) for worker in node1_workers]
-    tasks2 = [asyncio.create_task(worker.run()) for worker in node2_workers]
-    tasks3 = [asyncio.create_task(worker.run()) for worker in node3_workers]
-
-    for job in joe_source:
-        await node1_queue.put(job)
-    # await input_queue.put(joe_source[0])
-
-    try:
-        while len(job_sync) < num_jobs:
-            # print(f"Waiting for jobs to finish... Job sync size: {len(job_sync)}, node1_queue size: {node1_queue.qsize()}, node2_queue size: {node2_queue.qsize()}, node3_queue size: {node3_queue.qsize()}")
-            await asyncio.sleep(0.1)
-    except asyncio.CancelledError:
-        print("Pipeline cancelled")
-    for task in tasks1:
-        task.cancel()
-    for task in tasks2:
-        task.cancel()
-    for task in tasks3:
-        task.cancel()
-    await asyncio.gather(*tasks1, *tasks2, *tasks3, return_exceptions=True)
-
-
-start_time = time.time()
-
-try:
-    asyncio.run(main())
-except KeyboardInterrupt:
-    print("Pipeline interrupted by user")
-
-end_time = time.time()
-print(f"Pipeline processed in {end_time - start_time} seconds.")
+        raise NotImplementedError()
+
+
+class Pipeline:
+    def __init__(self):
+        self.input_queues = []
+        self.root_queue = None
+        # self.output_queues = []
+        # self.job_syncs = []
+        self.nodes = []
+        self.node_workers = {}
+        self.tasks = []
+        self._job_id = 0
+
+    async def add_node(self, node, num_workers=1, input_queue=None, output_queue=None, job_sync=None, sequential_node=False):  # node: a Node subclass
+        # input_queue must not be None
+        if input_queue is None:
+            raise ValueError('input_queue is None')
+        # job_sync nodes must be sequential nodes
+        if job_sync is not None and sequential_node == False:
+            raise ValueError('job_sync is not None and sequential_node is False')
+        # sequential nodes may only have 1 worker
+        if sequential_node == True and num_workers != 1:
+            raise ValueError('sequential nodes can only have one worker (sequential_node is True and num_workers is not 1)')
+        # output_queue must not be the same as input_queue
+        if output_queue == input_queue:
+            raise ValueError('output_queue must not be the same as input_queue')
+
+        node_name = node.__name__
+        if node_name not in self.nodes:
+            self.nodes.append(node_name)
+
+        # the first input_queue added is the root queue
+        if len(self.input_queues) == 0:
+            self.root_queue = input_queue
+
+        self.input_queues.append(input_queue)
+
+        for i in range(num_workers):
+            worker_id = i
+            node_worker = node(worker_id, input_queue, output_queue, job_sync, sequential_node)
+            self.node_workers[node_name] = node_worker
+            task = asyncio.create_task(node_worker.run())
+            self.tasks.append(task)
+
+    async def enqueue_job(self, job: Job):
+        job.id = self._job_id
+        self._job_id += 1
+        await self.root_queue.put(job)
+
+    async def close(self):
+        for task in self.tasks:
+            task.cancel()
+        await asyncio.gather(*self.tasks, return_exceptions=True)
+
+
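For context, a minimal usage sketch of the new Pipeline API (not part of the commit; Node1, Node2, and Node3 are the example Node subclasses defined in tests/test_pipeline.py below, and the queue names and job count are illustrative, assuming the script is run from the repository root):

import asyncio
from pipeline import Pipeline, Job
from tests.test_pipeline import Node1, Node2, Node3  # example Node subclasses

async def main():
    pipeline = Pipeline()
    job_sync = []  # finished jobs are appended here, in order
    node1_queue = asyncio.Queue()
    node2_queue = asyncio.Queue()
    node3_queue = asyncio.Queue()
    # one Node1 worker feeds five Node2 workers, which feed one ordered Node3 worker
    await pipeline.add_node(Node1, 1, node1_queue, node2_queue)
    await pipeline.add_node(Node2, 5, node2_queue, node3_queue)
    await pipeline.add_node(Node3, 1, node3_queue, job_sync=job_sync, sequential_node=True)
    num_jobs = 10
    for _ in range(num_jobs):
        await pipeline.enqueue_job(Job(""))  # Pipeline assigns job.id
    while len(job_sync) < num_jobs:  # poll until every job reaches the last node
        await asyncio.sleep(0.1)
    await pipeline.close()  # cancel the worker tasks and wait for them to finish

asyncio.run(main())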
tests/test_pipeline.py ADDED
@@ -0,0 +1,83 @@
+import asyncio
+import random
+import time
+import unittest
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from pipeline import Pipeline, Node, Job
+
+
+class Node1(Node):
+    async def process_job(self, job: Job):
+        job.data += f' (processed by node 1, worker {self.worker_id})'
+
+
+class Node2(Node):
+    async def process_job(self, job: Job):
+        sleep_duration = 0.08 + 0.04 * random.random()
+        await asyncio.sleep(sleep_duration)
+        job.data += f' (processed by node 2, worker {self.worker_id})'
+
+
+class Node3(Node):
+    async def process_job(self, job: Job):
+        job.data += f' (processed by node 3, worker {self.worker_id})'
+        print(f'{job.id} - {job.data}')
+
+
+class TestPipeline(unittest.TestCase):
+    def setUp(self):
+        pass
+
+    async def _test_pipeline_edge_cases(self):
+        # must have an input queue
+        with self.assertRaises(ValueError):
+            await self.pipeline.add_node(Node1, 1, None, None)
+        # the output queue must not be the same as the input queue
+        node1_queue = asyncio.Queue()
+        with self.assertRaises(ValueError):
+            await self.pipeline.add_node(Node1, 1, node1_queue, node1_queue)
+
+
+    async def _test_pipeline(self, num_jobs):
+        node1_queue = asyncio.Queue()
+        node2_queue = asyncio.Queue()
+        node3_queue = asyncio.Queue()
+        await self.pipeline.add_node(Node1, 1, node1_queue, node2_queue)
+        await self.pipeline.add_node(Node2, 5, node2_queue, node3_queue)
+        await self.pipeline.add_node(Node3, 1, node3_queue, job_sync=self.job_sync, sequential_node=True)
+        for i in range(num_jobs):
+            job = Job("")
+            await self.pipeline.enqueue_job(job)
+        while True:
+            if len(self.job_sync) == num_jobs:
+                break
+            await asyncio.sleep(0.1)
+        await self.pipeline.close()
+
+    def test_pipeline_edge_cases(self):
+        self.pipeline = Pipeline()
+        self.job_sync = []
+        asyncio.run(self._test_pipeline_edge_cases())
+
+
+    # def test_pipeline_keeps_order(self):
+    #     self.pipeline = Pipeline()
+    #     self.job_sync = []
+    #     num_jobs = 100
+    #     start_time = time.time()
+    #     asyncio.run(self._test_pipeline(num_jobs))
+    #     end_time = time.time()
+    #     print(f"Pipeline processed in {end_time - start_time} seconds.")
+    #     self.assertEqual(len(self.job_sync), num_jobs)
+    #     for i, job in enumerate(self.job_sync):
+    #         self.assertEqual(i, job.id)
+
+
+if __name__ == '__main__':
+    unittest.main()
+    # test = TestPipeline()
+    # test.setUp()
+    # test.test_pipeline()
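As written, the suite can be run directly with python tests/test_pipeline.py; the sys.path.append at the top of the file is what makes pipeline importable from the parent directory when the test is launched that way.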