Magiv2-Demo / app.py
ragavsachdeva's picture
Update app.py
9d0818e verified
raw
history blame
3.09 kB
import gradio as gr
from PIL import Image
import numpy as np
from transformers import AutoModel
import torch
import spaces
import os
# Load the model
# The checkpoint is fetched once at import time so every request reuses the
# same instance. NOTE: trust_remote_code=True lets transformers execute the
# modelling code shipped inside the "ragavsachdeva/magiv2" repository.
model = AutoModel.from_pretrained(
    "ragavsachdeva/magiv2",
    trust_remote_code=True,
)
# Move the weights to the GPU and switch the module to evaluation mode.
model = model.cuda().eval()
def read_image(image):
    """Load an image and return it as a greyscale-looking RGB NumPy array.

    The picture is first reduced to a single luminance channel (mode "L")
    and then expanded back to three channels (mode "RGB"), so colour input
    is flattened to grey while the array keeps a three-channel layout.

    Args:
        image: Anything accepted by ``PIL.Image.open``. The callers in this
            file pass the entries Gradio hands over from ``gr.Files``.

    Returns:
        numpy.ndarray: The converted image as produced by ``np.array``.
    """
    # Image.open is lazy and keeps the file open; the context manager
    # releases the handle as soon as convert() has produced an in-memory copy.
    with Image.open(image) as source:
        rgb_image = source.convert("L").convert("RGB")
    return np.array(rgb_image)
def _parse_character_names(raw_names, image_paths):
    """Return the character names that go with ``image_paths``.

    When ``raw_names`` is ``None``, empty or whitespace-only, the names are
    derived from the image file names (base name without extension).
    Otherwise ``raw_names`` is split on commas and each entry is stripped of
    surrounding whitespace; positions are preserved so names stay aligned
    with the images they were typed for.
    """
    if raw_names is None or not raw_names.strip():
        return [
            os.path.splitext(os.path.basename(path))[0] for path in image_paths
        ]
    return [name.strip() for name in raw_names.split(",")]


def _transcribe_page(page_result):
    """Build the ``<speaker>: text`` transcript lines for one page result."""
    # Map each text index to the name of the character it was associated with.
    speaker_by_text = {
        text_idx: page_result["character_names"][char_idx]
        for text_idx, char_idx in page_result["text_character_associations"]
    }
    lines = []
    text_entries = zip(page_result["ocr"], page_result["is_essential_text"])
    for text_idx, (text, is_essential) in enumerate(text_entries):
        if not is_essential:
            continue
        # Texts without an associated character are attributed to "unsure".
        speaker = speaker_by_text.get(text_idx, "unsure")
        lines.append(f"<{speaker}>: {text}")
    return lines


@spaces.GPU(duration=180)
def process_images(chapter_pages, character_bank_images, character_bank_names):
    """Run chapter-wide prediction and return visualisations plus a transcript.

    Args:
        chapter_pages: Uploaded page files in reading order, or ``None`` /
            an empty list when nothing was uploaded.
        character_bank_images: Uploaded character reference images, or
            ``None`` / empty when no character bank is provided.
        character_bank_names: Comma-separated character names matching
            ``character_bank_images``. When empty, the image file names
            (without extension) are used instead. Ignored when there are no
            reference images.

    Returns:
        tuple: ``(output_images, transcript_text)`` where ``output_images``
        holds one visualisation per page, as returned by
        ``model.visualise_single_image_prediction``, and ``transcript_text``
        is the newline-joined ``<speaker>: text`` lines of all essential
        texts. Returns ``([], "")`` when no pages were supplied.
    """
    # Nothing to do without pages (covers both None and an empty upload).
    if not chapter_pages:
        return [], ""

    if not character_bank_images:
        # No reference images: use an empty bank with no names, so that
        # "images" and "names" stay the same length.
        character_bank_images = []
        names = []
    else:
        names = _parse_character_names(character_bank_names, character_bank_images)

    pages = [read_image(page) for page in chapter_pages]
    character_bank = {
        "images": [read_image(reference) for reference in character_bank_images],
        "names": names,
    }

    # Inference only: no gradients are needed.
    with torch.no_grad():
        per_page_results = model.do_chapter_wide_prediction(
            pages, character_bank, use_tqdm=True, do_ocr=True
        )

    output_images = []
    transcript = []
    for page, page_result in zip(pages, per_page_results):
        # filename=None: the rendered image is used as the return value
        # rather than being written to a named file.
        output_images.append(
            model.visualise_single_image_prediction(page, page_result, filename=None)
        )
        transcript.extend(_transcribe_page(page_result))

    return output_images, "\n".join(transcript)
# Define Gradio interface
# Inputs: two multi-file uploads (pages and character references) and a
# free-text box for the comma-separated character names.
chapter_pages_input = gr.Files(
    label="Chapter pages in chronological order.",
)
character_bank_images_input = gr.Files(
    label="Character reference images. If left empty, the transcript will say 'Other' for all characters.",
)
character_bank_names_input = gr.Textbox(
    label="Character names (comma separated). If left empty, the filenames of character images will be used.",
)

# Outputs: a gallery of annotated pages and the generated transcript.
output_images = gr.Gallery(label="Output Images")
transcript_output = gr.Textbox(label="Transcript")

# Wire process_images to the widgets above, then start the app.
demo = gr.Interface(
    fn=process_images,
    inputs=[
        chapter_pages_input,
        character_bank_images_input,
        character_bank_names_input,
    ],
    outputs=[
        output_images,
        transcript_output,
    ],
    title="Tails Tell Tales: Chapter-Wide Manga Transcriptions With Character Names",
    description="Instructions: (i) Upload a sequence of manga pages, (ii) Upload a set of reference character images, (iii) Provide the names for each character image, (iv) Sit tight, this can take a couple of minutes (OCR model is slow). Note: The job will abort after 3mins, so don't upload too many images (30ish is fine).",
)
demo.launch()