# yangwu · update · 1c49612  (appears to be commit metadata: author, message, hash)
import gradio as gr
import os
import sys
import numpy as np
import numpy as np
import torch.backends.cudnn as cudnn
import torch.utils.data
import torch.nn.functional as F
import torchvision.transforms as transforms
from mmcv.utils import Config
sys.path.append('.')
from image_forgery_detection import build_detector
from image_forgery_detection import Compose
# One-step conversion pipeline: inference_api() uses it to turn the raw
# predicted mask (a torch tensor) into a PIL image before thresholding.
_to_pil_ops = [transforms.ToPILImage()]
transform_pil = transforms.Compose(_to_pil_ops)
def inference_api(f_path):
    """Run forgery detection on a single image file (Gradio callback).

    Reads the module-level ``pipelines``, ``model``, ``thresh`` and
    ``transform_pil`` objects, which the ``__main__`` section of this script
    sets up before launching the interface.

    Args:
        f_path: Path of the uploaded image. The Image input is declared with
            ``type="filepath"``. The value is ``None`` when the user submits
            without uploading anything.

    Returns:
        A ``(score, mask)`` tuple:
        - ``score`` is the image-level forgery score formatted to three
          decimals.
        - ``mask`` is the predicted mask as a numpy array whose pixels are
          255 where the prediction is at or above ``255 * thresh`` and 0
          elsewhere.

    Raises:
        gr.Error: if no image was provided.
    """
    if f_path is None:
        # Without this guard, None would be passed down to the image-loading
        # pipeline and fail there with an unrelated, confusing error.
        raise gr.Error('Please upload an image first.')
    print(f_path)

    # Build the minimal results dict the data pipeline expects. There is no
    # ground-truth mask at inference time, so seg_map is only a placeholder.
    results = dict(img_info=dict(filename=f_path, ann=dict(seg_map='None')))
    results['seg_fields'] = []
    results['img_prefix'] = None
    results['seg_prefix'] = None
    inputs = pipelines(results)

    img = inputs['img'].data
    img_meta = inputs['img_metas'].data

    with torch.no_grad():
        # Add a leading axis of size 1 so the model receives a one-item batch.
        img = img.unsqueeze(dim=0)
        if 'dct_vol' in inputs:
            # Some pipeline configs also emit 'dct_vol' and 'qtables'. These
            # are presumably DCT / JPEG quantization features (TODO confirm
            # against the config). The model then takes them as extra
            # positional inputs.
            dct_vol = inputs['dct_vol'].data.unsqueeze(dim=0)
            qtables = inputs['qtables'].data.unsqueeze(dim=0)
            cls_pred, seg_pred = model(img, dct_vol, qtables, [img_meta],
                                       return_loss=False, rescale=True)
        else:
            cls_pred, seg_pred = model(img, [img_meta],
                                       return_loss=False, rescale=True)

    cls_pred = cls_pred[0]
    # First sample, first channel of the mask prediction. It is assumed to be
    # a numpy float map, because it is passed to torch.from_numpy below.
    seg = seg_pred[0, 0]
    seg = np.array(transform_pil(torch.from_numpy(seg)))

    # Binarize on the same 0-255 scale as the converted mask.
    thresh_int = 255 * thresh
    seg[seg >= thresh_int] = 255
    seg[seg < thresh_int] = 0

    # float() accepts Python floats, numpy scalars and one-element tensors
    # alike, so the score always formats cleanly.
    return '{:.3f}'.format(float(cls_pred)), seg
if __name__ == '__main__':
    # Load the detector from ./models and serve it through a Gradio interface.
    # Module-level assignments here are the globals that inference_api() reads
    # (model, pipelines, thresh), so no `global` statements are needed.
    model_path = './models/latest.pth'

    # Fail fast, before parsing the config and building the model.
    if not os.path.exists(model_path):
        print("%s not exist" % model_path, file=sys.stderr)
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. under `python -S` or frozen builds).
        sys.exit(1)

    cfg = Config.fromfile('./models/config.py')
    # Mask binarization threshold, in [0, 1], used by inference_api().
    thresh = 0.5

    # All weights come from the local checkpoint, so clear the
    # pretrained-weight setting. It is set on the backbone when the config
    # has one, otherwise on the base model.
    if hasattr(cfg.model.base_model, 'backbone'):
        cfg.model.base_model.backbone.pretrained = None
    else:
        cfg.model.base_model.pretrained = None
    model = build_detector(cfg.model)

    checkpoint = torch.load(model_path, map_location='cpu')
    # The training checkpoint stores weights under 'state_dict'. A bare state
    # dict (no wrapper) is also accepted.
    state_dict = checkpoint.get('state_dict', checkpoint)
    model.load_state_dict(state_dict, strict=True)
    print("load %s finish" % (os.path.basename(model_path)))
    model.eval()

    # Preprocessing is taken from the first validation dataset's pipeline.
    pipelines = Compose(cfg.data.val[0].pipeline)

    iface = gr.Interface(
        inference_api,
        inputs=gr.components.Image(label="Upload image to detect", type="filepath"),
        outputs=[gr.components.Textbox(type="text", label="image forgery score"),
                 gr.components.Image(type="numpy", label="predict mask")],
        title="Forged? Or Not?",
    )
    # NOTE: to expose the demo publicly, launch with
    # iface.launch(server_name='0.0.0.0', share=True) instead.
    iface.launch()