import data
import torch
import gradio as gr

from models import imagebind_model
from models.imagebind_model import ModalityType

# Load the pretrained ImageBind (huge) model once at startup and keep it in
# eval mode on the available device.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)


def image_text_zeroshot(image, text_list):
    # Score one image against "|"-separated candidate texts.
    image_paths = [image]
    labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    scores = (
        torch.softmax(
            embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1
        )
        .squeeze(0)
        .tolist()
    )

    score_dict = {label: score for label, score in zip(labels, scores)}
    return score_dict


def audio_text_zeroshot(audio, text_list):
    # Score one audio clip against "|"-separated candidate texts.
    audio_paths = [audio]
    labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    scores = (
        torch.softmax(
            embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1
        )
        .squeeze(0)
        .tolist()
    )

    score_dict = {label: score for label, score in zip(labels, scores)}
    return score_dict


def video_text_zeroshot(image, text_list):
    # Mirrors image_text_zeroshot (the input is loaded with the vision
    # transform). Not wired into the Gradio interface below.
    image_paths = [image]
    labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    scores = (
        torch.softmax(
            embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1
        )
        .squeeze(0)
        .tolist()
    )

    score_dict = {label: score for label, score in zip(labels, scores)}
    return score_dict


def doubleimage_text_zeroshot(image, image2, text_list):
    # Score a pair of images against "|"-separated candidate texts.
    image_paths = [image, image2]
    labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    # With two input images the vision-text similarity is a 2 x len(labels)
    # matrix (one row per image). Averaging the rows gives a single
    # probability per candidate text for the Label output (assumed intent of
    # the "embeddings" task).
    scores = torch.softmax(
        embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1
    )
    scores = scores.mean(dim=0).tolist()

    score_dict = {label: score for label, score in zip(labels, scores)}
    return score_dict


def doubleimage_text_zeroshotOLD(image, image2, text_list):
    # Debug variant: returns the raw 2 x len(labels) softmax matrix as a
    # string. Not used by the Gradio interface below.
    image_paths = [image, image2]
    labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    return str(
        torch.softmax(
            embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1
        )
    )


def inference(
    task,
    text_list=None,
    image=None,
    audio=None,
    image2=None,
):
    # Dispatch the selected task to the matching zero-shot scoring function.
    if task == "image-text":
        result = image_text_zeroshot(image, text_list)
    elif task == "audio-text":
        result = audio_text_zeroshot(audio, text_list)
    elif task == "embeddings":
        result = doubleimage_text_zeroshot(image, image2, text_list)
    else:
        raise NotImplementedError
    return result
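# Example of calling the scoring functions directly, e.g. for a quick sanity
# check without the UI. The file path and labels below are placeholders
# (assumptions), not files shipped with this demo:
#
#   scores = inference("image-text", text_list="a dog|a cat|a car", image="example.jpg")
#   # -> dict mapping each candidate text to a softmax probability (sums to 1)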
def main():
    inputs = [
        gr.Radio(
            choices=[
                "image-text",
                "audio-text",
                "embeddings",
            ],
            type="value",
            value="embeddings",
            label="Task",
        ),
        gr.Textbox(lines=1, label="Candidate texts"),
        gr.Image(type="filepath", label="Input image"),
        gr.Audio(type="filepath", label="Input audio"),
        gr.Image(type="filepath", label="Input image2"),
    ]

    iface = gr.Interface(
        inference,
        inputs,
        "label",
        title="Multimodal AI assistive agents for Learning Disorders: Demo with embeddings of ImageBind",
    )

    iface.launch()


if __name__ == "__main__":
    main()
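# Usage sketch (the module name and paths are assumptions, not part of this demo):
#
#   python app.py    # starts the Gradio UI on a local port
#
# or, from another script, bypassing the UI entirely:
#
#   from app import inference
#   result = inference("embeddings", text_list="happy|sad", image="a.jpg", image2="b.jpg")
#
# To serve beyond localhost, launch() accepts the standard Gradio options,
# e.g. iface.launch(server_name="0.0.0.0", share=True).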