pix2code / app.py
Bruno's picture
Rename apply.py to app.py
11ff00c
from util import UIDataset, Vocabulary
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch
from torch.utils.data import DataLoader
from model import *
from torchvision import transforms
from PIL import Image
# Carrega o modelo treinado
net = Pix2Code()
net.load_state_dict(torch.load('./pix2code.weights'))
net.cuda().eval()
# Carrega o vocabulário
vocab = Vocabulary('voc.pkl')
# Define uma transformação para redimensionar e normalizar as imagens
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# Função que receberá a imagem e retornará o código GUI gerado
def generate_gui(image):
# Aplica a transformação na imagem
image = transform(image).unsqueeze(0).cuda()
# Cria um contexto inicial
context = torch.tensor([vocab.to_vec(' '), vocab.to_vec('<START>')]).unsqueeze(0).float().cuda()
# Inicializa uma lista para armazenar o código gerado
code = []
# Gera o código iterativamente até encontrar o token <END>
for i in range(200):
# Passa a imagem e o contexto para a rede neural e obtém o índice do token com maior probabilidade
index = torch.argmax(net(image, context), 2).squeeze()[-1:].squeeze()
# Converte o índice para o token correspondente
token = vocab.to_vocab(int(index))
# Se encontrar o token <END>, interrompe a geração do código
if token == '<END>':
break
# Adiciona o token à lista de código gerado
code.append(token)
# Atualiza o contexto com o token gerado
context = torch.cat([context, torch.tensor([vocab.to_vec(token)]).unsqueeze(0).float().cuda()], dim=1)
# Retorna o código gerado como uma string
return ''.join(code)
import gradio as gr
# Define o componente de entrada
image_input = gr.inputs.Image()
# Define o componente de saída
text_output = gr.outputs.Textbox()
# Cria a interface Gradio
iface = gr.Interface(
fn=generate_gui,
inputs=image_input,
outputs=text_output,
title='Pix2Code',
description='Gerador de código GUI a partir de imagens',
theme='default'
)
# Executa a interface
iface.launch()