import argparse
import os
import sys

import numpy as np
import cv2
import torch
import gradio as gr
from PIL import Image

# Add the parent of the current working directory to the import path
# (presumably the repository root, so that `unimernet` resolves when the
# script is launched from its own directory).
sys.path.insert(0, os.path.join(os.getcwd(), ".."))

from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor


class ImageProcessor:
    """Load a UniMERNet model from a config file and run it on single images.

    Attributes:
        cfg_path: Path of the config file handed to ``Config``.
        device: ``cuda`` when available, otherwise ``cpu``.
        model: Model built by the task described in the config.
        vis_processor: The ``formula_image_eval`` processor that turns a PIL
            image into the model's input tensor.
    """

    def __init__(self, cfg_path):
        self.cfg_path = cfg_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.vis_processor = self.load_model_and_processor()

    def load_model_and_processor(self):
        """Build the model and the evaluation image processor from ``cfg_path``.

        Returns:
            A ``(model, vis_processor)`` tuple. The model is moved to
            ``self.device`` and switched to eval mode.
        """
        # Config is constructed from an argparse-style namespace
        # (cfg_path + options) rather than from a bare path.
        args = argparse.Namespace(cfg_path=self.cfg_path, options=None)
        cfg = Config(args)
        task = tasks.setup_task(cfg)
        model = task.build_model(cfg).to(self.device)
        # Inference only: make sure dropout/normalisation layers use eval
        # behaviour even if build_model left the module in training mode.
        model.eval()
        vis_processor = load_processor(
            "formula_image_eval",
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        return model, vis_processor

    def process_single_image(self, image_path, *, show=True):
        """Run the model on one image file and return its prediction.

        Args:
            image_path: Path of the image to recognise.
            show: When true (the default, matching the historical behaviour),
                display the image in an OpenCV window and block until a key
                is pressed. Pass ``False`` for headless use, e.g. from a
                Gradio callback.

        Returns:
            The first entry of the model's ``pred_str`` output, or ``None``
            if the image could not be opened.
        """
        try:
            raw_image = Image.open(image_path)
        except IOError:
            print(f"Error: Unable to open image at {image_path}")
            return None

        # Batch of one image on the model's device.
        image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        output = self.model.generate({"image": image})
        pred = output["pred_str"][0]
        print(f'Prediction:\n{pred}')

        if show:
            # OpenCV expects BGR channel order. Converting to RGB first also
            # normalises palette, grayscale+alpha and RGBA images, so the
            # channel reversal below is always a correct RGB -> BGR swap.
            bgr_image = np.array(raw_image.convert("RGB"))[:, :, ::-1].copy()
            cv2.imshow('Original Image', bgr_image)
            cv2.waitKey(0)
            cv2.destroyAllWindows()

        return pred


def recognize_image(input_img):
    """Gradio callback for the "Recognize" button.

    Placeholder: it currently ignores ``input_img`` and always returns the
    fixed string "100", because model initialisation in ``__main__`` is
    commented out.

    TODO: to enable real recognition, un-comment the model initialisation in
    ``__main__`` and call
    ``processor_tiny.process_single_image(path, show=False)``. Note that
    ``gr.Image`` passes a numpy array by default, so either create it with
    ``type="filepath"`` or adapt the processor to accept arrays.
    """
    return "100"


def gradio_reset():
    """Clear the input image and the prediction textbox.

    This callback is bound to two output components
    (``[input_img, pred_latex]``), so it must return one value per output.
    """
    return None, ""


if __name__ == "__main__":
    # == init model ==
    # Model loading is currently disabled; recognize_image returns a
    # placeholder until these lines are enabled.
    # root_path = os.path.abspath(os.getcwd())
    # config_path = os.path.join(root_path, "cfg_tiny.yaml")
    # processor_tiny = ImageProcessor(config_path)
    # print("== all models init. ==")
    # == init model ==

    with open("header.html", "r", encoding="utf-8") as header_file:
        header = header_file.read()

    with gr.Blocks() as demo:
        gr.HTML(header)

        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(
                        value="Recognize", interactive=True, variant="primary"
                    )
            with gr.Column():
                # A disabled button is used purely as a caption above the output.
                gr.Button(value="Predict Latex:", interactive=False)
                pred_latex = gr.Textbox(label='Latex', interactive=False)

        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
        predict.click(recognize_image, inputs=[input_img], outputs=[pred_latex])

    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)