SwordElucidator committed
Commit 9538a7b
1 Parent(s): 2ee3b1b

Create handler.py

Files changed (1)
  1. handler.py +158 -0
handler.py ADDED
@@ -0,0 +1,158 @@
+ import json
+ from copy import deepcopy
+
+ import torch
+
+ import base64
+ from io import BytesIO
+ from typing import Any, List, Dict
+
+ from PIL import Image
+ from transformers import AutoTokenizer, AutoModel
+
+
+ def chat(
+     model,
+     image_list,
+     msgs_list,
+     tokenizer,
+     vision_hidden_states=None,
+     max_new_tokens=1024,
+     sampling=True,
+     max_inp_length=2048,
+     system_prompt_list=None,
+     **kwargs
+ ):
+     copy_msgs_lst = []
+     images_list = []
+     tgt_sizes_list = []
+     for i in range(len(msgs_list)):  # iterate over the batch of conversations
+         msgs = msgs_list[i]
+         image = image_list[i]
+         system_prompt = system_prompt_list[i] if system_prompt_list else None
+         if isinstance(msgs, str):
+             msgs = json.loads(msgs)
+
+         copy_msgs = deepcopy(msgs)
+
+         if image is not None and isinstance(copy_msgs[0]['content'], str):
+             copy_msgs[0]['content'] = [image, copy_msgs[0]['content']]
+
+         images = []
+         tgt_sizes = []
+         for j, msg in enumerate(copy_msgs):  # j indexes messages within one conversation
+             role = msg["role"]
+             content = msg["content"]
+             assert role in ["user", "assistant"]
+             if j == 0:
+                 assert role == "user", "The role of the first msg should be user"
+             if isinstance(content, str):
+                 content = [content]
+
+             cur_msgs = []
+             for c in content:
+                 if isinstance(c, Image.Image):
+                     image = c
+                     if model.config.slice_mode:
+                         slice_images, image_placeholder = model.get_slice_image_placeholder(
+                             image, tokenizer
+                         )
+                         cur_msgs.append(image_placeholder)
+                         for slice_image in slice_images:
+                             slice_image = model.transform(slice_image)
+                             H, W = slice_image.shape[1:]
+                             images.append(model.reshape_by_patch(slice_image))
+                             tgt_sizes.append(
+                                 torch.Tensor([H // model.config.patch_size, W // model.config.patch_size]).type(torch.int32))
+                     else:
+                         images.append(model.transform(image))
+                         cur_msgs.append(
+                             tokenizer.im_start
+                             + tokenizer.unk_token * model.config.query_num
+                             + tokenizer.im_end
+                         )
+                 elif isinstance(c, str):
+                     cur_msgs.append(c)
+
+             msg['content'] = '\n'.join(cur_msgs)
+         if tgt_sizes:
+             tgt_sizes = torch.vstack(tgt_sizes)
+
+         if system_prompt:
+             sys_msg = {'role': 'system', 'content': system_prompt}
+             copy_msgs = [sys_msg] + copy_msgs
+
+         copy_msgs_lst.append(copy_msgs)
+         images_list.append(images)
+         tgt_sizes_list.append(tgt_sizes)
+
+     input_ids_list = tokenizer.apply_chat_template(copy_msgs_lst, tokenize=True, add_generation_prompt=False)
+
+     if sampling:
+         generation_config = {
+             "top_p": 0.8,
+             "top_k": 100,
+             "temperature": 0.7,
+             "do_sample": True,
+             "repetition_penalty": 1.05
+         }
+     else:
+         generation_config = {
+             "num_beams": 3,
+             "repetition_penalty": 1.2,
+         }
+
+     generation_config.update(
+         (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
+     )  # caller-supplied kwargs override the defaults above
+
+     with torch.inference_mode():
+         res, vision_hidden_states = model.generate(
+             input_id_list=input_ids_list,
+             max_inp_length=max_inp_length,
+             img_list=images_list,
+             tgt_sizes=tgt_sizes_list,
+             tokenizer=tokenizer,
+             max_new_tokens=max_new_tokens,
+             vision_hidden_states=vision_hidden_states,
+             return_vision_hidden_states=True,
+             stream=False,
+             **generation_config
+         )
+     return res
+
+
+ class EndpointHandler:  # batched handler
+     def __init__(self, path=""):
+         # Load the int4-quantized MiniCPM-Llama3-V 2.5 model and its tokenizer.
+         model_name = "SwordElucidator/MiniCPM-Llama3-V-2_5-int4"
+         model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+         tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+         model.eval()
+         self.model = model
+         self.tokenizer = tokenizer
+
+     def __call__(self, data: Any) -> List[str]:
+         inputs = data.pop("inputs", data)
+
+         image_list = []
+         msgs_list = []
+
+         for input_ in inputs:
+             image = input_.pop("image", None)  # base64-encoded image (str or bytes)
+             question = input_.pop("question", None)
+             msgs = input_.pop("msgs", None)
+             image = Image.open(BytesIO(base64.b64decode(image))) if image is not None else None
+
+             if not msgs:
+                 msgs = [{'role': 'user', 'content': question}]
+
+             image_list.append(image)
+             msgs_list.append(msgs)
+
+         return chat(
+             self.model,
+             image_list,
+             msgs_list,
+             self.tokenizer,
+         )
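
For context, below is a minimal local smoke-test sketch for this handler. It is not part of the committed file: it assumes handler.py is importable from the working directory, that the int4 model fits on the available hardware, and the image path cat.jpg and the question text are placeholders.

import base64

from handler import EndpointHandler

# Build the payload shape __call__ expects: an "inputs" list of per-sample
# dicts, each carrying a base64-encoded "image" plus either a plain "question"
# or a ready-made "msgs" conversation.
with open("cat.jpg", "rb") as f:  # placeholder image path
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": [
        {"image": image_b64, "question": "What is shown in this picture?"},
    ]
}

handler = EndpointHandler()
responses = handler(payload)  # one generated answer per input sample
print(responses[0])

On a deployed endpoint, the same JSON shape would typically be posted as the request body.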