Spaces:
Build error
Build error
import gradio as gr | |
import os | |
import sys | |
import numpy as np | |
import numpy as np | |
import torch.backends.cudnn as cudnn | |
import torch.utils.data | |
import torch.nn.functional as F | |
import torchvision.transforms as transforms | |
from mmcv.utils import Config | |
sys.path.append('.') | |
from image_forgery_detection import build_detector | |
from image_forgery_detection import Compose | |
# Single-step torchvision pipeline: ToPILImage is its only transform.
# inference_api applies it to the predicted mask (wrapped with torch.from_numpy)
# before converting the result back to a numpy array and thresholding it.
transform_pil = transforms.Compose(
    [transforms.ToPILImage()]
)
def inference_api(f_path, threshold=None):
    """Run forgery detection on one image file.

    Uses the module-level ``pipelines``, ``model`` and ``thresh`` objects that
    are created in the ``__main__`` section of this script.

    Args:
        f_path: Path to the image file to analyse. Gradio's
            ``Image(type="filepath")`` input passes ``None`` when nothing was
            uploaded.
        threshold: Optional binarisation threshold for the predicted mask,
            expressed on the mask's original scale (presumably 0-1 — TODO
            confirm). Defaults to the global ``thresh``.

    Returns:
        Tuple ``(score, mask)``: ``score`` is ``cls_pred[0]`` formatted with
        three decimals; ``mask`` is a 2-D array whose pixels are 255 where the
        8-bit mask is >= ``255 * threshold`` and 0 elsewhere.

    Raises:
        ValueError: If ``f_path`` is ``None`` or empty.
    """
    if not f_path:
        # Fail early with a clear message instead of deep inside the
        # image-loading pipeline.
        raise ValueError("No image provided: please upload an image first.")
    if threshold is None:
        threshold = thresh

    print(f_path)
    results = dict(img_info=dict(filename=f_path, ann=dict(seg_map='None')))
    results['seg_fields'] = []
    results['img_prefix'] = None
    results['seg_prefix'] = None
    inputs = pipelines(results)

    # Add a leading batch dimension of 1; metas are passed as a one-item list.
    img = inputs['img'].data.unsqueeze(dim=0)
    img_metas = [inputs['img_metas'].data]
    with torch.no_grad():
        if 'dct_vol' in inputs:
            # Some pipelines also emit 'dct_vol' and 'qtables'; the model then
            # takes them as extra positional inputs.
            dct_vol = inputs['dct_vol'].data.unsqueeze(dim=0)
            qtables = inputs['qtables'].data.unsqueeze(dim=0)
            cls_pred, seg_pred = model(img, dct_vol, qtables, img_metas,
                                       return_loss=False, rescale=True)
        else:
            cls_pred, seg_pred = model(img, img_metas,
                                       return_loss=False, rescale=True)

    score = cls_pred[0]
    # ToPILImage turns a float mask into an 8-bit image (scaling by 255),
    # which is why the threshold is multiplied by 255 below.
    mask = np.array(transform_pil(torch.from_numpy(seg_pred[0, 0])))
    cutoff = 255 * threshold
    mask = np.where(mask >= cutoff, 255, 0).astype(mask.dtype)
    return '{:.3f}'.format(score), mask
if __name__ == '__main__':
    # Entry point: build the model from ./models/config.py, load the
    # checkpoint, build the preprocessing pipeline and serve the Gradio UI.
    # `model`, `pipelines` and `thresh` are plain module-level globals here,
    # which is where inference_api looks them up.
    model_path = './models/latest.pth'
    cfg = Config.fromfile('./models/config.py')
    # Default threshold used by inference_api to binarise the predicted mask.
    thresh = 0.5

    # Clear the `pretrained` reference before building. Presumably this stops
    # the builder from loading separate pretrained weights; every weight is
    # loaded from the checkpoint below with strict=True.
    if hasattr(cfg.model.base_model, 'backbone'):
        cfg.model.base_model.backbone.pretrained = None
    else:
        cfg.model.base_model.pretrained = None
    model = build_detector(cfg.model)

    if not os.path.exists(model_path):
        print("%s not exist" % model_path)
        # sys.exit rather than the site-injected exit(), which is meant for
        # interactive use and is missing when Python runs with -S.
        sys.exit(1)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location='cpu')['state_dict']
    model.load_state_dict(checkpoint, strict=True)
    print("load %s finish" % (os.path.basename(model_path)))
    model.eval()

    # Preprocessing is taken from the first validation dataset's pipeline.
    pipelines = Compose(cfg.data.val[0].pipeline)
    iface = gr.Interface(
        inference_api,
        inputs=gr.components.Image(label="Upload image to detect", type="filepath"),
        outputs=[gr.components.Textbox(type="text", label="image forgery score"),
                 gr.components.Image(type="numpy", label="predict mask")],
        title="Forged? Or Not?",
    )
    # Alternative launch options: iface.launch(server_name='0.0.0.0', share=True)
    iface.launch()