andrewqian123 committed
Commit 2033829 · Parent(s): 35f8d11
Update modeling_minicpmv.py

modeling_minicpmv.py CHANGED (+37 -34)
@@ -274,7 +274,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):

    def chat(
        self,
-        image,
+        images,
        msgs,
        tokenizer,
        processor=None,

@@ -290,42 +290,45 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
            processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
        if isinstance(msgs, str):
            msgs = json.loads(msgs)
-        copy_msgs = deepcopy(msgs)
+        # copy_msgs = deepcopy(msgs)

        assert len(msgs) > 0, "msgs is empty"
        assert sampling or not stream, "if use stream mode, make sure sampling=True"
-        ...  # the previous single-image prompt construction (32 lines)
+        assert(len(msgs) == len(images)), "Make sure to have one image per item in your batch"
+        batchM = []
+        batchI = []
+        for ind in range(len(images)):
+            image = images[ind]
+            if image is not None and isinstance(copy_msgs[0]["content"], str):
+                # deep copy element
+                copy_msgs = deepcopy(msgs[ind])
+                copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]
+
+            imagelist = []
+            for i, msg in enumerate(copy_msgs):
+                role = msg["role"]
+                content = msg["content"]
+                assert role in ["user", "assistant"]
+                if i == 0:
+                    assert role == "user", "The role of first msg should be user"
+                if isinstance(content, str):
+                    content = [content]
+                cur_msgs = []
+                for c in content:
+                    if isinstance(c, Image.Image):
+                        imagelist.append(c)
+                        cur_msgs.append("(<image>./</image>)")
+                    elif isinstance(c, str):
+                        cur_msgs.append(c)
+                msg["content"] = "\n".join(cur_msgs)
+
+            if system_prompt:
+                sys_msg = {'role': 'system', 'content': system_prompt}
+                copy_msgs = [sys_msg] + copy_msgs
+            batchM.append(copy_msgs)
+            batchI.append(imagelist)
+        prompt = processor.tokenizer.apply_chat_template(batchM, tokenize=False, add_generation_prompt=True)
+        inputs = processor(prompt, batchI, return_tensors="pt", max_length=max_inp_length).to(self.device)

        if sampling:
            generation_config = {
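
For context, a minimal sketch of how the batched chat() introduced by this commit might be called. Only the new interface comes from the diff (a list of images parallel to msgs, one conversation per item, enforced by the new assert); the model id, dtype, device, image paths, and the assumption that chat() returns one answer per batch item are illustrative placeholders, not part of this commit.

# Hypothetical usage sketch for the batched chat(); model id, file paths, and
# the shape of the return value are assumptions, not taken from this commit.
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

model_id = "openbmb/MiniCPM-Llama3-V-2_5"  # assumption: the upstream repo this file derives from
model = AutoModel.from_pretrained(model_id, trust_remote_code=True,
                                  torch_dtype=torch.float16).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# One PIL image and one conversation per batch item; the patched chat()
# asserts len(images) == len(msgs) and pairs images[i] with msgs[i].
images = [
    Image.open("example_cat.jpg").convert("RGB"),   # placeholder path
    Image.open("example_city.jpg").convert("RGB"),  # placeholder path
]
msgs = [
    [{"role": "user", "content": "What animal is shown in this picture?"}],
    [{"role": "user", "content": "Describe this scene in one sentence."}],
]

# Assumed to return one decoded answer per batch item.
answers = model.chat(images=images, msgs=msgs, tokenizer=tokenizer)
for ans in answers:
    print(ans)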