import spaces
import gradio as gr
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, DDIMScheduler
from compel import Compel, ReturnedEmbeddingsType
from PIL import Image
import os
import time

from utils.utils import load_cn_model, load_cn_config, load_tagger_model, load_lora_model, resize_image_aspect_ratio, base_generation
from utils.prompt_analysis import PromptAnalysis

class Img2Img:
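    """Gradio app that renders a line-art input image with SDXL + ControlNet
    (img2img starting from a white canvas), guided by prompts produced by the
    tagger-based PromptAnalysis helper."""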
    def __init__(self):
        self.setup_paths()
        self.setup_models()
        self.compel = self.setup_compel()
        self.demo = self.layout()

    def setup_paths(self):
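        """Create the local directories that hold the ControlNet, tagger, and LoRA files."""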
        self.path = os.getcwd()
        self.cn_dir = f"{self.path}/controlnet"
        self.tagger_dir = f"{self.path}/tagger"
        self.lora_dir = f"{self.path}/lora"
        os.makedirs(self.cn_dir, exist_ok=True)
        os.makedirs(self.tagger_dir, exist_ok=True)
        os.makedirs(self.lora_dir, exist_ok=True)

    def setup_models(self):
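        """Download the model assets (via the utils helpers) and build the SDXL
        ControlNet img2img pipeline with a DDIM scheduler."""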
        load_cn_model(self.cn_dir)
        load_cn_config(self.cn_dir)
        load_tagger_model(self.tagger_dir)
        load_lora_model(self.lora_dir)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16
        self.model = "cagliostrolab/animagine-xl-3.1"
        self.scheduler = DDIMScheduler.from_pretrained(self.model, subfolder="scheduler")
        self.controlnet = ControlNetModel.from_pretrained(self.cn_dir, torch_dtype=self.dtype, use_safetensors=True)
        self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
            self.model,
            controlnet=self.controlnet,
            torch_dtype=self.dtype,
            use_safetensors=True,
            scheduler=self.scheduler,
        )
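        # Line-art LoRA; the weight file name is assumed to match what
        # load_lora_model places in self.lora_dir.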
        self.pipe.load_lora_weights(self.lora_dir, weight_name="sdxl_BWLine.safetensors")
        self.pipe = self.pipe.to(self.device)

    def setup_compel(self):
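        # Compel needs both SDXL tokenizers/text encoders; only the second
        # (OpenCLIP) encoder provides pooled embeddings, hence
        # requires_pooled=[False, True].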
        return Compel(
            tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
            text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True],
        )

    def layout(self):
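        """Build the Gradio Blocks UI: image input, prompt analysis, a ControlNet
        fidelity slider, and the generated-image output."""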
        css = """
        #intro{
            max-width: 32rem;
            text-align: center;
            margin: 0 auto;
        }
        """
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                with gr.Column():
                    self.input_image_path = gr.Image(label="Input Image", type='filepath')
                    self.prompt_analysis = PromptAnalysis(self.tagger_dir)
                    self.prompt, self.negative_prompt = self.prompt_analysis.layout(self.input_image_path)
                    self.controlnet_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="Line Art Fidelity")
                    generate_button = gr.Button("Generate")
                with gr.Column():
                    self.output_image = gr.Image(type="pil", label="Generated Image")

            generate_button.click(
                fn=self.predict,
                inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
                outputs=self.output_image
            )
        return demo

    @spaces.GPU
    def predict(self, input_image_path, prompt, negative_prompt, controlnet_scale):
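        """Generate an image: the prompts are encoded with Compel, the resized
        input serves as the ControlNet control image, and generation starts
        from a white canvas so the line art fully drives the composition."""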
        input_image_pil = Image.open(input_image_path)
        base_size = input_image_pil.size
        # Resize to an SDXL-friendly resolution while keeping the aspect ratio.
        resized_image = resize_image_aspect_ratio(input_image_pil)
        width, height = resized_image.size
        # Plain white init canvas; with strength=1.0 the ControlNet line art,
        # not the init image, determines the output.
        white_base_pil = base_generation(resized_image.size, (255, 255, 255, 255)).convert("RGB")
        # Encode positive and negative prompts in one batched Compel call.
        conditioning, pooled = self.compel([prompt, negative_prompt])
        generator = torch.manual_seed(0)
        start_time = time.time()

        output = self.pipe(
            image=white_base_pil,
            control_image=resized_image,
            strength=1.0,
            prompt_embeds=conditioning[0:1],
            pooled_prompt_embeds=pooled[0:1],
            negative_prompt_embeds=conditioning[1:2],
            negative_pooled_prompt_embeds=pooled[1:2],
            width=width,
            height=height,
            controlnet_conditioning_scale=float(controlnet_scale),
            control_guidance_start=0.0,
            control_guidance_end=1.0,
            generator=generator,
            num_inference_steps=30,
            guidance_scale=8.5,
            eta=1.0,
        )
        # The pipeline returns an output object; take the first image.
        output_image = output.images[0]
        print(f"Time taken: {time.time() - start_time:.2f}s")
        # Restore the original input resolution.
        output_image = output_image.resize(base_size, Image.LANCZOS)
        return output_image

if __name__ == "__main__":
    img2img = Img2Img()
    img2img.demo.launch()