import argparse
import base64
import io
import json
import os.path
import random
import time
from functools import partial

import evaluate
from PIL import Image
from rapidfuzz.distance import Levenshtein
from tabulate import tabulate
from tqdm import tqdm

from texify.inference import batch_inference
from texify.model.model import load_model
from texify.model.processor import load_processor
from texify.settings import settings


def normalize_text(text):
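    r"""Strip LaTeX math delimiters ($, \[, \], \(, \)) and surrounding whitespace
    so predictions and references are compared on content alone."""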
    text = text.replace("$", "")
    text = text.replace("\\[", "")
    text = text.replace("\\]", "")
    text = text.replace("\\(", "")
    text = text.replace("\\)", "")
    text = text.strip()
    return text


def score_text(predictions, references):
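    """Score predictions against references with BLEU, METEOR, and the mean
    normalized Levenshtein (edit) distance (in [0, 1]; lower is better)."""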
    bleu = evaluate.load("bleu")
    bleu_results = bleu.compute(predictions=predictions, references=references)

    meteor = evaluate.load("meteor")
    meteor_results = meteor.compute(predictions=predictions, references=references)

    lev_dist = []
    for p, r in zip(predictions, references):
        lev_dist.append(Levenshtein.normalized_distance(p, r))

    return {
        "bleu": bleu_results["bleu"],
        "meteor": meteor_results["meteor"],
        "edit": sum(lev_dist) / len(lev_dist),
    }


def image_to_pil(image):
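    """Decode a base64-encoded image string into a PIL image."""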
    decoded = base64.b64decode(image)
    return Image.open(io.BytesIO(decoded))


def load_images(source_data):
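    """Decode the base64 "image" field of every benchmark record into a PIL image."""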
    images = [sd["image"] for sd in source_data]
    images = [image_to_pil(image) for image in images]
    return images


def inference_texify(source_data, model, processor):
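    """Run texify over the images in batches of settings.BATCH_SIZE, pairing each
    prediction with its reference equation."""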
    images = load_images(source_data)

    write_data = []
    for i in tqdm(range(0, len(images), settings.BATCH_SIZE), desc="Texify inference"):
        batch = images[i:i + settings.BATCH_SIZE]
        text = batch_inference(batch, model, processor)
        for j, t in enumerate(text):
            eq_idx = i + j
            write_data.append({"text": t, "equation": source_data[eq_idx]["equation"]})

    return write_data


def inference_pix2tex(source_data):
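    """Run pix2tex over the images one at a time (its API is unbatched).

    Imported lazily so pix2tex stays an optional dependency.
    """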
    from pix2tex.cli import LatexOCR

    model = LatexOCR()

    images = load_images(source_data)
    write_data = []
    for i in tqdm(range(len(images)), desc="Pix2tex inference"):
        try:
            text = model(images[i])
        except ValueError:
            # The model can raise ValueError on some images; score those as empty predictions.
            text = ""
        write_data.append({"text": text, "equation": source_data[i]["equation"]})

    return write_data


def image_to_bmp(image):
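    """Re-encode a PIL image as an in-memory BMP file object for nougat's ImageDataset."""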
    img_out = io.BytesIO()
    image.save(img_out, format="BMP")
    return img_out


def inference_nougat(source_data, batch_size=1):
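    """Run nougat (0.1.0-small checkpoint) over the images and return the
    markdown-compatible prediction strings."""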
    import torch
    from nougat import NougatModel
    from nougat.postprocessing import markdown_compatible
    from nougat.utils.checkpoint import get_checkpoint
    from nougat.utils.dataset import ImageDataset
    from nougat.utils.device import move_to_device

    images = load_images(source_data)
    images = [image_to_bmp(image) for image in images]
    predictions = []

    ckpt = get_checkpoint(None, model_tag="0.1.0-small")
    model = NougatModel.from_pretrained(ckpt)
    if settings.TORCH_DEVICE_MODEL != "cpu":
        move_to_device(model, bf16=settings.CUDA, cuda=settings.CUDA)
    model.eval()

    dataset = ImageDataset(
        images,
        partial(model.encoder.prepare_input, random_padding=False),
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
    )

    # max_length only needs to be set once, not per batch.
    model.config.max_length = settings.MAX_TOKENS
    for sample in tqdm(dataloader, desc="Nougat inference"):
        model_output = model.inference(image_tensors=sample, early_stopping=False)
        output = [markdown_compatible(o) for o in model_output["predictions"]]
        predictions.extend(output)
    return predictions


def main():
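    """Benchmark texify on the source dataset and optionally score pix2tex and nougat too."""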
    parser = argparse.ArgumentParser(description="Benchmark the performance of texify.")
    parser.add_argument("--data_path", type=str, help="Path to JSON file with source images/equations.", default=os.path.join(settings.DATA_DIR, "bench_data.json"))
    parser.add_argument("--result_path", type=str, help="Path to JSON file to save results to.", default=os.path.join(settings.DATA_DIR, "bench_results.json"))
    parser.add_argument("--max", type=int, help="Maximum number of images to benchmark.", default=None)
    parser.add_argument("--pix2tex", action="store_true", help="Run pix2tex scoring.", default=False)
    parser.add_argument("--nougat", action="store_true", help="Run nougat scoring.", default=False)
    args = parser.parse_args()

    source_path = os.path.abspath(args.data_path)
    result_path = os.path.abspath(args.result_path)
    os.makedirs(os.path.dirname(result_path), exist_ok=True)
    model = load_model()
    processor = load_processor()

    with open(source_path, "r") as f:
        source_data = json.load(f)

    if args.max:
        # Fixed seed so subsampled runs are reproducible and comparable.
        random.seed(1)
        source_data = random.sample(source_data, args.max)

    start = time.time()
    predictions = inference_texify(source_data, model, processor)
    times = {"texify": time.time() - start}
    text = [normalize_text(p["text"]) for p in predictions]
    references = [normalize_text(p["equation"]) for p in predictions]

    scores = score_text(text, references)

    write_data = {
        "texify": {
            "scores": scores,
            "text": [{"prediction": p, "reference": r} for p, r in zip(text, references)]
        }
    }

    if args.pix2tex:
        start = time.time()
        predictions = inference_pix2tex(source_data)
        times["pix2tex"] = time.time() - start

        p_text = [normalize_text(p["text"]) for p in predictions]
        p_scores = score_text(p_text, references)

        write_data["pix2tex"] = {
            "scores": p_scores,
            "text": [{"prediction": p, "reference": r} for p, r in zip(p_text, references)]
        }

    if args.nougat:
        start = time.time()
        predictions = inference_nougat(source_data)
        times["nougat"] = time.time() - start

        n_text = [normalize_text(p) for p in predictions]
        n_scores = score_text(n_text, references)

        write_data["nougat"] = {
            "scores": n_scores,
            "text": [{"prediction": p, "reference": r} for p, r in zip(n_text, references)]
        }

    score_table = []
    score_headers = ["bleu", "meteor", "edit"]
    # One direction arrow per printed column, including the time column appended below.
    score_dirs = ["⬆", "⬆", "⬇", "⬇"]

    for method in write_data.keys():
        score_table.append([method, *[write_data[method]["scores"][h] for h in score_headers], times[method]])

    score_headers.append("time taken (s)")
    score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)]
    print()
    print(tabulate(score_table, headers=["Method", *score_headers]))
    print()
    print("Higher is better for BLEU and METEOR, lower is better for edit distance and time taken.")
    print("Note that pix2tex is unbatched (I couldn't find a batch inference method in the docs), so time taken is higher than it should be.")

    with open(result_path, "w") as f:
        json.dump(write_data, f, indent=4)


if __name__ == "__main__":
    main()