File size: 5,207 Bytes
1cea0e1
 
 
 
 
 
 
 
11eab42
1cea0e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e99479
 
 
 
11eab42
 
34ac220
11eab42
 
34ac220
1cea0e1
 
 
11eab42
9d33c66
70f2494
 
11eab42
 
70f2494
11eab42
 
 
70f2494
11eab42
70f2494
 
b576ed1
 
34ac220
 
70f2494
 
1cea0e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
014763e
1cea0e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7649c8
 
70f2494
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import spaces
import os
import re
import gradio as gr
import torch
from transformers import AutoModelForCausalLM
from transformers import TextIteratorStreamer
from threading import Thread
from PIL import Image

model_name = 'AIDC-AI/Ovis1.6-Gemma2-9B'

# Load the checkpoint in bfloat16 (its modelling code lives in the model repo,
# hence trust_remote_code), forwarding the Ovis-specific multimodal_max_length
# option, and place the weights on the GPU.
model = (
    AutoModelForCausalLM
    .from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        multimodal_max_length=8192,
        trust_remote_code=True,
    )
    .to(device='cuda')
)

# Ovis keeps separate tokenizers for text and for images.
text_tokenizer, visual_tokenizer = (
    model.get_text_tokenizer(),
    model.get_visual_tokenizer(),
)

# Module-level text streamer that skips the prompt and special tokens when
# decoding generated text.
# NOTE(review): a single instance shared by every request that uses it.
streamer = TextIteratorStreamer(
    text_tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
)

# Literal marker for the image position in prompt text.
image_placeholder = '<image>'

# Absolute directory of this script, used to locate bundled resources.
cur_dir = os.path.dirname(os.path.abspath(__file__))


def _open_rgb(path):
    """Open the image file at *path* and return an RGB copy, or None on failure.

    Accepts either a plain path or a file dict carrying a ``"path"`` key.
    Unreadable or undecodable files yield None so a bad upload degrades to a
    text-only turn instead of aborting the chat request.
    """
    if isinstance(path, dict):
        path = path.get("path")
    if path is None:
        return None
    try:
        # The context manager releases the file handle; convert() returns an
        # independent, fully loaded image.
        with Image.open(path) as img:
            return img.convert("RGB")
    except (OSError, ValueError):
        # PIL reports missing/undecodable files as OSError subclasses
        # (including UnidentifiedImageError).
        return None


def _history_to_conversations(history):
    """Translate ``[user, bot]`` chat history pairs into Ovis conversation turns.

    Pairs whose user slot is a string become a human turn followed by a gpt
    turn. Pairs whose user slot is a tuple are image uploads: they add no turns,
    but the first element of the most recent one is returned so the caller can
    reuse that image when the current message carries none.

    Returns:
        (conversations, latest_image_path); latest_image_path is None when the
        history contains no image entry.
    """
    conversations = []
    latest_image_path = None
    for msg in history or []:
        user_part = msg[0]
        if isinstance(user_part, str):
            conversations.append({
                "from": "human",
                # Strip the placeholder, as is done for the current message, so
                # the only one in the prompt is the one added for the image.
                "value": user_part.replace(image_placeholder, ''),
            })
            conversations.append({
                "from": "gpt",
                # Guard against a pair with no reply (None).
                "value": msg[1] if msg[1] is not None else "",
            })
        elif isinstance(user_part, tuple) and user_part:
            # Later entries overwrite earlier ones: keep only the latest image.
            latest_image_path = user_part[0]
    return conversations, latest_image_path


