|
import tensorflow as tf |
|
from GPT import generate_output |
|
import numpy as np |
|
import gradio as gr |
|
from MinimalGPT_2 import MinimalGPT |
|
|
|
from subword.apply_bpe import BPE |
|
|
|
# Load the BPE merge operations ("codes") used to segment user prompts into
# subword units before they are vectorized.  The file handle is only needed
# while BPE is being constructed, so it is scoped to a ``with`` block and
# closed right afterwards instead of staying open for the process lifetime.
# NOTE(review): assumes BPE reads the whole codes file in its constructor (as
# subword-nmt's apply_bpe.BPE does) — confirm for the vendored copy.
with open('./subword/dataset/codec.txt', encoding='utf-8') as codec:
    bpe = BPE(codec)
|
|
|
# Rebuild the model from artifacts on disk rather than training it: the
# tokenizer and weights are loaded from files, learning rate and epoch count
# are zero, and nothing is saved back.
# NOTE(review): data_path points at the BPE codes file, not a corpus, and
# inference_only is False — presumably harmless because epochs=0; confirm
# against MinimalGPT_2.MinimalGPT.
output = MinimalGPT(
    # Data / chunking settings.
    data_path='./subword/dataset/codec.txt',
    starting_chunk=0,
    ending_chunk=0,
    chunk_size=1,
    vocabulary_start=0,
    vocabulary_end=0,
    # Training settings (zeroed: no training is intended here).
    learning_rate=0,
    epochs=0,
    batch_size=1,
    optimizer=None,
    # Architecture; presumably must match the saved weights — TODO confirm.
    output_length=11,
    gpt_input=10,
    d_model=128,
    h=8,
    decoder_stacks=1,
    GPT_attention=True,
    # Persistence: load existing artifacts, never write new ones.
    save=False,
    load_tokenizer='./tokenizer.mgt',
    load_weights='./weights.mgw',
    save_tokenizer=None,
    save_weights=None,
    # Execution / return-value options.
    inference_only=False,
    return_model_and_vectorizer=True,
    return_model_and_vectorizer_and_output=False,
    TPU=False,
)

# The first two items of the result are used as the model and the vectorizer.
model = output[0]
vectorizer = output[1]
|
|
|
def pad_or_slice_tensor(tensor, target_length=10):
    """Force a 1-D tensor to exactly ``target_length`` elements along axis 0.

    Inputs shorter than ``target_length`` are left-padded with zeros of the
    tensor's own dtype.  Inputs of ``target_length`` or more keep only their
    last ``target_length`` elements.  In both cases the real tokens end up at
    the right-hand end of the window.

    Args:
        tensor: 1-D tensor (e.g. token ids).  The padding branch concatenates
            a 1-D zeros tensor, so higher-rank input is not supported.
        target_length: Desired output length.  Defaults to 10, the context
            size this module was written for.

    Returns:
        A tensor of shape ``(target_length,)`` with the same dtype as
        ``tensor``.

    Raises:
        ValueError: If ``target_length`` is not positive.  For 0,
            ``tensor[-0:]`` would otherwise silently return the whole tensor.
    """
    if target_length < 1:
        raise ValueError(f"target_length must be positive, got {target_length}")

    length = tf.shape(tensor)[0]

    def pad_tensor():
        # Zeros go in front so the real tokens stay at the end of the window,
        # which is consistent with slice_tensor keeping the *last* elements.
        zeros = tf.zeros((target_length - length,), dtype=tensor.dtype)
        return tf.concat([zeros, tensor], axis=0)

    def slice_tensor():
        return tensor[-target_length:]

    # tf.cond selects the branch from the runtime length, which is a tensor.
    return tf.cond(tf.less(length, target_length), pad_tensor, slice_tensor)
|
|
|
def generate_text(input_text):
    """Generate a continuation of ``input_text`` with the module-level model.

    The prompt is BPE-segmented, vectorized, and reduced to a 10-token
    context window.  The window holds the last 10 ids; shorter prompts are
    left-padded with zeros.  It is passed to ``generate_output``, and BPE
    continuation markers are then removed from the result.

    Args:
        input_text: Raw prompt string from the Gradio text box.

    Returns:
        The generated text with BPE segmentation undone.
    """
    import re  # local import: only needed to undo BPE markers below

    subwords = bpe.process_line(input_text)
    # NOTE(review): assumes vectorizer returns an unpadded 1-D id sequence
    # for a single string — confirm against the loaded tokenizer.
    token_ids = vectorizer(subwords)
    # No ``[:10]`` here.  Pre-truncating to the *first* 10 ids would throw
    # away the end of long prompts, which is the part generation should
    # continue from.  It would also make the tail-keeping slice in
    # pad_or_slice_tensor unreachable.
    token_ids = pad_or_slice_tensor(token_ids)
    token_ids = tf.reshape(token_ids, (1, 10))

    output_text = generate_output(
        model, vectorizer, input_sequence=token_ids.numpy(), gpt_input=10
    )
    # Undo BPE: drop "@@ " joins and a dangling "@@" at the very end.  This is
    # subword-nmt's reversal rule; str.replace('@@ ', '') misses the latter.
    return re.sub(r'(@@ )|(@@ ?$)', '', output_text)
|
|
|
# Web UI: one text box in, generated text out, served by generate_text.
iface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="Text Generation",
    # Describe the model this script actually serves (MinimalGPT loaded from
    # local tokenizer/weights files); no Hugging Face Transformers code is used.
    description="Generate text using MinimalGPT",
    theme=gr.themes.Soft(),
)

# Start the Gradio server at import/run time, as the script always has.
iface.launch()