import numpy as np
import tensorflow as tf
import gradio as gr

# Local modules: the MinimalGPT model loader, its text-generation helper, and
# the subword BPE tokenizer bundled with this repository.
from GPT import generate_output
from MinimalGPT_2 import MinimalGPT
from subword.apply_bpe import BPE
# Load the learned BPE merge operations and build the subword tokenizer.
codec = open('./subword/dataset/codec.txt', encoding='utf-8')
bpe = BPE(codec)
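# Illustration only (the exact segmentation depends on the merges learned in
# codec.txt): bpe.process_line() splits words into subword units joined by the
# '@@ ' marker, e.g. 'unbelievable' -> something like 'un@@ believ@@ able'.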
# Rebuild the MinimalGPT model purely for inference: epochs=0 and learning_rate=0
# skip training, while the pretrained tokenizer and weights are loaded from disk.
output = MinimalGPT(data_path='./subword/dataset/codec.txt',
                    learning_rate=0,
                    output_length=11,
                    epochs=0,
                    batch_size=1,
                    gpt_input=10,
                    d_model=128,
                    h=8,
                    decoder_stacks=1,
                    starting_chunk=0,
                    ending_chunk=0,
                    chunk_size=1,
                    vocabulary_start=0,
                    vocabulary_end=0,
                    save=False,
                    load_tokenizer='./tokenizer.mgt',
                    load_weights='./weights.mgw',
                    save_tokenizer=None,
                    save_weights=None,
                    optimizer=None,
                    inference_only=False,
                    return_model_and_vectorizer=True,
                    return_model_and_vectorizer_and_output=False,
                    GPT_attention=True,
                    TPU=False)

# With return_model_and_vectorizer=True the call returns the model and the
# vectorizer (token-to-id mapping) as its first two elements.
model, vectorizer = output[0], output[1]
def pad_or_slice_tensor(tensor):
    """Force a 1-D token tensor to exactly 10 elements: left-pad with zeros when
    it is shorter, keep only the last 10 tokens when it is longer."""
    length = tf.shape(tensor)[0]

    def pad_tensor():
        num_zeros = 10 - length
        zeros = tf.zeros((num_zeros,), dtype=tensor.dtype)
        return tf.concat([zeros, tensor], axis=0)

    def slice_tensor():
        return tensor[-10:]

    return tf.cond(tf.less(length, 10), pad_tensor, slice_tensor)
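# Added illustration (not part of the original app) of the normalization above;
# both calls yield a length-10 tensor:
#   pad_or_slice_tensor(tf.constant([5, 6, 7]))  -> [0 0 0 0 0 0 0 5 6 7]
#   pad_or_slice_tensor(tf.range(15))            -> [5 6 7 8 9 10 11 12 13 14]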
def generate_text(input_text):
    """BPE-encode the prompt, vectorize it, normalize it to 10 tokens, and decode
    the model's continuation back into plain text."""
    tokens = bpe.process_line(input_text)
    tokens = vectorizer(tokens)[:10]
    tokens = pad_or_slice_tensor(tokens)
    tokens = tf.reshape(tokens, (1, 10))
    output_text = generate_output(model, vectorizer, input_sequence=tokens.numpy(), gpt_input=10)
    # Strip the '@@ ' BPE continuation markers to recover whole words.
    return output_text.replace('@@ ', '')
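# Quick local check before launching the UI (illustrative; the prompt below is
# an arbitrary example, not taken from the original app):
# print(generate_text('the meaning of life is'))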
# Minimal Gradio UI: a single text box in, generated text out.
iface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="Text Generation",
    description="Generate text with a MinimalGPT model",
    theme=gr.themes.Soft()
)
iface.launch()
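# launch() serves on localhost by default; if the demo must be reachable from
# other machines, Gradio also accepts, for example (optional, not part of the
# original configuration):
# iface.launch(server_name='0.0.0.0', share=True)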