Zimabluee committed on
Commit
6b9fb43
1 Parent(s): 9ef2a14

add gradio UI

Browse files
Files changed (1) hide show
  1. poc_app.py +37 -10
poc_app.py CHANGED
@@ -1,25 +1,31 @@
1
  # external imports
2
  import time
 
3
 
4
  # local imports
5
  from blip_image_caption_large import Blip_Image_Caption_Large
6
  from phi3_mini_4k_instruct import Phi3_Mini_4k_Instruct
7
  from musicgen_small import Musicgen_Small
8
 
9
- def main():
 
10
  # test image captioning
11
  image_caption_start_time = time.time()
12
  image_caption_model = Blip_Image_Caption_Large()
13
- test_caption = image_caption_model.caption_image_local_pipeline("data/test3.jpg")
 
 
14
  print(test_caption)
 
15
  image_caption_end_time = time.time()
16
 
17
  # test text generation
18
  text_generation_start_time = time.time()
19
  text_generation_model = Phi3_Mini_4k_Instruct()
20
-
21
  #TODO: move this to a config file
22
- text_generation_model.local_pipeline.model.config.max_length = 200
 
23
 
24
  #TODO: move system prompt somewhere else, allow for genre override
25
  messages = [
@@ -44,10 +50,31 @@ def main():
44
  music_generation_duration = music_generation_end_time - music_generation_start_time
45
  total_duration = music_generation_end_time - image_caption_start_time
46
 
47
- # output durations
48
- print(f"Image Captioning Duration: {image_caption_duration}")
49
- print(f"Text Generation Duration: {text_generation_duration}")
50
- print(f"Music Generation Duration: {music_generation_duration}")
51
- print(f"Total Duration: {total_duration}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- main()
 
1
  # external imports
2
  import time
3
+ import gradio as gr
4
 
5
  # local imports
6
  from blip_image_caption_large import Blip_Image_Caption_Large
7
  from phi3_mini_4k_instruct import Phi3_Mini_4k_Instruct
8
  from musicgen_small import Musicgen_Small
9
 
10
+ #image_to_music function
11
+ def image_to_music(image_path):
12
  # test image captioning
13
  image_caption_start_time = time.time()
14
  image_caption_model = Blip_Image_Caption_Large()
15
+
16
+ test_caption = image_caption_model.caption_image_local_pipeline(image_path)
17
+
18
  print(test_caption)
19
+
20
  image_caption_end_time = time.time()
21
 
22
  # test text generation
23
  text_generation_start_time = time.time()
24
  text_generation_model = Phi3_Mini_4k_Instruct()
25
+
26
  #TODO: move this to a config file
27
+ text_generation_model.local_pipeline.model.config.max_new_tokens = 200
28
+
29
 
30
  #TODO: move system prompt somewhere else, allow for genre override
31
  messages = [
 
50
  music_generation_duration = music_generation_end_time - music_generation_start_time
51
  total_duration = music_generation_end_time - image_caption_start_time
52
 
53
+ # output generated_text, audio and duration to gradio
54
+ return (test_caption[0]["generated_text"], test_text[-1]['generated_text'][-1]['content'], "data/musicgen_out.wav",
55
+ f"Image Captioning Duration: {image_caption_duration} sec",
56
+ f"Text Generation Duration: {text_generation_duration} sec",
57
+ f"Music Generation Duration: {music_generation_duration} sec",
58
+ f"Total Duration: {total_duration} sec")
59
+
60
+ # Gradio UI
61
+ def gradio():
62
+ # Define Gradio Interface, information from (https://www.gradio.app/docs/chatinterface)
63
+ with gr.Blocks() as demo:
64
+ gr.Markdown("<h1 style='text-align: center;'> ⛺ Image to Music Generator 🎼</h1>")
65
+ image_input = gr.Image(type="filepath", label="Upload Image")
66
+ with gr.Row():
67
+ caption_output = gr.Textbox(label="Image Caption")
68
+ music_description_output = gr.Textbox(label="Music Description")
69
+ durations = gr.Textbox(label="Processing Times", interactive=False, placeholder="Time statistics will appear here")
70
+
71
+ music_output = gr.Audio(label="Generated Music")
72
+ # Button to trigger the process
73
+ generate_button = gr.Button("Generate Music")
74
+
75
+ generate_button.click(fn=image_to_music, inputs=[image_input], outputs=[caption_output, music_description_output, music_output, durations])
76
+
77
+ # Launch Gradio app
78
+ demo.launch()
79
 
80
+ gradio()