import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from transformers import T5Tokenizer, T5ForConditionalGeneration, ViTImageProcessor, ViTModel
from torch import nn
from typing import Tuple
import sentencepiece  # sentencepiece must be installed for T5Tokenizer


class Encoder(nn.Module):
    def __init__(self, pretrained_model):
        """
        Implements the Encoder.

        Args:
            pretrained_model (str): name of the pretrained model
        """
        super(Encoder, self).__init__()
        self.encoder = ViTModel.from_pretrained(pretrained_model)

    def forward(self, pixel_values):
        out = self.encoder(pixel_values=pixel_values)
        return out


class Decoder(nn.Module):
    def __init__(self, pretrained_model, encoder_modeldim):
        """
        Implements the Decoder.

        Args:
            pretrained_model (str): name of the pretrained model
            encoder_modeldim (int): hidden dimension of the encoder output
        """
        super(Decoder, self).__init__()
        self.decoder = T5ForConditionalGeneration.from_pretrained(pretrained_model)
        self.linear = nn.Linear(self.decoder.model_dim, encoder_modeldim, bias=False)
        self.encoder_modeldim = encoder_modeldim

    def forward(self, output_encoder, targets, decoder_ids=None):
        # Project the encoder output when the encoder and decoder hidden sizes differ
        if self.decoder.model_dim != self.encoder_modeldim:
            print(f"Changed model hidden dimension from {self.encoder_modeldim} to {self.decoder.model_dim}")
            output_encoder = self.linear(output_encoder)
            print(output_encoder.shape)
        # Validation/testing: decode from explicit decoder input ids
        if decoder_ids is not None:
            out = self.decoder(encoder_outputs=output_encoder, decoder_input_ids=decoder_ids)
        # Training: pass labels so the loss is computed internally
        else:
            out = self.decoder(encoder_outputs=output_encoder, labels=targets)
        return out


class EncoderDecoder(nn.Module):
    def __init__(self, pretrained_model: Tuple[str, str], encoder_dmodel=768, eos_token_id=None, pad_token_id=None):
        """
        Implements a model that combines the Encoder and the Decoder.

        Args:
            pretrained_model (tuple): names of the pretrained encoder and decoder models
            encoder_dmodel (int): hidden dimension of the encoder output
            eos_token_id (torch.long): token used for end of sentence
            pad_token_id (torch.long): token used for padding
        """
        super(EncoderDecoder, self).__init__()
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.encoder = Encoder(pretrained_model[0])
        self.encoder_dmodel = encoder_dmodel
        # Freeze parameters from encoder
        # for p in self.encoder.parameters():
        #     p.requires_grad = False
        self.decoder = Decoder(pretrained_model[1], self.encoder_dmodel)
        self.decoder_start_token_id = self.decoder.decoder.config.decoder_start_token_id

    def forward(self, images=None, targets=None, decoder_ids=None):
        output_encoder = self.encoder(images)
        out = self.decoder(output_encoder, targets, decoder_ids)
        return out


# Model loading and setting up the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load("model_vit_ai.pt", map_location=device)
model.to(device)

# Tokenizer and feature extractor
tokenizer = T5Tokenizer.from_pretrained('t5-base')
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

# Define the image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
])


def preprocess_image(image):
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    image = transform(image)
    return image.unsqueeze(0)


def generate_caption(image):
    model.eval()
    with torch.no_grad():
        image_tensor = preprocess_image(image).to(device)
        decoder_input_ids = torch.full((1, 1), model.decoder_start_token_id, dtype=torch.long, device=device)
        for _ in range(50):
            outputs = model(images=image_tensor, decoder_ids=decoder_input_ids)
            next_token_logits = outputs.logits[:, -1, :]
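            # Greedy decoding step: take the argmax over the vocabulary, append it
            # to the running decoder input, and stop once the EOS token is emitted.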
            next_token_id = next_token_logits.argmax(dim=-1, keepdim=True)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id], dim=-1)
            if torch.eq(next_token_id, tokenizer.eos_token_id).all():
                break
        caption = tokenizer.decode(decoder_input_ids.squeeze(0), skip_special_tokens=True)
        return caption


sample_images = [
    "sample_image1.jpg",
    "sample_image2.jpg",
    "sample_image3.jpg"
]

interface = gr.Interface(
    fn=generate_caption,
    inputs="image",  # Specify the input type as "image"
    outputs="text",
    examples=sample_images,
    title="Image Captioning Model",
    description="Upload an image, select a sample image, or use your webcam to take a photo and generate a caption."
)

# Run the interface
interface.launch(debug=True)
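
# Optional sanity check (illustrative sketch, not part of the app): run the
# captioning function directly on one of the bundled sample images without the
# Gradio UI. Assumes "sample_image1.jpg" exists next to this script; since
# interface.launch() above blocks, run this with the launch call commented out.
# import numpy as np
# test_image = np.array(Image.open("sample_image1.jpg").convert("RGB"))
# print(generate_caption(test_image))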