|
from util import UIDataset, Vocabulary |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import torch |
|
from torch.utils.data import DataLoader |
|
from model import * |
|
from torchvision import transforms |
|
from PIL import Image |
|
|
|
|
|
# Pick the inference device once; fall back to the CPU when CUDA is not
# available so the model can still be loaded on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = Pix2Code()

# map_location remaps tensors saved from a GPU onto the selected device, so
# the checkpoint also loads on CPU-only hosts.
# NOTE: torch.load unpickles the file -- only load weights from a trusted source.
net.load_state_dict(torch.load('./pix2code.weights', map_location=device))

# eval() switches the module to inference mode.
net.to(device).eval()
|
|
|
|
|
# Token vocabulary, constructed from the 'voc.pkl' file.
vocab = Vocabulary('voc.pkl')

# Image preprocessing applied before the model sees a screenshot:
# resize to 256x256, convert to a tensor, then normalize each of the three
# channels with mean 0.5 and standard deviation 0.5.
_resize_step = transforms.Resize((256, 256))
_to_tensor_step = transforms.ToTensor()
_normalize_step = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

transform = transforms.Compose([_resize_step, _to_tensor_step, _normalize_step])
|
|
|
|
|
def generate_gui(image):
    """Generate GUI code for a screenshot by greedy decoding with the model.

    Starting from the context ``[' ', '<START>']``, the model is queried
    repeatedly; at each step the highest-scoring token at the last position
    is appended to the context and to the output.

    Args:
        image: Screenshot to translate. Accepts a PIL image or a NumPy array
            (Gradio's image input supplies NumPy arrays by default); any
            other type is passed to the preprocessing pipeline unchanged.

    Returns:
        str: The predicted tokens concatenated in order. Decoding stops when
        ``<END>`` is predicted (it is not included) or after 200 tokens.
    """
    max_tokens = 200

    # torchvision's Resize does not accept NumPy arrays, so convert first.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if isinstance(image, Image.Image):
        # The Normalize step is configured for exactly three channels.
        image = image.convert('RGB')

    # Follow whichever device the model lives on instead of assuming CUDA.
    device = next(net.parameters()).device

    image = transform(image).unsqueeze(0).to(device)

    # Seed context: a blank token followed by the start marker, batch of 1.
    # NOTE(review): presumably vocab.to_vec returns a one-hot vector -- confirm.
    context = torch.tensor(
        [vocab.to_vec(' '), vocab.to_vec('<START>')]
    ).unsqueeze(0).float().to(device)

    code = []

    # Inference only: skip autograd bookkeeping, which would otherwise keep a
    # graph alive for every forward pass of the loop.
    with torch.no_grad():
        for _ in range(max_tokens):
            # Best-scoring token id per position; keep only the last position.
            predictions = torch.argmax(net(image, context), 2)
            index = predictions.squeeze()[-1:].squeeze()

            token = vocab.to_vocab(int(index))
            if token == '<END>':
                break

            code.append(token)

            # Feed the chosen token back in as the next context element.
            next_step = torch.tensor([vocab.to_vec(token)]).unsqueeze(0).float().to(device)
            context = torch.cat([context, next_step], dim=1)

    return ''.join(code)
|
|
|
import gradio as gr |
|
|
|
|
|
# Request PIL images: the torchvision Resize step used for preprocessing
# rejects the NumPy arrays that Gradio's image input passes by default.
image_input = gr.inputs.Image(type='pil')

# The generated code is shown as plain text.
text_output = gr.outputs.Textbox()

# Web UI wiring: one image in, generated code out.
iface = gr.Interface(
    fn=generate_gui,
    inputs=image_input,
    outputs=text_output,
    title='Pix2Code',
    description='Gerador de código GUI a partir de imagens',
    theme='default',
)

# Start the Gradio server (runs at import time, as in the original script).
iface.launch()
|
|