Spaces:
Sleeping
Sleeping
import os
from typing import List, Optional

import numpy
import torch
from PIL import Image
from torch import nn
from transformers import BertTokenizer

from Model import clip
from Model.bert import BertLMHeadModel, BertConfig
from Model.clip.model import Transformer
class Proj(nn.Module):
    """Refine a token sequence with one Transformer layer, then project it linearly.

    Args:
        encoder_output_size: feature width of the incoming tokens; also the
            width of the internal Transformer layer.
        num_head: number of attention heads passed to ``Transformer``.
        output_size: width of the projected output. Defaults to 768, the value
            that was previously hard-coded here.
            NOTE(review): 768 presumably matches a BERT-base decoder's hidden size — confirm.
    """

    def __init__(self, encoder_output_size, num_head=16, output_size=768):
        super().__init__()
        self.encoder_output_size = encoder_output_size
        # NOTE(review): arguments presumably follow OpenAI CLIP's
        # Transformer(width, layers, heads), i.e. a single layer here — confirm.
        self.transformer = Transformer(encoder_output_size, 1, num_head)
        self.linear = nn.Linear(encoder_output_size, output_size)

    def forward(self, x):
        """Return ``linear(transformer(x))`` for a batch-first input ``x``.

        ``x`` is laid out as (N, L, D). The Transformer is fed sequence-first
        (L, N, D), hence the permutes around it.
        """
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        return self.linear(x)
class TRCaptionNet(nn.Module):
    """Image captioner: CLIP visual tower + BERT-style LM head decoder with cross-attention.

    Config keys read by ``__init__``:
        max_length (int): default maximum caption length used by ``generate``.
        proj (bool): whether image embeddings are projected before the decoder.
        proj_num_head (int | None): heads for ``Proj``; ``None`` selects a plain
            ``nn.Linear`` projection instead. Only used when ``proj`` is True.
        clip (str): model name passed to ``clip.load``.
        bert (str): either a pretrained model name/directory, or a path to a
            JSON ``BertConfig`` file.
        tokenizer (str, optional): tokenizer name used when ``bert`` is a JSON
            config file; defaults to "dbmdz/bert-base-turkish-cased".

    Raises:
        KeyError: if a required config key is missing.
        TypeError: if ``config["proj"]`` is not a bool.
    """

    def __init__(self, config: dict):
        super().__init__()
        # parameters
        self.max_length = config["max_length"]
        self.proj_flag = config["proj"]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(self.proj_flag, bool):
            raise TypeError(
                f"config['proj'] must be a bool, got {type(self.proj_flag).__name__}")
        self.proj_num_head = config["proj_num_head"]

        # vision encoder: keep only the visual tower of the loaded CLIP model, in float32
        clip_model, preprocess = clip.load(config["clip"], jit=False)
        clip_model.eval()
        self.vision_encoder = clip_model.visual.float()
        # Probe the encoder once with a blank image to discover its output feature width.
        with torch.no_grad():
            blank_image = Image.fromarray(numpy.zeros((512, 512, 3), dtype=numpy.uint8))
            device = next(self.parameters()).device  # vision encoder is the only submodule so far
            dummy_input_image = preprocess(blank_image).to(device)
            encoder_output_size = self.vision_encoder(dummy_input_image.unsqueeze(0)).shape[-1]

        # language decoder
        if not os.path.isfile(config["bert"]):
            # `bert` names a pretrained model (hub id or directory)
            self.language_decoder = BertLMHeadModel.from_pretrained(config["bert"],
                                                                    is_decoder=True,
                                                                    add_cross_attention=True)
            self.tokenizer = BertTokenizer.from_pretrained(config["bert"])
        else:
            # `bert` is a JSON config file; no weights are loaded here
            # (presumably restored later from a checkpoint — confirm with caller).
            med_config = BertConfig.from_json_file(config["bert"])
            self.language_decoder = BertLMHeadModel(config=med_config)
            self.tokenizer = BertTokenizer.from_pretrained(
                config.get("tokenizer", "dbmdz/bert-base-turkish-cased"))

        # optional projection of image embeddings before they reach the decoder
        # NOTE(review): 768 presumably matches a BERT-base decoder's hidden size — confirm
        # if a JSON config with a different hidden size is ever used.
        if not self.proj_flag:
            self.proj = None
        elif self.proj_num_head is None:
            self.proj = nn.Linear(encoder_output_size, 768)
        else:
            self.proj = Proj(encoder_output_size, self.proj_num_head)

    @torch.no_grad()  # inference only: avoid building an autograd graph for the encoder
    def generate(self, images, max_length: Optional[int] = None, min_length: int = 12,
                 num_beams: int = 3, repetition_penalty: float = 1.1) -> List[str]:
        """Generate one caption per image.

        Args:
            images: batched image tensor fed straight to the vision encoder.
                Presumably produced by the CLIP ``preprocess`` transform, as in the
                width probe in ``__init__`` — confirm with callers.
            max_length: maximum caption length; ``None`` falls back to ``config["max_length"]``.
            min_length: minimum generated length, forwarded to the decoder's ``generate``.
            num_beams: beam width, forwarded to the decoder's ``generate``.
            repetition_penalty: forwarded to the decoder's ``generate``.

        Returns:
            The decoded captions, with special tokens removed.
        """
        image_embeds = self.vision_encoder(images)
        if self.proj is not None:
            image_embeds = self.proj(image_embeds)
        # All image positions may be attended to: an all-ones cross-attention mask.
        image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long, device=images.device)
        model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask": image_atts}

        # Every sequence starts from token id 2.
        # NOTE(review): presumably the tokenizer's [CLS]/BOS id — confirm against the vocab.
        input_ids = torch.full((image_embeds.shape[0], 1), 2,
                               dtype=torch.long, device=images.device)
        outputs = self.language_decoder.generate(input_ids=input_ids,
                                                 max_length=self.max_length if max_length is None else max_length,
                                                 min_length=min_length,
                                                 num_beams=num_beams,
                                                 eos_token_id=self.tokenizer.sep_token_id,
                                                 pad_token_id=self.tokenizer.pad_token_id,
                                                 repetition_penalty=repetition_penalty,
                                                 **model_kwargs)
        captions = [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
        return captions
def test():
    """Smoke test: construct a TRCaptionNet from a minimal configuration.

    Loads the CLIP "ViT-B/32" model and the "dbmdz/bert-base-turkish-cased"
    decoder/tokenizer by name. Presumably this needs network access or a warm
    local cache.
    """
    TRCaptionNet({
        "max_length": 35,
        "clip": "ViT-B/32",
        "bert": "dbmdz/bert-base-turkish-cased",
        # Required by TRCaptionNet.__init__; omitting them raised KeyError.
        "proj": True,
        "proj_num_head": None,  # None selects a plain linear projection
    })
if __name__ == "__main__":
    # Execute the smoke test only when this file is run as a script, never on import.
    test()