import time
from threading import Thread
from typing import Dict, List

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    MllamaForConditionalGeneration,
    TextIteratorStreamer,
)

# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT = "toandev/Viet-Receipt-Llama-3.2-11B-Vision-Instruct"

# Model initialization
model = MllamaForConditionalGeneration.from_pretrained(
    CHECKPOINT, torch_dtype=torch.bfloat16
).to(DEVICE)
processor = AutoProcessor.from_pretrained(CHECKPOINT)


def process_chat_history(history: List) -> tuple[List[Dict], List[Image.Image]]:
    """
    Convert tuple-style Gradio chat history into chat-template messages
    and a parallel list of PIL images.

    Gradio stores an image upload as its own history row, e.g.
    ``[("path/to/img.jpg",), None]``, followed by a ``[text, response]``
    row holding the accompanying prompt and the model's reply.

    Args:
        history: List of ``[user, assistant]`` history rows.

    Returns:
        Tuple of (messages, images) ready for ``processor.apply_chat_template``.
    """
    messages = []
    images = []

    for i, msg in enumerate(history):
        if isinstance(msg[0], tuple):
            # Image row: the prompt and reply live in the *next* row.
            messages.extend(
                [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": history[i + 1][0]},
                            {"type": "image"},
                        ],
                    },
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": history[i + 1][1]}],
                    },
                ]
            )
            images.append(Image.open(msg[0][0]).convert("RGB"))
        elif i > 0 and isinstance(history[i - 1][0], tuple) and isinstance(msg[0], str):
            # Text row already consumed by the preceding image row.
            continue
        elif isinstance(msg[0], str):
            # Text-only turn.
            messages.extend(
                [
                    {"role": "user", "content": [{"type": "text", "text": msg[0]}]},
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": msg[1]}],
                    },
                ]
            )

    return messages, images


@spaces.GPU
def bot_streaming(message: Dict, history: List, max_new_tokens: int = 250):
    """
    Generate streaming responses for the chatbot.

    Args:
        message: Current message containing text and files
        history: Chat history
        max_new_tokens: Maximum number of tokens to generate

    Yields:
        The response text accumulated so far
    """
    text = message["text"]
    messages, images = process_chat_history(history)

    # Handle the current message; only the first attachment is used.
    if message["files"]:
        file = message["files"][0]
        path = file if isinstance(file, str) else file["path"]
        images.append(Image.open(path).convert("RGB"))
        messages.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": text}, {"type": "image"}],
            }
        )
    else:
        messages.append({"role": "user", "content": [{"type": "text", "text": text}]})

    # Process inputs
    texts = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = (
        processor(text=texts, images=images, return_tensors="pt")
        if images
        else processor(text=texts, return_tensors="pt")
    ).to(DEVICE)

    # Run generation on a background thread and stream tokens as they arrive.
    streamer = TextIteratorStreamer(
        processor, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer


# Note: process_chat_history expects tuple-style (default) chat history,
# so type="messages" must not be set on the interface.
demo = gr.ChatInterface(
    fn=bot_streaming,
    textbox=gr.MultimodalTextbox(placeholder="Ask me anything..."),
    additional_inputs=[
        gr.Slider(
            minimum=10,
            maximum=500,
            value=250,
            step=10,
            label="Maximum number of new tokens to generate",
        )
    ],
    examples=[
        [
            {
                "text": "What is the total amount in this bill?",
                "files": ["./examples/01.jpg"],
            },
            200,
        ],
        [
            {
                "text": "What is the name of the restaurant in this bill?",
                "files": ["./examples/02.jpg"],
            },
            200,
        ],
    ],
    cache_examples=False,
    stop_btn="Stop",
    fill_height=True,
    multimodal=True,
)

if __name__ == "__main__":
    demo.launch(debug=True)
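# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original app): with a CUDA GPU
# large enough for the bfloat16 11B checkpoint and the dependencies installed
# (torch, transformers, gradio, spaces, pillow), the demo can be launched as a
# plain script:
#
#   python app.py      # the filename app.py is an assumption
#
# The spaces.GPU decorator is a no-op outside Hugging Face Spaces, so the same
# file should run both locally and as a Space.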