File size: 5,537 Bytes
c3e525b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e687a1c
c3e525b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torch
from PIL import Image, ImageOps, ImageSequence
import numpy as np

import comfy.sample
import comfy.sd


def vencode(vae, pth):
    """Encode an image with the VAE and wrap the result as a latent dict.

    Args:
        vae: object exposing ``encode(pixels)``.
        pth: the image to encode; anything ``np.array`` accepts. Assumed to
            be an (H, W, C) image with C >= 3 — TODO confirm for all callers.

    Returns:
        ``{"samples": vae.encode(pixels)}``, where ``pixels`` is a float32
        tensor of shape (1, H, W, 3): the image divided by 255 with a batch
        axis prepended and only the first three channels kept.
    """
    scaled = np.array(pth).astype(np.float32) / 255.0
    batch = torch.from_numpy(scaled).unsqueeze(0)
    # Drop any alpha (or extra) channels; the VAE gets RGB only.
    rgb = batch[:, :, :, :3]
    return {"samples": vae.encode(rgb)}
from pathlib import Path
# Local checkpoint path; fetched from Hugging Face on first start.
MODEL_FILE = "model.safetensors"
MODEL_URL = (
    "https://huggingface.co/parsee-mizuhashi/mangaka/resolve/main/"
    "mangaka.safetensors?download=true"
)
if not Path(MODEL_FILE).exists():
    import requests

    # Stream into a temporary sibling file and rename it only once the download
    # completed successfully. Writing straight to MODEL_FILE would leave an
    # empty/truncated checkpoint after a failed or interrupted download, and the
    # exists() check above would then skip the download on every later start.
    _partial = Path(MODEL_FILE + ".part")
    with requests.get(MODEL_URL, stream=True, timeout=60) as _response:
        _response.raise_for_status()
        with open(_partial, "wb") as f:
            # Chunked writes avoid holding the whole checkpoint in memory.
            for _chunk in _response.iter_content(chunk_size=1 << 20):
                f.write(_chunk)
    _partial.replace(MODEL_FILE)

# Only the first three returned items (model, CLIP, VAE) are kept.
unet, clip, vae = comfy.sd.load_checkpoint_guess_config(MODEL_FILE, output_vae=True, output_clip=True)[:3]
# Fixed negative prompt; the user's negative prompt is appended to it in main().
# NOTE(review): the parentheses here are unbalanced (four "(" vs one ")"), so the
# ":1.4"/":1.2" weights may not be parsed as intended — confirm before changing.
BASE_NEG = "(low-quality worst-quality:1.4 (bad-anatomy (inaccurate-limb:1.2 bad-composition inaccurate-eyes extra-digit fewer-digits (extra-arms:1.2)"
# NOTE(review): DEVICE is not referenced by the code shown.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0):
    """Sample a latent with ComfyUI and return the resulting latent tensor.

    Args:
        model: diffusion model passed to ``comfy.sample.sample``.
        seed: seed for ``comfy.sample.prepare_noise`` and for sampling.
        steps, cfg, sampler_name, scheduler, denoise: sampler settings,
            forwarded unchanged.
        positive, negative: conditioning lists, forwarded unchanged.
        latent: dict holding the latent tensor under ``"samples"`` and,
            optionally, a ``"noise_mask"``.

    Returns:
        Whatever ``comfy.sample.sample`` returns (the sampled latents); the
        progress bar is disabled.
    """
    noise_mask = latent["noise_mask"] if "noise_mask" in latent else None
    latent_image = latent["samples"]
    noise = comfy.sample.prepare_noise(latent_image, seed, None)
    return comfy.sample.sample(
        model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
        denoise=denoise, noise_mask=noise_mask, disable_pbar=True, seed=seed,
    )
def set_mask(samples, mask):
    """Return a shallow copy of ``samples`` with ``mask`` stored as ``noise_mask``.

    The mask is reshaped to (N, 1, H, W), where H and W are its last two
    dimensions. The input dict is left unmodified.
    """
    height, width = mask.shape[-2], mask.shape[-1]
    return {**samples, "noise_mask": mask.reshape((-1, 1, height, width))}
def load_image_mask(image):
    """Load an image file's alpha channel as a float mask tensor.

    Args:
        image: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        float32 tensor of shape (1, H, W): the alpha channel divided by 255.
        Images that are not already RGBA are converted first, so an image with
        no alpha channel (e.g. plain RGB) yields an all-ones mask.
    """
    # ``with`` closes the file deterministically; all pixel access happens
    # inside the block.
    with Image.open(image) as src:
        img = ImageOps.exif_transpose(src)
        if img.getbands() != ("R", "G", "B", "A"):
            if img.mode == "I":
                # Scale 32-bit integer pixel values by 1/255 before converting.
                img = img.point(lambda value: value * (1 / 255))
            img = img.convert("RGBA")
        # The image is RGBA at this point, so the "A" band always exists.
        alpha = np.array(img.getchannel("A")).astype(np.float32) / 255.0
    return torch.from_numpy(alpha).unsqueeze(0)
