"""Gradio chat demo for the Mantis multi-image conversational model."""
import time
from typing import List

import gradio as gr
import spaces
from PIL import Image

from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava, MLlavaForConditionalGeneration

# Placeholder users put in their message to mark where an uploaded image goes.
# NOTE(review): restored as "<image>" because the token had degraded to an
# empty string, which made every `count`/`replace` below meaningless; confirm
# it matches the token expected by `chat_mllava` / the processor.
IMAGE_TOKEN = "<image>"

processor = MLlavaProcessor.from_pretrained("TIGER-Lab/Mantis-llava-7b-v1.1")
model = LlavaForConditionalGeneration.from_pretrained("TIGER-Lab/Mantis-llava-7b-v1.1")


@spaces.GPU
def generate(text: str, images: List[Image.Image], history: List[dict], **kwargs):
    """Stream the model's reply for one chat turn.

    Args:
        text: Prompt for the current turn (``bot`` passes ``None`` and puts
            the prompt in ``history`` instead).
        images: Images referenced by the conversation; an empty list is
            forwarded to ``chat_mllava`` as ``None``.
        history: Conversation so far as ``{"role", "text"}`` dicts.
        **kwargs: Generation options forwarded to ``chat_mllava``.

    Yields:
        The text produced by ``chat_mllava`` at each streaming step.
    """
    global processor, model
    model = model.to("cuda")
    if not images:
        images = None
    # `text` and `history` are deliberately rebound on every step so that the
    # final `return text` reports the last streamed value.
    for text, history in chat_mllava(text, images, model, processor, history=history, stream=True, **kwargs):
        yield text

    return text


def enable_next_image(uploaded_images, image):
    """Append ``image`` to ``uploaded_images`` and return it with a cleared, non-interactive textbox."""
    uploaded_images.append(image)
    return uploaded_images, gr.MultimodalTextbox(value=None, interactive=False)


def add_message(history, message):
    """Append the submitted files and text to the chatbot history.

    Each uploaded file becomes its own ``[(file,), None]`` entry, followed by
    a ``[text, None]`` entry when text was entered. Returns the updated
    history and a cleared input textbox.
    """
    if message["files"]:
        for file in message["files"]:
            history.append([(file,), None])
    if message["text"]:
        history.append([message["text"], None])
    return history, gr.MultimodalTextbox(value=None)


def print_like_dislike(x: gr.LikeData):
    """Log like/dislike feedback from the chatbot to stdout."""
    print(x.index, x.value, x.liked)


def get_chat_history(history):
    """Convert chatbot ``[user, bot]`` pairs into a flat list of role dicts.

    Only entries whose user part is a string are converted; file entries
    (tuples) are skipped. The final entry must not have a bot reply yet and is
    given an empty assistant turn for the model to fill in.

    Raises:
        AssertionError: If an earlier text entry has no bot reply, or the
            final entry already has one.
    """
    chat_history = []
    last_index = len(history) - 1
    for i, message in enumerate(history):
        if isinstance(message[0], str):
            chat_history.append({"role": "user", "text": message[0]})
            if i != last_index:
                assert message[1], "The bot message is not provided, internal error"
                chat_history.append({"role": "assistant", "text": message[1]})
            else:
                assert not message[1], "the bot message internal error, get: {}".format(message[1])
                chat_history.append({"role": "assistant", "text": ""})
    return chat_history


def get_chat_images(history):
    """Collect, in order, every file from the tuple-typed entries of ``history``."""
    images = []
    for message in history:
        if isinstance(message[0], tuple):
            images.extend(message[0])
    return images


def bot(history):
    """Generate and stream the assistant reply for the pending user turn.

    Walks the history backwards up to the most recent answered entry to gather
    the pending text and images, reconciles the number of ``IMAGE_TOKEN``
    placeholders with the number of pending images, then streams the model
    output into ``history[-1][1]``.

    Yields:
        The updated ``history`` after each streamed chunk.

    Raises:
        gr.Error: If the pending turn contains no text.
    """
    print(history)
    cur_messages = {"text": "", "images": []}
    for message in history[::-1]:
        if message[1]:
            # Reached an entry that already has a reply: the pending turn ends here.
            break
        if isinstance(message[0], str):
            cur_messages["text"] = message[0] + " " + cur_messages["text"]
        elif isinstance(message[0], tuple):
            cur_messages["images"].extend(message[0])
    cur_messages["text"] = cur_messages["text"].strip()
    # Images were gathered newest-first; restore upload order.
    cur_messages["images"] = cur_messages["images"][::-1]
    if not cur_messages["text"]:
        raise gr.Error("Please enter a message")

    num_images = len(cur_messages["images"])
    if cur_messages["text"].count(IMAGE_TOKEN) < num_images:
        gr.Warning("The number of images uploaded is more than the number of placeholders in the text. Will automatically prepend <image> to the text.")
        missing = num_images - cur_messages["text"].count(IMAGE_TOKEN)
        cur_messages["text"] = (IMAGE_TOKEN + " ") * missing + cur_messages["text"]
        history[-1][0] = cur_messages["text"]
    if cur_messages["text"].count(IMAGE_TOKEN) > num_images:
        gr.Warning("The number of images uploaded is less than the number of placeholders in the text. Will automatically remove extra placeholders from the text.")
        extra = cur_messages["text"].count(IMAGE_TOKEN) - num_images
        # Reverse the text so `replace(..., count)` drops the *last* `extra`
        # placeholders instead of the first ones.
        cur_messages["text"] = cur_messages["text"][::-1].replace(IMAGE_TOKEN[::-1], "", extra)[::-1]
        history[-1][0] = cur_messages["text"]

    chat_history = get_chat_history(history)
    chat_images = get_chat_images(history)
    generation_kwargs = {
        "max_new_tokens": 4096,
        "temperature": 0.7,
        "top_p": 1.0,
        "do_sample": True,
    }
    print(None, chat_images, chat_history, generation_kwargs)
    response = generate(None, chat_images, chat_history, **generation_kwargs)

    for _output in response:
        history[-1][1] = _output
        time.sleep(0.05)
        yield history


def build_demo():
    """Build and return the Gradio Blocks UI for the Mantis chat demo."""
    with gr.Blocks() as demo:
        gr.Markdown("""
# Mantis
Mantis is a multimodal conversational AI model that can chat with users about images and text. It's optimized for multi-image reasoning, where inverleaved text and images can be used to generate responses.
""")

        chatbot = gr.Chatbot(line_breaks=True)
        chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images", show_label=True)

        # Pressing Enter in the textbox: record the message, then stream the reply.
        chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
        chat_msg.success(bot, chatbot, chatbot, api_name="bot_response")

        chatbot.like(print_like_dislike, None, None)

        with gr.Row():
            send_button = gr.Button("Send")
            gr.ClearButton([chatbot, chat_input])

        # The Send button mirrors the textbox submit flow.
        send_button.click(
            add_message, [chatbot, chat_input], [chatbot, chat_input]
        ).then(
            bot, chatbot, chatbot, api_name="bot_response"
        )

    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()