# Author: toandev — last commit 6800630 "Update app.py" (verified).
import time
from threading import Thread
from typing import Dict, List
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
AutoProcessor,
MllamaForConditionalGeneration,
TextIteratorStreamer,
)
# Constants: run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Repository id used to load both the model weights and the processor.
CHECKPOINT = "toandev/Viet-Receipt-Llama-3.2-11B-Vision-Instruct"

# Model initialization: load the weights in bfloat16, then move them to DEVICE.
model = MllamaForConditionalGeneration.from_pretrained(
    CHECKPOINT,
    torch_dtype=torch.bfloat16,
)
model = model.to(DEVICE)

processor = AutoProcessor.from_pretrained(CHECKPOINT)
def process_chat_history(history: List) -> tuple[List[Dict], List[Image.Image]]:
"""
Process chat history to extract messages and images.
Args:
history: List of chat messages
Returns:
Tuple containing processed messages and images
"""
messages = []
images = []
for i, msg in enumerate(history):
if isinstance(msg[0], tuple):
messages.extend(
[
{
"role": "user",
"content": [
{"type": "text", "text": history[i + 1][0]},
{"type": "image"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": history[i + 1][1]}],
},
]
)
images.append(Image.open(msg[0][0]).convert("RGB"))
elif isinstance(history[i - 1], tuple) and isinstance(msg[0], str):
continue
elif isinstance(history[i - 1][0], str) and isinstance(msg[0], str):
messages.extend(
[
{"role": "user", "content": [{"type": "text", "text": msg[0]}]},
{
"role": "assistant",
"content": [{"type": "text", "text": msg[1]}],
},
]
)
return messages, images
@spaces.GPU
def bot_streaming(message: Dict, history: List, max_new_tokens: int = 250):
    """
    Generate streaming responses for the chatbot.

    Args:
        message: Current message. ``message["text"]`` is the prompt.
            ``message["files"]`` lists uploads, each either a path string or
            a dict with a ``"path"`` key. Exactly one attached file is used as
            an image; otherwise the turn is sent as text only.
        history: Chat history, converted via ``process_chat_history``.
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        The accumulated response text after each streamed chunk.

    Raises:
        Exception: whatever ``model.generate`` raised in the worker thread is
            re-raised after the stream is closed. Without this, the consumer
            loop would stay blocked on the streamer forever.
    """
    text = message["text"]
    messages, images = process_chat_history(history)

    # Handle current message
    files = message["files"]
    if len(files) == 1:
        upload = files[0]
        image_path = upload if isinstance(upload, str) else upload["path"]
        # Context manager releases the file handle once the RGB copy exists.
        with Image.open(image_path) as img:
            images.append(img.convert("RGB"))
        messages.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": text}, {"type": "image"}],
            }
        )
    else:
        messages.append({"role": "user", "content": [{"type": "text", "text": text}]})

    # Process inputs
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = (
        processor(text=prompt, images=images, return_tensors="pt")
        if images
        else processor(text=prompt, return_tensors="pt")
    ).to(DEVICE)

    # Setup streaming: generate() runs in a worker thread and feeds decoded
    # text into the streamer, which this generator drains below.
    streamer = TextIteratorStreamer(
        processor, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    errors: List[Exception] = []

    def _generate() -> None:
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            # A failed generate() never signals end-of-stream on its own, so
            # record the error and close the stream to unblock the consumer.
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # NOTE(review): fixed pause between yields, presumably to pace UI updates.
        time.sleep(0.01)
        yield buffer

    thread.join()
    if errors:
        raise errors[0]
# Gradio chat UI. The multimodal textbox takes text plus an optional image.
# The slider is passed to bot_streaming as its extra max_new_tokens argument,
# and every example pre-fills that value with 200.
demo = gr.ChatInterface(
    fn=bot_streaming,
    textbox=gr.MultimodalTextbox(placeholder="Ask me anything..."),
    additional_inputs=[
        gr.Slider(
            label="Maximum number of new tokens to generate",
            minimum=10,
            maximum=500,
            step=10,
            value=250,
        ),
    ],
    examples=[
        [{"text": question, "files": [image_path]}, 200]
        for question, image_path in (
            ("What is the total amount in this bill?", "./examples/01.jpg"),
            ("What is the name of the restaurant in this bill?", "./examples/02.jpg"),
        )
    ],
    cache_examples=False,
    stop_btn="Stop",
    fill_height=True,
    multimodal=True,
    type="messages",
)

if __name__ == "__main__":
    # Start the app only when this file is executed directly.
    demo.launch(debug=True)