sohojoe committed
Commit de8cac1
1 Parent(s): 80eea9e

create a node class to help things be a bit more robust

Files changed (1):
  pipeline_test.py +75 -41
pipeline_test.py CHANGED
@@ -9,66 +9,100 @@ class Job:
         self.data = data
 
 
-async def node1(worker_id: int, input_queue, output_queue):
-    while True:
-        job:Job = await input_queue.get()
-        job.data += f' (processed by node 1, worker {worker_id})'
-        await output_queue.put(job)
-
-async def node2(worker_id: int, input_queue, output_queue):
-    while True:
-        job:Job = await input_queue.get()
-        sleep_duration = 0.8 + 0.4 * random.random() # Generate a random sleep duration between 0.8 and 1.2 seconds
+class Node:
+    # def __init__(self, worker_id: int, input_queue, output_queue, buffer=None, job_sync=None):
+    def __init__(self, worker_id: int, input_queue, output_queue=None, job_sync=None, sequential_node=False):
+        self.worker_id = worker_id
+        self.input_queue = input_queue
+        self.output_queue = output_queue
+        self.buffer = {}
+        self.job_sync = job_sync
+        self.sequential_node = sequential_node
+        self.next_i = 0
+        self._jobs_dequeued = 0
+        self._jobs_processed = 0
+        # throw an error if job_sync is not None and sequential_node is False
+        if self.job_sync is not None and self.sequential_node == False:
+            raise ValueError('job_sync is not None and sequential_node is False')
+
+    async def run(self):
+        while True:
+            job: Job = await self.input_queue.get()
+            self._jobs_dequeued += 1
+            if self.sequential_node == False:
+                await self.process_job(job)
+            else:
+                # ensure that jobs are processed in order
+                self.buffer[job.id] = job
+                while self.next_i in self.buffer:
+                    job = self.buffer.pop(self.next_i)
+                    await self.process_job(job)
+                    self.next_i += 1
+            if self.output_queue is not None:
+                await self.output_queue.put(job)
+            if self.job_sync is not None:
+                self.job_sync.append(job)
+            self._jobs_processed += 1
+
+    async def process_job(self, job: Job):
+        raise NotImplementedError
+
+
+class Node1(Node):
+    async def process_job(self, job: Job):
+        job.data += f' (processed by node 1, worker {self.worker_id})'
+
+
+class Node2(Node):
+    async def process_job(self, job: Job):
+        sleep_duration = 0.8 + 0.4 * random.random()
         await asyncio.sleep(sleep_duration)
-        job.data += f' (processed by node 2, worker {worker_id})'
-        await output_queue.put(job)
-
-async def node3(worker_id: int, input_queue, job_sync):
-    buffer = {}
-    next_i = 0
-    while True:
-        job:Job = await input_queue.get()
-        buffer[job.id] = job # Store the data in the buffer
-        # While the next expected item is in the buffer, output it and increment the index
-        while next_i in buffer:
-            curr_job = buffer.pop(next_i)
-            curr_job.data += f' (processed by node 3, worker {worker_id})'
-            print(f'{curr_job.id} - {curr_job.data}')
-            next_i += 1
-            job_sync.append(curr_job)
+        job.data += f' (processed by node 2, worker {self.worker_id})'
+
+
+class Node3(Node):
+    async def process_job(self, job: Job):
+        job.data += f' (processed by node 3, worker {self.worker_id})'
+        print(f'{job.id} - {job.data}')
+
 
 async def main():
-    input_queue = asyncio.Queue()
-    buffer_queue = asyncio.Queue()
-    output_queue = asyncio.Queue()
+    node1_queue = asyncio.Queue()
+    node2_queue = asyncio.Queue()
+    node3_queue = asyncio.Queue()
 
     num_jobs = 100
     joe_source = [Job(i, "") for i in range(num_jobs)]
     job_sync = []
 
-    task1 = asyncio.create_task(node1(None, input_queue, buffer_queue))
-    task3 = asyncio.create_task(node3(None, output_queue, job_sync))
-
+    # create the workers
     num_workers = 5
-    tasks2 = []
-    for i in range(num_workers):
-        task2 = asyncio.create_task(node2(i + 1, buffer_queue, output_queue))
-        tasks2.append(task2)
+    node1_workers = [Node1(i + 1, node1_queue, node2_queue) for i in range(1)]
+    node2_workers = [Node2(i + 1, node2_queue, node3_queue) for i in range(num_workers)]
+    node3_workers = [Node3(i + 1, node3_queue, job_sync=job_sync, sequential_node=True) for i in range(1)]
+
+    # create tasks for the workers
+    tasks1 = [asyncio.create_task(worker.run()) for worker in node1_workers]
+    tasks2 = [asyncio.create_task(worker.run()) for worker in node2_workers]
+    tasks3 = [asyncio.create_task(worker.run()) for worker in node3_workers]
 
     for job in joe_source:
-        await input_queue.put(job)
+        await node1_queue.put(job)
+    # await input_queue.put(joe_source[0])
 
     try:
-        # await asyncio.gather(task1, *tasks2, task3)
         while len(job_sync) < num_jobs:
+            # print(f"Waiting for jobs to finish... Job sync size: {len(job_sync)}, node1_queue size: {node1_queue.qsize()}, node2_queue size: {node2_queue.qsize()}, node3_queue size: {node3_queue.qsize()}")
             await asyncio.sleep(0.1)
     except asyncio.CancelledError:
         print("Pipeline cancelled")
-        task1.cancel()
+        for task in tasks1:
+            task.cancel()
         for task in tasks2:
             task.cancel()
-        task3.cancel()
-        await asyncio.gather(task1, *tasks2, task3, return_exceptions=True)
+        for task in tasks3:
+            task.cancel()
+        await asyncio.gather(*tasks1, *tasks2, *tasks3, return_exceptions=True)
 
 
 start_time = time.time()
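
For context, a minimal, self-contained sketch of the pattern this commit introduces. The Job and Node definitions are condensed from the diff above (the dequeued/processed counters and the job_sync/sequential_node validation are omitted for brevity); ShoutNode and SinkNode are hypothetical stages invented for illustration, not part of the commit.

import asyncio
import random


class Job:
    # condensed from pipeline_test.py: an integer id (used for ordering) plus a payload
    def __init__(self, id: int, data: str):
        self.id = id
        self.data = data


class Node:
    # condensed from the Node class added in this commit
    def __init__(self, worker_id, input_queue, output_queue=None,
                 job_sync=None, sequential_node=False):
        self.worker_id = worker_id
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.job_sync = job_sync
        self.sequential_node = sequential_node
        self.buffer = {}
        self.next_i = 0

    async def run(self):
        while True:
            job = await self.input_queue.get()
            if not self.sequential_node:
                await self.process_job(job)
            else:
                # hold early arrivals until every lower id has been emitted
                self.buffer[job.id] = job
                while self.next_i in self.buffer:
                    job = self.buffer.pop(self.next_i)
                    await self.process_job(job)
                    self.next_i += 1
            if self.output_queue is not None:
                await self.output_queue.put(job)
            if self.job_sync is not None:
                self.job_sync.append(job)

    async def process_job(self, job):
        raise NotImplementedError


class ShoutNode(Node):
    # hypothetical stage (not in the commit): random delay, so workers finish out of order
    async def process_job(self, job):
        await asyncio.sleep(random.random() * 0.1)
        job.data = f'job {job.id}'.upper()


class SinkNode(Node):
    # hypothetical ordered sink: prints jobs strictly in id order
    async def process_job(self, job):
        print(job.id, job.data)


async def main():
    num_jobs = 10
    q1, q2 = asyncio.Queue(), asyncio.Queue()
    done = []
    workers = [ShoutNode(i + 1, q1, q2) for i in range(3)]
    workers.append(SinkNode(1, q2, job_sync=done, sequential_node=True))
    tasks = [asyncio.create_task(w.run()) for w in workers]

    for i in range(num_jobs):
        await q1.put(Job(i, ""))
    while len(done) < num_jobs:  # same polling style as pipeline_test.py
        await asyncio.sleep(0.05)
    for t in tasks:
        t.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)


asyncio.run(main())

Run it and the sink prints ids 0 through 9 in order, even though the three ShoutNode workers complete out of order; the sequential_node buffer is what restores the ordering.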