czd358121692 committed
Commit b8336b2 • 1 Parent(s): 2d85f4c

Create app.py

Files changed (1):
  1. app.py +398 -0
app.py ADDED
@@ -0,0 +1,398 @@
+ import os
+
+ # Redirect all Hugging Face caches to the writable /tmp directory. This has to
+ # happen before transformers / huggingface_hub are imported, because the hub
+ # resolves its cache paths from these variables at import time.
+ CACHE_ROOT = '/tmp'
+ os.environ['HF_HOME'] = CACHE_ROOT
+ os.environ['HUGGINGFACE_HUB_CACHE'] = os.path.join(CACHE_ROOT, 'hub')
+ os.environ['XDG_CACHE_HOME'] = os.path.join(CACHE_ROOT, 'cache')
+
+ import shutil
+ import gradio as gr
+ import torch
+ import numpy as np
+ from PIL import Image
+ import torchaudio
+ from einops import rearrange
+ import psutil
+ import humanize
+ import spaces  # noqa: F401 -- imported for Hugging Face Spaces GPU support, not used directly
+ from transformers import (
+     AutoProcessor,
+     AutoModelForVision2Seq,
+     pipeline
+ )
+ from huggingface_hub import scan_cache_dir
+ from stable_audio_tools import get_pretrained_model
+ from stable_audio_tools.inference.generation import generate_diffusion_cond
+
+ # Global model variables
+ kosmos_model = None
+ kosmos_processor = None
+ zephyr_pipe = None
+ audio_model = None
+ audio_config = None
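+
+ # These module-level handles act as singletons: each model is loaded exactly
+ # once per process by initialize_models() and shared by every Gradio request.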
+
+ def initialize_models():
+     global kosmos_model, kosmos_processor, zephyr_pipe, audio_model, audio_config
+     try:
+         print("Loading Kosmos-2...")
+         kosmos_model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224")
+         kosmos_processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
+         if torch.cuda.is_available():
+             kosmos_model = kosmos_model.to("cuda")
+     except Exception as e:
+         print(f"Error loading Kosmos-2: {e}")
+         raise
+     try:
+         print("Loading Zephyr...")
+         zephyr_pipe = pipeline(
+             "text-generation",
+             model="HuggingFaceH4/zephyr-7b-beta",
+             torch_dtype=torch.bfloat16,
+             device_map="auto"
+         )
+     except Exception as e:
+         print(f"Error loading Zephyr: {e}")
+         raise
+
+     try:
+         print("Loading Stable Audio...")
+         audio_model, audio_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
+         if torch.cuda.is_available():
+             audio_model = audio_model.to("cuda")
+     except Exception as e:
+         print(f"Error loading Stable Audio: {e}")
+         raise
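+
+ # Rough memory budget (back-of-envelope estimate, not measured): Zephyr-7B in
+ # bfloat16 is about 7e9 params x 2 bytes ~= 14 GB, Kosmos-2 (~1.6B params) a
+ # few GB more, plus the Stable Audio weights -- so device_map="auto" may shard
+ # or offload Zephyr on smaller GPUs.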
+
+ def get_caption(image_in):
+     if not image_in:
+         raise gr.Error("Please provide an image")
+
+     try:
+         # Convert image to PIL if needed
+         if isinstance(image_in, str):
+             image = Image.open(image_in)
+         elif isinstance(image_in, np.ndarray):
+             image = Image.fromarray(image_in)
+         else:
+             image = image_in  # assume a PIL image was passed
+
+         if image.mode != "RGB":
+             image = image.convert("RGB")
+
+         prompt = "<grounding>Describe this image in detail without names:"
+         inputs = kosmos_processor(text=prompt, images=image, return_tensors="pt")
+
+         device = next(kosmos_model.parameters()).device
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             generated_ids = kosmos_model.generate(
+                 pixel_values=inputs["pixel_values"],
+                 input_ids=inputs["input_ids"],
+                 attention_mask=inputs["attention_mask"],
+                 image_embeds_position_mask=inputs["image_embeds_position_mask"],
+                 max_new_tokens=128,
+             )
+
+         generated_text = kosmos_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+         processed_text, _ = kosmos_processor.post_process_generation(generated_text)
+
+         # Clean up output
+         for prefix in ["Describe this image in detail without names", "An image of", "<grounding>"]:
+             processed_text = processed_text.replace(prefix, "").strip()
+
+         return processed_text
+
+     except Exception as e:
+         raise gr.Error(f"Image caption generation failed: {str(e)}")
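+
+ # For reference: Kosmos-2's post_process_generation returns a (text, entities)
+ # tuple, where entities is a list of grounded phrases roughly of the form
+ #   [("a glass facade", (10, 24), [(x1, y1, x2, y2)]), ...]   (illustrative)
+ # The bounding boxes are discarded above; only the plain caption is needed.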
+
+ def get_musical_prompt(user_prompt, chosen_model):
+     if not user_prompt:
+         raise gr.Error("No image caption provided")
+
+     try:
+         standard_sys = """
+ You are a musician AI who specializes in translating architectural spaces into musical experiences. Your job is to create concise musical descriptions that capture the essence of architectural photographs.
+
+ Consider these elements in your composition:
+ - Spatial Experience: expansive/intimate spaces, layered forms, acoustical qualities
+ - Materials & Textures: metallic, glass, concrete translated into instrumental textures
+ - Musical Elements: blend of classical structure and jazz improvisation
+ - Orchestration: symphonic layers, solo instruments, or ensemble variations
+ - Soundscapes: environmental depth and spatial audio qualities
+
+ Respond immediately with a single musical prompt. No explanation, just the musical description.
+ """
+         instruction = f"""
+ <|system|>
+ {standard_sys}</s>
+ <|user|>
+ {user_prompt}</s>
+ <|assistant|>
+ """
+
+         outputs = zephyr_pipe(
+             instruction.strip(),
+             max_new_tokens=256,
+             do_sample=True,
+             temperature=0.75,
+             top_k=50,
+             top_p=0.92
+         )
+
+         musical_prompt = outputs[0]["generated_text"]
+
+         # Clean system message and tokens
+         cleaned_prompt = musical_prompt.replace("<|system|>", "").replace("</s>", "").replace("<|user|>", "").replace("<|assistant|>", "")
+
+         lines = cleaned_prompt.split('\n')
+         relevant_lines = [line.strip() for line in lines
+                           if line.strip() and
+                           not line.startswith('-') and
+                           not line.startswith('Example') and
+                           not line.startswith('Instructions') and
+                           not line.startswith('Consider') and
+                           not line.startswith('Incorporate')]
+
+         if relevant_lines:
+             final_prompt = relevant_lines[-1].strip()
+             if len(final_prompt) >= 10:
+                 return final_prompt
+
+         raise ValueError("Could not extract valid musical prompt")
+
+     except Exception as e:
+         print(f"Error in get_musical_prompt: {str(e)}")
+         return "Ambient orchestral composition with piano and strings, creating a contemplative atmosphere"
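+
+ # A sturdier alternative (a sketch, not what this commit ships): let the
+ # tokenizer build the Zephyr chat format instead of hand-writing the
+ # <|system|>/<|user|> tags:
+ #   messages = [{"role": "system", "content": standard_sys},
+ #               {"role": "user", "content": user_prompt}]
+ #   instruction = zephyr_pipe.tokenizer.apply_chat_template(
+ #       messages, tokenize=False, add_generation_prompt=True)
+ # Passing return_full_text=False to the pipeline would likewise avoid stripping
+ # the prompt scaffolding from the output by hand.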
+
+ def get_stable_audio_open(prompt, seconds_total=47, steps=100, cfg_scale=7):
+     try:
+         torch.cuda.empty_cache()  # Clear GPU memory before generation
+
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         sample_rate = audio_config["sample_rate"]
+         sample_size = audio_config["sample_size"]
+
+         # Set up conditioning
+         conditioning = [{
+             "prompt": prompt,
+             "seconds_start": 0,
+             "seconds_total": seconds_total
+         }]
+
+         # Generate audio
+         output = generate_diffusion_cond(
+             audio_model,
+             steps=steps,
+             cfg_scale=cfg_scale,
+             conditioning=conditioning,
+             sample_size=sample_size,
+             sigma_min=0.3,
+             sigma_max=500,
+             sampler_type="dpmpp-3m-sde",
+             device=device
+         )
+
+         output = rearrange(output, "b d n -> d (b n)")
+         output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+
+         # Save to temporary file
+         output_path = os.path.join(CACHE_ROOT, f"output_{os.urandom(8).hex()}.wav")
+         torchaudio.save(output_path, output, sample_rate)
+
+         return output_path
+
+     except Exception as e:
+         torch.cuda.empty_cache()  # Clear GPU memory on error
+         raise gr.Error(f"Music generation failed: {str(e)}")
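+
+ # The conversion line above peak-normalizes before quantizing to 16-bit PCM:
+ # a sample of 0.5 with a batch peak of 2.0 becomes 0.25, then 0.25 * 32767 ~= 8191.
+ # Caveat: an all-silent output would divide by zero; a guard such as
+ #   output.div(output.abs().max().clamp(min=1e-8))
+ # would cover that edge case.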
+
+ def check_api():
+     try:
+         if all([kosmos_model, kosmos_processor, zephyr_pipe, audio_model, audio_config]):
+             return "Orchestra ready. 🎹 👁️ 🎼"
+         return "Orchestra is tuning..."
+     except Exception:
+         return "Orchestra is tuning..."
+
+ # Utility functions
+ def get_storage_info():
+     disk_usage = psutil.disk_usage('/tmp')
+     used = humanize.naturalsize(disk_usage.used)
+     total = humanize.naturalsize(disk_usage.total)
+     percent = disk_usage.percent
+     return f"Storage: {used}/{total} ({percent}% used)"
+
+ def smart_cleanup():
+     try:
+         cache_info = scan_cache_dir()
+         seen_models = {}
+
+         for repo in cache_info.repos:
+             model_id = repo.repo_id
+             if model_id not in seen_models:
+                 seen_models[model_id] = []
+             seen_models[model_id].append(repo)
+
+         for model_id, repos in seen_models.items():
+             if len(repos) > 1:
+                 repos.sort(key=lambda x: x.last_modified, reverse=True)
+                 for repo in repos[1:]:
+                     shutil.rmtree(repo.repo_path)
+                     print(f"Removed duplicate cache for {model_id}")
+
+         return get_storage_info()
+
+     except Exception as e:
+         print(f"Error during cleanup: {e}")
+         return "Cleanup error occurred"
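+
+ # Alternative (not used above): huggingface_hub exposes a supported deletion
+ # API that avoids calling rmtree on cache internals by hand. Sketch, where
+ # `stale` is the hypothetical list of repos to drop:
+ #   strategy = cache_info.delete_revisions(
+ #       *[rev.commit_hash for repo in stale for rev in repo.revisions])
+ #   strategy.execute()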
+
+ def get_image_examples():
+     image_dir = "images"
+     image_extensions = ['.jpg', '.jpeg', '.png']
+     examples = []
+
+     if not os.path.exists(image_dir):
+         print(f"Warning: Image directory '{image_dir}' not found")
+         return []
+
+     for filename in os.listdir(image_dir):
+         if any(filename.lower().endswith(ext) for ext in image_extensions):
+             examples.append([os.path.join(image_dir, filename)])
+
+     return examples
+
+ def infer(image_in, api_status):
+     if image_in is None:
+         raise gr.Error("Please provide an image of architecture")
+
+     if api_status == "Orchestra is tuning...":
+         raise gr.Error("The model is still tuning, please try again later")
+
+     try:
+         gr.Info("🎭 Finding poetry in form and light...")
+         user_prompt = get_caption(image_in)
+
+         gr.Info("🎼 Weaving it into melody...")
+         musical_prompt = get_musical_prompt(user_prompt, "Stable Audio Open")
+
+         gr.Info("🎻 Breathing life into notes...")
+         music_o = get_stable_audio_open(musical_prompt)
+
+         torch.cuda.empty_cache()  # Clear GPU memory after generation
+         return gr.update(value=musical_prompt, interactive=True), gr.update(visible=True), music_o
+     except Exception as e:
+         torch.cuda.empty_cache()
+         raise gr.Error(f"Generation failed: {str(e)}")
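+
+ # End-to-end chain, for orientation: image -> Kosmos-2 caption -> Zephyr
+ # musical prompt -> Stable Audio Open waveform. Captioning and audio synthesis
+ # raise gr.Error on failure; the prompt stage instead falls back to a default
+ # ambient description, so infer only aborts on the first and last stages.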
+
+ def retry(caption):
+     if not caption:
+         raise gr.Error("No musical prompt to retry with")
+     musical_prompt = caption
+     gr.Info("🎹 Refreshing with a new vibe...")
+     music_o = get_stable_audio_open(musical_prompt)
+     return music_o
+
+ # UI definition
+ demo_title = "Musical Toy for Frank"
+ description = "A humble attempt to hear Architecture through Music"
+
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 980px;
+     text-align: left;
+ }
+ #inspi-prompt textarea {
+     font-size: 20px;
+     line-height: 24px;
+     font-weight: 600;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.HTML(f"""
+             <h2 style="text-align: center;">{demo_title}</h2>
+             <p style="text-align: center;">{description}</p>
+         """)
+         with gr.Row():
+             with gr.Column():
+                 image_in = gr.Image(
+                     label="Inspire us:",
+                     type="filepath",
+                     elem_id="image-in"
+                 )
+                 gr.Examples(
+                     examples=get_image_examples(),
+                     fn=infer,
+                     inputs=[image_in],
+                     examples_per_page=5,
+                     label="♪ ♪ ..."
+                 )
+                 submit_btn = gr.Button("Listen to it...")
+             with gr.Column():
+                 check_status = gr.Textbox(
+                     label="Status",
+                     interactive=False,
+                     value=check_api()
+                 )
+                 caption = gr.Textbox(
+                     label="Explanation & Inspiration...",
+                     interactive=False,
+                     elem_id="inspi-prompt"
+                 )
+                 retry_btn = gr.Button("🎲", visible=False)
+                 result = gr.Audio(
+                     label="Music"
+                 )
+
+         # Credits section
+         gr.HTML("""
+             <div style="margin-top: 40px; padding: 20px; border-top: 1px solid #ddd;">
+                 <!-- Your existing credits HTML -->
+             </div>
+         """)
+
+     # Event handlers
+     demo.load(
+         fn=check_api,
+         outputs=check_status,
+     )
+
+     retry_btn.click(
+         fn=retry,
+         inputs=[caption],
+         outputs=[result]
+     )
+
+     submit_btn.click(
+         fn=infer,
+         inputs=[
+             image_in,
+             check_status
+         ],
+         outputs=[
+             caption,
+             retry_btn,
+             result
+         ]
+     )
+
+     with gr.Column():
+         storage_info = gr.Textbox(label="Storage Info", value=get_storage_info())
+         cleanup_btn = gr.Button("Smart Cleanup")
+
+     cleanup_btn.click(
+         fn=smart_cleanup,
+         outputs=storage_info
+     )
+
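+ # Startup note: models load before launch(), so the first page visit should
+ # already report "Orchestra ready"; on a Space this also means cold starts pay
+ # the full download and initialization cost up front.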
+ if __name__ == "__main__":
+     print("Initializing models...")
+     initialize_models()
+     print("Models initialized successfully")
+
+     demo.queue(max_size=16).launch(
+         show_api=False,
+         show_error=True,
+         server_name="0.0.0.0",
+         server_port=7860,
+     )