# RECModel / app.py
# Last commit f13173e by mmazuecos: "Back to gr.Interface, removed interactivity."
from models import IntuitionKillingMachine
from transforms import undo_box_transforms_batch, ToTensor, Normalize, SquarePad, Resize, NormalizeBoxCoords
from torchvision.transforms import Compose
from encoders import get_tokenizer
from PIL import Image, ImageDraw
from zipfile import ZipFile
from copy import copy
import gradio as gr
import pandas as pd
import torch
def parse_model_args(model_path):
    """Decode the hyper-parameters encoded in a checkpoint path.

    The path is split on ``'_'`` and only the first 13 fields are read.
    Fields 0, 1, 9 and 10 are ignored; the others are mapped, in order, to
    ``dataset``, ``max_length``, ``input_size``, ``backbone``, ``num_heads``,
    ``num_layers``, ``num_conv``, ``mu`` and ``mask_pooling``.

    Args:
        model_path: Checkpoint path whose name carries the ``'_'``-separated
            settings.

    Returns:
        A dict with the nine keys above. ``dataset`` and ``backbone`` stay
        strings, ``mu`` is a float, ``mask_pooling`` is ``True`` only when
        its field is exactly ``'1'``, and the rest are ints.

    Raises:
        ValueError: If the path has fewer than 13 fields, or a numeric field
            cannot be converted.
    """
    fields = model_path.split('_')[:13]
    (_, _, dataset, max_length, input_size, backbone,
     num_heads, num_layers, num_conv, _, _, mu, mask_pooling) = fields

    # Built incrementally, keeping the key order of the settings in the path.
    args = {'dataset': dataset}
    args.update(max_length=int(max_length), input_size=int(input_size))
    args['backbone'] = backbone
    args.update(
        num_heads=int(num_heads),
        num_layers=int(num_layers),
        num_conv=int(num_conv),
    )
    args['mu'] = float(mu)
    args['mask_pooling'] = mask_pooling == '1'
    return args
