Spaces:
Sleeping
Sleeping
import gradio as gr | |
import PIL.Image | |
import transformers | |
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor | |
import torch | |
import os | |
import string | |
import functools | |
import re | |
import flax.linen as nn | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
import spaces | |
# Fine-tuned PaliGemma checkpoint served by this demo.
model_id = "mattraj/curacel-transcription-1"
# Hex colour palette.
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load weights once at import time, switch to eval mode and move to `device`.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
model = model.eval().to(device)
processor = PaliGemmaProcessor.from_pretrained(model_id)
def resize_and_pad(image, target_dim=448):
    """Fit ``image`` inside a ``target_dim`` x ``target_dim`` square and pad it.

    The aspect ratio is preserved: the longer side is scaled to ``target_dim``
    and the resized image is centred on a white RGB canvas.

    Args:
      image: a ``PIL.Image.Image``.
      target_dim: side length, in pixels, of the square output. Defaults to
        448, the input resolution of the google/paligemma-3b-pt-448 base model
        this demo is fine-tuned from.

    Returns:
      A new ``target_dim`` x ``target_dim`` RGB ``PIL.Image.Image``.
    """
    aspect_ratio = image.width / image.height
    if aspect_ratio > 1:
        # Landscape: width is the longer side.
        new_width = target_dim
        new_height = int(target_dim / aspect_ratio)
    else:
        # Portrait or square: height is the longer (or equal) side.
        new_height = target_dim
        new_width = int(target_dim * aspect_ratio)
    # Extremely elongated inputs would otherwise truncate one side to 0 px,
    # which PIL refuses to resize to.
    new_width = max(1, new_width)
    new_height = max(1, new_height)

    # Only `PIL.Image` is imported in this module, so use the qualified name.
    resized_image = image.resize((new_width, new_height), PIL.Image.LANCZOS)
    # White square canvas; the resized image is pasted centred on it.
    new_image = PIL.Image.new("RGB", (target_dim, target_dim), (255, 255, 255))
    new_image.paste(
        resized_image,
        ((target_dim - new_width) // 2, (target_dim - new_height) // 2),
    )
    return new_image
###### Transformers Inference | |
def infer(
    image: PIL.Image.Image,
    text: str,
    max_new_tokens: int
) -> str:
    """Generate text for ``image`` conditioned on the prompt ``text``.

    The image is resized with white padding to 448x448 (the resolution of the
    google/paligemma-3b-pt-448 base model) via ``resize_and_pad``. Decoding is
    greedy (``do_sample=False``).

    Args:
      image: input image from the ``gr.Image(type="pil")`` component.
      text: prompt, e.g. "Transcribe the Arabic text.".
      max_new_tokens: upper bound on the number of generated tokens.

    Returns:
      Only the newly generated text (the prompt is not echoed back), with
      leading newlines stripped.

    Raises:
      gr.Error: if no image was provided.
    """
    if image is None:
        # E.g. the button was clicked before an image was uploaded.
        raise gr.Error("Please upload an image first.")
    inputs = processor(
        text=text,
        images=resize_and_pad(image, 448),
        return_tensors="pt",
    ).to(device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            # UI sliders may deliver a float; generate expects an int.
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
        )
    # generate() returns prompt + continuation. Decode only the continuation
    # rather than slicing the decoded string by len(text), which breaks
    # whenever the decoded prompt is not character-identical to `text`.
    result = processor.batch_decode(
        generated_ids[:, prompt_len:], skip_special_tokens=True
    )
    return result[0].lstrip("\n")
def parse_segmentation(input_image, input_text):
    """Run a segmentation/detection prompt and build image annotations.

    The model output is parsed into objects by ``extract_objs``, which also
    decodes masks and bounding boxes.

    Args:
      input_image: the ``PIL.Image.Image`` being annotated.
      input_text: the prompt sent to the model.

    Returns:
      A ``(image, annotations)`` tuple for ``gr.AnnotatedImage``. Each
      annotation is ``(mask, label)`` when a mask was decoded, otherwise
      ``((x1, y1, x2, y2), label)``; missing labels become ``''``.
    """
    out = infer(input_image, input_text, max_new_tokens=100)
    width, height = input_image.size
    objs = extract_objs(out.lstrip("\n"), width, height, unique_labels=True)
    annotations = [
        (
            obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
            obj['name'] or '',
        )
        for obj in objs
        # Plain-text segments carry only 'content' and are skipped.
        if 'mask' in obj or 'xyxy' in obj
    ]
    return input_image, annotations
######## Demo | |
# Markdown shown at the top of the demo. Entries are separated by blank lines
# so Markdown renders them as separate paragraphs instead of one run-on line.
INTRO_TEXT = """## Curacel Handwritten Arabic demo

Finetuned from: google/paligemma-3b-pt-448

Translation model demo at: https://prod.arabic-gpt.ai/

Prompts:

1. Translate the Arabic to English: {model output}

2. The following is a diagnosis in Arabic from a medical billing form we need to translate to English. The transcriber is not necessarily accurate so one or more characters or words may be wrong. Given what is written, what is the most likely diagnosis? Think step by step, and think about similar words or misspellings in Arabic. Give multiple Arabic diagnoses along with the translation in English for each, then finally select the diagnosis that makes the most sense given what was transcribed and print the English translation as your most likely final translation. Transcribed text: {model output}
"""
# Gradio UI. Only the "Text Generation" tab is active; it is wired to `infer`.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)
    with gr.Tab("Text Generation"):
        with gr.Column():
            image = gr.Image(type="pil")
            text_input = gr.Text(label="Input Text")
            text_output = gr.Text(label="Text Output")
            chat_btn = gr.Button()
            # Forwarded to `infer` as its `max_new_tokens` argument.
            tokens = gr.Slider(
                label="Max New Tokens",
                info="Set to larger for longer generation.",
                minimum=10,
                maximum=100,
                value=20,
                step=10,
            )
        # Positional order must match infer(image, text, max_new_tokens).
        chat_inputs = [
            image,
            text_input,
            tokens
        ]
        chat_outputs = [
            text_output
        ]
        chat_btn.click(
            fn=infer,
            inputs=chat_inputs,
            outputs=chat_outputs,
        )
        # NOTE(review): each example row supplies only (image, text) for the
        # three inputs; confirm the installed Gradio tolerates the missing
        # slider value. The two paths also use different directories; check
        # that both files exist in the Space.
        examples = [["./diagnosis-1.jpg", "Transcribe the Arabic text."],
                    ["./examples/sign.jpg", "Transcribe the Arabic text."]]
        gr.Markdown("")
        gr.Examples(
            examples=examples,
            inputs=chat_inputs,
        )
    # The "Segment/Detect" tab below is disabled: it is wrapped in a string
    # literal, so none of it runs. NOTE(review): its examples reuse the
    # transcription prompt, which is presumably not a segmentation prompt.
    '''
    with gr.Tab("Segment/Detect"):
        image = gr.Image(type="pil")
        seg_input = gr.Text(label="Entities to Segment/Detect")
        seg_btn = gr.Button("Submit")
        annotated_image = gr.AnnotatedImage(label="Output")
        examples = [["./diagnosis-1.jpg", "Transcribe the Arabic text."],
                    ["./examples/sign.jpg", "Transcribe the Arabic text."]]
        gr.Markdown(
            "")
        gr.Examples(
            examples=examples,
            inputs=[image, seg_input],
        )
        seg_inputs = [
            image,
            seg_input
        ]
        seg_outputs = [
            annotated_image
        ]
        seg_btn.click(
            fn=parse_segmentation,
            inputs=seg_inputs,
            outputs=seg_outputs,
        )
    '''
### Postprocessing utilities for segmentation tokens.
### <seg> tokens are decoded into a mask by a separate VAE decoder whose
### weights are read from this file.
_MODEL_PATH = 'vae-oid.npz'

# Matches one object at the start of the model output. Groups, in order:
#   1      -- text preceding the object (lazy, may be empty),
#   2..5   -- four <locDDDD> values (read as y1, x1, y2, x2 by extract_objs),
#   6..21  -- sixteen optional <segDDD> codebook indices,
#   22     -- optional label; an optional space and "; " separator after it
#             are consumed as well.
_SEGMENT_DETECT_RE = re.compile(''.join([
    r'(.*?)',
    r'<loc(\d{4})>' * 4,
    r'\s*',
    '(?:' + r'<seg(\d{3})>' * 16 + ')?',
    r'\s*([^;<>]+)? ?(?:; )?',
]))
def _get_params(checkpoint):
    """Convert a PyTorch VAE-decoder checkpoint into a Flax params tree.

    Args:
      checkpoint: mapping from PyTorch state-dict keys to arrays, such as the
        dict built from the ``.npz`` file in ``_get_reconstruct_masks``.

    Returns:
      Nested dict whose keys (``_embeddings``, ``Conv_0``, ``ResBlock_0``,
      ``ConvTranspose_0`` ...) match the decoder's submodule names.
    """
    def conv(prefix):
        # Bias is copied as-is; the kernel's axes are reordered
        # (0, 1, 2, 3) -> (2, 3, 1, 0). NOTE(review): presumably PyTorch
        # (out, in, kh, kw) -> Flax (kh, kw, in, out); confirm for the
        # transposed-conv layers.
        bias = checkpoint[f'{prefix}.bias']
        kernel = np.transpose(checkpoint[f'{prefix}.weight'], (2, 3, 1, 0))
        return {'bias': bias, 'kernel': kernel}

    def resblock(prefix):
        # The three convs of a residual block sit at sub-indices 0, 2 and 4
        # of the source checkpoint.
        return {
            f'Conv_{flax_idx}': conv(f'{prefix}.{torch_idx}')
            for flax_idx, torch_idx in enumerate((0, 2, 4))
        }

    params = {
        '_embeddings': checkpoint['_vq_vae._embedding'],
        'Conv_0': conv('decoder.0'),
        'ResBlock_0': resblock('decoder.2.net'),
        'ResBlock_1': resblock('decoder.3.net'),
    }
    for flax_idx, layer in enumerate((4, 6, 8, 10)):
        params[f'ConvTranspose_{flax_idx}'] = conv(f'decoder.{layer}')
    params['Conv_1'] = conv('decoder.12')
    return params
def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
    """Look up codebook vectors for a batch of 16 token indices each.

    Args:
      codebook_indices: integer array shaped ``[B, 16]``.
      embeddings: 2-D codebook array shaped ``[num_embeddings, embedding_dim]``.

    Returns:
      Array shaped ``[B, 4, 4, embedding_dim]``: each example's 16 looked-up
      vectors laid out row-major on a 4x4 grid.

    Raises:
      AssertionError: if the second dimension of ``codebook_indices`` is not 16.
    """
    batch_size, num_tokens = codebook_indices.shape
    assert num_tokens == 16, codebook_indices.shape
    # Unpacking (rather than indexing) also enforces that `embeddings` is 2-D.
    _, embedding_dim = embeddings.shape
    flat_indices = codebook_indices.reshape((-1))
    vectors = jnp.take(embeddings, flat_indices, axis=0)
    return vectors.reshape((batch_size, 4, 4, embedding_dim))
@functools.lru_cache(maxsize=None)
def _get_reconstruct_masks():
    """Reconstructs masks from codebook indices.

    The decoder weights are read from ``_MODEL_PATH`` and the jitted decoder
    is built on the first call only. The result is cached, so later calls
    (``extract_objs`` calls this once per segmented object) neither reload
    the file nor build, and hence re-trace and compile, a fresh jitted
    function.

    Returns:
      A function that expects indices shaped `[B, 16]` of dtype int32, each
      ranging from 0 to 127 (inclusive), and that returns decoded masks sized
      `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
    """

    class ResBlock(nn.Module):
        """conv3x3 -> relu -> conv3x3 -> relu -> conv1x1, plus the input."""
        features: int

        # Submodules are declared inline below; Flax only allows that in
        # methods decorated with @nn.compact (or in setup()).
        @nn.compact
        def __call__(self, x):
            original_x = x
            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
            x = nn.relu(x)
            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
            x = nn.relu(x)
            x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
            return x + original_x

    class Decoder(nn.Module):
        """Upscales quantized vectors to mask."""

        @nn.compact
        def __call__(self, x):
            num_res_blocks = 2
            dim = 128
            num_upsample_layers = 4
            x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
            x = nn.relu(x)
            for _ in range(num_res_blocks):
                x = ResBlock(features=dim)(x)
            # Four stride-2 transposed convs take the 4x4 grid to the
            # documented 64x64 output; channels halve after each layer.
            for _ in range(num_upsample_layers):
                x = nn.ConvTranspose(
                    features=dim,
                    kernel_size=(4, 4),
                    strides=(2, 2),
                    padding=2,
                    transpose_kernel=True,
                )(x)
                x = nn.relu(x)
                dim //= 2
            x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
            return x

    def reconstruct_masks(codebook_indices):
        quantized = _quantized_values_from_codebook_indices(
            codebook_indices, params['_embeddings']
        )
        return Decoder().apply({'params': params}, quantized)

    # `params` is captured by the reconstruct_masks closure above.
    with open(_MODEL_PATH, 'rb') as f:
        params = _get_params(dict(np.load(f)))

    return jax.jit(reconstruct_masks, backend='cpu')
def extract_objs(text, width, height, unique_labels=False):
    """Returns objs for a string with "<loc>" and "<seg>" tokens.

    ``text`` is consumed left to right with ``_SEGMENT_DETECT_RE``. Each match
    appends up to two dicts:

      * ``{'content': before}`` for any plain text preceding the tokens, and
      * ``{'content', 'xyxy', 'mask', 'name'}`` for the object itself, where
        ``xyxy`` is ``(x1, y1, x2, y2)`` in pixels, ``mask`` is a
        ``[height, width]`` float array with values in [0, 1] (or ``None``
        when the object had no ``<seg>`` tokens) and ``name`` is the label,
        possibly ``None``.

    Any remaining text that no longer matches becomes a final
    ``{'content': ...}`` dict.

    Args:
      text: model output to parse.
      width: image width in pixels, used to scale x coordinates.
      height: image height in pixels, used to scale y coordinates.
      unique_labels: if true, a repeated label gets ``'`` appended until it
        is unique.

    Returns:
      The list of dicts described above.
    """
    objs = []
    seen = set()
    while text:
        match = _SEGMENT_DETECT_RE.match(text)
        if not match:
            break
        groups = list(match.groups())
        before = groups.pop(0)
        name = groups.pop()
        # <locDDDD> values are fractions of 1024, ordered y1, x1, y2, x2.
        y1, x1, y2, x2 = [int(g) / 1024 for g in groups[:4]]
        y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
        seg_indices = groups[4:20]
        if seg_indices[0] is None:
            mask = None
        else:
            seg_indices = np.array([int(g) for g in seg_indices], dtype=np.int32)
            m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
            # Map decoder output (documented as [-1, 1]) to [0, 1].
            m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
            m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
            mask = np.zeros([height, width])
            if y2 > y1 and x2 > x1:
                # Stretch the 64x64 mask over the bounding box.
                mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
        content = match.group()
        if before:
            objs.append(dict(content=before))
            content = content[len(before):]
        while unique_labels and name in seen:
            name = (name or '') + "'"
        seen.add(name)
        objs.append(dict(
            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
        # Advance past the whole match (preceding text + object tokens).
        text = text[len(before) + len(content):]
    if text:
        objs.append(dict(content=text))
    return objs
# Script entry point: enable request queuing (capped at max_size=10) and
# launch the demo in debug mode. queue() returns the same Blocks object.
if __name__ == "__main__":
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch(debug=True)