import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import os
from PIL import Image
import warnings
import gradio as gr
from model import DocGeoNet
from seg import U2NETP
import glob

warnings.filterwarnings('ignore')
class Net(nn.Module):
    """Segmentation (U2NETP) + geometric rectification (DocGeoNet) pipeline."""
    def __init__(self):
        super(Net, self).__init__()
        self.msk = U2NETP(3, 1)
        self.DocTr = DocGeoNet()

    def forward(self, x):
        # Predict a document foreground mask and zero out the background.
        msk, _1, _2, _3, _4, _5, _6 = self.msk(x)
        msk = (msk > 0.5).float()
        x = msk * x
        # Predict the backward map and normalize it to [-1, 1] for grid_sample.
        _, _, bm = self.DocTr(x)
        bm = (2 * (bm / 255.) - 1) * 0.99
        return bm
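
# A minimal sanity check of the forward pass, assuming model.py and seg.py from
# this repo are importable; the dummy input stands in for one 256x256 RGB image
# (names here are illustrative only and not used by the app below):
#
#   net = Net()
#   dummy = torch.rand(1, 3, 256, 256)
#   with torch.no_grad():
#       bm = net(dummy)  # backward map in [-1, 1], shape (1, 2, h, w)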
def reload_seg_model(model, path=""):
    if not bool(path):
        return model
    else:
        model_dict = model.state_dict()
        pretrained_dict = torch.load(path, map_location='cpu')
        # Strip the 6-character prefix on checkpoint keys so they match this model's parameter names.
        pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        return model
def reload_rec_model(model, path=""):
    if not bool(path):
        return model
    else:
        model_dict = model.state_dict()
        pretrained_dict = torch.load(path, map_location='cpu')
        # Strip the 7-character prefix (e.g. 'module.' left by DataParallel) from checkpoint keys.
        pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        return model
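
# The two loaders above differ only in how many prefix characters they strip
# from the checkpoint keys. A hypothetical generic variant (a sketch, not used
# by this app) could take the prefix length as a parameter:
#
#   def reload_model_generic(model, path="", prefix_len=0):
#       if not path:
#           return model
#       model_dict = model.state_dict()
#       pretrained_dict = torch.load(path, map_location='cpu')
#       pretrained_dict = {k[prefix_len:]: v for k, v in pretrained_dict.items()
#                          if k[prefix_len:] in model_dict}
#       model_dict.update(pretrained_dict)
#       model.load_state_dict(model_dict)
#       return model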
def rec(input_image):
    seg_model_path = './model_pretrained/preprocess.pth'
    rec_model_path = './model_pretrained/DocGeoNet.pth'

    net = Net()
    reload_rec_model(net.DocTr, rec_model_path)
    reload_seg_model(net.msk, seg_model_path)
    net.eval()

    im_ori = np.array(input_image)[:, :, :3] / 255.  # RGB, rescaled from 0-255 to 0-1
    h, w, _ = im_ori.shape
    im = cv2.resize(im_ori, (256, 256))
    im = im.transpose(2, 0, 1)
    im = torch.from_numpy(im).float().unsqueeze(0)

    with torch.no_grad():
        bm = net(im)
        bm = bm.cpu()
        # Upsample the 256x256 backward map to the original resolution and smooth it.
        bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))  # x flow
        bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))  # y flow
        bm0 = cv2.blur(bm0, (3, 3))
        bm1 = cv2.blur(bm1, (3, 3))
        lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)  # 1 x h x w x 2
        # Warp the original image with the backward map.
        out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
        # The warped tensor is already RGB; the original flipped to BGR and back
        # with cvtColor, which cancels out, so we convert to uint8 directly.
        img_rec = (out[0] * 255).permute(1, 2, 0).numpy().astype(np.uint8)

    return Image.fromarray(img_rec)
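
# Example usage outside Gradio (the file path here is hypothetical):
#
#   rectified = rec(Image.open('./distorted/example.jpg').convert('RGB'))
#   rectified.save('./example_rectified.png')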
demo_img_files = glob.glob('./distorted/*.[jJ][pP][gG]') + glob.glob('./distorted/*.[pP][nN][gG]')

# Gradio interface. The gr.inputs / gr.outputs namespaces were removed in
# Gradio 3.x (the likely cause of the Space's runtime error); use gr.Image directly.
input_image = gr.Image()
output_image = gr.Image(type='pil')

iface = gr.Interface(fn=rec, inputs=input_image, outputs=output_image, title="DocGeoNet", examples=demo_img_files)
# iface.launch(server_port=8821, server_name="0.0.0.0")
iface.launch()