import time
from threading import Thread
from typing import Dict, List
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
AutoProcessor,
MllamaForConditionalGeneration,
TextIteratorStreamer,
)

# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT = "toandev/Viet-Receipt-Llama-3.2-11B-Vision-Instruct"

# Model initialization
model = MllamaForConditionalGeneration.from_pretrained(
CHECKPOINT, torch_dtype=torch.bfloat16
).to(DEVICE)
processor = AutoProcessor.from_pretrained(CHECKPOINT)


def process_chat_history(history: List) -> tuple[List[Dict], List[Image.Image]]:
    """
    Process chat history to extract messages and images.

    Args:
        history: List of chat messages

    Returns:
        Tuple containing processed messages and images
    """
messages = []
images = []
for i, msg in enumerate(history):
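        # An uploaded image arrives as a (filepath,) tuple entry; the question text
        # and the assistant reply for that image live in the next history entry.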
if isinstance(msg[0], tuple):
messages.extend(
[
{
"role": "user",
"content": [
{"type": "text", "text": history[i + 1][0]},
{"type": "image"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": history[i + 1][1]}],
},
]
)
images.append(Image.open(msg[0][0]).convert("RGB"))
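        # Skip the text turn that was already folded into the preceding image entry.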
elif isinstance(history[i - 1], tuple) and isinstance(msg[0], str):
continue
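        # Plain text-only exchange: append the user/assistant pair directly.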
elif isinstance(history[i - 1][0], str) and isinstance(msg[0], str):
messages.extend(
[
{"role": "user", "content": [{"type": "text", "text": msg[0]}]},
{
"role": "assistant",
"content": [{"type": "text", "text": msg[1]}],
},
]
)
    return messages, images


@spaces.GPU
def bot_streaming(message: Dict, history: List, max_new_tokens: int = 250):
    """
    Generate streaming responses for the chatbot.

    Args:
        message: Current message containing text and files
        history: Chat history
        max_new_tokens: Maximum number of tokens to generate

    Yields:
        Generated text buffer
    """
text = message["text"]
messages, images = process_chat_history(history)
# Handle current message
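    # The upload may arrive as a plain file path or as a dict with a "path" key;
    # handle both before converting to RGB.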
if len(message["files"]) == 1:
image = (
Image.open(message["files"][0])
if isinstance(message["files"][0], str)
else Image.open(message["files"][0]["path"])
).convert("RGB")
images.append(image)
messages.append(
{
"role": "user",
"content": [{"type": "text", "text": text}, {"type": "image"}],
}
)
else:
messages.append({"role": "user", "content": [{"type": "text", "text": text}]})
# Process inputs
texts = processor.apply_chat_template(messages, add_generation_prompt=True)
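    # Build model inputs; text-only turns call the processor without any images.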
inputs = (
processor(text=texts, images=images, return_tensors="pt")
if images
else processor(text=texts, return_tensors="pt")
).to(DEVICE)
# Setup streaming
streamer = TextIteratorStreamer(
processor, skip_special_tokens=True, skip_prompt=True
)
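    # Run generation on a background thread so tokens can be streamed back
    # to the UI as they are produced.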
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
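    # Accumulate partial output and yield the growing buffer; the short sleep
    # smooths the streaming updates in the UI.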
buffer = ""
for new_text in streamer:
buffer += new_text
time.sleep(0.01)
        yield buffer


demo = gr.ChatInterface(
fn=bot_streaming,
textbox=gr.MultimodalTextbox(placeholder="Ask me anything..."),
additional_inputs=[
gr.Slider(
minimum=10,
maximum=500,
value=250,
step=10,
label="Maximum number of new tokens to generate",
)
],
examples=[
[
{
"text": "What is the total amount in this bill?",
"files": ["./examples/01.jpg"],
},
200,
],
[
{
"text": "What is the name of the restaurant in this bill?",
"files": ["./examples/02.jpg"],
},
200,
],
],
cache_examples=False,
stop_btn="Stop",
fill_height=True,
multimodal=True,
type="messages",
)


if __name__ == "__main__":
demo.launch(debug=True)