Spaces:
Build error
Build error
Atin Sakkeer Hussain
commited on
Commit
•
c399026
1
Parent(s):
b5e6f78
Add app.py
Browse files
app.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.cuda
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import mdtex2html
|
5 |
+
import tempfile
|
6 |
+
from PIL import Image
|
7 |
+
import scipy
|
8 |
+
import argparse
|
9 |
+
|
10 |
+
from llama.m2ugen import M2UGen
|
11 |
+
import llama
|
12 |
+
import numpy as np
|
13 |
+
import os
|
14 |
+
import torch
|
15 |
+
import torchaudio
|
16 |
+
import torchvision.transforms as transforms
|
17 |
+
import av
|
18 |
+
import subprocess
|
19 |
+
import librosa
|
20 |
+
|
21 |
+
parser = argparse.ArgumentParser()
|
22 |
+
parser.add_argument(
|
23 |
+
"--model", default="./ckpts/checkpoint.pth", type=str,
|
24 |
+
help="Name of or path to M2UGen pretrained checkpoint",
|
25 |
+
)
|
26 |
+
parser.add_argument(
|
27 |
+
"--llama_type", default="7B", type=str,
|
28 |
+
help="Type of llama original weight",
|
29 |
+
)
|
30 |
+
parser.add_argument(
|
31 |
+
"--llama_dir", default="/path/to/llama", type=str,
|
32 |
+
help="Path to LLaMA pretrained checkpoint",
|
33 |
+
)
|
34 |
+
parser.add_argument(
|
35 |
+
"--mert_path", default="m-a-p/MERT-v1-330M", type=str,
|
36 |
+
help="Path to MERT pretrained checkpoint",
|
37 |
+
)
|
38 |
+
parser.add_argument(
|
39 |
+
"--vit_path", default="m-a-p/MERT-v1-330M", type=str,
|
40 |
+
help="Path to ViT pretrained checkpoint",
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
"--vivit_path", default="m-a-p/MERT-v1-330M", type=str,
|
44 |
+
help="Path to ViViT pretrained checkpoint",
|
45 |
+
)
|
46 |
+
parser.add_argument(
|
47 |
+
"--knn_dir", default="./ckpts", type=str,
|
48 |
+
help="Path to directory with KNN Index",
|
49 |
+
)
|
50 |
+
parser.add_argument(
|
51 |
+
'--music_decoder', default="musicgen", type=str,
|
52 |
+
help='Decoder to use musicgen/audioldm2')
|
53 |
+
|
54 |
+
parser.add_argument(
|
55 |
+
'--music_decoder_path', default="facebook/musicgen-medium", type=str,
|
56 |
+
help='Path to decoder to use musicgen/audioldm2')
|
57 |
+
|
58 |
+
args = parser.parse_args()
|
59 |
+
|
60 |
+
generated_audio_files = []
|
61 |
+
|
62 |
+
llama_type = args.llama_type
|
63 |
+
llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
|
64 |
+
llama_tokenzier_path = args.llama_dir
|
65 |
+
model = M2UGen(llama_ckpt_dir, llama_tokenzier_path, args, knn=False, stage=None, load_llama=False)
|
66 |
+
|
67 |
+
print("Loading Model Checkpoint")
|
68 |
+
checkpoint = torch.load(args.model, map_location='cpu')
|
69 |
+
|
70 |
+
new_ckpt = {}
|
71 |
+
for key, value in checkpoint['model'].items():
|
72 |
+
if "generation_model" in key:
|
73 |
+
continue
|
74 |
+
key = key.replace("module.", "")
|
75 |
+
new_ckpt[key] = value
|
76 |
+
|
77 |
+
load_result = model.load_state_dict(new_ckpt, strict=False)
|
78 |
+
assert len(load_result.unexpected_keys) == 0, f"Unexpected keys: {load_result.unexpected_keys}"
|
79 |
+
model.eval()
|
80 |
+
model.to("cuda")
|
81 |
+
#model.generation_model.to("cuda")
|
82 |
+
#model.mert_model.to("cuda")
|
83 |
+
#model.vit_model.to("cuda")
|
84 |
+
#model.vivit_model.to("cuda")
|
85 |
+
|
86 |
+
transform = transforms.Compose(
|
87 |
+
[transforms.ToTensor(), transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x)])
|
88 |
+
|
89 |
+
|
90 |
+
def postprocess(self, y):
|
91 |
+
if y is None:
|
92 |
+
return []
|
93 |
+
for i, (message, response) in enumerate(y):
|
94 |
+
y[i] = (
|
95 |
+
None if message is None else mdtex2html.convert((message)),
|
96 |
+
None if response is None else mdtex2html.convert(response),
|
97 |
+
)
|
98 |
+
return y
|
99 |
+
|
100 |
+
|
101 |
+
gr.Chatbot.postprocess = postprocess
|
102 |
+
|
103 |
+
|
104 |
+
def parse_text(text, image_path, video_path, audio_path):
|
105 |
+
"""copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
|
106 |
+
outputs = text
|
107 |
+
lines = text.split("\n")
|
108 |
+
lines = [line for line in lines if line != ""]
|
109 |
+
count = 0
|
110 |
+
for i, line in enumerate(lines):
|
111 |
+
if "```" in line:
|
112 |
+
count += 1
|
113 |
+
items = line.split('`')
|
114 |
+
if count % 2 == 1:
|
115 |
+
lines[i] = f'<pre><code class="language-{items[-1]}">'
|
116 |
+
else:
|
117 |
+
lines[i] = f'<br></code></pre>'
|
118 |
+
else:
|
119 |
+
if i > 0:
|
120 |
+
if count % 2 == 1:
|
121 |
+
line = line.replace("`", "\`")
|
122 |
+
line = line.replace("<", "<")
|
123 |
+
line = line.replace(">", ">")
|
124 |
+
line = line.replace(" ", " ")
|
125 |
+
line = line.replace("*", "*")
|
126 |
+
line = line.replace("_", "_")
|
127 |
+
line = line.replace("-", "-")
|
128 |
+
line = line.replace(".", ".")
|
129 |
+
line = line.replace("!", "!")
|
130 |
+
line = line.replace("(", "(")
|
131 |
+
line = line.replace(")", ")")
|
132 |
+
line = line.replace("$", "$")
|
133 |
+
lines[i] = "<br>" + line
|
134 |
+
text = "".join(lines) + "<br>"
|
135 |
+
if image_path is not None:
|
136 |
+
text += f'<img src="./file={image_path}" style="display: inline-block;"><br>'
|
137 |
+
outputs = f'<Image>{image_path}</Image> ' + outputs
|
138 |
+
if video_path is not None:
|
139 |
+
text += f' <video controls playsinline height="320" width="240" style="display: inline-block;" src="./file={video_path}"></video6><br>'
|
140 |
+
outputs = f'<Video>{video_path}</Video> ' + outputs
|
141 |
+
if audio_path is not None:
|
142 |
+
text += f'<audio controls playsinline><source src="./file={audio_path}" type="audio/wav"></audio><br>'
|
143 |
+
outputs = f'<Audio>{audio_path}</Audio> ' + outputs
|
144 |
+
# text = text[::-1].replace(">rb<", "", 1)[::-1]
|
145 |
+
text = text[:-len("<br>")].rstrip() if text.endswith("<br>") else text
|
146 |
+
return text, outputs
|
147 |
+
|
148 |
+
|
149 |
+
def save_audio_to_local(audio, sec):
|
150 |
+
global generated_audio_files
|
151 |
+
if not os.path.exists('temp'):
|
152 |
+
os.mkdir('temp')
|
153 |
+
filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.wav')
|
154 |
+
if args.music_decoder == "audioldm2":
|
155 |
+
scipy.io.wavfile.write(filename, rate=16000, data=audio[0])
|
156 |
+
else:
|
157 |
+
scipy.io.wavfile.write(filename, rate=model.generation_model.config.audio_encoder.sampling_rate, data=audio)
|
158 |
+
generated_audio_files.append(filename)
|
159 |
+
return filename
|
160 |
+
|
161 |
+
|
162 |
+
def parse_reponse(model_outputs, audio_length_in_s):
|
163 |
+
response = ''
|
164 |
+
text_outputs = []
|
165 |
+
for output_i, p in enumerate(model_outputs):
|
166 |
+
if isinstance(p, str):
|
167 |
+
response += p.replace(' '.join([f'[AUD{i}]' for i in range(8)]), '')
|
168 |
+
response += '<br>'
|
169 |
+
text_outputs.append(p.replace(' '.join([f'[AUD{i}]' for i in range(8)]), ''))
|
170 |
+
elif 'aud' in p.keys():
|
171 |
+
_temp_output = ''
|
172 |
+
for idx, m in enumerate(p['aud']):
|
173 |
+
if isinstance(m, str):
|
174 |
+
response += m.replace(' '.join([f'[AUD{i}]' for i in range(8)]), '')
|
175 |
+
response += '<br>'
|
176 |
+
_temp_output += m.replace(' '.join([f'[AUD{i}]' for i in range(8)]), '')
|
177 |
+
else:
|
178 |
+
filename = save_audio_to_local(m, audio_length_in_s)
|
179 |
+
print(filename)
|
180 |
+
_temp_output = f'<Audio>{filename}</Audio> ' + _temp_output
|
181 |
+
response += f'<audio controls playsinline><source src="./file={filename}" type="audio/wav"></audio>'
|
182 |
+
text_outputs.append(_temp_output)
|
183 |
+
else:
|
184 |
+
pass
|
185 |
+
response = response[:-len("<br>")].rstrip() if response.endswith("<br>") else response
|
186 |
+
return response, text_outputs
|
187 |
+
|
188 |
+
|
189 |
+
def reset_user_input():
|
190 |
+
return gr.update(value='')
|
191 |
+
|
192 |
+
|
193 |
+
def reset_dialog():
|
194 |
+
return [], []
|
195 |
+
|
196 |
+
|
197 |
+
def reset_state():
|
198 |
+
global generated_audio_files
|
199 |
+
generated_audio_files = []
|
200 |
+
return None, None, None, None, [], [], []
|
201 |
+
|
202 |
+
|
203 |
+
def upload_image(conversation, chat_history, image_input):
|
204 |
+
input_image = Image.open(image_input.name).resize(
|
205 |
+
(224, 224)).convert('RGB')
|
206 |
+
input_image.save(image_input.name) # Overwrite with smaller image.
|
207 |
+
conversation += [(f'<img src="./file={image_input.name}" style="display: inline-block;">', "")]
|
208 |
+
return conversation, chat_history + [input_image, ""]
|
209 |
+
|
210 |
+
|
211 |
+
def read_video_pyav(container, indices):
|
212 |
+
frames = []
|
213 |
+
container.seek(0)
|
214 |
+
for i, frame in enumerate(container.decode(video=0)):
|
215 |
+
frames.append(frame)
|
216 |
+
chosen_frames = []
|
217 |
+
for i in indices:
|
218 |
+
chosen_frames.append(frames[i])
|
219 |
+
return np.stack([x.to_ndarray(format="rgb24") for x in chosen_frames])
|
220 |
+
|
221 |
+
|
222 |
+
def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
|
223 |
+
converted_len = int(clip_len * frame_sample_rate)
|
224 |
+
if converted_len > seg_len:
|
225 |
+
converted_len = 0
|
226 |
+
end_idx = np.random.randint(converted_len, seg_len)
|
227 |
+
start_idx = end_idx - converted_len
|
228 |
+
indices = np.linspace(start_idx, end_idx, num=clip_len)
|
229 |
+
indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
|
230 |
+
return indices
|
231 |
+
|
232 |
+
|
233 |
+
def get_video_length(filename):
|
234 |
+
print("Getting Video Length")
|
235 |
+
result = subprocess.run(["ffprobe", "-v", "error", "-show_entries",
|
236 |
+
"format=duration", "-of",
|
237 |
+
"default=noprint_wrappers=1:nokey=1", filename],
|
238 |
+
stdout=subprocess.PIPE,
|
239 |
+
stderr=subprocess.STDOUT)
|
240 |
+
return int(round(float(result.stdout)))
|
241 |
+
|
242 |
+
|
243 |
+
def get_audio_length(filename):
|
244 |
+
return int(round(librosa.get_duration(path=filename)))
|
245 |
+
|
246 |
+
|
247 |
+
def predict(
|
248 |
+
prompt_input,
|
249 |
+
image_path,
|
250 |
+
audio_path,
|
251 |
+
video_path,
|
252 |
+
chatbot,
|
253 |
+
top_p,
|
254 |
+
temperature,
|
255 |
+
history,
|
256 |
+
modality_cache,
|
257 |
+
audio_length_in_s):
|
258 |
+
global generated_audio_files
|
259 |
+
prompts = [llama.format_prompt(prompt_input)]
|
260 |
+
prompts = [model.tokenizer(x).input_ids for x in prompts]
|
261 |
+
print(image_path, audio_path, video_path)
|
262 |
+
image, audio, video = None, None, None
|
263 |
+
if image_path is not None:
|
264 |
+
image = transform(Image.open(image_path))
|
265 |
+
if audio_path is not None:
|
266 |
+
sample_rate = 24000
|
267 |
+
waveform, sr = torchaudio.load(audio_path)
|
268 |
+
if sample_rate != sr:
|
269 |
+
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
|
270 |
+
audio = torch.mean(waveform, 0)
|
271 |
+
if video_path is not None:
|
272 |
+
print("Opening Video")
|
273 |
+
container = av.open(video_path)
|
274 |
+
indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
|
275 |
+
video = read_video_pyav(container=container, indices=indices)
|
276 |
+
|
277 |
+
if len(generated_audio_files) != 0:
|
278 |
+
audio_length_in_s = get_audio_length(generated_audio_files[-1])
|
279 |
+
sample_rate = 24000
|
280 |
+
waveform, sr = torchaudio.load(generated_audio_files[-1])
|
281 |
+
if sample_rate != sr:
|
282 |
+
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
|
283 |
+
audio = torch.mean(waveform, 0)
|
284 |
+
audio_length_in_s = int(len(audio)//sample_rate)
|
285 |
+
print(f"Audio Length: {audio_length_in_s}")
|
286 |
+
if video_path is not None:
|
287 |
+
audio_length_in_s = get_video_length(video_path)
|
288 |
+
print(f"Video Length: {audio_length_in_s}")
|
289 |
+
if audio_path is not None:
|
290 |
+
audio_length_in_s = get_audio_length(audio_path)
|
291 |
+
generated_audio_files.append(audio_path)
|
292 |
+
print(f"Audio Length: {audio_length_in_s}")
|
293 |
+
|
294 |
+
print(image, video, audio)
|
295 |
+
response = model.generate(prompts, audio, image, video, 200, temperature, top_p,
|
296 |
+
audio_length_in_s=audio_length_in_s)
|
297 |
+
print(response)
|
298 |
+
response_chat, response_outputs = parse_reponse(response, audio_length_in_s)
|
299 |
+
print('text_outputs: ', response_outputs)
|
300 |
+
user_chat, user_outputs = parse_text(prompt_input, image_path, video_path, audio_path)
|
301 |
+
chatbot.append((user_chat, response_chat))
|
302 |
+
history.append((user_outputs, ''.join(response_outputs).replace('\n###', '')))
|
303 |
+
return chatbot, history, modality_cache, None, None, None,
|
304 |
+
|
305 |
+
|
306 |
+
with gr.Blocks() as demo:
|
307 |
+
gr.HTML("""
|
308 |
+
<h1 align="center" style=" display: flex; flex-direction: row; justify-content: center; font-size: 25pt; "><img src='./file=bot.png' width="50" height="50" style="margin-right: 10px;">M<sup style="line-height: 200%; font-size: 60%">2</sup>UGen</h1>
|
309 |
+
<h3>This is the demo page of M<sup>2</sup>UGen, a Multimodal LLM capable of Music Understanding and Generation!</h3>
|
310 |
+
<div style="display: flex;"><a href='https://arxiv.org/pdf/2311.11255.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
|
311 |
+
""")
|
312 |
+
|
313 |
+
with gr.Row():
|
314 |
+
with gr.Column(scale=0.7, min_width=500):
|
315 |
+
with gr.Row():
|
316 |
+
chatbot = gr.Chatbot(label='M2UGen Chatbot', avatar_images=(
|
317 |
+
(os.path.join(os.path.dirname(__file__), 'user.png')),
|
318 |
+
(os.path.join(os.path.dirname(__file__), "bot.png")))).style(height=440)
|
319 |
+
|
320 |
+
with gr.Tab("User Input"):
|
321 |
+
with gr.Row(scale=3):
|
322 |
+
user_input = gr.Textbox(label="Text", placeholder="Key in something here...", lines=3)
|
323 |
+
with gr.Row(scale=3):
|
324 |
+
with gr.Column(scale=1):
|
325 |
+
# image_btn = gr.UploadButton("🖼️ Upload Image", file_types=["image"])
|
326 |
+
image_path = gr.Image(type="filepath",
|
327 |
+
label="Image") # .style(height=200) # <PIL.Image.Image image mode=RGB size=512x512 at 0x7F6E06738D90>
|
328 |
+
with gr.Column(scale=1):
|
329 |
+
audio_path = gr.Audio(type='filepath') # .style(height=200)
|
330 |
+
with gr.Column(scale=1):
|
331 |
+
video_path = gr.Video() # .style(height=200) # , value=None, interactive=True
|
332 |
+
with gr.Column(scale=0.3, min_width=300):
|
333 |
+
with gr.Group():
|
334 |
+
with gr.Accordion('Text Advanced Options', open=True):
|
335 |
+
top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
|
336 |
+
temperature = gr.Slider(0, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
|
337 |
+
with gr.Accordion('Audio Advanced Options', open=False):
|
338 |
+
audio_length_in_s = gr.Slider(5, 30, value=30, step=1, label="The audio length in seconds",
|
339 |
+
interactive=True)
|
340 |
+
with gr.Tab("Operation"):
|
341 |
+
with gr.Row(scale=1):
|
342 |
+
submitBtn = gr.Button(value="Submit & Run", variant="primary")
|
343 |
+
with gr.Row(scale=1):
|
344 |
+
emptyBtn = gr.Button("Clear History")
|
345 |
+
|
346 |
+
history = gr.State([])
|
347 |
+
modality_cache = gr.State([])
|
348 |
+
|
349 |
+
submitBtn.click(
|
350 |
+
predict, [
|
351 |
+
user_input,
|
352 |
+
image_path,
|
353 |
+
audio_path,
|
354 |
+
video_path,
|
355 |
+
chatbot,
|
356 |
+
top_p,
|
357 |
+
temperature,
|
358 |
+
history,
|
359 |
+
modality_cache,
|
360 |
+
audio_length_in_s
|
361 |
+
], [
|
362 |
+
chatbot,
|
363 |
+
history,
|
364 |
+
modality_cache,
|
365 |
+
image_path,
|
366 |
+
audio_path,
|
367 |
+
video_path
|
368 |
+
],
|
369 |
+
show_progress=True
|
370 |
+
)
|
371 |
+
|
372 |
+
submitBtn.click(reset_user_input, [], [user_input])
|
373 |
+
emptyBtn.click(reset_state, outputs=[
|
374 |
+
image_path,
|
375 |
+
audio_path,
|
376 |
+
video_path,
|
377 |
+
chatbot,
|
378 |
+
history,
|
379 |
+
modality_cache
|
380 |
+
], show_progress=True)
|
381 |
+
|
382 |
+
demo.queue().launch(share=True, inbrowser=True, server_name='0.0.0.0', server_port=24000)
|
bot.png
ADDED
user.png
ADDED