"""Gradio app: chapter-wide manga transcription with character names (magiv2)."""

import gradio as gr
from PIL import Image
import numpy as np
from transformers import AutoModel
import torch
import spaces
import os

# Load the model once at import time. trust_remote_code=True executes the
# model code published in the "ragavsachdeva/magiv2" repository.
# NOTE(review): the model is moved to CUDA here while GPU work happens inside
# the @spaces.GPU-decorated function below — presumably the ZeroGPU pattern; confirm.
model = AutoModel.from_pretrained("ragavsachdeva/magiv2", trust_remote_code=True).cuda().eval()


def read_image(image):
    """Load an image and return it as an RGB numpy array with grayscale content.

    The image is converted to single-channel "L" and then back to "RGB", so all
    three channels hold the same grayscale values.

    Args:
        image: A path or file object accepted by ``PIL.Image.open``.

    Returns:
        A numpy array built from the RGB-converted image.
    """
    # Use the context manager so the underlying file is closed as soon as the
    # pixels have been copied into the numpy array.
    with Image.open(image) as img:
        return np.array(img.convert("L").convert("RGB"))


def _parse_character_names(character_bank_names, character_bank_images):
    """Return the list of character names to pair with the reference images.

    Names are taken from the comma-separated ``character_bank_names`` string,
    with surrounding whitespace stripped from each entry (so "Luffy, Zoro"
    gives ["Luffy", "Zoro"]). If that string is None, empty or whitespace-only,
    each image's filename without its extension is used instead. With no
    reference images the result is an empty list.
    """
    if not character_bank_images:
        return []
    if character_bank_names is None or not character_bank_names.strip():
        # Build the list directly instead of joining and re-splitting on ","
        # so a filename that itself contains a comma stays a single name.
        return [os.path.splitext(os.path.basename(path))[0] for path in character_bank_images]
    return [name.strip() for name in character_bank_names.split(",")]


def _page_transcript(page_result):
    """Return "<speaker>: text" lines for the essential OCR text of one page.

    Text indices listed in ``page_result["text_character_associations"]`` get the
    associated name from ``page_result["character_names"]``. Any other essential
    text is attributed to "unsure". Entries whose ``is_essential_text`` flag is
    falsy are skipped.
    """
    speaker_name = {
        text_idx: page_result["character_names"][char_idx]
        for text_idx, char_idx in page_result["text_character_associations"]
    }
    return [
        f"<{speaker_name.get(j, 'unsure')}>: {text}"
        for j, text in enumerate(page_result["ocr"])
        if page_result["is_essential_text"][j]
    ]


@spaces.GPU(duration=180)  # 180 s matches the "abort after 3mins" note in the UI description.
def process_images(chapter_pages, character_bank_images, character_bank_names):
    """Run chapter-wide prediction and build the annotated pages and transcript.

    Args:
        chapter_pages: Uploaded chapter page files in reading order; None or empty
            means nothing to process.
        character_bank_images: Uploaded character reference image files, or None.
        character_bank_names: Comma-separated names, one per reference image. If
            blank, the reference images' filenames (without extension) are used.

    Returns:
        A tuple ``(output_images, transcript_text)``. The first element holds one
        visualisation per page from ``model.visualise_single_image_prediction``.
        The second element is the newline-joined "<name>: text" transcript.

    Raises:
        gr.Error: If the number of names does not match the number of reference
            images.
    """
    if not chapter_pages:
        return [], ""
    character_bank_images = character_bank_images or []

    names = _parse_character_names(character_bank_names, character_bank_images)
    if len(names) != len(character_bank_images):
        # Fail loudly rather than silently pairing images with the wrong names.
        raise gr.Error(
            f"Got {len(names)} character name(s) for {len(character_bank_images)} "
            "character image(s); please provide exactly one name per image."
        )

    pages = [read_image(page) for page in chapter_pages]
    character_bank = {
        "images": [read_image(image) for image in character_bank_images],
        "names": names,
    }

    with torch.no_grad():
        per_page_results = model.do_chapter_wide_prediction(pages, character_bank, use_tqdm=True, do_ocr=True)

    output_images = []
    transcript = []
    for page, page_result in zip(pages, per_page_results):
        output_images.append(model.visualise_single_image_prediction(page, page_result, filename=None))
        transcript.extend(_page_transcript(page_result))

    return output_images, "\n".join(transcript)


# Define Gradio interface
chapter_pages_input = gr.Files(label="Chapter pages in chronological order.")
character_bank_images_input = gr.Files(label="Character reference images. If left empty, the transcript will say 'Other' for all characters.")
character_bank_names_input = gr.Textbox(label="Character names (comma separated). If left empty, the filenames of character images will be used.")
output_images = gr.Gallery(label="Output Images")
transcript_output = gr.Textbox(label="Transcript")

demo = gr.Interface(
    fn=process_images,
    inputs=[chapter_pages_input, character_bank_images_input, character_bank_names_input],
    outputs=[output_images, transcript_output],
    title="Tails Tell Tales: Chapter-Wide Manga Transcriptions With Character Names",
    description="Instructions: (i) Upload a sequence of manga pages, (ii) Upload a set of reference character images, (iii) Provide the names for each character image, (iv) Sit tight, this can take a couple of minutes (OCR model is slow). Note: The job will abort after 3mins, so don't upload too many images (30ish is fine).",
)
demo.launch()