---
license: mit
---

# GigaBind

A finetuned ImageBind using LoRA for images, audio, and many other modalities.

## Usage

```python
import logging

import torch

import data
from models import imagebind_model
from models.imagebind_model import ModalityType, load_module
from models import lora as LoRA

logging.basicConfig(level=logging.INFO, force=True)

lora = True
linear_probing = False
device = "cpu"  # "cuda:0" if torch.cuda.is_available() else "cpu"
load_head_post_proc_finetuned = True

assert not (linear_probing and lora), \
    "Linear probing is a subset of LoRA training procedure for ImageBind. " \
    "Cannot set both linear_probing=True and lora=True. "

if lora and not load_head_post_proc_finetuned:
    # Hack: adjust lora_factor to the `max batch size used during training / temperature`
    # to compensate missing norm
    lora_factor = 12 / 0.07
else:
    # This assumes proper loading of all params but results in shift from original dist
    # in case of LoRA
    lora_factor = 1

text_list = ["bird", "car", "dog3", "dog5", "dog8", "grey_sloth_plushie"]
image_paths = [".assets/bird_image.jpg", ".assets/car_image.jpg", ".assets/dog3.jpg",
               ".assets/dog5.jpg", ".assets/dog8.jpg", ".assets/grey_sloth_plushie.jpg"]
audio_paths = [".assets/bird_audio.wav", ".assets/car_audio.wav", ".assets/dog_audio.wav"]

# Instantiate model
model = imagebind_model.imagebind_huge(pretrained=True)
if lora:
    model.modality_trunks.update(
        LoRA.apply_lora_modality_trunks(
            model.modality_trunks,
            rank=4,
            layer_idxs={ModalityType.TEXT: [0, 1, 2, 3, 4, 5, 6, 7, 8],
                        ModalityType.VISION: [0, 1, 2, 3, 4, 5, 6, 7, 8]},
            modality_names=[ModalityType.TEXT, ModalityType.VISION]))

    # Load LoRA params if found
    LoRA.load_lora_modality_trunks(model.modality_trunks,
                                   checkpoint_dir=".checkpoints/lora/550_epochs_lora",
                                   postfix="_dreambooth_last")

    if load_head_post_proc_finetuned:
        # Load postprocessors & heads
        load_module(model.modality_postprocessors, module_name="postprocessors",
                    checkpoint_dir=".checkpoints/lora/550_epochs_lora",
                    postfix="_dreambooth_last")
        load_module(model.modality_heads, module_name="heads",
                    checkpoint_dir=".checkpoints/lora/550_epochs_lora",
                    postfix="_dreambooth_last")
elif linear_probing:
    # Load heads
    load_module(model.modality_heads, module_name="heads",
                checkpoint_dir="./.checkpoints/lora/500_epochs_lp",
                postfix="_dreambooth_last")

model.eval()
model.to(device)

# Load data
inputs = {
    ModalityType.TEXT: data.load_and_transform_text(text_list, device),
    ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device, to_tensor=True),
    ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
}

with torch.no_grad():
    embeddings = model(inputs)

print(
    "Vision x Text: ",
    torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T
                  * (lora_factor if lora else 1), dim=-1),
)
print(
    "Audio x Text: ",
    torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T
                  * (lora_factor if lora else 1), dim=-1),
)
print(
    "Vision x Audio: ",
    torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, dim=-1),
)
```