# DocGeoNet / app.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import os
from PIL import Image
import warnings
import gradio as gr
from model import DocGeoNet
from seg import U2NETP
import glob
warnings.filterwarnings('ignore')
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.msk = U2NETP(3, 1)    # U2-Net-P head for document segmentation
        self.DocTr = DocGeoNet()   # geometric rectification network

    def forward(self, x):
        # Segment the document and zero out the background before rectification.
        msk, _1, _2, _3, _4, _5, _6 = self.msk(x)
        msk = (msk > 0.5).float()
        x = msk * x

        # DocGeoNet predicts a backward map; normalize it to [-1, 1] for grid_sample.
        _, _, bm = self.DocTr(x)
        bm = (2 * (bm / 255.) - 1) * 0.99
        return bm
def reload_seg_model(model, path=""):
    if not bool(path):
        return model
    else:
        model_dict = model.state_dict()
        pretrained_dict = torch.load(path, map_location='cpu')
        # Strip the checkpoint key prefix and keep only parameters that match the model.
        pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        return model
def reload_rec_model(model, path=""):
    if not bool(path):
        return model
    else:
        model_dict = model.state_dict()
        pretrained_dict = torch.load(path, map_location='cpu')
        # Strip the checkpoint key prefix and keep only parameters that match the model.
        pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        return model
def rec(input_image):
    seg_model_path = './model_pretrained/preprocess.pth'
    rec_model_path = './model_pretrained/DocGeoNet.pth'

    net = Net()
    reload_rec_model(net.DocTr, rec_model_path)
    reload_seg_model(net.msk, seg_model_path)
    net.eval()

    im_ori = np.array(input_image)[:, :, :3] / 255.  # scale image from 0-255 to 0-1
    h, w, _ = im_ori.shape
    im = cv2.resize(im_ori, (256, 256))
    im = im.transpose(2, 0, 1)
    im = torch.from_numpy(im).float().unsqueeze(0)

    with torch.no_grad():
        bm = net(im)
        bm = bm.cpu()
        # Resize the predicted backward map to the original resolution and smooth it.
        bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))  # x flow
        bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))  # y flow
        bm0 = cv2.blur(bm0, (3, 3))
        bm1 = cv2.blur(bm1, (3, 3))
        lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)  # 1 x h x w x 2

        # Warp the original image with the backward map.
        out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
        # The sampled image is already in RGB order; just rescale to uint8.
        img_rec = ((out[0] * 255).permute(1, 2, 0).numpy()).astype(np.uint8)

    return Image.fromarray(img_rec)
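# A minimal offline usage sketch (the input path below is illustrative and assumed
# to exist; it is not part of the Gradio app itself):
#
#   img = Image.open('./distorted/example.jpg').convert('RGB')
#   rec(img).save('./rectified_example.png')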
demo_img_files = glob.glob('./distorted/*.[jJ][pP][gG]') + glob.glob('./distorted/*.[pP][nN][gG]')
# Gradio Interface
input_image = gr.Image()
output_image = gr.Image(type='pil')
iface = gr.Interface(fn=rec, inputs=input_image, outputs=output_image, title="DocGeoNet", examples=demo_img_files)
#iface.launch(server_port=8821, server_name="0.0.0.0")
iface.launch()