@torch.no_grad()
def main(img, variant, positive, negative, pilimg):
    """Inpaint one panel of a page and return the result as a PIL image.

    Args:
        img: page id; a key of ``limits`` and a folder under ``./mangaka-d/``.
        variant: panel number (string or int), clamped to ``limits[img]``.
            The panel mask is read from ``./mangaka-d/<img>/i<panel>.png``.
        positive: user positive prompt, appended to a fixed monochrome prefix.
        negative: user negative prompt, appended to ``BASE_NEG``.
        pilimg: current page image to edit. When it is ``None`` (nothing has
            been loaded yet), the page's ``base.png`` is used instead.

    Returns:
        The decoded sample as a PIL image.

    Raises:
        ValueError: if ``variant`` is not a panel number (e.g. "none").
    """
    page = img
    try:
        panel = int(variant)
    except (TypeError, ValueError):
        raise ValueError(f"Select a panel number to generate (got {variant!r}).") from None
    panel = min(panel, limits[page])
    mask = load_image_mask(f"./mangaka-d/{page}/i{panel}.png")

    cond_tokens = clip.tokenize("(greyscale monochrome black-and-white:1.3)" + positive)
    cond, cond_pooled = clip.encode_from_tokens(cond_tokens, return_pooled=True)
    uncond_tokens = clip.tokenize(BASE_NEG + negative)
    uncond, uncond_pooled = clip.encode_from_tokens(uncond_tokens, return_pooled=True)
    positive_cond = [[cond, {"pooled_output": cond_pooled}]]
    negative_cond = [[uncond, {"pooled_output": uncond_pooled}]]

    if pilimg is None:
        # No working image yet: start from the page's base image, converted to
        # RGB so vencode receives three colour channels.
        with Image.open(f"./mangaka-d/{page}/base.png") as base:
            pilimg = base.convert("RGB")

    # The panel mask is attached as the latent's noise_mask; presumably this
    # restricts sampling to the masked area — confirm against ComfyUI.
    latent = set_mask(vencode(vae, pilimg), mask)

    # Fixed sampler settings: seed 0, 20 steps, CFG 7, ddpm/karras, full denoise.
    denoised = common_ksampler(unet, 0, 20, 7, 'ddpm', 'karras', positive_cond, negative_cond, latent, denoise=1)
    decoded = vae.decode(denoised)
    # First image of the batch; assumed channel-last (H, W, C) in [0, 1] as
    # Image.fromarray needs — TODO confirm against comfy's VAE.decode.
    pixels = 255. * decoded[0].cpu().numpy()
    return Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8))

# Maximum panel number for each page id. main() and visualize_fn() clamp the
# requested panel to this value before reading ./mangaka-d/<page>/i<panel>.png.
limits = {
    **{str(number): count
       for number, count in enumerate((4, 4, 5, 6, 4, 6, 8, 5, 5), start=1)},
    **{f"s{number}": count
       for number, count in enumerate((4, 6, 5, 5, 4, 4), start=1)},
}
import gradio as gr
def visualize_fn(page, panel):
    """Return a page's base image, optionally with one panel's mask overlaid.

    Args:
        page: page id; a key of ``limits`` and a folder under ``./mangaka-d/``.
        panel: panel number (string), clamped to ``limits[page]``, or "none"
            to return the base page unchanged.

    Returns:
        A PIL image. For a panel, the base page is converted to RGBA. The
        panel mask ``i<panel>.png`` then has every pure-white pixel recolored
        opaque red, and is pasted on top using its own alpha as paste mask.
    """
    with Image.open(f"./mangaka-d/{page}/base.png") as src:
        # copy() loads the pixels so the image stays usable after the file closes.
        base = src.copy()
    if panel == "none":
        return base
    panel = min(int(panel), limits[page])
    with Image.open(f"./mangaka-d/{page}/i{panel}.png") as src:
        overlay = np.array(src.convert("RGBA"))
    # Highlight the panel: pure-white pixels (any alpha) become opaque red. One
    # vectorised numpy pass instead of a Python loop over every page pixel.
    white = np.all(overlay[..., :3] == 255, axis=-1)
    overlay[white] = (255, 0, 0, 255)
    overlay_img = Image.fromarray(overlay)
    base = base.convert("RGBA")
    base.paste(overlay_img, (0, 0), overlay_img)
    return base
def reset_fn(page):
    """Return the unmodified ``base.png`` image of ``page`` from ./mangaka-d/."""
    return Image.open(f"./mangaka-d/{page}/base.png")
# Gradio UI. Left column: prompts plus a "Page Settings" accordion for picking a
# page/panel and previewing the panel mask. Right column: Generate / Reset buttons
# and the working image that Generate reads and overwrites.
with gr.Blocks() as demo:
    with gr.Tab("Mangaka"):
        with gr.Row():
            with gr.Column():
                positive = gr.Textbox(label="Positive prompt", lines=2)
                negative = gr.Textbox(label="Negative prompt")
                with gr.Accordion("Page Settings"):
                    with gr.Row():
                        with gr.Column():
                            # Page ids must match the keys of `limits` and the folders under ./mangaka-d/.
                            page = gr.Dropdown(label="Page", choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "s1", "s2", "s3", "s4", "s5", "s6"], value="s1")
                            # Panel numbers above limits[page] are clamped by the handlers.
                            # "none" is handled by visualize_fn; main() calls int() on this value.
                            panel = gr.Dropdown(label="Panel", choices=["1", "2", "3", "4", "5", "6", "7", "8", "none"], value="1")
                            visualize = gr.Button("Visualize")
                        with gr.Column():
                            visualize_output = gr.Image(interactive=False)
                    visualize.click(visualize_fn, inputs=[page, panel], outputs=visualize_output)
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        generate = gr.Button("Generate", variant="primary")
                    with gr.Column():
                        reset = gr.Button("Reset", variant="stop")
                # Working image: Reset loads the page's base.png into it; Generate takes it
                # as input and writes its result back here, so successive generations build
                # on the previous output.
                current_panel = gr.Image(interactive=False)
                reset.click(reset_fn, inputs=[page], outputs=current_panel)
                generate.click(main, inputs=[page, panel, positive, negative, current_panel], outputs=current_panel)

# NOTE(review): launch() runs at import time (no `if __name__ == "__main__":` guard).
demo.launch()