@spaces.GPU
def ovis_chat(message, history, hist=None):
    """Stream a reply from the Ovis model for one multimodal chat turn.

    Args:
        message: Multimodal message dict with ``"text"`` and ``"files"``; only
            the first file, if any, is used as the image.
        history: Sequence of ``[user, bot]`` pairs; a tuple in the user slot
            denotes an uploaded image file.
        hist: Optional history supplied through the extra Chatbot input (API
            workaround). When non-empty it replaces ``history``.

    Yields:
        The accumulated response text after each streamed chunk.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread is
        re-raised here once streaming stops.
    """
    # Workaround for API callers, who pass the conversation through ``hist``.
    # Only override when it actually carries turns: an empty ``hist`` (e.g. the
    # otherwise unused extra Chatbot in the UI) must not wipe out ``history``.
    if hist:
        history = hist

    files = message.get("files") or []
    image_input = _open_rgb(files[0]) if files else None

    conversations, history_image_path = _history_to_conversations(history)
    if image_input is None and history_image_path is not None:
        # No usable new image: fall back to the most recent image in history.
        image_input = _open_rgb(history_image_path)

    conversations.append({
        "from": "human",
        "value": message["text"].replace(image_placeholder, ''),
    })
    if image_input is not None:
        # The single image placeholder is attached to the first human turn.
        conversations[0]["value"] = image_placeholder + '\n' + conversations[0]["value"]

    _, input_ids, pixel_values = model.preprocess_inputs(conversations, [image_input])
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    input_ids = input_ids.unsqueeze(0).to(device=model.device)
    attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
    if image_input is None:
        pixel_values = [None]
    else:
        pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]

    # A fresh streamer per call: a shared one would mix tokens from overlapping
    # requests and keep leftovers from a generation whose consumer stopped early.
    call_streamer = TextIteratorStreamer(text_tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        inputs=input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        streamer=call_streamer,
        max_new_tokens=512,
        do_sample=False,
        top_p=None,
        top_k=None,
        temperature=None,
        repetition_penalty=None,
        eos_token_id=model.generation_config.eos_token_id,
        pad_token_id=text_tokenizer.pad_token_id,
        use_cache=True,
    )
    errors = []

    def _generate():
        # Grad/inference mode is thread-local, so inference_mode must be
        # entered inside the worker thread to cover generate().
        try:
            with torch.inference_mode():
                model.generate(**gen_kwargs)
        except Exception as exc:  # re-raised in the caller after streaming
            errors.append(exc)
            # Unblock the consumer loop below; the streamer has no timeout and
            # would otherwise wait indefinitely for more text.
            call_streamer.end()

    thread = Thread(target=_generate)
    thread.start()
    response = ""
    for new_text in call_streamer:
        response += new_text
        yield response
    thread.join()
    if errors:
        raise errors[0]

def clear_chat():
    """Return the reset triple ``([], None, "")``, with a fresh list each call.

    NOTE(review): not wired to any component in this file; presumably meant to
    reset (chat history, image, textbox) from a clear button.
    """
    empty_history = []
    no_image = None
    blank_text = ""
    return empty_history, no_image, blank_text

# Header HTML: the project logo (inlined SVG) next to the model's short name,
# followed by links to the Hugging Face model page and the GitHub repository.
with open(f"{cur_dir}/resource/logo.svg", "r", encoding="utf-8") as svg_file:
    svg_content = svg_file.read()
font_size = "2.5em"
# Size the logo to the title's font size and align it inline. Only the first
# (root) <svg> tag is rewritten so nested <svg> elements keep their own sizing,
# and a self-closing "/>" stays intact instead of being split by the new
# attributes.
svg_content = re.sub(
    r'(<svg\b[^>]*?)(/?>)',
    rf'\1 height="{font_size}" style="vertical-align: middle; display: inline-block;"\2',
    svg_content,
    count=1,
)
html = f"""
<p align="center" style="font-size: {font_size}; line-height: 1;">
    <span style="display: inline-block; vertical-align: middle;">{svg_content}</span>
    <span style="display: inline-block; vertical-align: middle;">{model_name.split('/')[-1]}</span>
</p>
<center><font size=3><b>Ovis</b> has been open-sourced on <a href='https://huggingface.co/{model_name}'>😊 Huggingface</a> and <a href='https://github.com/AIDC-AI/Ovis'>🌟 GitHub</a>. If you find Ovis useful, a like❤️ or a star🌟 would be appreciated.</font></center>
"""

# Math delimiter specs as {"left", "right", "display"} dicts; "display" is True
# for the block-style environments and False only for the inline \( ... \) form.
# Presumably intended for a Gradio Chatbot's latex_delimiters option.
latex_delimiters_set = [
    {"left": left, "right": right, "display": display}
    for left, right, display in (
        ("\\(", "\\)", False),
        ("\\begin{equation}", "\\end{equation}", True),
        ("\\begin{align}", "\\end{align}", True),
        ("\\begin{alignat}", "\\end{alignat}", True),
        ("\\begin{gather}", "\\end{gather}", True),
        ("\\begin{CD}", "\\end{CD}", True),
        ("\\[", "\\]", True),
    )
]

# Extra Chatbot passed as an additional input: its value reaches ovis_chat as
# ``hist`` (the "workaround for API" there), letting callers supply prior turns.
hist = gr.Chatbot()
demo = gr.ChatInterface(
    fn=ovis_chat,
    textbox=gr.MultimodalTextbox(),
    multimodal=True,
    # Main chat display: render LaTeX in replies using latex_delimiters_set,
    # plus the "$$ ... $$" pair that Gradio's Chatbot uses by default.
    chatbot=gr.Chatbot(
        latex_delimiters=latex_delimiters_set + [{"left": "$$", "right": "$$", "display": True}],
    ),
    additional_inputs=hist,
)

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch(debug=True)