import os
import cv2
import numpy as np
from PIL import Image
import json
import gradio as gr
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import timm
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer


class CFG:
    image_path = './images'
    captions_path = './captions'
    batch_size = 64
    num_workers = 4
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 1
    factor = 0.8
    epochs = 2
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = 'resnet50'
    image_embedding = 2048
    text_encoder_model = "distilbert-base-uncased"
    text_embedding = 768
    text_tokenizer = "distilbert-base-uncased"
    max_length = 200

    pretrained = True   # for both image encoder and text encoder
    trainable = True    # for both image encoder and text encoder
    temperature = 1.0

    # image size
    size = 224

    # for projection head; used for both image and text encoders
    num_projection_layers = 1
    projection_dim = 256
    dropout = 0.1


class ImageEncoder(nn.Module):
    """Encode images to a fixed size vector."""

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        return self.model(x)


class TextEncoder(nn.Module):
    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained,
                 trainable=CFG.trainable):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained(model_name)
        else:
            self.model = DistilBertModel(config=DistilBertConfig())

        for p in self.model.parameters():
            p.requires_grad = trainable

        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]


class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x


class CLIPModel(nn.Module):
    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        # Getting Image and Text Features
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )

        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Calculating the Loss: symmetric cross-entropy over the similarity matrix
        images_similarity = image_embeddings @ text_embeddings.T / self.temperature
        texts_similarity = images_similarity.T
        labels = torch.arange(batch["image"].shape[0]).long().to(CFG.device)
        total_loss = (
            F.cross_entropy(images_similarity, labels)
            + F.cross_entropy(texts_similarity, labels)
        ) / 2
        return total_loss


def find_matches_cpu(model, image_embeddings, query, image_filenames, n=4):
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    encoded_query = tokenizer([query])
    batch = {
        key: torch.tensor(values).to('cpu')
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T

    # Over-fetch n * 5 candidates and keep every 5th index: the precomputed image
    # embeddings are assumed to contain repeated rows per image (one per caption),
    # so striding avoids returning the same image multiple times.
    values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
    matches = [image_filenames[idx] for idx in indices[::5]]
    return matches


def rle_decode(img_rle_array, img_name, img_size):
    encoded_image = img_rle_array

    # Expand the run-length encoding: the array holds (pixel_value, run_length) pairs
    decoded_image = []
    for i in range(0, len(encoded_image), 2):
        pixel_value = encoded_image[i]
        run_length = encoded_image[i + 1]
        decoded_image.extend([pixel_value] * run_length)

    # Convert the decoded pixels back to a NumPy array
    decoded_array = np.array(decoded_image, dtype=np.uint8)

    # Reshape the decoded array to the original image shape (img_size)
    decoded_image = decoded_array.reshape(img_size)

    # Create a PIL Image from the decoded array and save it to a file
    decoded_image = Image.fromarray(decoded_image)
    decoded_image_save_path = './' + str(img_name)
    decoded_image.save(decoded_image_save_path)
    return decoded_image_save_path


def get_matched_image(matches, val_file_dict_loaded):
    img_size = (112, 112)
    match_img_list = []
    for img_name in matches:
        img_rle_array = val_file_dict_loaded[img_name]
        decoded_image_path = rle_decode(img_rle_array, img_name, img_size)
        match_img_list.append(decoded_image_path)
    return match_img_list


def get_grayscale_image(text_query):
    # Note: the checkpoint, embeddings, and RLE dictionary are reloaded on every
    # query; caching them at module level would make the demo noticeably faster.
    model_inf = CLIPModel().to('cpu')
    model_inf.load_state_dict(torch.load('best_clip_model_cpu.pt', map_location='cpu'))
    model_inf.eval()  # disable dropout for deterministic inference

    clip_image_embeddings_np_inf = np.load('clip_image_embeddings.npy')
    image_embeddings_inf = torch.tensor(clip_image_embeddings_np_inf)
    img_file_names = np.load('val_img_file_names.npy', allow_pickle=True)

    with open("val_imgs_rle_encode.json", "r") as json_file:
        val_file_dict_loaded = json.load(json_file)

    matches = find_matches_cpu(model_inf, image_embeddings_inf, query=text_query,
                               image_filenames=img_file_names, n=1)
    matched_images = get_matched_image(matches, val_file_dict_loaded)
    return matched_images


def gradio_fn(text):
    text_query = str(text)
    match_img_list = get_grayscale_image(text_query)
    pil_img = Image.open(match_img_list[0])
    pil_img = pil_img.resize((224, 224))
    np_img_array = np.array(pil_img)
    return np_img_array


demo = gr.Interface(fn=gradio_fn,
                    inputs="text",
                    outputs=gr.Image(height=224, width=224),
                    title="CLIP Image Search")
demo.launch(share=True)
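
# ----------------------------------------------------------------------------
# Illustrative sketch (assumptions flagged, kept as comments so it does not run
# with the demo): rle_decode expects each entry of val_imgs_rle_encode.json to
# be a flat [pixel_value, run_length, pixel_value, run_length, ...] list for a
# 112x112 grayscale image. The array and file name below are made up purely to
# show the expected round trip:
#
#   example_rle = [255, 112 * 112]                           # a single run: an all-white image
#   path = rle_decode(example_rle, "rle_example.png", (112, 112))
#   print(path)                                              # -> ./rle_example.png
# ----------------------------------------------------------------------------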