import torch
from threading import Thread
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextIteratorStreamer
import gradio as gr
# Check if a GPU is available and use it, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pre-trained model and tokenizer from the saved directory
model_path = "Blexus/Quble_Test_Model_v1_Pretrain"
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
# Set model to evaluation mode
model.eval()
# Generator that streams text back to the UI as tokens are produced.
# (The previous version iterated over generate(...).sequences, which returns
# the finished sequences all at once, so nothing actually streamed.)
def generate_text(prompt):
    # Tokenize and encode the input prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # TextIteratorStreamer yields decoded text piece by piece while generate() runs.
    # skip_prompt=False keeps the prompt in the output, matching the old behavior.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

    generation_kwargs = dict(
        input_ids=input_ids,
        max_length=50,  # Maximum total length (prompt + generated tokens)
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; reuse EOS
        do_sample=True,  # Sample instead of greedy decoding
        top_k=50,  # Consider only the 50 most likely next tokens
        top_p=0.95,  # Nucleus sampling: keep tokens covering 95% of probability mass
        streamer=streamer,
        use_cache=True,
    )

    # Run generation on a background thread so we can consume the streamer here.
    # model.generate already runs under no_grad, so no explicit context manager is needed.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate the pieces and yield the growing text to the UI
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text  # Stream the partial text back to the UI
    thread.join()
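# Quick sanity check outside the UI (illustrative prompt, not part of the Space):
#     for partial in generate_text("Once upon a time"):
#         print(partial)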
# Create a Gradio interface with streaming enabled
interface = gr.Interface(
fn=generate_text, # Function to call when interacting with the UI
inputs="text", # Input type: Single-line text
outputs=gr.Markdown(), # Stream output using Markdown
title="Quble Text Generation", # Title of the UI
description="Enter a prompt to generate text using Quble with live streaming." # Simple description
)
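# Assumption about Gradio versions: on Gradio 3.x, streaming generator outputs
# require an explicit queue; Gradio 4+ enables queueing by default, and calling
# queue() with no arguments is harmless on both.
interface.queue()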
# Launch the Gradio app
interface.launch()