File size: 7,691 Bytes
330bd18
 
 
925a881
 
a5ee5dc
 
f7f39bd
330bd18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7f39bd
129c500
b80895e
330bd18
8071afe
330bd18
8071afe
 
 
 
 
 
 
 
330bd18
 
8071afe
330bd18
 
 
 
 
 
 
8ed40d7
 
 
 
 
 
8b0046a
8ed40d7
 
 
 
 
 
 
 
966d40f
 
 
 
 
 
 
 
 
 
 
 
330bd18
 
 
 
8071afe
16c6824
 
 
 
 
ef3e243
8071afe
a3aa488
8b0046a
8071afe
8b0046a
8071afe
8ed40d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330bd18
966d40f
 
 
 
 
 
 
 
4ba2ab9
f7f39bd
 
330bd18
 
 
 
 
966d40f
 
330bd18
 
 
 
f92ca2e
330bd18
966d40f
330bd18
966d40f
f5a084e
966d40f
b9d404b
 
 
 
 
 
 
 
 
 
4ba2ab9
b9d404b
e39cf31
4ba2ab9
8071afe
330bd18
 
03648b3
8071afe
4ba2ab9
c66a07c
 
 
330bd18
1cf9f4d
4ba2ab9
fd27c98
 
 
8071afe
 
 
 
 
 
 
4ba2ab9
330bd18
 
 
 
 
f7f39bd
 
 
f5a084e
330bd18
f7f39bd
 
330bd18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
925a881
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# import base64
# import pathlib
# import tempfile
import os
os.system("python -m unidic download")
import nltk
nltk.download('averaged_perceptron_tagger_eng')
import gradio as gr

# recorder_js = pathlib.Path('recorder.js').read_text()
# main_js = pathlib.Path('main.js').read_text()
# record_button_js = pathlib.Path('record_button.js').read_text().replace('let recorder_js = null;', recorder_js).replace(
#     'let main_js = null;', main_js)


# def save_base64_video(base64_string):
#     base64_video = base64_string
#     video_data = base64.b64decode(base64_video)
#     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
#         temp_filename = temp_file.name
#         temp_file.write(video_data)
#     print(f"Temporary MP4 file saved as: {temp_filename}")
#     return temp_filename
# import os

# os.system('python -m unidic download')
import numpy as np
from VAD.vad_iterator import VADIterator
import torch
import librosa
# from mlx_lm import load, stream_generate, generate
from LLM.chat import Chat
# from lightning_whisper_mlx import LightningWhisperMLX
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
)
from melo.api import TTS

# LM_model, LM_tokenizer = load("mlx-community/SmolLM-360M-Instruct")
# Conversation history shared by every turn. Chat is project-local (LLM/chat.py);
# the meaning of the argument 2 is not visible here — presumably a history
# length limit, TODO confirm.
chat = Chat(2)
chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words."})
user_role = "user"

# Text-to-speech model and the single speaker used for every reply.
tts_model = TTS(language="EN_NEWEST", device="auto")
speaker_id = tts_model.hps.data.spk2id["EN-Newest"]
blocksize = 512
# Warm-up synthesis so the first real reply does not pay one-time start-up cost.
tts_model.tts_to_file("text", speaker_id, quiet=True)

# Speech-to-text pipeline, warmed up with one second of silence.
# The ASR pipeline's "raw" input is a 1-D waveform accompanied by its sampling
# rate (it computes the log-mel features itself), so the warm-up input is audio,
# not a (1, 80, 3000) feature tensor.
ASR_SAMPLING_RATE = 16000
dummy_input = np.zeros(ASR_SAMPLING_RATE, dtype=np.float32)
transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en", device="cuda")
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
transcriber({"sampling_rate": ASR_SAMPLING_RATE, "raw": dummy_input})["text"]
end_event.record()
torch.cuda.synchronize()
print(f"ASR warm-up: {start_event.elapsed_time(end_event):.1f} ms")

