"""Gradio demo that captions an image.

A CLIP image embedding is projected into a sequence of GPT-2 input embeddings
(the "prefix"), and the ``sberbank-ai/rugpt3small_based_on_gpt2`` language
model then continues that prefix token by token to produce the caption.
"""
import os
import sys

import clip
import gradio as gr
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as nnf
import transformers
from PIL import Image
from torch import nn
from tqdm import tqdm, trange
from transformers import GPT2Tokenizer, GPT2LMHeadModel

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_path = 'coco_prefix_latest.pt'


class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers with an activation between consecutive layers."""

    def __init__(self, sizes, bias=True, act=nn.Tanh):
        """Build the layer stack.

        Args:
            sizes: Sequence of layer widths; ``sizes[i] -> sizes[i + 1]`` becomes
                one ``nn.Linear``.
            bias: Whether each linear layer has a bias term.
            act: Activation class, instantiated after every linear layer except
                the last one.
        """
        super().__init__()
        layers = []
        last_linear = len(sizes) - 2
        for i, (in_features, out_features) in enumerate(zip(sizes[:-1], sizes[1:])):
            layers.append(nn.Linear(in_features, out_features, bias=bias))
            if i < last_linear:
                layers.append(act())
        # The attribute name and layer order define the state_dict keys of
        # saved checkpoints, so they must stay as they are.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.model(x)


class ClipCaptionModel(nn.Module):
    """GPT-2 language model conditioned on a projected CLIP prefix."""

    def __init__(self, prefix_length, prefix_size: int = 512):
        """Load the GPT-2 weights and create the prefix projection.

        Args:
            prefix_length: Number of GPT-2 embedding positions the prefix occupies.
            prefix_size: Width of the incoming prefix vector.
        """
        super().__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        projected_size = self.gpt_embedding_size * prefix_length
        if prefix_length > 10:  # not enough memory
            # A single linear layer replaces the MLP for long prefixes.
            self.clip_project = nn.Linear(prefix_size, projected_size)
        else:
            self.clip_project = MLP((prefix_size, projected_size // 2, projected_size))

    def get_dummy_token(self, batch_size, device):
        """Return a ``(batch_size, prefix_length)`` int64 tensor of zeros on ``device``."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens, prefix, mask=None, labels=None):
        """Run GPT-2 on the projected prefix followed by the token embeddings.

        Args:
            tokens: Token ids; embedded with the GPT-2 input embedding table.
            prefix: Prefix vectors fed to ``clip_project``.
            mask: Passed to GPT-2 as ``attention_mask``.
            labels: Only tested against ``None``. When it is not ``None`` the
                labels actually given to GPT-2 are zeros for the prefix
                positions concatenated with ``tokens``.

        Returns:
            Whatever ``self.gpt`` returns for these inputs.
        """
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(
            -1, self.prefix_length, self.gpt_embedding_size
        )
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        return self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)


class ClipCaptionPrefix(ClipCaptionModel):
    """Variant that exposes only the prefix projection as parameters."""

    def parameters(self, recurse: bool = True):
        """Return the parameters of ``clip_project`` only."""
        return self.clip_project.parameters(recurse)

    def train(self, mode: bool = True):
        """Switch the module to ``mode`` but always keep GPT-2 in eval mode."""
        super().train(mode)
        self.gpt.eval()
        return self


clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
prefix_length = 10
model = ClipCaptionModel(prefix_length)
# NOTE: torch.load unpickles the file; only point model_path at a trusted checkpoint.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.to(device)


def generate2(
    model,
    tokenizer,
    tokens=None,
    prompt=None,
    embed=None,
    entry_count=1,
    entry_length=67,
    top_p=0.98,
    temperature=1.,
    stop_token='.',
):
    """Decode text from a prefix embedding, token ids, or a text prompt.

    The starting point is ``embed`` if it is given, otherwise the embedding of
    ``tokens``, otherwise the embedding of the encoded ``prompt``. Every step
    takes the arg-max token of the filtered logits, so decoding is
    deterministic.

    Args:
        model: Object exposing ``gpt`` (a GPT-2 LM head model) and ``parameters()``.
        tokenizer: Tokenizer providing ``encode`` and ``decode``.
        tokens: Optional token-id tensor used as-is; generated ids are appended
            to it along dim 1.
        prompt: Optional text, encoded only when both ``embed`` and ``tokens``
            are ``None``.
        embed: Optional input embeddings to start decoding from.
        entry_count: Number of decoding runs; only the first result is returned.
        entry_length: Maximum number of tokens generated per run.
        top_p: Cumulative-probability threshold of the nucleus filter.
        temperature: Logit divisor; values ``<= 0`` are treated as ``1.0``.
        stop_token: Text whose first encoded id ends a run when generated.

    Returns:
        The decoded text of the first run.

    Raises:
        ValueError: If none of ``embed``, ``tokens`` and ``prompt`` is given.
        IndexError: If ``entry_count`` is 0, because no run produces a result.
    """
    model.eval()
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    model_device = next(model.parameters()).device

    if embed is None and tokens is None:
        if prompt is None:
            raise ValueError("generate2 needs one of `embed`, `tokens` or `prompt`")
        tokens = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(model_device)

    scale = temperature if temperature > 0 else 1.0
    generated_list = []
    with torch.no_grad():
        for _ in trange(entry_count):
            # Every run restarts from the caller's tokens; rebinding `tokens`
            # itself would leak one run's output into the next.
            entry_tokens = tokens
            if embed is not None:
                generated = embed
            else:
                generated = model.gpt.transformer.wte(tokens)

            for _ in range(entry_length):
                outputs = model.gpt(inputs_embeds=generated)
                logits = outputs.logits[:, -1, :] / scale

                # Nucleus filter: mask everything after the smallest set of
                # top-ranked tokens whose cumulative probability exceeds top_p.
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                # The best-ranked token is always kept.
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value

                next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = model.gpt.transformer.wte(next_token)
                if entry_tokens is None:
                    entry_tokens = next_token
                else:
                    entry_tokens = torch.cat((entry_tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break

            # Drop only the batch axis: a bare squeeze() turns a single token
            # into a 0-d tensor, which cannot be converted to a list of ids.
            if entry_tokens is None:
                output_list = []
            else:
                output_list = entry_tokens.squeeze(0).tolist()
            generated_list.append(tokenizer.decode(output_list))

    return generated_list[0]


def _to_caption(pil_image):
    """Return the caption generated for ``pil_image`` by the module-level models."""
    image = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
        prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
    return generate2(model, tokenizer, embed=prefix_embed)


def classify_image(inp):
    """Gradio callback: convert the incoming array to a PIL image and caption it."""
    print(type(inp))
    inp = Image.fromarray(inp)
    texts = _to_caption(inp)
    print(texts)
    return texts


# NOTE(review): the gr.inputs / gr.outputs namespaces are specific to older
# Gradio releases - confirm the pinned version before upgrading.
image = gr.inputs.Image(shape=(256, 256))
label = gr.outputs.Label(num_top_classes=3)
iface = gr.Interface(
    fn=classify_image,
    description="https://github.com/AlexWortega/ruImageCaptioning RuImage Captioning trained for a image2text task to predict caption of image by https://t.me/lovedeathtransformers Alex Wortega",
    inputs=image,
    outputs="text",
    examples=[['1.jpeg']],
)
iface.launch()