sohojoe commited on
Commit
8ea370a
1 Parent(s): 0be6b60

remove old files

Browse files
Agent/character_properties.py DELETED
@@ -1,9 +0,0 @@
1
- class CharacterProperties:
2
- def __init__(self, happiness, energy, confidence, fear, excitement, wanderlust, restful):
3
- self.happiness = happiness
4
- self.energy = energy
5
- self.confidence = confidence
6
- self.fear = fear
7
- self.excitement = excitement
8
- self.wanderlust = wanderlust
9
- self.restful = restful
 
 
 
 
 
 
 
 
 
 
Agent/character_state.py DELETED
@@ -1,26 +0,0 @@
1
-
2
-
3
- from Agent.character_properties import CharacterProperties
4
-
5
-
6
- class CharacterState:
7
- def __init__(self, name, fixed_traits, properties):
8
- self.name = name
9
- self.fixed_traits = fixed_traits
10
- self.properties = properties
11
-
12
- charles_properties = CharacterProperties(
13
- happiness = 0.85,
14
- energy = 0.9,
15
- confidence = 0.8,
16
- fear = 0.2,
17
- excitement = 0.75,
18
- wanderlust = 0.95,
19
- restful = 0.8
20
- )
21
-
22
- charles_state = CharacterState(
23
- name = 'Charles Petrescu',
24
- fixed_traits = ['quirky', 'honest', 'inquisitive', 'adventurous', 'friendly', 'random', 'knowledgeable', 'humorous'],
25
- properties = charles_properties
26
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
debug_000.py DELETED
@@ -1,74 +0,0 @@
1
- import os
2
- import subprocess
3
- from elevenlabs import generate, play
4
- from elevenlabs import set_api_key
5
- from elevenlabs import generate, stream
6
- from dotenv import load_dotenv
7
- load_dotenv()
8
-
9
- account_sid = os.environ["ELEVENLABS_API_KEY"]
10
- voice_id="2OviOUQc1JsQRQgNkVBj"
11
- model_id="eleven_monolingual_v1"
12
- set_api_key(account_sid)
13
-
14
- def stream_tts(prompt):
15
- audio_stream = generate(
16
- text=prompt,
17
- voice=voice_id,
18
- model=model_id,
19
- stream_chunk_size=4096,
20
- stream=True,
21
- )
22
- return audio_stream
23
-
24
- prompts=[
25
- "erm",
26
- "Cabbages, my dear friend!",
27
- "Did you know that the world's largest cabbage weighed 62.71 kilograms?",
28
- "Simply remarkable!",
29
- "How are you today?",
30
- ]
31
-
32
- mpv_command = ["mpv", "--no-cache", "--no-terminal", "--", "fd://0"]
33
- mpv_process = subprocess.Popen(
34
- mpv_command,
35
- stdin=subprocess.PIPE,
36
- stdout=subprocess.DEVNULL,
37
- stderr=subprocess.DEVNULL,
38
- )
39
-
40
- load_chunks = False
41
- load_chunks = os.path.exists("chunks.pkl")
42
-
43
- # check if chunks.pkl exists
44
- if load_chunks:
45
- # try open chunks
46
- with open("chunks.pkl", "rb") as f:
47
- import pickle
48
- chunks = pickle.load(f)
49
- for chunk in chunks:
50
- mpv_process.stdin.write(chunk)
51
- mpv_process.stdin.flush()
52
-
53
- else:
54
- chunks = []
55
-
56
- for prompt in prompts:
57
- for chunk in stream_tts(prompt):
58
- if chunk is not None:
59
- chunks.append(chunk)
60
- mpv_process.stdin.write(chunk) # type: ignore
61
- mpv_process.stdin.flush() # type: ignore
62
-
63
- # save chunks to file as a pickled list of bytes
64
- with open("chunks.pkl", "wb") as f:
65
- import pickle
66
- pickle.dump(chunks, f)
67
- with open("chunks.mp3", "wb") as f:
68
- for chunk in chunks:
69
- f.write(chunk)
70
-
71
-
72
- if mpv_process.stdin:
73
- mpv_process.stdin.close()
74
- mpv_process.wait()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
debug_001.py DELETED
@@ -1,64 +0,0 @@
1
- import io
2
- import os
3
- import subprocess
4
- import av
5
- import numpy as np
6
-
7
-
8
- mpv_command = ["mpv", "--no-cache", "--no-terminal", "--", "fd://0"]
9
- mpv_process = subprocess.Popen(
10
- mpv_command,
11
- stdin=subprocess.PIPE,
12
- stdout=subprocess.DEVNULL,
13
- stderr=subprocess.DEVNULL,
14
- )
15
-
16
- load_chunks = False
17
- load_chunks = os.path.exists("chunks.pkl")
18
-
19
-
20
- audio_frames = []
21
-
22
- # try open chunks
23
- with open("chunks.pkl", "rb") as f:
24
- import pickle
25
- chunks = pickle.load(f)
26
- append = False
27
- for chunk in chunks:
28
- mpv_process.stdin.write(chunk)
29
- mpv_process.stdin.flush()
30
- # np_chunk = np.frombuffer(chunk, dtype=np.int16)
31
- # aa = av.AudioFrame.from_ndarray(chunk)
32
- try:
33
- if append:
34
- bytes_io.write(chunk)
35
- append = False
36
- bytes_io.seek(0)
37
- else:
38
- bytes_io = io.BytesIO(chunk)
39
- container = av.open(bytes_io, 'r', format='mp3')
40
- audio_stream = next(s for s in container.streams if s.type == 'audio')
41
- for frame in container.decode(audio_stream):
42
- # Convert the audio frame to a NumPy array
43
- array = frame.to_ndarray()
44
-
45
- # Now you can use av.AudioFrame.from_ndarray
46
- audio_frame = av.AudioFrame.from_ndarray(array, format='flt', layout='mono')
47
- audio_frame.sample_rate = 44100
48
-
49
- audio_frames.append(audio_frame)
50
-
51
- except Exception as e:
52
- print (e)
53
- append = True
54
- bytes_io.seek(0, io.SEEK_END)
55
- continue
56
-
57
- # with open("frames.pkl", "wb") as f:
58
- # import pickle
59
- # pickle.dump(audio_frames, f)
60
-
61
-
62
- if mpv_process.stdin:
63
- mpv_process.stdin.close()
64
- mpv_process.wait()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
debug_app.py DELETED
@@ -1,106 +0,0 @@
1
- import logging
2
- import traceback
3
- from typing import List
4
-
5
- import av
6
- import numpy as np
7
- import streamlit as st
8
- from streamlit_webrtc import WebRtcMode, webrtc_streamer
9
- import pydub
10
-
11
- from dotenv import load_dotenv
12
- from ffmpeg_converter_actor import FFMpegConverterActor
13
- load_dotenv()
14
- from sample_utils.turn import get_ice_servers
15
-
16
- logger = logging.getLogger(__name__)
17
-
18
-
19
- import os
20
-
21
- import ray
22
- from ray.util.queue import Queue
23
- if not ray.is_initialized():
24
- # Try to connect to a running Ray cluster
25
- ray_address = os.getenv('RAY_ADDRESS')
26
- if ray_address:
27
- ray.init(ray_address, namespace="project_charles")
28
- else:
29
- ray.init(namespace="project_charles")
30
-
31
-
32
-
33
-
34
- def video_frame_callback(
35
- frame: av.VideoFrame,
36
- ) -> av.VideoFrame:
37
- return frame
38
-
39
-
40
- with open("chunks.pkl", "rb") as f:
41
- import pickle
42
- debug_chunks = pickle.load(f)
43
-
44
- converter_queue = Queue(maxsize=100)
45
- converter_actor = FFMpegConverterActor.remote(converter_queue)
46
- ray.get(converter_actor.start_process.remote())
47
- converter_actor.run.remote()
48
- for chunk in debug_chunks:
49
- ray.get(converter_actor.push_chunk.remote(chunk))
50
-
51
-
52
- # emptry array of type int16
53
- sample_buffer = np.zeros((0), dtype=np.int16)
54
-
55
- def process_frame(old_frame):
56
-
57
- try:
58
- output_channels = 2
59
- output_sample_rate = 44100
60
- required_samples = old_frame.samples
61
-
62
- if not converter_queue.empty():
63
- frame_as_bytes = converter_queue.get()
64
- # print(f"frame_as_bytes: {len(frame_as_bytes)}")
65
- samples = np.frombuffer(frame_as_bytes, dtype=np.int16)
66
- else:
67
- # create a byte array of zeros
68
- samples = np.zeros((required_samples * 2 * 1), dtype=np.int16)
69
-
70
- # Duplicate mono channel for stereo
71
- if output_channels == 2:
72
- samples = np.vstack((samples, samples)).reshape((-1,), order='F')
73
-
74
- samples = samples.reshape(1, -1)
75
-
76
- layout = 'stereo' if output_channels == 2 else 'mono'
77
- new_frame = av.AudioFrame.from_ndarray(samples, format='s16', layout=layout)
78
- new_frame.sample_rate = old_frame.sample_rate
79
- new_frame.pts = old_frame.pts
80
- return new_frame
81
- except Exception as e:
82
- print (e)
83
- traceback.print_exc()
84
- raise(e)
85
-
86
-
87
-
88
- def audio_frame_callback(old_frame: av.AudioFrame) -> av.AudioFrame:
89
- new_frame = process_frame(old_frame)
90
-
91
- # print (f"frame: {old_frame}, pts: {old_frame.pts}")
92
- # print (f"new_frame: {new_frame}, pts: {new_frame.pts}")
93
-
94
- return new_frame
95
- # return old_frame
96
-
97
-
98
- webrtc_streamer(
99
- key="delay",
100
- mode=WebRtcMode.SENDRECV,
101
- rtc_configuration={"iceServers": get_ice_servers()},
102
-
103
- video_frame_callback=video_frame_callback,
104
- audio_frame_callback=audio_frame_callback,
105
-
106
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
legacy_to_delete/chat_pipeline.py DELETED
@@ -1,96 +0,0 @@
1
- import asyncio
2
- import time
3
- from clip_transform import CLIPTransform
4
- from chat_service import ChatService
5
- from dotenv import load_dotenv
6
- from text_to_speech_service import TextToSpeechService
7
- from concurrent.futures import ThreadPoolExecutor
8
- from local_speaker_service import LocalSpeakerService
9
- from chat_service import ChatService
10
- from legacy_to_delete.pipeline import Pipeline, Node, Job
11
- from typing import List
12
-
13
- class ChatJob(Job):
14
- def __init__(self, data, chat_service: ChatService):
15
- super().__init__(data)
16
- self.chat_service = chat_service
17
-
18
- class Node1(Node):
19
- next_id = 0
20
-
21
- async def process_job(self, job: ChatJob):
22
- # input job.data is the input string
23
- # output job.data is the next sentance
24
- async for sentence in job.chat_service.get_responses_as_sentances_async(job.data):
25
- if job.chat_service.ignore_sentence(sentence):
26
- continue
27
- print(f"{sentence}")
28
- new_job = ChatJob(sentence, job.chat_service)
29
- new_job.id = self.next_id
30
- self.next_id += 1
31
- yield new_job
32
-
33
- class Node2(Node):
34
- next_id = 0
35
-
36
- async def process_job(self, job: ChatJob):
37
- # input job.data is the sentance
38
- # output job.data is the streamed speech bytes
39
- async for chunk in job.chat_service.get_speech_chunks_async(job.data):
40
- new_job = ChatJob(chunk, job.chat_service)
41
- new_job.id = self.next_id
42
- self.next_id += 1
43
- yield new_job
44
-
45
-
46
- class Node3(Node):
47
- # sync_size = 64
48
- # sync = []
49
-
50
- async def process_job(self, job: ChatJob):
51
- # input job.data is the streamed speech bytes
52
- # Node3.sync.append(job.data)
53
- job.chat_service.enqueue_speech_bytes_to_play([job.data])
54
- yield job
55
- # if len(Node3.sync) >= Node3.sync_size:
56
- # audio_chunks = Node3.sync[:Node3.sync_size]
57
- # Node3.sync = Node3.sync[Node3.sync_size:]
58
- # job.chat_service.enqueue_speech_bytes_to_play(audio_chunks)
59
- # yield job
60
-
61
- class ChatPipeline():
62
- def __init__(self):
63
- load_dotenv()
64
- self.pipeline = Pipeline()
65
- self.audio_processor = LocalSpeakerService()
66
- self.chat_service = ChatService(self.audio_processor, voice_id="2OviOUQc1JsQRQgNkVBj") # Chales003
67
-
68
- def __enter__(self):
69
- return self
70
-
71
- def __exit__(self, exc_type, exc_value, traceback):
72
- self.audio_processor.close()
73
- self.audio_processor = None
74
-
75
- def __del__(self):
76
- if self.audio_processor:
77
- self.audio_processor.close()
78
- self.audio_processor = None
79
-
80
- async def start(self):
81
- self.node1_queue = asyncio.Queue()
82
- self.node2_queue = asyncio.Queue()
83
- self.node3_queue = asyncio.Queue()
84
- self.sync = []
85
- await self.pipeline.add_node(Node1, 1, self.node1_queue, self.node2_queue, sequential_node=True)
86
- await self.pipeline.add_node(Node2, 1, self.node2_queue, self.node3_queue, sequential_node=True)
87
- await self.pipeline.add_node(Node3, 1, self.node3_queue, None, sequential_node=True)
88
-
89
- async def enqueue(self, prompt):
90
- job = ChatJob(prompt, self.chat_service)
91
- await self.pipeline.enqueue_job(job)
92
-
93
- async def wait_until_all_jobs_idle(self):
94
- # TODO - implement this
95
- while True:
96
- await asyncio.sleep(0.1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
legacy_to_delete/debug.py DELETED
@@ -1,157 +0,0 @@
1
- import asyncio
2
- import time
3
- import traceback
4
- from chat_pipeline import ChatPipeline
5
- from clip_transform import CLIPTransform
6
- from chat_service import ChatService
7
- from dotenv import load_dotenv
8
- from text_to_speech_service import TextToSpeechService
9
- from concurrent.futures import ThreadPoolExecutor
10
- from local_speaker_service import LocalSpeakerService
11
- from chat_service import ChatService
12
-
13
- def time_sentance_lenghts():
14
- load_dotenv()
15
-
16
- print ("Initializing Chat")
17
- # audio_processor = AudioStreamProcessor()
18
- user_speech_service0 = TextToSpeechService(voice_id="Adam")
19
- prompts = [
20
- "hello, i am a long sentance, how are you today? Tell me about your shadow self?",
21
- "a shorter sentance",
22
- "Jung believed that the process of self-discovery and personal growth involves confronting and integrating the shadow self into the conscious mind.",
23
- "By doing so, we become more self-aware and more fully actualized individuals.",
24
- ]
25
-
26
- print ("Timing prompts\n")
27
- for prompt in prompts:
28
- start_time = time.time()
29
- start_stream_time = time.time()
30
- stream = user_speech_service0.stream(prompt)
31
- audio = b""
32
- for chunk in stream:
33
- if chunk is not None:
34
- audio += chunk
35
- end_stream_time = time.time()
36
- from elevenlabs import play
37
- start_speech_time = time.time()
38
- play(audio)
39
- end_speech_time = time.time()
40
- end_time = time.time()
41
- total_time = (end_time - start_time)
42
- stream_time = (end_stream_time - start_stream_time)
43
- speech_time = (end_speech_time - start_speech_time)
44
- stream_multiple = speech_time / stream_time
45
- print(f"Stream time: {stream_time:.4f}, Acutual audio time: {speech_time:.4f}, a multiple of {stream_multiple:.2f}. for prompt: {prompt}")
46
-
47
- print ("\nChat success")
48
-
49
- def test_sentance_lenghts():
50
- load_dotenv()
51
-
52
- print ("Initializing Chat")
53
- audio_processor = LocalSpeakerService()
54
- user_speech_service0 = TextToSpeechService(voice_id="Adam")
55
- user_speech_service1 = TextToSpeechService(voice_id="Adam")
56
- user_speech_service2 = TextToSpeechService(voice_id="Adam")
57
- user_speech_service3 = TextToSpeechService(voice_id="Adam")
58
-
59
- prompts = [
60
- "hello, i am a long sentance, how are you today? Tell me about your shadow self?",
61
- "a shorter sentance",
62
- "Jung believed that the process of self-discovery and personal growth involves confronting and integrating the shadow self into the conscious mind.",
63
- "By doing so, we become more self-aware and more fully actualized individuals.",
64
- ]
65
- first = True
66
- stream1 = user_speech_service1.stream(prompts[1])
67
- stream0 = user_speech_service0.stream(prompts[0])
68
- time.sleep(5)
69
- stream2 = user_speech_service2.stream(prompts[2])
70
- stream3 = user_speech_service3.stream(prompts[3])
71
- audio_processor.add_audio_stream(stream0)
72
- audio_processor.add_audio_stream(stream1)
73
- audio_processor.add_audio_stream(stream2)
74
- audio_processor.add_audio_stream(stream3)
75
- audio_processor.close()
76
- from elevenlabs import generate, play
77
- speech0 = generate(prompts[0], voice="Adam")
78
- speech1 = generate(prompts[1], voice="Adam")
79
- speech2 = generate(prompts[2], voice="Adam")
80
- speech3 = generate(prompts[3], voice="Adam")
81
- play(speech0)
82
- play(speech1)
83
- play(speech2)
84
- play(speech1)
85
- play(speech3)
86
- play(speech1)
87
- # for prompt in prompts:
88
- # stream = user_speech_service.stream(prompt)
89
- # if first:
90
- # first = False
91
- # time.sleep(5)
92
- # audio_processor.add_audio_stream(stream)
93
- audio_processor.close()
94
- print ("Chat success")
95
-
96
- def run_debug_code():
97
- load_dotenv()
98
-
99
- # print ("Initializing CLIP templates")
100
- # clip_transform = CLIPTransform()
101
- # print ("CLIP success")
102
-
103
- print ("Initializing Chat")
104
- # chat_service = ChatService()
105
- audio_processor = LocalSpeakerService()
106
- chat_service = ChatService(audio_processor, voice_id="2OviOUQc1JsQRQgNkVBj") # Chales003
107
-
108
- user_speech_service = TextToSpeechService(voice_id="Adam")
109
-
110
- # user_speech_service.print_voices() # if you want to see your custom voices
111
-
112
- prompts = [
113
- "hello, how are you today?",
114
- "tell me about your shadow self?",
115
- "hmm, interesting, tell me more about that.",
116
- "wait, that is so interesting, what else?",
117
- ]
118
- for prompt in prompts:
119
- print ("")
120
- print (f'prompt: "{prompt}"')
121
- stream = user_speech_service.stream(prompt)
122
- audio_processor.add_audio_stream(stream)
123
-
124
- print ("")
125
- print (f'response:')
126
- response = chat_service.respond_to(prompt)
127
-
128
- audio_processor.close()
129
- print ("Chat success")
130
-
131
- async def run_pipeline():
132
- load_dotenv()
133
-
134
- try:
135
- chat_pipeline = ChatPipeline()
136
- await chat_pipeline.start()
137
- prompts = [
138
- "hello, how are you today?",
139
- "tell me about your shadow self?",
140
- "hmm, interesting, tell me more about that.",
141
- "wait, that is so interesting, what else?",
142
- ]
143
- for prompt in prompts:
144
- await chat_pipeline.enqueue(prompt)
145
- await chat_pipeline.wait_until_all_jobs_idle()
146
- except KeyboardInterrupt:
147
- print("Pipeline interrupted by user")
148
- except Exception as e:
149
- traceback.print_exc()
150
- print(f"An error occurred: {e}")
151
-
152
- if __name__ == '__main__':
153
- # time_sentance_lenghts()
154
- # test_sentance_lenghts()
155
- # run_debug_code()
156
- asyncio.run(run_pipeline())
157
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
legacy_to_delete/pipeline.py DELETED
@@ -1,112 +0,0 @@
1
- import asyncio
2
- import traceback
3
-
4
- class Job:
5
- def __init__(self, data):
6
- self._id = None
7
- self.data = data
8
-
9
-
10
- class Node:
11
- # def __init__(self, worker_id: int, input_queue, output_queue, buffer=None, job_sync=None):
12
- def __init__(self, worker_id: int, input_queue, output_queue=None, job_sync=None, sequential_node=False ):
13
- self.worker_id = worker_id
14
- self.input_queue = input_queue
15
- self.output_queue = output_queue
16
- self.buffer = {}
17
- self.job_sync = job_sync
18
- self.sequential_node = sequential_node
19
- self.next_i = 0
20
- self._jobs_dequeued = 0
21
- self._jobs_processed = 0
22
- # throw an error if job_sync is not None and sequential_node is False
23
- if self.job_sync is not None and self.sequential_node == False:
24
- raise ValueError('job_sync is not None and sequential_node is False')
25
-
26
- async def run(self):
27
- try:
28
- while True:
29
- job: Job = await self.input_queue.get()
30
- self._jobs_dequeued += 1
31
- if self.sequential_node == False:
32
- async for job in self.process_job(job):
33
- if self.output_queue is not None:
34
- await self.output_queue.put(job)
35
- if self.job_sync is not None:
36
- self.job_sync.append(job)
37
- self._jobs_processed += 1
38
- else:
39
- # ensure that jobs are processed in order
40
- self.buffer[job.id] = job
41
- while self.next_i in self.buffer:
42
- job = self.buffer.pop(self.next_i)
43
- async for job in self.process_job(job):
44
- if self.output_queue is not None:
45
- await self.output_queue.put(job)
46
- if self.job_sync is not None:
47
- self.job_sync.append(job)
48
- self._jobs_processed += 1
49
- self.next_i += 1
50
- except Exception as e:
51
- print(f"An error occurred in node: {self.__class__.__name__} worker: {self.worker_id}: {e}")
52
- traceback.print_exc()
53
- raise # Re-raises the last exception.
54
-
55
- async def process_job(self, job: Job):
56
- raise NotImplementedError()
57
-
58
- class Pipeline:
59
- def __init__(self):
60
- self.input_queues = []
61
- self.root_queue = None
62
- # self.output_queues = []
63
- # self.job_sysncs = []
64
- self.nodes= []
65
- self.node_workers = {}
66
- self.tasks = []
67
- self._job_id = 0
68
-
69
- async def add_node(self, node: Node, num_workers=1, input_queue=None, output_queue=None, job_sync=None, sequential_node=False ):
70
- # input_queue must not be None
71
- if input_queue is None:
72
- raise ValueError('input_queue is None')
73
- # job_sync nodes must be sequential_nodes
74
- if job_sync is not None and sequential_node == False:
75
- raise ValueError('job_sync is not None and sequential_node is False')
76
- # sequential_nodes should one have 1 worker
77
- if sequential_node == True and num_workers != 1:
78
- raise ValueError('sequentaial nodes can only have one node (sequential_node is True and num_workers is not 1)')
79
- # output queue must not equal input_queue
80
- if output_queue == input_queue:
81
- raise ValueError('output_queue must not be the same as input_queue')
82
-
83
- node_name = node.__name__
84
- if node_name not in self.nodes:
85
- self.nodes.append(node_name)
86
-
87
- # if input_queue is None then this is the root node
88
- if len(self.input_queues) == 0:
89
- self.root_queue = input_queue
90
-
91
- self.input_queues.append(input_queue)
92
-
93
- for i in range(num_workers):
94
- worker_id = i
95
- node_worker = node(worker_id, input_queue, output_queue, job_sync, sequential_node)
96
- if node_name not in self.node_workers:
97
- self.node_workers[node_name] = []
98
- self.node_workers[node_name].append(node_worker)
99
- task = asyncio.create_task(node_worker.run())
100
- self.tasks.append(task)
101
-
102
- async def enqueue_job(self, job: Job):
103
- job.id = self._job_id
104
- self._job_id += 1
105
- await self.root_queue.put(job)
106
-
107
- async def close(self):
108
- for task in self.tasks:
109
- task.cancel()
110
- await asyncio.gather(*self.tasks, return_exceptions=True)
111
-
112
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_pipeline.py DELETED
@@ -1,86 +0,0 @@
1
- import asyncio
2
- import random
3
- import time
4
- import unittest
5
- import sys
6
- import os
7
- sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8
-
9
- from legacy_to_delete.pipeline import Pipeline, Node, Job
10
-
11
-
12
- class Node1(Node):
13
- async def process_job(self, job: Job):
14
- job.data += f' (processed by node 1, worker {self.worker_id})'
15
- yield job
16
-
17
-
18
- class Node2(Node):
19
- async def process_job(self, job: Job):
20
- sleep_duration = 0.08 + 0.04 * random.random()
21
- await asyncio.sleep(sleep_duration)
22
- job.data += f' (processed by node 2, worker {self.worker_id})'
23
- yield job
24
-
25
-
26
- class Node3(Node):
27
- async def process_job(self, job: Job):
28
- job.data += f' (processed by node 3, worker {self.worker_id})'
29
- print(f'{job.id} - {job.data}')
30
- yield job
31
-
32
-
33
- class TestPipeline(unittest.TestCase):
34
- def setUp(self):
35
- pass
36
-
37
- async def _test_pipeline_edge_cases(self):
38
- # must have a input queue
39
- with self.assertRaises(ValueError):
40
- await self.pipeline.add_node(Node1, 1, None, None)
41
- # too output queue must not equal from input queue
42
- node1_queue = asyncio.Queue()
43
- with self.assertRaises(ValueError):
44
- await self.pipeline.add_node(Node1, 1, node1_queue, node1_queue)
45
-
46
-
47
- async def _test_pipeline(self, num_jobs):
48
- node1_queue = asyncio.Queue()
49
- node2_queue = asyncio.Queue()
50
- node3_queue = asyncio.Queue()
51
- await self.pipeline.add_node(Node1, 1, node1_queue, node2_queue)
52
- await self.pipeline.add_node(Node2, 5, node2_queue, node3_queue)
53
- await self.pipeline.add_node(Node3, 1, node3_queue, job_sync=self.job_sync, sequential_node=True)
54
- for i in range(num_jobs):
55
- job = Job("")
56
- await self.pipeline.enqueue_job(job)
57
- while True:
58
- if len(self.job_sync) == num_jobs:
59
- break
60
- await asyncio.sleep(0.1)
61
- await self.pipeline.close()
62
-
63
- def test_pipeline_edge_cases(self):
64
- self.pipeline = Pipeline()
65
- self.job_sync = []
66
- asyncio.run(self._test_pipeline_edge_cases())
67
-
68
-
69
- def test_pipeline_keeps_order(self):
70
- self.pipeline = Pipeline()
71
- self.job_sync = []
72
- num_jobs = 100
73
- start_time = time.time()
74
- asyncio.run(self._test_pipeline(num_jobs))
75
- end_time = time.time()
76
- print(f"Pipeline processed in {end_time - start_time} seconds.")
77
- self.assertEqual(len(self.job_sync), num_jobs)
78
- for i, job in enumerate(self.job_sync):
79
- self.assertEqual(i, job.id)
80
-
81
-
82
- if __name__ == '__main__':
83
- unittest.main()
84
- # test = TestPipeline()
85
- # test.setUp()
86
- # test.test_pipeline()