from models import IntuitionKillingMachine
from transforms import (undo_box_transforms_batch, ToTensor, Normalize,
                        SquarePad, Resize, NormalizeBoxCoords)
from torchvision.transforms import Compose
from encoders import get_tokenizer
from PIL import Image, ImageDraw
from zipfile import ZipFile
from copy import copy
import gradio as gr
import pandas as pd
import torch


def parse_model_args(model_path):
    # the checkpoint directory name encodes the run's hyperparameters as
    # '_'-separated fields, e.g. '<date>_<time>_refclef_32_512_resnet50_...'
    (_, _, dataset, max_length, input_size, backbone, num_heads,
     num_layers, num_conv, _, _, mu, mask_pooling) = model_path.split('_')[:13]
    return {
        'dataset': dataset,
        'max_length': int(max_length),
        'input_size': int(input_size),
        'backbone': backbone,
        'num_heads': int(num_heads),
        'num_layers': int(num_layers),
        'num_conv': int(num_conv),
        'mu': float(mu),
        'mask_pooling': mask_pooling == '1',
    }


class Prober:
    def __init__(self, df_path=None, dataset_path=None, model_checkpoint=None):
        params = parse_model_args(model_checkpoint)

        # ImageNet normalization statistics
        mean = [0.485, 0.456, 0.406]
        sdev = [0.229, 0.224, 0.225]

        self.tokenizer = get_tokenizer()

        self.df = pd.read_json(df_path)[['sample_idx', 'bbox', 'file_path', 'sent']]
        self.df.loc[:, 'image_id'] = self.df.file_path.apply(
            lambda x: int(x.split('/')[-1][:-4])
        )
        self.df.file_path = self.df.file_path.apply(
            lambda x: x.replace('refer/data/images/', '')
        )

        self.model = IntuitionKillingMachine(
            backbone=params['backbone'],
            pretrained=True,
            num_heads=params['num_heads'],
            num_layers=params['num_layers'],
            num_conv=params['num_conv'],
            segmentation_head=params['mu'] > 0.0,
            mask_pooling=params['mask_pooling']
        )
        self.load_model(model_checkpoint)

        self.transform = Compose([
            ToTensor(),
            Normalize(mean, sdev),
            SquarePad(),
            Resize(size=(params['input_size'], params['input_size'])),
            NormalizeBoxCoords(),
        ])

        self.max_length = 30

        self.zipfile = ZipFile(dataset_path, 'r')

    def load_model(self, model_checkpoint):
        checkpoint = torch.load(
            model_checkpoint,
            map_location=lambda storage, loc: storage
        )
        # strip the 'model.' prefix that the PyTorch Lightning wrapper
        # adds to every key in the checkpoint's state_dict
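        # e.g. 'model.backbone.conv1.weight' -> 'backbone.conv1.weight'
        # (key names here are illustrative; actual names depend on the model)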
        state_dict = {
            k[len('model.'):]: v
            for k, v in checkpoint['state_dict'].items()
        }
        missing, _ = self.model.load_state_dict(state_dict, strict=False)

        # the only keys allowed to be missing are those of the
        # (optional) segmentation head
        assert [k for k in missing if 'segm' not in k] == []

        self.model = self.model.eval()

    def preview_image(self, idx):
        img_path, _ = self.df.loc[idx][['file_path', 'bbox']].values
        img = Image.open(self.zipfile.open(img_path)).convert('RGB')
        return img

    @torch.no_grad()
    def probe(self, idx, re, search_by_sample_id: bool = True):
        idx = int(idx)  # Gradio's number input may deliver a float

        if search_by_sample_id:
            img_path, target = self.df.loc[idx][['file_path', 'bbox']].values
        else:
            img_path, target = self.df[self.df.image_id == idx][['file_path', 'bbox']].values[0]

        img = Image.open(self.zipfile.open(img_path)).convert('RGB')

        # with an empty referring expression, just return the raw image
        if re == '':
            return img

        W0, H0 = img.size
        sample = {
            'image': img,
            'image_size': (H0, W0),  # original image size
            'bbox': torch.tensor([copy(target)]),
            'bbox_raw': torch.tensor([copy(target)]),
            'mask': torch.ones((1, H0, W0), dtype=torch.float32),  # visibility mask
            'mask_bbox': None,  # target bbox mask
        }
        sample = self.transform(sample)

        tok = self.tokenizer(
            re,
            max_length=self.max_length,
            return_tensors='pt',
            truncation=True
        )
        inn = {
            'image': torch.stack([sample['image']]),
            'mask': torch.stack([sample['mask']]),
            'tok': tok
        }

        # run the model and map the predicted box back to the
        # original image coordinates
        output = undo_box_transforms_batch(
            self.model(inn)[0],
            [sample['tr_param']]
        ).numpy().tolist()[0]

        # draw the predicted box (green) on the image
        img1 = ImageDraw.Draw(img)
        # img1.rectangle(target, outline='#0000FF00', width=3)
        img1.rectangle(output, outline='#00FF0000', width=3)
        return img


prober = Prober(
    df_path='data/val-sim_metric.json',
    dataset_path='data/saiapr_tc-12.zip',
    model_checkpoint='cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt'
)

demo = gr.Interface(fn=prober.probe, inputs=['number', 'text'], outputs='image')  # , live=True)
demo.queue(concurrency_count=10)
demo.launch()
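# Usage sketch (hypothetical values): the Prober can also be driven
# programmatically, without the Gradio UI. The sample index and referring
# expression below are placeholders whose valid values depend on the contents
# of data/val-sim_metric.json. Since demo.launch() above blocks, run these
# lines in a separate session (or before launching the interface):
#
#   prober.preview_image(0).save('sample0.png')                      # raw image
#   prober.probe(0, 'the man on the left').save('sample0_pred.png')  # with predicted box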