Vasudevakrishna committed on
Commit fd8d20e
1 Parent(s): 48ccb79

App added.

app.py ADDED
import torch
import whisperx
import gradio as gr
from peft import PeftModel
from configs import get_config_phase2
from transformers import AutoTokenizer, AutoProcessor, CLIPVisionModel, AutoModelForCausalLM

config = get_config_phase2()

# Vision encoder (CLIP) and base language model (Phi-2)
clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name"))

base_model = AutoModelForCausalLM.from_pretrained(
    config.get("phi2_model_name"),
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True
)

# Merge the QLoRA adapter into the base model
ckpts = "ckpts/Qlora_adaptor/"
phi2_model = PeftModel.from_pretrained(base_model, ckpts)
phi2_model = phi2_model.merge_and_unload().to(config.get("device"))

# Projection from the CLIP embedding space into the Phi-2 embedding space
projection_layer = torch.nn.Linear(config.get("clip_embed"), config.get("phi_embed"))
projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))

# Tokenizer and image processor
tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
processor = AutoProcessor.from_pretrained(config.get("clip_model_name"), trust_remote_code=True)

# Speech-to-text model for audio questions
audio_model = whisperx.load_model('tiny', 'cpu', compute_type="float32")


def generate_answers(img=None, aud=None, q=None, max_tokens=30):
    batch_size = 1

    # <iQ> ... </iQ> delimiters wrap the multimodal context
    start_iq = tokenizer.encode("<iQ>")
    end_iq = tokenizer.encode("</iQ>")
    start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
    end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
    start_iq_embeds = phi2_model.model.embed_tokens(start_iq_embeds.to(config.get("device")))
    end_iq_embeds = phi2_model.model.embed_tokens(end_iq_embeds.to(config.get("device")))

    inputs_embeddings = [start_iq_embeds]

    if img is not None:
        # Encode the image with CLIP and project it into Phi-2's embedding space
        pixel_values = processor(images=img, return_tensors="pt")['pixel_values'].to(config.get("device"))
        clip_outputs = clip_model(pixel_values=pixel_values)
        # Drop the CLS token, keep the patch embeddings
        patch_embeds = clip_outputs.last_hidden_state[:, 1:, :]
        image_embeddings = projection_layer(patch_embeds).to(torch.float32)
        inputs_embeddings.append(image_embeddings)

    if aud is not None:
        # Transcribe the spoken question and embed the transcript
        trans = audio_model.transcribe(aud)
        audio_res = "".join(seg['text'] for seg in trans['segments']).strip()
        audio_tokens = tokenizer(audio_res, return_tensors="pt", return_attention_mask=False)['input_ids']
        audio_embeds = phi2_model.model.embed_tokens(audio_tokens.to(config.get("device")))
        inputs_embeddings.append(audio_embeds)

    if q:
        # Embed the typed question
        ques = tokenizer(q, return_tensors="pt", return_attention_mask=False)['input_ids']
        q_embeds = phi2_model.model.embed_tokens(ques.to(config.get("device")))
        inputs_embeddings.append(q_embeds)

    inputs_embeddings.append(end_iq_embeds)

    # Combine all embeddings and generate the answer
    combined_embeds = torch.cat(inputs_embeddings, dim=1)
    predicted_caption = phi2_model.generate(inputs_embeds=combined_embeds,
                                            max_new_tokens=int(max_tokens),
                                            return_dict_in_generate=True)

    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption.sequences[:, 1:])[0]
    predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")
    return predicted_captions_decoded


# List of examples (image, audio, question, max_tokens)
examples = [
    ["./examples/Image_1.jpg", None, "Explain the image?", 20],
    ["./examples/Image_2.jpg", None, "How many animals are there in the image?", 10],
    ["./examples/Image_3.jpg", None, "What is in the image?", 20],
    ["./examples/Image_4.jpg", None, "What does this image represent?", 20],
]

with gr.Blocks() as demo:

    gr.Markdown(
        """
        # MultiModelLLM
        Multimodal GPT that takes image, audio, and text as input and returns text as output.
        """
    )

    with gr.Row():
        with gr.Column():
            image = gr.Image(label='Image', type="pil", value=None)
            audio_q = gr.Audio(label="Audio Question", value=None, sources=['microphone', 'upload'], type='filepath')
            question = gr.Text(label='Question?', value=None)
            max_tokens = gr.Slider(1, 50, value=10, step=1, label="Max tokens")
    with gr.Row():
        answer = gr.Text(label='Answer')
    with gr.Row():
        submit = gr.Button("Submit")
        submit.click(generate_answers, inputs=[image, audio_q, question, max_tokens], outputs=[answer])
        clear_btn = gr.ClearButton([image, audio_q, question, max_tokens, answer])
    # Add examples
    gr.Examples(examples=examples, inputs=[image, audio_q, question, max_tokens], outputs=answer)


if __name__ == "__main__":
    demo.launch(share=True, debug=True)
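Note: app.py imports get_config_phase2 from a local configs module that is not part of this commit. A minimal sketch of what that module is assumed to provide is shown below; the key names are exactly those read by app.py, but the concrete values are assumptions, not the project's actual settings.

# configs.py — hypothetical sketch, not the committed module
import torch

def get_config_phase2():
    # Keys below are the ones app.py reads; values are assumed placeholders.
    return {
        "clip_model_name": "openai/clip-vit-base-patch32",  # assumed CLIP checkpoint
        "phi2_model_name": "microsoft/phi-2",                # assumed Phi-2 checkpoint
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "clip_embed": 768,    # hidden size of the assumed CLIP model
        "phi_embed": 2560,    # hidden size of Phi-2 (assumed)
    }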
examples/Image_1.jpg ADDED
examples/Image_2.jpg ADDED
examples/Image_3.jpg ADDED
examples/Images_4.jpg ADDED
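For a quick check outside the Gradio UI, the handler can also be called directly once the checkpoints and example images above are in place. A minimal sketch, assuming the ckpts/ directory and the committed examples exist locally (importing app loads the models but does not launch the demo):

# Hypothetical smoke test for the generate_answers handler
from PIL import Image
from app import generate_answers

img = Image.open("./examples/Image_1.jpg")
print(generate_answers(img=img, aud=None, q="What is in the image?", max_tokens=20))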