import gradio as gr
from PIL import Image
import numpy as np
from transformers import AutoModel
import torch
import spaces
import os

# Load the Magi v2 model (remote code from the Hub) onto the GPU in eval mode.
model = AutoModel.from_pretrained("ragavsachdeva/magiv2", trust_remote_code=True).cuda().eval()
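# Note: .cuda() assumes a CUDA device is available (as on a GPU Space). For
# CPU-only local runs, a device-agnostic variant (sketch, not the deployed
# setup) could look like:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = AutoModel.from_pretrained("ragavsachdeva/magiv2", trust_remote_code=True).to(device).eval()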

def read_image(image):
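    # Normalise each page to a uniform 3-channel array: converting to
    # grayscale ("L") and back to RGB strips colour while keeping the
    # HxWx3 layout the model pipeline consumes.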
    image = Image.open(image).convert("L").convert("RGB")
    image = np.array(image)
    return image

# ZeroGPU: reserve a GPU for up to 180 s per call; longer jobs are aborted.
@spaces.GPU(duration=180)
def process_images(chapter_pages, character_bank_images, character_bank_names):
    if chapter_pages is None:
        return [], ""
    if character_bank_images is None:
        character_bank_images = []
        character_bank_names = ""
    if not character_bank_names:
        # Fall back to the reference images' filenames (sans extension) as names.
        character_bank_names = ",".join([os.path.splitext(os.path.basename(x))[0] for x in character_bank_images])
    chapter_pages = [read_image(image) for image in chapter_pages]
    character_bank = {
        "images": [read_image(image) for image in character_bank_images],
        # Guard: "".split(",") would yield [""], a phantom name with no matching image.
        "names": character_bank_names.split(",") if character_bank_names else [],
    }
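    # Illustrative structure (names are placeholders):
    #   character_bank == {"images": [<HxWx3 ndarray>, ...], "names": ["Luffy", "Zoro"]}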

    with torch.no_grad():
        per_page_results = model.do_chapter_wide_prediction(chapter_pages, character_bank, use_tqdm=True, do_ocr=True)
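        # per_page_results: one dict per page exposing (at least) "ocr",
        # "is_essential_text", "character_names", and
        # "text_character_associations", which are consumed below.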

    output_images = []
    transcript = []
    for i, (image, page_result) in enumerate(zip(chapter_pages, per_page_results)):
        output_image = model.visualise_single_image_prediction(image, page_result, filename=None)
        output_images.append(output_image)
        
        # Map each OCR text index to the name of the character it is attributed to.
        speaker_name = {
            text_idx: page_result["character_names"][char_idx]
            for text_idx, char_idx in page_result["text_character_associations"]
        }
        
        for j in range(len(page_result["ocr"])):
            if not page_result["is_essential_text"][j]:
                continue
            name = speaker_name.get(j, "unsure")  # no character was linked to this text box
            transcript.append(f"<{name}>: {page_result['ocr'][j]}")
    
    transcript_text = "\n".join(transcript)
    
    return output_images, transcript_text
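
# Example of exercising the pipeline directly, bypassing the UI
# (hypothetical file paths, for quick local testing):
#   images, transcript = process_images(
#       ["page_01.png", "page_02.png"],   # chapter pages, in order
#       ["luffy.png", "zoro.png"],        # character reference images
#       "Luffy,Zoro",                     # comma-separated character names
#   )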

# Define Gradio interface
chapter_pages_input = gr.Files(label="Chapter pages in chronological order.")
character_bank_images_input = gr.Files(label="Character reference images. If left empty, the transcript will say 'Other' for all characters.")
character_bank_names_input = gr.Textbox(label="Character names (comma separated). If left empty, the filenames of character images will be used.")

output_images = gr.Gallery(label="Output Images")
transcript_output = gr.Textbox(label="Transcript")

gr.Interface(
    fn=process_images,
    inputs=[chapter_pages_input, character_bank_images_input, character_bank_names_input],
    outputs=[output_images, transcript_output],
    title="Tails Tell Tales: Chapter-Wide Manga Transcriptions With Character Names",
    description="Instructions: (i) Upload a sequence of manga pages, (ii) upload a set of character reference images, (iii) provide a name for each character image, (iv) sit tight; this can take a couple of minutes (the OCR model is slow). Note: the job aborts after 3 minutes, so don't upload too many images (around 30 is fine).",
).launch()