"""Gradio chat demo for the TIGER-Lab/Mantis-8B-Idefics2 multi-image model."""

import time
from threading import Thread
from typing import List

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    TextIteratorStreamer,
)
from transformers.image_utils import load_image

MODEL_ID = "TIGER-Lab/Mantis-8B-Idefics2"

# Marker a user types to position an uploaded image inside the text.
# NOTE(review): this literal was empty in the reviewed copy of the file
# (angle-bracket markup appears to have been stripped, which made
# ``str.split("")`` raise ValueError). ``<image>` is restored on the
# assumption that it is the placeholder the Mantis chat format uses --
# confirm against the model card.
IMAGE_PLACEHOLDER = "<image>"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)

# Markdown is kept flush-left at module level so that indentation inside
# build_demo() cannot turn paragraphs into code blocks.
_HEADER_MD = """# Mantis
Mantis is a multimodal conversational AI model that can chat with users about images and text. It's optimized for multi-image reasoning, where interleaved text and images can be used to generate responses.

### [Paper](https://arxiv.org/abs/2405.01483) | [Github](https://github.com/TIGER-AI-Lab/Mantis) | [Models](https://huggingface.co/collections/TIGER-Lab/mantis-6619b0834594c878cdb1d6e4) | [Dataset](https://huggingface.co/datasets/TIGER-Lab/Mantis-Instruct) | [Website](https://tiger-ai-lab.github.io/Mantis/)
"""

_USAGE_MD = """## Chat with Mantis
Mantis supports interleaved text-image input format, where you can simply use the placeholder `<image>` to indicate the position of uploaded images.
The model is optimized for multi-image reasoning, while preserving the ability to chat about text and images in a single conversation.
(The model currently serving is [🤗 TIGER-Lab/Mantis-8B-Idefics2](https://huggingface.co/TIGER-Lab/Mantis-8B-Idefics2))
"""

_CITATION_MD = """
## Citation
```
@article{jiang2024mantis,
  title={MANTIS: Interleaved Multi-Image Instruction Tuning},
  author={Jiang, Dongfu and He, Xuan and Zeng, Huaye and Wei, Con and Ku, Max and Liu, Qian and Chen, Wenhu},
  journal={arXiv preprint arXiv:2405.01483},
  year={2024}
}
```"""


@spaces.GPU
def generate_stream(text: str, images: List[Image.Image], history: List[dict], **kwargs):
    """Stream the model's reply for a chat history.

    Args:
        text: Unused; kept so existing callers keep working.
        images: Images referenced by the history, in order. An empty list is
            passed to the processor as ``None``.
        history: Chat messages in the structure accepted by
            ``processor.apply_chat_template``.
        **kwargs: Extra keyword arguments forwarded to ``model.generate``.

    Yields:
        The reply accumulated so far, growing with every streamed chunk.
    """
    model.to("cuda")
    if not images:
        images = None

    prompt = processor.apply_chat_template(history, add_generation_prompt=True)
    print("Prompt: ")
    print(prompt)
    print("Images: ")
    print(images)

    inputs = processor(text=prompt, images=images, return_tensors="pt")
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # Build a fresh dict rather than mutating the caller's kwargs; the
    # streamer goes last so it always wins over a caller-supplied one.
    generation_args = {**inputs, **kwargs, "streamer": streamer}

    # generate() blocks until done, so it runs in a worker thread while this
    # generator drains the streamer.
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    output = ""
    for chunk in streamer:
        output += chunk
        yield output
    thread.join()


def enable_next_image(uploaded_images, image):
    """Append ``image`` to ``uploaded_images`` and return it with a locked, cleared textbox."""
    uploaded_images.append(image)
    return uploaded_images, gr.MultimodalTextbox(value=None, interactive=False)


def add_message(history, message):
    """Append a submitted multimodal message to the chat history.

    Each uploaded file becomes its own ``[(path,), None]`` row, followed by
    one ``[text, None]`` row when text was entered.

    Returns:
        The updated history and a cleared ``gr.MultimodalTextbox``.
    """
    for file in message["files"] or []:
        history.append([(file,), None])
    if message["text"]:
        history.append([message["text"], None])
    return history, gr.MultimodalTextbox(value=None)


def print_like_dislike(x: gr.LikeData):
    """Print the index, value and liked flag of a like/dislike event."""
    print(x.index, x.value, x.liked)


def get_chat_images(history):
    """Load every image referenced by a tuple-typed row of ``history``, in order."""
    return [
        load_image(message[0][0])
        for message in history
        if isinstance(message[0], tuple)
    ]


def _split_interleaved(text):
    """Split ``text`` on IMAGE_PLACEHOLDER into text and image content parts."""
    segments = text.split(IMAGE_PLACEHOLDER)
    content = []
    leading = segments[0].strip()
    if leading:
        content.append({"type": "text", "text": leading})
    for segment in segments[1:]:
        # Every separator hit corresponds to exactly one image slot.
        content.append({"type": "image"})
        segment = segment.strip()
        if segment:
            content.append({"type": "text", "text": segment})
    return content


