import time
from threading import Thread

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, Idefics2ForConditionalGeneration, TextIteratorStreamer
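
# Load the processor and model once at startup so every request reuses them.
# flash_attention_2 assumes the Space's GPU supports it; bfloat16 keeps the
# 8B checkpoint's memory footprint manageable.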
PROCESSOR = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    _attn_implementation="flash_attention_2",
    trust_remote_code=True,
).to("cuda")


def turn_is_pure_media(turn):
    # In Gradio's multimodal history, an image-only turn is ((filepath, ...), None).
    return turn[1] is None


def format_user_prompt_with_im_history_and_system_conditioning(
    user_prompt, chat_history
):
    """
    Produces the list of messages (and the accompanying images) that goes into
    the processor. It handles the potential image(s), the history, and the
    system conditioning.
    """
    resulting_messages = []
    resulting_images = []
    # Format history
    for turn in chat_history:
        if not resulting_messages or resulting_messages[-1]["role"] != "user":
            resulting_messages.append(
                {
                    "role": "user",
                    "content": [],
                }
            )

        if turn_is_pure_media(turn):
            media = turn[0][0]
            resulting_messages[-1]["content"].append({"type": "image"})
            resulting_images.append(Image.open(media))
        else:
            user_utterance, assistant_utterance = turn
            resulting_messages[-1]["content"].append(
                {"type": "text", "text": user_utterance.strip()}
            )
            resulting_messages.append(
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": assistant_utterance.strip()}
                    ],
                }
            )
    # Format current input
    if not user_prompt["files"]:
        resulting_messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_prompt["text"]}
                ],
            }
        )
    else:
        # Choosing to put the image(s) first (i.e. before the text), but this
        # is an arbitrary choice.
        resulting_messages.append(
            {
                "role": "user",
                "content": [{"type": "image"}] * len(user_prompt["files"])
                + [{"type": "text", "text": user_prompt["text"]}],
            }
        )
        for im in user_prompt["files"]:
            # Gradio may pass uploads as plain paths or as FileData dicts.
            if isinstance(im, str):
                resulting_images.append(Image.open(im))
            elif isinstance(im, dict):
                resulting_images.append(Image.open(im["path"]))

    return resulting_messages, resulting_images
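
# For illustration, a fresh single-turn prompt with one attached image yields
# roughly this structure (the image itself travels separately, as a PIL.Image
# in resulting_images):
#   [{"role": "user",
#     "content": [{"type": "image"},
#                 {"type": "text", "text": "How many items are sold?"}]}]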


def extract_images_from_msg_list(msg_list):
    # Note: kept for completeness; this helper is not called anywhere in this demo.
    all_images = []
    for msg in msg_list:
        for c_ in msg["content"]:
            if isinstance(c_, Image.Image):
                all_images.append(c_)
    return all_images


@spaces.GPU  # ZeroGPU Spaces allocate a GPU only for the duration of this call.
def model_inference(
    user_prompt,
    chat_history,
    decoding_strategy,
    temperature,
    max_new_tokens,
    repetition_penalty,
    top_p,
):
    if user_prompt["text"].strip() == "" and not user_prompt["files"]:
        raise gr.Error("Please input a query and optionally image(s).")
    if user_prompt["text"].strip() == "" and user_prompt["files"]:
        raise gr.Error("Please input a text query along with the image(s).")
    streamer = TextIteratorStreamer(
        PROCESSOR.tokenizer,
        skip_prompt=True,
        timeout=5.0,
    )

    # Common parameters to all decoding strategies.
    # This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }
    assert decoding_strategy in [
        "Greedy",
        "Top P Sampling",
    ]
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p

    # Creating model inputs
    resulting_text, resulting_images = format_user_prompt_with_im_history_and_system_conditioning(
        user_prompt=user_prompt,
        chat_history=chat_history,
    )
    prompt = PROCESSOR.apply_chat_template(resulting_text, add_generation_prompt=True)
    inputs = PROCESSOR(text=prompt, images=resulting_images if resulting_images else None, return_tensors="pt")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    generation_args.update(inputs)

    # Run generation in a background thread so tokens can be yielded as they arrive.
    thread = Thread(
        target=model.generate,
        kwargs=generation_args,
    )
    thread.start()
print("Start generating") | |
acc_text = "" | |
for text_token in streamer: | |
time.sleep(0.04) | |
acc_text += text_token | |
if acc_text.endswith("<end_of_utterance>"): | |
acc_text = acc_text[:-18] | |
yield acc_text | |
print("Success - generated the following text:", acc_text) | |
print("-----") | |

BOT_AVATAR = "IDEFICS_logo.png"

# Hyper-parameters for generation
max_new_tokens = gr.Slider(
    minimum=8,
    maximum=1024,
    value=512,
    step=1,
    interactive=True,
    label="Maximum number of new tokens to generate",
)
repetition_penalty = gr.Slider(
    minimum=0.01,
    maximum=5.0,
    value=1.2,
    step=0.01,
    interactive=True,
    label="Repetition penalty",
    info="1.0 is equivalent to no penalty",
)
decoding_strategy = gr.Radio(
    [
        "Greedy",
        "Top P Sampling",
    ],
    value="Greedy",
    label="Decoding strategy",
    interactive=True,
    info="Greedy always picks the most likely next token; Top P sampling draws from the smallest set of tokens whose cumulative probability exceeds top_p.",
)
temperature = gr.Slider(
    minimum=0.0,
    maximum=5.0,
    value=0.4,
    step=0.1,
    interactive=True,
    label="Sampling temperature",
    info="Higher values will produce more diverse outputs.",
)
top_p = gr.Slider(
    minimum=0.01,
    maximum=0.99,
    value=0.8,
    step=0.01,
    interactive=True,
    label="Top P",
    info="Higher values are equivalent to sampling more low-probability tokens.",
)
chatbot = gr.Chatbot(
    label="Idefics2",
    avatar_images=[None, BOT_AVATAR],
    # height=750,
)
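
# The custom CSS scales down images rendered inside chat messages so large
# uploads don't dominate the window.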
with gr.Blocks(
    fill_height=True,
    css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img { width: auto; max-width: 30%; height: auto; max-height: 30%; }",
) as demo:
    # Show the sampling hyper-parameter sliders only for sampling-based
    # strategies. Of the names checked below, only "Top P Sampling" is among
    # the current radio choices.
    decoding_strategy.change(
        fn=lambda selection: gr.Slider(
            visible=(
                selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
            )
        ),
        inputs=decoding_strategy,
        outputs=temperature,
    )
    decoding_strategy.change(
        fn=lambda selection: gr.Slider(
            visible=(
                selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
            )
        ),
        inputs=decoding_strategy,
        outputs=repetition_penalty,
    )
    decoding_strategy.change(
        fn=lambda selection: gr.Slider(visible=(selection in ["Top P Sampling"])),
        inputs=decoding_strategy,
        outputs=top_p,
    )

    examples = [
        {"text": "How many items are sold?", "files": ["./example_images/docvqa_example.png"]},
        {"text": "What is this UI about?", "files": ["./example_images/s2w_example.png"]},
        {"text": "I want to go somewhere similar to the one in the photo. Give me destinations and travel tips.", "files": ["./example_images/travel_tips.jpg"]},
        {"text": "Can you tell me a very short story based on this image?", "files": ["./example_images/chicken_on_money.png"]},
        {"text": "Where is this pastry from?", "files": ["./example_images/baklava.png"]},
        {"text": "What percentage is the order status at?", "files": ["./example_images/dummy_pdf.png"]},
        {"text": "As an art critic AI assistant, could you describe this painting in detail and write a thorough critique?", "files": ["./example_images/art_critic.jpg"]},
    ]
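
    # Note: the example prompts assume the referenced images exist under
    # ./example_images/ in the Space repository.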
    description = "Try [IDEFICS2-8B](https://huggingface.co/HuggingFaceM4/idefics2-8b), the instruction fine-tuned IDEFICS2, in this demo. 💬 IDEFICS2 is a state-of-the-art vision-language model on various benchmarks. To get started, upload an image and write a text prompt, or try one of the examples. You can also play with the advanced generation parameters. To learn more about IDEFICS2, read [the blog](https://huggingface.co/blog/idefics2). Note that this model is not as chatty as the upcoming chatty model, and it will give shorter answers."

    gr.ChatInterface(
        fn=model_inference,
        chatbot=chatbot,
        examples=examples,
        description=description,
        title="Idefics2 Playground 🐶",
        multimodal=True,
        additional_inputs=[decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p],
    )

demo.launch(debug=True)