"""Kosmos-2.5 OCR example.

Downloads a sample receipt image, runs ``microsoft/kosmos-2.5`` on it with
the OCR prompt, prints the recognised text lines together with their
bounding boxes, and saves a copy of the image with the boxes outlined in
red as ``output.png``.
"""
import re

REPO = "microsoft/kosmos-2.5"
DEVICE = "cuda:0"

# Use the Hub's "/resolve/" endpoint: "/blob/" serves the HTML file viewer,
# which PIL cannot decode as an image.
IMAGE_URL = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"

# Task prompts understood by the model: OCR with boxes, or markdown output.
OCR_PROMPT = "<ocr>"
MD_PROMPT = "<md>"
PROMPT = OCR_PROMPT

# One bounding-box tag as it appears in the decoded OCR output.
_BBOX_RE = re.compile(r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>")
_INT_RE = re.compile(r"\d+")


def post_process(y, scale_height, scale_width, prompt=PROMPT):
    """Convert decoded model output into ``x0,y0,x1,y0,x1,y1,x0,y1,text`` records.

    Args:
        y: Decoded text for one sample, as returned by ``batch_decode``.
        scale_height: Factor applied to every y coordinate. Anything that
            supports ``int(coord * scale)`` works (a float or a
            single-element tensor).
        scale_width: Factor applied to every x coordinate.
        prompt: The prompt that was given to the model. It is stripped from
            ``y``; if it contains ``<md>`` the remaining text is returned
            unchanged because markdown output carries no boxes.

    Returns:
        The concatenation of one record per non-degenerate box. Each record
        is the four corner points of the scaled box (clockwise from the
        top-left) followed by the text that came after the box tag. Text
        preceding the first box tag is discarded. Returns ``""`` when ``y``
        contains no box tags.
    """
    y = y.replace(prompt, "")
    if MD_PROMPT in prompt:
        return y

    tags = _BBOX_RE.findall(y)
    # split() yields the text before the first tag at index 0; dropping it
    # leaves exactly one text chunk per tag.
    texts = _BBOX_RE.split(y)[1:]

    records = []
    for tag, text in zip(tags, texts):
        x0, y0, x1, y1 = (int(number) for number in _INT_RE.findall(tag))
        if x0 >= x1 or y0 >= y1:
            # Degenerate (empty or inverted) box: skip it.
            continue
        x0 = int(x0 * scale_width)
        y0 = int(y0 * scale_height)
        x1 = int(x1 * scale_width)
        y1 = int(y1 * scale_height)
        records.append(f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{text}")
    return "".join(records)


def extract_polygons(output_text):
    """Return the 8-integer polygon at the start of each line of *output_text*.

    Lines with fewer than eight comma-separated fields, or whose first eight
    fields are not all integers (e.g. ordinary prose or markdown), are
    skipped instead of raising.
    """
    polygons = []
    for line in output_text.split("\n"):
        fields = line.split(",")
        if len(fields) < 8:
            continue
        try:
            polygons.append([int(field) for field in fields[:8]])
        except ValueError:
            # Not a box record; nothing to draw for this line.
            continue
    return polygons


def main():
    """Run the end-to-end example: download, generate, print, draw, save."""
    # Third-party dependencies are imported here so that the pure-Python
    # helpers above can be imported without the model stack installed.
    import requests
    import torch
    from PIL import Image, ImageDraw
    from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration

    dtype = torch.bfloat16
    model = Kosmos2_5ForConditionalGeneration.from_pretrained(
        REPO, device_map=DEVICE, torch_dtype=dtype
    )
    processor = AutoProcessor.from_pretrained(REPO)

    # sample image
    response = requests.get(IMAGE_URL, stream=True, timeout=30)
    response.raise_for_status()  # fail here rather than inside PIL
    image = Image.open(response.raw)

    # bs = 1
    inputs = processor(text=PROMPT, images=image, return_tensors="pt")
    # "height"/"width" are popped before generate() and used only to map
    # box coordinates back onto the original image size.
    height, width = inputs.pop("height"), inputs.pop("width")
    raw_width, raw_height = image.size
    scale_height = raw_height / height
    scale_width = raw_width / width

    # bs > 1, batch generation
    # inputs = processor(text=[PROMPT, PROMPT], images=[image, image], return_tensors="pt")
    # height, width = inputs.pop("height"), inputs.pop("width")
    # raw_width, raw_height = image.size
    # scale_height = raw_height / height[0]
    # scale_width = raw_width / width[0]

    inputs = {k: v.to(DEVICE) if v is not None else None for k, v in inputs.items()}
    inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=1024,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

    output_text = post_process(generated_text[0], scale_height, scale_width, prompt=PROMPT)
    print(output_text)

    # draw the bounding boxes
    draw = ImageDraw.Draw(image)
    for polygon in extract_polygons(output_text):
        draw.polygon(polygon, outline="red")
    image.save("output.png")


if __name__ == "__main__":
    main()