import json
from copy import deepcopy
import torch
import base64
from io import BytesIO
from typing import Any, List, Dict
from PIL import Image
from transformers import AutoTokenizer, AutoModel


def _build_conversation(model, msgs, image, tokenizer, system_prompt=None):
    """Convert one conversation into chat-template messages plus image tensors.

    Every PIL image in the conversation is replaced by the textual placeholder
    the model expects, and its pixel data is collected separately.

    Args:
        model: Vision-language model. Only ``config.slice_mode``,
            ``config.patch_size``, ``config.query_num``, ``transform``,
            ``get_slice_image_placeholder`` and ``reshape_by_patch`` are used.
        msgs: List of ``{"role": ..., "content": ...}`` dicts, or a JSON string
            encoding such a list. It is deep-copied and never mutated.
        image: Optional PIL image. It is attached ahead of the first message's
            text when that message's content is a plain string.
        tokenizer: Tokenizer providing ``im_start``, ``im_end`` and
            ``unk_token`` (used for the non-slice placeholder).
        system_prompt: Optional system message prepended when truthy.

    Returns:
        A tuple ``(messages, images, tgt_sizes)``.
        ``messages``: the new message list, where each ``content`` is the
        newline-joined text and placeholders.
        ``images``: the transformed image tensors.
        ``tgt_sizes``: a stacked int32 tensor of per-slice patch-grid sizes,
        or an empty list when no slices were produced.

    Raises:
        ValueError: if the conversation is empty.
        AssertionError: if a role is not user/assistant, or the first message
            is not from the user.
    """
    if isinstance(msgs, str):
        msgs = json.loads(msgs)
    if not msgs:
        raise ValueError("msgs is empty")

    messages = deepcopy(msgs)
    # A standalone image is attached to the first turn, ahead of its text.
    if image is not None and isinstance(messages[0]['content'], str):
        messages[0]['content'] = [image, messages[0]['content']]

    images = []
    tgt_sizes = []
    for turn, msg in enumerate(messages):
        role = msg["role"]
        content = msg["content"]
        assert role in ["user", "assistant"]
        if turn == 0:
            assert role == "user", "The role of first msg should be user"
        if isinstance(content, str):
            content = [content]

        parts = []
        for item in content:
            if isinstance(item, Image.Image):
                if model.config.slice_mode:
                    slices, placeholder = model.get_slice_image_placeholder(
                        item, tokenizer
                    )
                    parts.append(placeholder)
                    for slice_image in slices:
                        pixels = model.transform(slice_image)
                        height, width = pixels.shape[1:]
                        images.append(model.reshape_by_patch(pixels))
                        # Patch-grid size of this slice, as int32.
                        tgt_sizes.append(
                            torch.Tensor([
                                height // model.config.patch_size,
                                width // model.config.patch_size,
                            ]).type(torch.int32)
                        )
                else:
                    images.append(model.transform(item))
                    parts.append(
                        tokenizer.im_start
                        + tokenizer.unk_token * model.config.query_num
                        + tokenizer.im_end
                    )
            elif isinstance(item, str):
                parts.append(item)
            # Content items of any other type are silently skipped.
        msg['content'] = '\n'.join(parts)

    if tgt_sizes:
        tgt_sizes = torch.vstack(tgt_sizes)

    if system_prompt:
        messages = [{'role': 'system', 'content': system_prompt}] + messages

    return messages, images, tgt_sizes


def _generation_config(sampling, overrides):
    """Return decoding settings for ``model.generate``.

    Args:
        sampling: If true, use top-p/top-k sampling; otherwise use beam search.
        overrides: Mapping of caller-supplied keyword arguments. Only keys that
            already exist in the chosen default config are applied; all other
            keys are ignored.

    Returns:
        A freshly built dict, so no state is shared between calls.
    """
    if sampling:
        config = {
            "top_p": 0.8,
            "top_k": 100,
            "temperature": 0.7,
            "do_sample": True,
            "repetition_penalty": 1.05,
        }
    else:
        config = {
            "num_beams": 3,
            "repetition_penalty": 1.2,
        }
    config.update((k, overrides[k]) for k in config.keys() & overrides.keys())
    return config


