sohojoe commited on
Commit
8c9e2db
1 Parent(s): ed232fa

add: pipeline nodes can now spawn one to many jobs via yield

Browse files
Files changed (2) hide show
  1. pipeline.py +12 -12
  2. tests/test_pipeline.py +14 -11
pipeline.py CHANGED
@@ -27,24 +27,24 @@ class Node:
27
  job: Job = await self.input_queue.get()
28
  self._jobs_dequeued += 1
29
  if self.sequential_node == False:
30
- await self.process_job(job)
31
- if self.output_queue is not None:
32
- await self.output_queue.put(job)
33
- if self.job_sync is not None:
34
- self.job_sync.append(job)
35
- self._jobs_processed += 1
36
  else:
37
  # ensure that jobs are processed in order
38
  self.buffer[job.id] = job
39
  while self.next_i in self.buffer:
40
  job = self.buffer.pop(self.next_i)
41
- await self.process_job(job)
 
 
 
 
 
42
  self.next_i += 1
43
- if self.output_queue is not None:
44
- await self.output_queue.put(job)
45
- if self.job_sync is not None:
46
- self.job_sync.append(job)
47
- self._jobs_processed += 1
48
 
49
  async def process_job(self, job: Job):
50
  raise NotImplementedError()
 
27
  job: Job = await self.input_queue.get()
28
  self._jobs_dequeued += 1
29
  if self.sequential_node == False:
30
+ async for job in self.process_job(job):
31
+ if self.output_queue is not None:
32
+ await self.output_queue.put(job)
33
+ if self.job_sync is not None:
34
+ self.job_sync.append(job)
35
+ self._jobs_processed += 1
36
  else:
37
  # ensure that jobs are processed in order
38
  self.buffer[job.id] = job
39
  while self.next_i in self.buffer:
40
  job = self.buffer.pop(self.next_i)
41
+ async for job in self.process_job(job):
42
+ if self.output_queue is not None:
43
+ await self.output_queue.put(job)
44
+ if self.job_sync is not None:
45
+ self.job_sync.append(job)
46
+ self._jobs_processed += 1
47
  self.next_i += 1
 
 
 
 
 
48
 
49
  async def process_job(self, job: Job):
50
  raise NotImplementedError()
tests/test_pipeline.py CHANGED
@@ -12,6 +12,7 @@ from pipeline import Pipeline, Node, Job
12
  class Node1(Node):
13
  async def process_job(self, job: Job):
14
  job.data += f' (processed by node 1, worker {self.worker_id})'
 
15
 
16
 
17
  class Node2(Node):
@@ -19,12 +20,14 @@ class Node2(Node):
19
  sleep_duration = 0.08 + 0.04 * random.random()
20
  await asyncio.sleep(sleep_duration)
21
  job.data += f' (processed by node 2, worker {self.worker_id})'
 
22
 
23
 
24
  class Node3(Node):
25
  async def process_job(self, job: Job):
26
  job.data += f' (processed by node 3, worker {self.worker_id})'
27
  print(f'{job.id} - {job.data}')
 
28
 
29
 
30
  class TestPipeline(unittest.TestCase):
@@ -63,17 +66,17 @@ class TestPipeline(unittest.TestCase):
63
  asyncio.run(self._test_pipeline_edge_cases())
64
 
65
 
66
- # def test_pipeline_keeps_order(self):
67
- # self.pipeline = Pipeline()
68
- # self.job_sync = []
69
- # num_jobs = 100
70
- # start_time = time.time()
71
- # asyncio.run(self._test_pipeline(num_jobs))
72
- # end_time = time.time()
73
- # print(f"Pipeline processed in {end_time - start_time} seconds.")
74
- # self.assertEqual(len(self.job_sync), num_jobs)
75
- # for i, job in enumerate(self.job_sync):
76
- # self.assertEqual(i, job.id)
77
 
78
 
79
  if __name__ == '__main__':
 
12
  class Node1(Node):
13
  async def process_job(self, job: Job):
14
  job.data += f' (processed by node 1, worker {self.worker_id})'
15
+ yield job
16
 
17
 
18
  class Node2(Node):
 
20
  sleep_duration = 0.08 + 0.04 * random.random()
21
  await asyncio.sleep(sleep_duration)
22
  job.data += f' (processed by node 2, worker {self.worker_id})'
23
+ yield job
24
 
25
 
26
  class Node3(Node):
27
  async def process_job(self, job: Job):
28
  job.data += f' (processed by node 3, worker {self.worker_id})'
29
  print(f'{job.id} - {job.data}')
30
+ yield job
31
 
32
 
33
  class TestPipeline(unittest.TestCase):
 
66
  asyncio.run(self._test_pipeline_edge_cases())
67
 
68
 
69
+ def test_pipeline_keeps_order(self):
70
+ self.pipeline = Pipeline()
71
+ self.job_sync = []
72
+ num_jobs = 100
73
+ start_time = time.time()
74
+ asyncio.run(self._test_pipeline(num_jobs))
75
+ end_time = time.time()
76
+ print(f"Pipeline processed in {end_time - start_time} seconds.")
77
+ self.assertEqual(len(self.job_sync), num_jobs)
78
+ for i, job in enumerate(self.job_sync):
79
+ self.assertEqual(i, job.id)
80
 
81
 
82
  if __name__ == '__main__':