import os
import time

import spaces
import gradio as gr
import torch

# Replace torch.jit.script with a no-op decorator so that libraries which
# script functions at import time simply receive the original function back.
torch.jit.script = lambda f: f

from diffusers import (
    ControlNetModel,
    StableDiffusionXLControlNetImg2ImgPipeline,
    DDIMScheduler,
)
from compel import Compel, ReturnedEmbeddingsType
from PIL import Image

from utils.utils import (
    load_cn_model,
    load_cn_config,
    load_tagger_model,
    resize_image_aspect_ratio,
    base_generation,
)
from utils.prompt_analysis import PromptAnalysis

path = os.getcwd()
cn_dir = f"{path}/controlnet"
tagger_dir = f"{path}/tagger"

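# Ensure the ControlNet weights/config and the tagger model are available in
# the local directories above (helpers implemented in utils/utils.py).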
load_cn_model(cn_dir)
load_cn_config(cn_dir)
load_tagger_model(tagger_dir)

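# Runtime flags: detect whether we are running on a Hugging Face Space and/or
# a ZeroGPU instance.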
IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
IS_SPACE = os.environ.get("SPACE_ID", None) is not None

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16

LOW_MEMORY = os.getenv("LOW_MEMORY", "0") == "1"

print(f"device: {device}")
print(f"dtype: {dtype}")
print(f"low memory: {LOW_MEMORY}")


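# Load the SDXL base model (Animagine XL 3.1) and attach the ControlNet
# downloaded above.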
model = "cagliostrolab/animagine-xl-3.1"
scheduler = DDIMScheduler.from_pretrained(model, subfolder="scheduler")
controlnet = ControlNetModel.from_pretrained(
    cn_dir, torch_dtype=dtype, use_safetensors=True
)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    model,
    controlnet=controlnet,
    torch_dtype=dtype,
    variant="fp16",
    use_safetensors=True,
    scheduler=scheduler,
)

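# Compel handles SDXL's dual text encoders: token embeddings come from both
# encoders, while the pooled embedding is taken only from the second one
# (requires_pooled=[False, True]).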
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)
pipe = pipe.to(device)



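# On ZeroGPU Spaces, @spaces.GPU requests a GPU slice for the duration of the call.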
@spaces.GPU
def predict(
    input_image,
    prompt,
    negative_prompt,
    controlnet_conditioning_scale,
):
    base_size = input_image.size
    resize_image = resize_image_aspect_ratio(input_image)
    width, height = resize_image.size
    # Start from a plain white canvas; with strength=1.0 the img2img input is
    # fully re-noised, so the lineart ControlNet drives the composition.
    white_base_pil = base_generation(resize_image.size, (255, 255, 255, 255)).convert("RGB")
    conditioning, pooled = compel([prompt, negative_prompt])
    # Fixed seed for reproducible output
    generator = torch.manual_seed(0)
    last_time = time.time()

    output_image = pipe(
        image=white_base_pil,
        control_image=resize_image,
        strength=1.0,
        prompt_embeds=conditioning[0:1],
        pooled_prompt_embeds=pooled[0:1],
        negative_prompt_embeds=conditioning[1:2],
        negative_pooled_prompt_embeds=pooled[1:2],
        width=width,
        height=height,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        control_guidance_start=0.0,
        control_guidance_end=1.0,
        generator=generator,
        num_inference_steps=30,
        guidance_scale=8.5,
        eta=1.0,
    ).images[0]
    print(f"Time taken: {time.time() - last_time}")
    # Scale the result back to the original input resolution.
    output_image = output_image.resize(base_size, Image.LANCZOS)
    return output_image


css = """
#intro {
    /*
    max-width: 32rem;
    text-align: center;
    margin: 0 auto;
    */
}
"""

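# Gradio UI: lineart upload and prompt controls on the left, result on the right.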
with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column():
            # Row for the image upload
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil")

            # Row for the prompt inputs
            with gr.Row():
                prompt_analysis = PromptAnalysis(tagger_dir)
                prompt, nega = prompt_analysis.layout(input_image)

            # Row of sliders for detailed generation settings
            with gr.Row():
                controlnet_conditioning_scale = gr.Slider(
                    minimum=0.5,
                    maximum=1.25,
                    value=1.0,
                    step=0.01,
                    interactive=True,
                    label="Lineart fidelity",
                )

            # Row for the generate button
            with gr.Row():
                generate_button = gr.Button("Generate")

        with gr.Column():
            output_image = gr.Image(type="pil", label="Output Image")

        # Inputs and outputs for the click handler
        inputs = [
            input_image,
            prompt,
            nega,
            controlnet_conditioning_scale,
        ]
        outputs = [output_image]

        # Wire the generate button to the prediction function
        generate_button.click(
            fn=predict,
            inputs=inputs,
            outputs=outputs,
        )

# Configure the queue and launch the demo
demo.queue(api_open=True)
demo.launch(show_api=True)
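
# Minimal local sanity check (hypothetical file names; run instead of
# demo.launch above, since launch() blocks):
#   from PIL import Image
#   img = Image.open("lineart.png")
#   out = predict(img, "1girl, solo, colorful", "lowres, worst quality", 1.0)
#   out.save("colored.png")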