# not-lain's picture
# Update app.py
# b7649c8 verified
import spaces
import os
import re
import gradio as gr
import torch
from transformers import AutoModelForCausalLM
from transformers import TextIteratorStreamer
from threading import Thread
from PIL import Image
# Directory containing this file; used to locate bundled resources.
cur_dir = os.path.dirname(os.path.abspath(__file__))

# Placeholder token that marks where an image goes in a prompt.
image_placeholder = '<image>'

# Hub repository id of the checkpoint to load.
model_name = 'AIDC-AI/Ovis1.6-Gemma2-9B'

# Load the checkpoint in bfloat16 using the repository's own model code
# (trust_remote_code=True), then move it to the CUDA device.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    multimodal_max_length=8192,
    trust_remote_code=True,
)
model = model.to(device='cuda')

# Text and visual tokenizers are obtained from the model object itself.
text_tokenizer = model.get_text_tokenizer()
visual_tokenizer = model.get_visual_tokenizer()

# Module-level text streamer that skips the prompt and special tokens when
# decoding generated ids.
streamer = TextIteratorStreamer(
    text_tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
)
@spaces.GPU
def ovis_chat(message, history, hist=None):
    """Stream the model's reply for one multimodal chat turn.

    Args:
        message: Mapping for the current user turn. ``message["text"]`` is the
            prompt; ``message["files"][0]``, when it can be opened, is used as
            the image for this turn.
        history: Previous turns as a sequence of pairs. A pair whose first
            item is a ``str`` is a text exchange ``(user, assistant)``; a pair
            whose first item is a ``tuple`` carries an image path in
            ``pair[0][0]``.
        hist: History forwarded through the extra input component. Whenever it
            differs from ``history`` it replaces it (workaround for API
            calls). ``None`` (the default) is treated as an empty history.

    Yields:
        str: The response accumulated so far, once per streamed text chunk.

    Raises:
        Exception: Any error raised by ``model.generate`` in the worker thread
            is re-raised here after the stream ends.
    """
    # Workaround for API: the extra input is the source of truth for history.
    if hist is None:
        hist = []
    if hist != history:
        history = hist

    # Best effort: any failure to obtain an image from the message simply
    # means "no new image". `Exception` (rather than a bare except) keeps
    # KeyboardInterrupt / SystemExit / GeneratorExit propagating.
    try:
        image_input = Image.open(message["files"][0]).convert("RGB")
        new_image = True
    except Exception:
        image_input = None
        new_image = False

    # Preprocess inputs: rebuild the conversation from the history.
    text_input = message["text"]
    conversations = []
    for msg in history:
        if isinstance(msg[0], str):
            # History pair that only has text.
            conversations.append({"from": "human", "value": msg[0]})
            conversations.append({"from": "gpt", "value": msg[1]})
        elif isinstance(msg[0], tuple) and not new_image:
            # History pair holds an image and the user did not pass a new
            # one: fall back to it. Later entries overwrite earlier ones, so
            # the latest image in the history wins.
            image_input = Image.open(msg[0][0]).convert("RGB")

    text_input = text_input.replace(image_placeholder, '')
    conversations.append({"from": "human", "value": text_input})

    # The image placeholder is attached to the first message of the
    # conversation.
    if image_input is not None:
        conversations[0]["value"] = image_placeholder + '\n' + conversations[0]["value"]

    prompt, input_ids, pixel_values = model.preprocess_inputs(conversations, [image_input])
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    input_ids = input_ids.unsqueeze(0).to(device=model.device)
    attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
    if image_input is None:
        pixel_values = [None]
    else:
        pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]

    gen_kwargs = dict(
        max_new_tokens=512,
        do_sample=False,
        top_p=None,
        top_k=None,
        temperature=None,
        repetition_penalty=None,
        eos_token_id=model.generation_config.eos_token_id,
        pad_token_id=text_tokenizer.pad_token_id,
        use_cache=True
    )

    # A fresh streamer per call: a single shared queue would let overlapping
    # or abandoned generations leak text into another reply.
    token_stream = TextIteratorStreamer(text_tokenizer, skip_prompt=True, skip_special_tokens=True)
    worker_errors = []

    def _generate():
        # Run generation in the worker thread and record any failure.
        try:
            # inference_mode is thread-local, so it has to be entered in the
            # thread that actually executes generate().
            with torch.inference_mode():
                model.generate(inputs=input_ids,
                               pixel_values=pixel_values,
                               attention_mask=attention_mask,
                               streamer=token_stream,
                               **gen_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            # Signal end-of-stream so the consumer loop below cannot block
            # forever on a generation that died.
            token_stream.end()

    thread = Thread(target=_generate)
    thread.start()

    response = ""
    for new_text in token_stream:
        response += new_text
        yield response
    thread.join()

    # Surface worker failures to the caller instead of ending silently.
    if worker_errors:
        raise worker_errors[0]
def clear_chat():
    """Return a reset triple: a new empty list, ``None`` and an empty string.

    NOTE(review): presumably meant as a "clear" callback for three UI
    components; it is not referenced in the visible part of the file.
    """
    blank_list = []
    no_value = None
    blank_text = ""
    return blank_list, no_value, blank_text
# Size applied both to the inline logo height and to the title text.
font_size = "2.5em"

# Read the SVG logo stored under resource/ next to this file.
with open(f"{cur_dir}/resource/logo.svg", "r", encoding="utf-8") as logo_fp:
    svg_content = logo_fp.read()

# Inject a height and inline-alignment style into each matched opening
# "<svg ...>" tag so the logo lines up with the title text.
_svg_open_tag = re.compile(r'(<svg[^>]*)(>)')
svg_content = _svg_open_tag.sub(
    rf'\1 height="{font_size}" style="vertical-align: middle; display: inline-block;"\2',
    svg_content,
)

# Title shown next to the logo: the part of the model id after the last "/".
_model_display_name = model_name.rsplit('/', 1)[-1]

# Header markup: logo + model name, followed by links to the model pages.
html = f"""
<p align="center" style="font-size: {font_size}; line-height: 1;">
<span style="display: inline-block; vertical-align: middle;">{svg_content}</span>
<span style="display: inline-block; vertical-align: middle;">{_model_display_name}</span>
</p>
<center><font size=3><b>Ovis</b> has been open-sourced on <a href='https://huggingface.co/{model_name}'>😊 Huggingface</a> and <a href='https://github.com/AIDC-AI/Ovis'>🌟 GitHub</a>. If you find Ovis useful, a like❤️ or a star🌟 would be appreciated.</font></center>
"""
# (left, right, display) triples, in the order they appear in the final list.
# Only the inline "\( ... \)" pair is non-display; all the others are display
# math.
_latex_delimiter_specs = (
    ("\\(", "\\)", False),
    ("\\begin{equation}", "\\end{equation}", True),
    ("\\begin{align}", "\\end{align}", True),
    ("\\begin{alignat}", "\\end{alignat}", True),
    ("\\begin{gather}", "\\end{gather}", True),
    ("\\begin{CD}", "\\end{CD}", True),
    ("\\[", "\\]", True),
)

# LaTeX delimiter configuration as a list of {"left", "right", "display"}
# dicts.
latex_delimiters_set = [
    {"left": opening, "right": closing, "display": is_display}
    for opening, closing, is_display in _latex_delimiter_specs
]
# Extra Chatbot component, handed to the chat function as an additional input.
hist = gr.Chatbot()

# Multimodal chat UI driven by ovis_chat, with a multimodal textbox for input.
_chat_interface_options = dict(
    fn=ovis_chat,
    textbox=gr.MultimodalTextbox(),
    multimodal=True,
    additional_inputs=hist,
)
demo = gr.ChatInterface(**_chat_interface_options)

# Start the app with debug output enabled.
demo.launch(debug=True)