import os

import gradio as gui
import torch
import torch.nn as nn
import whisper
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer, CLIPVisionModel

clip_model_name = "openai/clip-vit-base-patch32"
phi_model_name = "microsoft/phi-2"

tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(clip_model_name)
tokenizer.pad_token = tokenizer.eos_token

IMAGE_TOKEN_ID = 23893  # phi-2 token id for the word "comment", used as the image marker
QA_TOKEN_ID = 50295     # phi-2 token id used to mark the question/answer segment
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_embed = 768   # CLIP ViT-B/32 hidden size
phi_embed = 2560   # phi-2 hidden size
audio_batch_size = 16
current_dir = os.getcwd()


class SimpleResBlock(nn.Module):
    """Residual MLP block applied on top of the projected image embeddings."""

    def __init__(self, phi_embed):
        super().__init__()
        self.pre_norm = nn.LayerNorm(phi_embed)
        self.proj = nn.Sequential(
            nn.Linear(phi_embed, phi_embed),
            nn.GELU(),
            nn.Linear(phi_embed, phi_embed),
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


# models
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
resblock = SimpleResBlock(phi_embed).to(device)
phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True).to(device)
audio_model = whisper.load_model("tiny", device=device)

lora_adaptor_path = os.path.join(current_dir, 'model_chkpt', 'lora_adaptor')
projection_path = os.path.join(current_dir, 'model_chkpt', 'step2_projection.pth')
resblock_path = os.path.join(current_dir, 'model_chkpt', 'step2_resblock.pth')

# load weights: merge the LoRA adaptor into phi-2 and restore the trained projection layers
model_to_merge = PeftModel.from_pretrained(
    phi_model, lora_adaptor_path, local_files_only=True, device_map={'': device}
)
merged_model = model_to_merge.merge_and_unload()
projection.load_state_dict(torch.load(projection_path, map_location=torch.device(device)))
resblock.load_state_dict(torch.load(resblock_path, map_location=torch.device(device)))


def generate_response(img=None, img_audio=None, val_q=None):
    """Build a single embedding sequence from image/audio/text inputs and generate a reply."""
    max_generate_length = 100
    val_combined_embeds = []

    with torch.no_grad():
        # image: CLIP patch embeddings (CLS token dropped) projected into phi-2's embedding space
        if img is not None:
            image_processed = processor(images=img, return_tensors="pt").to(device)
            clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]
            val_image_embeds = projection(clip_val_outputs)
            val_image_embeds = resblock(val_image_embeds).to(torch.float16)

            img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
            img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)

            val_combined_embeds.append(val_image_embeds)
            val_combined_embeds.append(img_token_embeds)

        # audio: transcribe with Whisper, then embed the transcript as ordinary text tokens
        if img_audio is not None:
            audio_result = audio_model.transcribe(img_audio)
            audio_text = ''.join(seg['text'] for seg in audio_result['segments']).strip()
            audio_tokens = tokenizer(audio_text, return_tensors="pt",
                                     return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
            val_combined_embeds.append(audio_embeds)

        # text question
        if val_q:
            val_q_tokenised = tokenizer(val_q, return_tensors="pt",
                                        return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
            val_combined_embeds.append(val_q_embeds)

        # add the QA token whenever a question was asked (spoken or typed)
        if img_audio is not None or val_q:
            QA_token_tensor = torch.tensor(QA_TOKEN_ID).to(device)
            QA_token_embeds = merged_model.model.embed_tokens(QA_token_tensor).unsqueeze(0).unsqueeze(0)
            val_combined_embeds.append(QA_token_embeds)

        val_combined_embeds = torch.cat(val_combined_embeds, dim=1)

        predicted_caption = merged_model.generate(
            inputs_embeds=val_combined_embeds,
            max_new_tokens=max_generate_length,
            return_dict_in_generate=True,
        )

    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption.sequences[:, 1:])[0]
    predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")

    return predicted_captions_decoded


# Gradio interface setup
with gui.Blocks() as app_interface:
    with gui.Row():
        with gui.Column():
            image_input = gui.Image(label='Upload Image', type="pil")
        with gui.Column():
            audio_input = gui.Audio(label="Audio Input", sources=['microphone', 'upload'], type='filepath')
            text_input = gui.Text(label='Enter Text', placeholder="Type your query here...")
    with gui.Row():
        output_response = gui.Textbox(label='Generated Response', placeholder="Response will appear here...", lines=5)
    submit_button = gui.Button("Generate Response", variant="primary")

    submit_button.click(generate_response, inputs=[image_input, audio_input, text_input], outputs=output_response)

if __name__ == "__main__":
    app_interface.launch(share=True)
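
# Example (sketch, not part of the app itself): calling generate_response directly for a
# quick smoke test without the Gradio UI. "sample.jpg" is a placeholder path, not a file
# shipped with this repo; audio can likewise be passed as a path to a WAV/MP3 file.
#
#   from PIL import Image
#   caption = generate_response(img=Image.open("sample.jpg"), img_audio=None,
#                               val_q="What is in this image?")
#   print(caption)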