echo840 commited on
Commit
d28c270
1 Parent(s): 33f1780

Add application file

Browse files
Files changed (1) hide show
  1. app.py +352 -0
app.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from argparse import ArgumentParser
3
+ from pathlib import Path
4
+
5
+ import copy
6
+ import gradio as gr
7
+ import os
8
+ import re
9
+ import secrets
10
+ import tempfile
11
+
12
+ from PIL import Image
13
+ from monkey_model.modeling_monkey import MonkeyLMHeadModel
14
+ from monkey_model.tokenization_qwen import QWenTokenizer
15
+ from monkey_model.configuration_monkey import MonkeyConfig
16
+
17
+ import shutil
18
+ from pathlib import Path
19
+ import json
20
+ DEFAULT_CKPT_PATH = 'echo840/Monkey' # '/home/zhangli/demo/'
21
+ BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
22
+ PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
23
+ title_markdown = ("""
24
+ # Welcome to Monkey
25
+
26
+ Hello! I'm Monkey, a Large Language and Vision Assistant. Before talking to me, please read the **Operation Guide** and **Terms of Use**.
27
+
28
+ > Note: This demo represents a more advanced iteration of the chat system, building upon the previous version to deliver an enhanced interactive experience. As a result, we cannot guarantee that the question-answering scenarios presented in the paper can be replicated accurately using this updated version.
29
+
30
+ ## Operation Guide
31
+
32
+ Click the **Upload** button to upload an image. Then, you can get Monkey's answer in two ways:
33
+ - Click the **Generate** and Monkey will generate a description of the image.
34
+ - Enter the question in the dialog box, click the **Submit**, and Monkey will answer the question based on the image.
35
+ - Click **Clear History** to clear the current image and Q&A content.
36
+
37
+ """)
38
+
39
+ policy_markdown = ("""
40
+ ## Terms of Use
41
+
42
+ By using this service, users are required to agree to the following terms:
43
+
44
+ - Monkey is for research use only and unauthorized commercial use is prohibited. For any query, please contact the author.
45
+ - Monkey's generation capabilities are limited, so we recommend that users do not rely entirely on its answers.
46
+ - Monkey's security measures are limited, so we cannot guarantee that the output is completely appropriate. We strongly recommend that users do not intentionally guide Monkey to generate harmful content, including hate speech, discrimination, violence, pornography, deception, etc.
47
+
48
+ """)
49
+
50
+ # ## Some Prompt Examples
51
+
52
+ # In order to generate more detailed captions, we provide some input examples so that you can conduct more interesting explorations.
53
+
54
+ # - Generate the detailed caption in English.
55
+ # - Explain the visual content of the image in great detail.
56
+ # - Analyze the image in a comprehensive and detailed manner.
57
+ # - Describe the image in as much detail as possible in English without duplicating it.
58
+ # - Describe the image in as much detail as possible in English, including as many elements from the image as possible, but without repetition.
59
+
60
+
61
+ def _get_args():
62
+ parser = ArgumentParser()
63
+ parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
64
+ help="Checkpoint name or path, default to %(default)r")
65
+ parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
66
+
67
+ parser.add_argument("--share", action="store_true", default=True,
68
+ help="Create a publicly shareable link for the interface.")
69
+ parser.add_argument("--inbrowser", action="store_true", default=False,
70
+ help="Automatically launch the interface in a new tab on the default browser.")
71
+ parser.add_argument("--server-port", type=int, default=8000,
72
+ help="Demo server port.")
73
+ parser.add_argument("--server-name", type=str, default="127.0.0.1",
74
+ help="Demo server name.")
75
+
76
+ args = parser.parse_args()
77
+ return args
78
+
79
+
80
+ def _load_model_tokenizer(args):
81
+ tokenizer = QWenTokenizer.from_pretrained(
82
+ args.checkpoint_path, trust_remote_code=True)
83
+
84
+ if args.cpu_only:
85
+ device_map = "cpu"
86
+ else:
87
+ device_map = "cuda"
88
+
89
+ model = MonkeyLMHeadModel.from_pretrained(
90
+ args.checkpoint_path,
91
+ device_map=device_map,
92
+ trust_remote_code=True,
93
+ ).eval()
94
+ # model.generation_config = GenerationConfig.from_pretrained(
95
+ # args.checkpoint_path, trust_remote_code=True, resume_download=True,
96
+ # )
97
+ tokenizer.padding_side = 'left'
98
+ tokenizer.pad_token_id = tokenizer.eod_id
99
+ return model, tokenizer
100
+
101
+
102
+ def _parse_text(text):
103
+ lines = text.split("\n")
104
+ lines = [line for line in lines if line != ""]
105
+ count = 0
106
+ for i, line in enumerate(lines):
107
+ if "```" in line:
108
+ count += 1
109
+ items = line.split("`")
110
+ if count % 2 == 1:
111
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
112
+ else:
113
+ lines[i] = f"<br></code></pre>"
114
+ else:
115
+ if i > 0:
116
+ if count % 2 == 1:
117
+ line = line.replace("`", r"\`")
118
+ line = line.replace("<", "&lt;")
119
+ line = line.replace(">", "&gt;")
120
+ line = line.replace(" ", "&nbsp;")
121
+ line = line.replace("*", "&ast;")
122
+ line = line.replace("_", "&lowbar;")
123
+ line = line.replace("-", "&#45;")
124
+ line = line.replace(".", "&#46;")
125
+ line = line.replace("!", "&#33;")
126
+ line = line.replace("(", "&#40;")
127
+ line = line.replace(")", "&#41;")
128
+ line = line.replace("$", "&#36;")
129
+ lines[i] = "<br>" + line
130
+ text = "".join(lines)
131
+ return text
132
+
133
+
134
+ def _launch_demo(args, model, tokenizer):
135
+ def predict(_chatbot, task_history):
136
+ chat_query = _chatbot[-1][0]
137
+ query = task_history[-1][0]
138
+ question = _parse_text(query)
139
+ # print("User: " + _parse_text(query))
140
+ full_response = ""
141
+
142
+
143
+ img_path = _chatbot[0][0][0]
144
+ try:
145
+ Image.open(img_path)
146
+ except:
147
+ response = "Please upload a picture."
148
+ _chatbot[-1] = (_parse_text(chat_query), response)
149
+ full_response = _parse_text(response)
150
+
151
+ task_history[-1] = (query, full_response)
152
+ # print("Monkey: " + _parse_text(full_response))
153
+ return _chatbot
154
+
155
+ query = f'<img>{img_path}</img> {question} Answer: '
156
+ print(query)
157
+
158
+ all_history = query
159
+ all_history_0 = ''
160
+ if len(_chatbot) > 2:
161
+ all_history = ''
162
+ for conv in _chatbot[1:-1]:
163
+ q = conv[0]
164
+ a = conv[1]
165
+ all_history_0 = all_history + f'{q} Answer: {a} '
166
+ all_history = all_history_0 + f'<img>{img_path}</img> ' # 1288 tokens
167
+ all_history = all_history + f'{question} Answer: '
168
+ print(all_history)
169
+ tokens = all_history.split()
170
+ last_2048_tokens = tokens[-760:]
171
+ all_history = " ".join(last_2048_tokens)
172
+ print(all_history)
173
+
174
+ # input_ids = tokenizer(query, return_tensors='pt', padding='longest')
175
+ input_ids = tokenizer(all_history, return_tensors='pt', padding='longest')
176
+
177
+ attention_mask = input_ids.attention_mask
178
+ input_ids = input_ids.input_ids
179
+
180
+ pred = model.generate(
181
+ input_ids=input_ids.cuda(),
182
+ attention_mask=attention_mask.cuda(),
183
+ do_sample=False,
184
+ num_beams=1,
185
+ max_new_tokens=512,
186
+ min_new_tokens=1,
187
+ length_penalty=3,
188
+ num_return_sequences=1,
189
+ output_hidden_states=True,
190
+ use_cache=True,
191
+ pad_token_id=tokenizer.eod_id,
192
+ eos_token_id=tokenizer.eod_id,
193
+ )
194
+ response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
195
+
196
+ _chatbot[-1] = (_parse_text(chat_query), response)
197
+ full_response = _parse_text(response)
198
+
199
+ # with open('./history/question_answer.jsonl', 'a',encoding="utf-8") as file: # 使用 'a' 模式打开文件,表示以追加模式写入
200
+ # data = {query:response}
201
+ # json_line = json.dumps(data)
202
+ # file.write(json_line + '\n')
203
+ # with open('./history/all_history_together.jsonl', 'a',encoding="utf-8") as file: # 使用 'a' 模式打开文件,表示以追加模式写入
204
+ # data = f'<img>{img_path}</img> ' + all_history_0 + f'{question} Answer: {full_response}'
205
+ # json_line = json.dumps(data)
206
+ # file.write(json_line + '\n')
207
+
208
+
209
+ task_history[-1] = (query, full_response)
210
+ print("Monkey: " + _parse_text(full_response))
211
+ return _chatbot
212
+
213
+ def caption(_chatbot, task_history):
214
+
215
+ query = "Describe the image in as much detail as possible in English, including as many elements from the image as possible, but without repetition. Answer: "
216
+ chat_query = "Describe the image in as much detail as possible in English, including as many elements from the image as possible, but without repetition. Answer: "
217
+
218
+ question = _parse_text(query)
219
+ print("User: " + _parse_text(query))
220
+
221
+ full_response = ""
222
+
223
+ try:
224
+ img_path = _chatbot[0][0][0]
225
+ Image.open(img_path)
226
+ except:
227
+ response = "Please upload a picture."
228
+
229
+ _chatbot.append((None, response))
230
+ full_response = _parse_text(response)
231
+
232
+ task_history.append((None, full_response))
233
+ print("Monkey: " + _parse_text(full_response))
234
+ return _chatbot
235
+ img_path = _chatbot[0][0][0]
236
+ query = f'<img>{img_path}</img> {chat_query} '
237
+ print(query)
238
+ input_ids = tokenizer(query, return_tensors='pt', padding='longest')
239
+ attention_mask = input_ids.attention_mask
240
+ input_ids = input_ids.input_ids
241
+
242
+ pred = model.generate(
243
+ input_ids=input_ids.cuda(),
244
+ attention_mask=attention_mask.cuda(),
245
+ do_sample=True,
246
+ temperature=0.7,
247
+ max_new_tokens=250,
248
+ min_new_tokens=1,
249
+ length_penalty=3,
250
+ num_return_sequences=1,
251
+ output_hidden_states=True,
252
+ use_cache=True,
253
+ pad_token_id=tokenizer.eod_id,
254
+ eos_token_id=tokenizer.eod_id,
255
+ )
256
+ response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
257
+
258
+ _chatbot.append((None, response))
259
+ full_response = _parse_text(response)
260
+
261
+ task_history.append((None, full_response))
262
+ print("Monkey: " + _parse_text(full_response))
263
+ return _chatbot
264
+
265
+ def add_text(history, task_history, text):
266
+ task_text = text
267
+ if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
268
+ task_text = text[:-1]
269
+ history = history + [(_parse_text(text), None)]
270
+ task_history = task_history + [(task_text, None)]
271
+ # print(history, task_history, text)
272
+ return history, task_history, ""
273
+
274
+ def add_file(history, task_history, file):
275
+ save_path = os.path.join("./history/test_image",file.name.split("/")[-2])
276
+ Path(save_path).mkdir(exist_ok=True,parents=True)
277
+ shutil.copy(file.name,save_path)
278
+ history = [((file.name,), None)]
279
+ task_history = [((file.name,), None)]
280
+ # print(history, task_history, file)
281
+ return history, task_history
282
+
283
+ def reset_user_input():
284
+ return gr.update(value="")
285
+
286
+ def reset_state(task_history):
287
+ # with open('./history/all_history_separate.jsonl', 'a',encoding="utf-8") as file: # 使用 'a' 模式打开文件,表示以追加模式写入
288
+ # data = task_history
289
+ # json_line = json.dumps(data)
290
+ # file.write(json_line + '\n')
291
+ task_history.clear()
292
+ return []
293
+
294
+
295
+ with gr.Blocks() as demo:
296
+ gr.Markdown(title_markdown)
297
+
298
+ chatbot = gr.Chatbot(label='Monkey', elem_classes="control-height", height=600,avatar_images=("./images/logo_user.png","./images/logo_monkey.png"),layout="bubble",bubble_full_width=False,show_copy_button=True)
299
+ query = gr.Textbox(lines=1, label='Input')
300
+ task_history = gr.State([])
301
+
302
+ with gr.Row():
303
+ empty_bin = gr.Button("Clear History")
304
+ submit_btn = gr.Button("Submit")
305
+
306
+ generate_btn_en = gr.Button("Generate")
307
+ addfile_btn = gr.UploadButton("Upload", file_types=["image"])
308
+
309
+ submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
310
+ predict, [chatbot, task_history], [chatbot], show_progress=True
311
+ )
312
+ generate_btn_en.click(caption, [chatbot, task_history], [chatbot], show_progress=True)
313
+
314
+ submit_btn.click(reset_user_input, [], [query])
315
+ empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
316
+
317
+ addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True,scroll_to_output=True)
318
+
319
+ with gr.Row(variant="compact"):
320
+ with gr.Column(scale=2):
321
+ with gr.Row():
322
+ a = gr.Image(Image.open("./images/logo_monkey.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False,render=False)
323
+ b = gr.Image(Image.open("./images/logo_hust.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False,render=False)
324
+ with gr.Column(scale=4):
325
+ with gr.Row():
326
+ a = gr.Image(Image.open("./images/logo_monkey.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False)
327
+ c = gr.Image(Image.open("./images/logo_vlr.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False)
328
+ b = gr.Image(Image.open("./images/logo_hust.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False)
329
+ b = gr.Image(Image.open("./images/logo_king.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False)
330
+ with gr.Column(scale=2):
331
+ with gr.Row():
332
+ a = gr.Image(Image.open("./images/logo_monkey.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False,render=False)
333
+ b = gr.Image(Image.open("./images/logo_hust.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False,render=False)
334
+
335
+ gr.Markdown(policy_markdown)
336
+
337
+ demo.queue().launch(
338
+ server_name="0.0.0.0",
339
+ server_port=7682,
340
+ share=True
341
+ )
342
+
343
+
344
+ def main():
345
+ args = _get_args()
346
+
347
+ model, tokenizer = _load_model_tokenizer(args)
348
+ _launch_demo(args, model, tokenizer)
349
+
350
+
351
+ if __name__ == '__main__':
352
+ main()