def chat(
    model,
    image_list,
    msgs_list,
    tokenizer,
    vision_hidden_states=None,
    max_new_tokens=1024,
    sampling=True,
    max_inp_length=2048,
    system_prompt_list=None,
    **kwargs
):
    """Run one batched, non-streaming generation over several conversations.

    Args:
        model: Vision-language model exposing ``generate`` plus the attributes
            used by ``_build_conversation``.
        image_list: One optional PIL image (or ``None``) per conversation.
        msgs_list: One conversation per batch entry. Each is a list of message
            dicts or a JSON string of one.
        tokenizer: Tokenizer with ``apply_chat_template``.
        vision_hidden_states: Forwarded unchanged to ``model.generate``.
        max_new_tokens: Maximum number of tokens to generate.
        sampling: Sampling (True) or beam search (False).
        max_inp_length: Forwarded to ``model.generate``.
        system_prompt_list: Optional per-conversation system prompts. ``None``
            or an empty list means no system prompts.
        **kwargs: Overrides for existing generation-config keys only, such as
            ``temperature`` or ``num_beams``.

    Returns:
        The first element returned by ``model.generate``. Presumably this is
        one decoded response per conversation; confirm against the model's
        remote code.

    Raises:
        ValueError: if the per-conversation lists differ in length, or a
            conversation is empty.
    """
    if len(image_list) != len(msgs_list):
        raise ValueError(
            f"image_list has {len(image_list)} entries but msgs_list has "
            f"{len(msgs_list)}"
        )
    if system_prompt_list:
        if len(system_prompt_list) != len(msgs_list):
            raise ValueError(
                f"system_prompt_list has {len(system_prompt_list)} entries but "
                f"msgs_list has {len(msgs_list)}"
            )
        prompts = system_prompt_list
    else:
        prompts = [None] * len(msgs_list)

    conversations = []
    images_list = []
    tgt_sizes_list = []
    for msgs, image, system_prompt in zip(msgs_list, image_list, prompts):
        messages, images, tgt_sizes = _build_conversation(
            model, msgs, image, tokenizer, system_prompt
        )
        conversations.append(messages)
        images_list.append(images)
        tgt_sizes_list.append(tgt_sizes)

    # NOTE(review): no generation prompt is appended here. Confirm that the
    # model's chat template / generate() adds the assistant header itself.
    input_ids_list = tokenizer.apply_chat_template(
        conversations, tokenize=True, add_generation_prompt=False
    )

    generation_config = _generation_config(sampling, kwargs)

    with torch.inference_mode():
        res, _ = model.generate(
            input_id_list=input_ids_list,
            max_inp_length=max_inp_length,
            img_list=images_list,
            tgt_sizes=tgt_sizes_list,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            vision_hidden_states=vision_hidden_states,
            return_vision_hidden_states=True,
            stream=False,
            **generation_config
        )
    return res


class EndpointHandler:
    """Batched inference handler for a MiniCPM-V style model.

    Each request item carries an optional base64-encoded ``image`` and either
    a ``msgs`` conversation or a single ``question`` string.
    """

    # Checkpoint loaded when no explicit model_name is given.
    DEFAULT_MODEL_NAME = "SwordElucidator/MiniCPM-Llama3-V-2_5-int4"

    def __init__(self, path="", model_name=DEFAULT_MODEL_NAME):
        """Load the model and tokenizer.

        Args:
            path: Accepted for handler-interface compatibility. It is ignored,
                and ``model_name`` is loaded instead.
            model_name: Hub id or local path of the checkpoint to load.
        """
        # trust_remote_code executes the repository's Python code; only point
        # model_name at repositories you trust.
        model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        model.eval()
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, data: Any) -> Any:
        """Answer a batch of image/question requests.

        Args:
            data: Either ``{"inputs": [...]}``, a list of request dicts, or a
                single request dict. Each request dict may contain
                ``"image"`` (base64 string or bytes, optional), ``"msgs"``
                (conversation) and ``"question"`` (used when ``msgs`` is
                missing or empty).

        Returns:
            Whatever ``chat`` returns for the batch.

        Raises:
            ValueError: if a request has neither ``msgs`` nor ``question``.
        """
        inputs = data.get("inputs", data) if isinstance(data, dict) else data
        # A single request object is treated as a batch of one.
        if isinstance(inputs, dict):
            inputs = [inputs]

        image_list = []
        msgs_list = []
        for item in inputs:
            encoded = item.get("image")
            msgs = item.get("msgs")
            if not msgs:
                question = item.get("question")
                if question is None:
                    raise ValueError(
                        "each input needs either 'msgs' or 'question'"
                    )
                msgs = [{'role': 'user', 'content': question}]

            image = None
            if encoded:
                # Normalise to 3-channel RGB. RGBA, palette or grayscale
                # uploads would otherwise reach model.transform with the
                # wrong channel count.
                image = Image.open(BytesIO(base64.b64decode(encoded))).convert(
                    "RGB"
                )
            image_list.append(image)
            msgs_list.append(msgs)

        return chat(
            self.model,
            image_list,
            msgs_list,
            self.tokenizer,
        )