def int2float(sound):
    """Convert integer PCM samples to float32 in [-1.0, 1.0).

    Adapted from https://github.com/snakers4/silero-vad

    Every sample is multiplied by 1/32768 (the int16 full-scale magnitude).
    Size-1 axes are then squeezed away, so a ``(1, n)`` or ``(n, 1)`` buffer
    comes back as ``(n,)``. An empty buffer returns an empty float32 array.

    The upstream helper only scaled when ``np.abs(sound).max() > 0``. For
    int16 data ``np.abs(-32768)`` wraps to ``-32768``, so a buffer such as
    ``[0, -32768]`` skipped scaling and was returned as raw integer values.
    The same check also raised on an empty buffer. Scaling an all-zero
    buffer is a no-op, so the check is unnecessary and has been removed.
    """
    # astype() copies, so the in-place scale below never touches the caller's array.
    sound = np.asarray(sound).astype(np.float32)
    sound *= 1 / 32768
    return sound.squeeze()  # depends on the use case

# Latest assistant reply and its synthesized audio; both are updated by transcribe().
text_str = ""
audio_output = None
# transcribe() only answers utterances whose length lies in [min_speech_ms, max_speech_ms].
min_speech_ms = 500
max_speech_ms = float("inf")
# ASR_model = LightningWhisperMLX(model="distil-large-v3", batch_size=6, quant=None)
# ASR_processor = AutoProcessor.from_pretrained("distil-whisper/distil-large-v3")
# ASR_model = AutoModelForSpeechSeq2Seq.from_pretrained(
#     "distil-whisper/distil-large-v3",
#     torch_dtype="float16",
# ).to("cpu")

# Tokenizer and weights must come from the same checkpoint so the vocabulary and
# chat template match the model (previously the tokenizer came from the 135M
# variant while the weights were the 360M one).
LM_checkpoint = "HuggingFaceTB/SmolLM-360M-Instruct"
LM_tokenizer = AutoTokenizer.from_pretrained(LM_checkpoint)
LM_model = AutoModelForCausalLM.from_pretrained(
    LM_checkpoint, torch_dtype=torch.float16, trust_remote_code=True
).to("cuda")
LM_pipe = pipeline(
    "text-generation", model=LM_model, tokenizer=LM_tokenizer, device="cuda"
)

# Warm-up generation using the same decoding settings as transcribe(), timed with CUDA events.
dummy_input_text = "Write me a poem about Machine Learning."
dummy_chat = [{"role": "user", "content": dummy_input_text}]
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
LM_pipe(
    dummy_chat,
    max_new_tokens=32,
    min_new_tokens=0,
    temperature=0.0,
    do_sample=False,
)
end_event.record()
torch.cuda.synchronize()
print(f"LM warm-up: {start_event.elapsed_time(end_event):.1f} ms")

# Silero voice-activity detector, wrapped by the project-local VADIterator;
# configured for 16 kHz input, matching the rate transcribe() resamples to.
vad_model, _ = torch.hub.load("snakers4/silero-vad:v4.0", "silero_vad")
vad_iterator = VADIterator(
    vad_model,
    threshold=0.3,
    sampling_rate=16000,
    min_silence_duration_ms=250,
    speech_pad_ms=500,
)

import time
def float2int16(audio):
    """Convert a float waveform (nominally in [-1.0, 1.0]) to int16 PCM.

    This is the counterpart of ``int2float``. Samples are scaled by 32768 and
    clipped to the int16 range before the cast. Without the clip, a
    full-scale sample of +1.0 becomes 32768, which ``astype(np.int16)`` wraps
    around to -32768 and produces an audible click.
    """
    scaled = np.asarray(audio, dtype=np.float32) * 32768
    return np.clip(scaled, -32768, 32767).astype(np.int16)