def get_chat_history(history):
    """Convert chatbot rows into chat-template messages plus the loaded images.

    Text rows become ``user`` messages whose content interleaves text and
    ``{"type": "image"}`` parts at each IMAGE_PLACEHOLDER; a non-empty reply
    becomes an ``assistant`` message. Tuple (image) rows add no message of
    their own -- they only contribute to the returned image list.

    Returns:
        ``(messages, images)``.

    Raises:
        gr.Error: If the placeholders seen so far outnumber the uploaded images.
    """
    images = get_chat_images(history)
    messages = []
    cur_image_idx = 0  # images already claimed by earlier placeholders

    for message in history:
        user_text = message[0]
        if not isinstance(user_text, str):
            continue

        num_images = user_text.count(IMAGE_PLACEHOLDER)
        if num_images + cur_image_idx > len(images):
            # Raised explicitly: an assert would vanish under ``python -O``.
            raise gr.Error(
                "Number of images uploaded is less than the number of placeholders in the text. Please upload more images."
            )
        cur_image_idx += num_images

        if num_images > 0:
            content = _split_interleaved(user_text)
        else:
            content = [{"type": "text", "text": user_text}]
        messages.append({"role": "user", "content": content})

        if message[1]:
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": message[1]}],
                }
            )

    return messages, images


def bot(history):
    """Generate and stream the assistant reply for the pending user turn.

    Walks back from the end of ``history`` to the most recent answered row,
    gathering the pending text and counting the pending images. The number of
    IMAGE_PLACEHOLDER markers in the text is then reconciled with the image
    count, and the adjusted text is written to the last row.

    Yields:
        ``history`` with the last row's reply updated as chunks arrive.

    Raises:
        gr.Error: If the pending turn contains no text.
    """
    pending_text = ""
    num_pending_images = 0
    for message in reversed(history):
        if message[1]:
            break
        if isinstance(message[0], str):
            pending_text = message[0] + " " + pending_text
        elif isinstance(message[0], tuple):
            num_pending_images += len(message[0])
    pending_text = pending_text.strip()

    if not pending_text:
        raise gr.Error("Please enter a message")

    num_placeholders = pending_text.count(IMAGE_PLACEHOLDER)
    if num_placeholders < num_pending_images:
        gr.Warning("The number of images uploaded is more than the number of placeholders in the text. Will automatically prepend to the text.")
        missing = num_pending_images - num_placeholders
        pending_text = f"{IMAGE_PLACEHOLDER} " * missing + pending_text
        history[-1][0] = pending_text
    elif num_placeholders > num_pending_images:
        gr.Warning("The number of images uploaded is less than the number of placeholders in the text. Will automatically remove extra placeholders from the text.")
        extra = num_placeholders - num_pending_images
        # rsplit with maxsplit cuts at the LAST `extra` placeholders only;
        # re-joining with "" deletes exactly those.
        pending_text = "".join(pending_text.rsplit(IMAGE_PLACEHOLDER, extra))
        history[-1][0] = pending_text

    chat_history, chat_images = get_chat_history(history)

    generation_kwargs = {
        "max_new_tokens": 4096,
        "num_beams": 1,
        "do_sample": False,
    }

    response = generate_stream(None, chat_images, chat_history, **generation_kwargs)
    for partial in response:
        history[-1][1] = partial
        time.sleep(0.05)  # short pause between UI updates
        yield history


def build_demo():
    """Build and return the Gradio Blocks app (not launched)."""
    with gr.Blocks() as demo:
        gr.Markdown(_HEADER_MD)
        gr.Markdown(_USAGE_MD)

        chatbot = gr.Chatbot(line_breaks=True)
        chat_input = gr.MultimodalTextbox(
            interactive=True,
            file_types=["image"],
            placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images",
            show_label=True,
        )

        chat_msg = chat_input.submit(
            add_message, [chatbot, chat_input], [chatbot, chat_input]
        )

        # NOTE: no sampling controls (temperature / top-p) are exposed;
        # bot() always generates with do_sample=False.

        bot_msg = chat_msg.success(bot, chatbot, chatbot, api_name="bot_response")

        chatbot.like(print_like_dislike, None, None)

        with gr.Row():
            send_button = gr.Button("Send")
            clear_button = gr.ClearButton([chatbot, chat_input])

        send_button.click(
            add_message, [chatbot, chat_input], [chatbot, chat_input]
        ).then(
            bot, chatbot, chatbot, api_name="bot_response"
        )

        # NOTE(review): the placeholders in these prompts were restored along
        # with IMAGE_PLACEHOLDER, one per attached file. The fourth prompt had
        # no stripped prefix and is left as found; bot() prepends the missing
        # placeholders for it.
        gr.Examples(
            examples=[
                {
                    "text": "<image> <image> <image> Which image shows a different mood of character from the others?",
                    "files": ["./examples/image12.jpg", "./examples/image13.jpg", "./examples/image14.jpg"],
                },
                {
                    "text": "<image> <image> What's the difference between these two images? Please describe as much as you can.",
                    "files": ["./examples/image1.jpg", "./examples/image2.jpg"],
                },
                {
                    "text": "<image> <image> Which image shows an older dog?",
                    "files": ["./examples/image8.jpg", "./examples/image9.jpg"],
                },
                {
                    "text": "Write a description for the given image sequence in a single paragraph, what is happening in this episode?",
                    "files": ["./examples/image3.jpg", "./examples/image4.jpg", "./examples/image5.jpg", "./examples/image6.jpg", "./examples/image7.jpg"],
                },
                {
                    "text": "<image> <image> How many dices are there in image 1 and image 2 respectively?",
                    "files": ["./examples/image10.jpg", "./examples/image15.jpg"],
                },
            ],
            inputs=[chat_input],
        )

        gr.Markdown(_CITATION_MD)
    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()