File size: 5,032 Bytes
42fea26
 
 
2005ef8
42fea26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46b9cd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42fea26
 
46b9cd3
 
42fea26
 
 
 
 
 
 
 
 
 
46b9cd3
42fea26
 
 
 
2005ef8
42fea26
 
 
2005ef8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import spaces
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM

# Hugging Face repository id of the checkpoint served by this demo.
model_name = 'AIDC-AI/Ovis1.6-Gemma2-9B'

# Load the checkpoint in bfloat16 first, then move the loaded model to the GPU.
# NOTE: trust_remote_code=True runs model code shipped with the checkpoint repo.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    multimodal_max_length=8192,
    trust_remote_code=True,
)
model = model.to(device='cuda')

# Tokenizers exposed by the loaded model, and the literal marker that stands
# for an image inside a prompt.
text_tokenizer = model.get_text_tokenizer()
visual_tokenizer = model.get_visual_tokenizer()
image_placeholder = '<image>'


@spaces.GPU
def ovis_chat(chatbot, image_input, text_input):
    """Run one chat turn through the model and append it to the history.

    Args:
        chatbot: Chat history as a list of ``(query, response)`` pairs.
            ``None`` is treated as an empty history. An existing list is
            extended in place.
        image_input: Image to attach to the conversation, or ``None`` for a
            text-only exchange.
        text_input: The user's new message. Any literal image placeholder
            typed by the user is stripped before the message is used.

    Returns:
        A ``(chatbot, "")`` tuple: the history including the new
        ``(text_input, output)`` pair, and an empty string.
    """
    if chatbot is None:
        chatbot = []

    # Rebuild the whole conversation in the "from"/"value" format that
    # model.preprocess_inputs consumes.
    conversations = []
    for query, response in chatbot:
        conversations.append({"from": "human", "value": query})
        conversations.append({"from": "gpt", "value": response})
    # The placeholder is inserted programmatically below, so user-typed
    # occurrences are removed to avoid a stray image marker.
    text_input = text_input.replace(image_placeholder, '')
    conversations.append({"from": "human", "value": text_input})
    if image_input is not None:
        # The image marker always goes on the first message of the
        # conversation, not on the newest one.
        conversations[0]["value"] = image_placeholder + '\n' + conversations[0]["value"]

    _, input_ids, pixel_values = model.preprocess_inputs(conversations, [image_input])
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    # Add a leading batch dimension of size 1.
    input_ids = input_ids.unsqueeze(0).to(device=model.device)
    attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
    if image_input is None:
        pixel_values = [None]
    else:
        pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]

    # Greedy decoding: sampling is disabled and the sampling-related knobs are
    # explicitly set to None.
    gen_kwargs = dict(
        max_new_tokens=512,
        do_sample=False,
        top_p=None,
        top_k=None,
        temperature=None,
        repetition_penalty=None,
        eos_token_id=model.generation_config.eos_token_id,
        pad_token_id=text_tokenizer.pad_token_id,
        use_cache=True
    )

    # Generation itself must run inside inference_mode; previously only the
    # kwargs dict was built inside the context and generate() ran outside it.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            **gen_kwargs
        )[0]
    output = text_tokenizer.decode(output_ids, skip_special_tokens=True)
    chatbot.append((text_input, output))

    return chatbot, ""


def clear_chat():
    """Return the reset values for the chat UI.

    Returns:
        A 3-tuple ``([], None, "")``: a fresh empty history list, no image,
        and an empty prompt string.
    """
    empty_history = []
    no_image = None
    blank_prompt = ""
    return empty_history, no_image, blank_prompt

# Markdown variant of the page header: the model's short name (the part of
# model_name after the last '/') followed by links to the project pages.
md = f"""# <center>{model_name.rpartition('/')[2]}</center>
###
Ovis has been open-sourced on [GitHub](https://github.com/AIDC-AI/Ovis) and [Huggingface](https://huggingface.co/{model_name}). If you find Ovis useful, a star or a like would be appreciated.
"""

# HTML variant of the same header, with explicit font sizes.
html = f"""
<center><font size=8> {model_name.rpartition('/')[2]}</font></center>
<center><font size=3>Ovis has been open-sourced on <a href='https://github.com/AIDC-AI/Ovis'>GitHub</a> and <a href='https://huggingface.co/{model_name}'>Huggingface</a>. If you find Ovis useful, a star or a like would be appreciated.</font></center>
"""


# Delimiter pairs used to recognise LaTeX in chat messages. Each entry gives
# the opening and closing delimiter and whether the enclosed formula is
# rendered in display (block) mode or inline.
latex_delimiters_set = [
    # Inline formula.
    {"left": "\\(", "right": "\\)", "display": False},
    # Block formulas written as \begin{env} ... \end{env}.
    *(
        {"left": f"\\begin{{{env}}}", "right": f"\\end{{{env}}}", "display": True}
        for env in ("equation", "align", "alignat", "gather", "CD")
    ),
    # Block formula written as \[ ... \].
    {"left": "\\[", "right": "\\]", "display": True},
]


# The prompt box is instantiated before the Blocks layout is opened: it is
# referenced as an input by gr.Examples in the left column, but is only placed
# into the layout later (right column) via text_input.render().
text_input = gr.Textbox(label="prompt", placeholder="Enter your text here...", lines=1, container=False)
with gr.Blocks(title=model_name.split('/')[-1]) as demo:
    # Markdown header variant, currently disabled in favour of the HTML header.
    # gr.Markdown(md)
    gr.HTML(html)
    # Directory containing this file; used to build the path of the bundled
    # example image.
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    with gr.Row():
        # Left column (scale 3): image upload plus a clickable example that
        # fills both the image and the prompt.
        with gr.Column(scale=3):
            image_input = gr.Image(label="image", height=350, type="pil")
            gr.Examples(
                examples=[
                    [f"{cur_dir}/examples/rs-1.png", "What shape should come as the fourth shape?"]],
                inputs=[image_input, text_input]
            )
        # Right column (scale 7): conversation view, prompt box and buttons.
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="Ovis", layout="panel", height=800, show_copy_button=True, latex_delimiters=latex_delimiters_set)
            # Place the pre-built prompt box here, below the chat view.
            text_input.render()
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")

    # Both the Send button and the textbox's submit event run ovis_chat with
    # the same inputs; its two return values update the chat view and the
    # prompt box respectively.
    send_click_event = send_btn.click(ovis_chat, [chatbot, image_input, text_input], [chatbot, text_input])
    submit_event = text_input.submit(ovis_chat, [chatbot, image_input, text_input], [chatbot, text_input])
    # Clear resets the chat view, the image and the prompt box using the
    # values returned by clear_chat.
    clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input])

# NOTE: launch is not guarded by `if __name__ == "__main__"`, so the app
# starts whenever this module is executed or imported.
demo.launch()