class Prober:
    """Run a trained referring-expression model on images stored in a zip file.

    The instance holds the sample table (``self.df``), the tokenizer, the
    model restored from a checkpoint, the image transform pipeline and an
    open ``ZipFile`` from which images are read on demand.
    """

    def __init__(self,
                 df_path=None,
                 dataset_path=None,
                 model_checkpoint=None):
        """Load the sample table, build the model and open the image archive.

        Args:
            df_path: Path to a JSON file readable by ``pandas.read_json`` that
                has at least the columns ``sample_idx``, ``bbox``,
                ``file_path`` and ``sent``.
            dataset_path: Path to the zip archive containing the images.
            model_checkpoint: Path to the checkpoint; its name is also parsed
                by ``parse_model_args`` to recover the model settings.
        """
        params = parse_model_args(model_checkpoint)

        # Per-channel normalisation constants (presumably the usual ImageNet
        # statistics -- confirm against the training code).
        mean = [0.485, 0.456, 0.406]
        sdev = [0.229, 0.224, 0.225]

        self.tokenizer = get_tokenizer()

        self.df = pd.read_json(df_path)[['sample_idx', 'bbox', 'file_path', 'sent']]
        # The image id is the file name without its last four characters
        # (i.e. it assumes a three-letter extension such as ".jpg").
        self.df.loc[:, "image_id"] = self.df.file_path.apply(lambda x: int(x.split('/')[-1][:-4]))
        # Drop the on-disk prefix so that paths are relative to the archive.
        self.df.file_path = self.df.file_path.apply(lambda x: x.replace('refer/data/images/', ''))

        self.model = IntuitionKillingMachine(
            backbone=params['backbone'],
            pretrained=True,
            num_heads=params['num_heads'],
            num_layers=params['num_layers'],
            num_conv=params['num_conv'],
            segmentation_head=bool(params['mu'] > 0.0),
            mask_pooling=params['mask_pooling']
        )

        self.load_model(model_checkpoint)

        self.transform = Compose([
            ToTensor(),
            Normalize(mean, sdev),
            SquarePad(),
            Resize(size=(params['input_size'], params['input_size'])),
            NormalizeBoxCoords(),
        ])

        # NOTE(review): the checkpoint name also encodes a max_length
        # (params['max_length']) that is not used here -- confirm that the
        # fixed value of 30 is intended.
        self.max_length = 30

        self.zipfile = ZipFile(dataset_path, 'r')

    def load_model(self, model_checkpoint):
        """Restore ``self.model`` from *model_checkpoint* and set eval mode.

        Args:
            model_checkpoint: Path to a checkpoint whose ``'state_dict'``
                entry maps parameter names to tensors.

        Raises:
            RuntimeError: If parameters other than those whose name contains
                ``'segm'`` are missing from the checkpoint.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_checkpoint, map_location='cpu')

        # strip 'model.' from pl checkpoint (only where the prefix is present,
        # so that un-prefixed keys are not mangled)
        prefix = 'model.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint['state_dict'].items()
        }
        missing, _ = self.model.load_state_dict(state_dict, strict=False)

        # ensure the only missing keys are those of the segmentation head only
        unexpected_missing = [k for k in missing if 'segm' not in k]
        if unexpected_missing:
            raise RuntimeError(
                f"checkpoint is missing non-segmentation parameters: {unexpected_missing}"
            )

        self.model = self.model.eval()

    def _open_image(self, img_path):
        """Read archive member *img_path* and return it as an RGB PIL image."""
        # convert() decodes the pixel data, so the archive member can be
        # closed as soon as it returns instead of leaking the handle.
        with self.zipfile.open(img_path) as fp:
            return Image.open(fp).convert('RGB')

    def preview_image(self, idx):
        """Return the RGB image of the sample stored under index label *idx*."""
        img_path = self.df.loc[idx]['file_path']
        return self._open_image(img_path)

    @torch.no_grad()
    def probe(self, idx, re, search_by_sample_id: bool = True):
        """Draw the box predicted for referring expression *re* on an image.

        Args:
            idx: Index label of the sample in ``self.df`` or, when
                *search_by_sample_id* is false, an ``image_id`` value.
            re: The referring expression. When empty or ``None`` the image is
                returned without running the model.
            search_by_sample_id: Select the row by index label (default) or
                take the first row whose ``image_id`` equals *idx*.

        Returns:
            The RGB PIL image, with the predicted box outlined when *re* is
            non-empty.

        Raises:
            IndexError: If *search_by_sample_id* is false and no row has the
                requested ``image_id``.
        """
        if search_by_sample_id:
            img_path, target = self.df.loc[idx][['file_path', 'bbox']].values
        else:
            matches = self.df[self.df.image_id == idx][['file_path', 'bbox']].values
            if len(matches) == 0:
                raise IndexError(f"no sample with image_id {idx!r}")
            img_path, target = matches[0]

        img = self._open_image(img_path)

        if not re:
            return img

        W0, H0 = img.size
        sample = {
            'image': img,
            'image_size': (H0, W0),  # image original size
            'bbox': torch.tensor([copy(target)]),
            'bbox_raw': torch.tensor([copy(target)]),
            'mask': torch.ones((1, H0, W0), dtype=torch.float32),  # visibility mask
            'mask_bbox': None,  # target bbox mask
        }
        sample = self.transform(sample)

        tok = self.tokenizer(re,
                             max_length=self.max_length,
                             return_tensors='pt',
                             truncation=True)

        # Batch of one: the model input is built by stacking a single sample.
        inn = {'image': torch.stack([sample['image']]),
               'mask': torch.stack([sample['mask']]),
               'tok': tok}

        # sample['tr_param'] is presumably recorded by the transform pipeline
        # so that the prediction can be mapped back to original image
        # coordinates -- confirm in the transforms module.
        output = undo_box_transforms_batch(self.model(inn)[0],
                                           [sample['tr_param']]).numpy().tolist()[0]

        # Only the predicted box is drawn; the ground-truth box is not shown.
        img1 = ImageDraw.Draw(img)
        img1.rectangle(output, outline="#00FF0000", width=3)
        return img
# Build the prober once at start-up; it loads the sample table, restores the
# model from the checkpoint and opens the image archive.
prober = Prober(
    df_path='data/val-sim_metric.json',
    dataset_path="data/saiapr_tc-12.zip",
    model_checkpoint="cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt",
)

# The interface passes a number (sample index) and a text (referring
# expression) to Prober.probe and shows the returned image.
demo = gr.Interface(fn=prober.probe, inputs=["number", "text"], outputs="image")
demo.queue(concurrency_count=10)

# Start the server only when executed as a script, so importing this module
# (e.g. for reuse of `demo` or `prober`) does not block on launch().
if __name__ == "__main__":
    demo.launch()