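"""Gradio demo for xtuner/llava-llava-llama-3-8b: upload an image and chat
about it with LLaVA-Llama-3-8B, with responses streamed token by token."""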
import time
from threading import Thread
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
# Model Configuration
model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"
print("Loading model...")
processor = AutoProcessor.from_pretrained(model_id)
# Adjusted model loading to use Accelerate's `device_map`
model = LlavaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="auto" # Uses the Accelerate library for efficient memory usage
)
print("Model loaded successfully!")
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<img src="https://cdn-uploads.huggingface.co/production/uploads/64ccdc322e592905f922a06e/DDIW0kbWmdOQWwy4XMhwX.png"
style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaVA-Llama-3-8B</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">
LLaVA-Llama-3-8B is fine-tuned from Meta-Llama-3-8B-Instruct and CLIP-ViT-Large-patch14-336
using ShareGPT4V-PT and InternVL-SFT by XTuner.
</p>
</div>
"""
def bot_streaming(message, history):
    """Stream a LLaVA response for the latest message, reusing an image from
    earlier in the conversation if none was uploaded this turn."""
    try:
        image = None
        # Prefer an image attached to the current message; Gradio may pass
        # files as plain paths or as dicts with a "path" key.
        if message["files"]:
            last_file = message["files"][-1]
            image = last_file["path"] if isinstance(last_file, dict) else last_file
        else:
            # Fall back to the most recent image found in the chat history
            # (tuple entries hold uploaded file paths).
            for hist in history:
                if isinstance(hist[0], tuple):
                    image = hist[0][0]
        if not image:
            # This function is a generator, so the error must be yielded,
            # not returned, to reach the UI.
            yield "Error: Please upload an image for LLaVA to work."
            return

        # Prepare inputs in the Llama-3 chat format, ending with the assistant
        # header so the model begins its reply immediately.
        image = Image.open(image)
        prompt = (
            f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(
            model.device, dtype=torch.float16
        )

        # Stream text generation from a background thread; the streamer needs
        # the underlying tokenizer, not the full multimodal processor.
        streamer = TextIteratorStreamer(
            processor.tokenizer, skip_special_tokens=True, skip_prompt=True
        )
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        buffer = ""
        time.sleep(0.5)  # Allow some time for initial generation
        # Stream the generated response, trimming a stray end-of-turn token
        # in case special tokens slip through.
        for new_text in streamer:
            if "<|eot_id|>" in new_text:
                new_text = new_text.split("<|eot_id|>")[0]
            buffer += new_text
            yield buffer
    except Exception as e:
        yield f"Error: {e}"
# Define Gradio interface components
chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1)
chat_input = gr.MultimodalTextbox(
interactive=True, file_types=["image"], placeholder="Enter message or upload a file...", show_label=False
)
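# The example prompts below assume bee.jpg and baklava.png ship alongside
# this script in the Space repository.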
with gr.Blocks(fill_height=True) as demo:
gr.ChatInterface(
fn=bot_streaming,
title="LLaVA Llama-3-8B",
examples=[
{"text": "What is on the flower?", "files": ["./bee.jpg"]},
{"text": "How to make this pastry?", "files": ["./baklava.png"]}
],
description=(
"Try [LLaVA Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). "
"Upload an image and start chatting about it, or simply try one of the examples below. "
"If you don't upload an image, you will receive an error."
),
stop_btn="Stop Generation",
multimodal=True,
textbox=chat_input,
chatbot=chatbot,
)
# Launch the Gradio app
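# Queueing is what lets bot_streaming's partial outputs stream to the client;
# api_open=False keeps the queue's API endpoints private.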
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)