def transcribe(stream, new_chunk):
    """Process one streamed microphone chunk and reply once speech has ended.

    Each chunk goes through these steps:
    - Convert it to float32 and resample it to 16 kHz.
    - Feed it to the VAD iterator.
    - When the VAD returns a finished utterance whose length lies within
      [min_speech_ms, max_speech_ms], transcribe it and append it to ``chat``.
    - Ask the language model for a reply and synthesize that reply with the
      TTS model.

    Args:
        stream: Gradio "state" value; returned unchanged.
        new_chunk: ``(sampling_rate, samples)`` tuple from the streaming
            microphone component, or ``None`` if no audio was delivered.

    Returns:
        A tuple ``(stream, text, audio)``:
        - ``text`` is the latest assistant reply.
        - ``audio`` is ``(sampling_rate, int16 samples)`` when a new reply was
          synthesized during this call.
        - Otherwise ``audio`` is ``gr.update()``, so the output player keeps
          its current clip. Without this, the previous reply would be
          re-encoded and re-sent on every incoming chunk.
    """
    global text_str
    global audio_output

    if new_chunk is None:
        return stream, text_str, gr.update()
    sr, y = new_chunk

    samples = np.asarray(y)
    if samples.ndim > 1:
        # Gradio's numpy audio is (samples, channels) for multi-channel input;
        # downmix instead of reinterpreting the interleaved buffer as mono.
        samples = samples.mean(axis=1)
    if samples.size == 0:
        return stream, text_str, gr.update()

    audio_float32 = int2float(samples)
    audio_float32 = librosa.resample(audio_float32, orig_sr=sr, target_sr=16000)
    sr = 16000
    print(sr)
    print(audio_float32.shape)
    # vad_iterator returns a non-empty list of tensors when an utterance has ended.
    vad_output = vad_iterator(torch.from_numpy(audio_float32))

    new_audio = None
    if vad_output is not None and len(vad_output) != 0:
        print("VAD: end of speech detected")
        array = torch.cat(vad_output).cpu().numpy()
        duration_ms = len(array) / sr * 1000
        if min_speech_ms <= duration_ms <= max_speech_ms:
            start_time = time.time()
            prompt = transcriber({"sampling_rate": sr, "raw": array})["text"]
            print(prompt)
            print("--- %s seconds ---" % (time.time() - start_time))
            if not prompt.strip():
                # Nothing intelligible was heard; don't send an empty user turn to the LM.
                print("ASR: empty transcription, skipping")
            else:
                chat.append({"role": user_role, "content": prompt})
                output = LM_pipe(
                    chat.to_list(),
                    max_new_tokens=32,
                    min_new_tokens=0,
                    temperature=0.0,
                    do_sample=False,
                )
                print(output)
                print("--- %s seconds ---" % (time.time() - start_time))
                # The pipeline echoes the conversation; its last message is the new reply.
                generated_text = output[0]["generated_text"][-1]["content"]
                print(generated_text)

                chat.append({"role": "assistant", "content": generated_text})
                text_str = generated_text
                audio_chunk = tts_model.tts_to_file(text_str, speaker_id, quiet=True)
                # Use the TTS model's own output rate rather than a hard-coded 44100.
                audio_output = (tts_model.hps.data.sampling_rate, float2int16(audio_chunk))
                new_audio = audio_output
                print("--- %s seconds ---" % (time.time() - start_time))

    return stream, text_str, new_audio if new_audio is not None else gr.update()

# Live streaming UI. Each microphone chunk is passed to transcribe() together
# with the "state" value. Its three return values fill, in order:
# the state, the reply textbox and the autoplaying output player.
demo = gr.Interface(
    transcribe,
    ["state", gr.Audio(sources=["microphone"], streaming=True, waveform_options=gr.WaveformOptions(sample_rate=16000))],
    ["state", "text", gr.Audio(label="Output", autoplay=True)],
    live=True,
)
# with demo:
#     start_button = gr.Button("Record Screen 🔴")
#     video_component = gr.Video(interactive=True, show_share_button=True, include_audio=True)


#     def toggle_button_label(returned_string):
#         if returned_string.startswith("Record"):
#             return gr.Button(value="Stop Recording ⚪"), None
#         else:
#             try:
#                 temp_filename = save_base64_video(returned_string)
#             except Exception as e:
#                 return gr.Button(value="Record Screen 🔴"), gr.Warning(f'Failed to convert video to mp4:\n{e}')
#             return gr.Button(value="Record Screen 🔴"), gr.Video(value=temp_filename, interactive=True,
#                                                                 show_share_button=True)
#     start_button.click(toggle_button_label, start_button, [start_button, video_component], js=record_button_js)

# `share` must be a keyword argument. launch()'s first positional parameter is
# `inline`, so the string "share=True" was consumed as `inline` and no public
# link was ever created.
demo.launch(share=True)