import spaces
import gradio as gr
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, DDIMScheduler
from compel import Compel, ReturnedEmbeddingsType
from PIL import Image
import os
import time
from utils.utils import load_cn_model, load_cn_config, load_tagger_model, load_lora_model, resize_image_aspect_ratio, base_generation
from utils.prompt_analysis import PromptAnalysis
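

# Gradio app: img2img generation with SDXL (cagliostrolab/animagine-xl-3.1),
# a ControlNet loaded from ./controlnet, and the sdxl_BWLine LoRA. Tags
# extracted from the input image by PromptAnalysis seed the prompt, and
# Compel encodes the prompt/negative prompt for both SDXL text encoders.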
class Img2Img:
    def __init__(self):
        self.setup_paths()
        self.setup_models()
        self.compel = self.setup_compel()
        self.demo = self.layout()

    def setup_paths(self):
        # Working directories for the ControlNet, tagger, and LoRA weights.
        self.path = os.getcwd()
        self.cn_dir = f"{self.path}/controlnet"
        self.tagger_dir = f"{self.path}/tagger"
        self.lora_dir = f"{self.path}/lora"
        os.makedirs(self.cn_dir, exist_ok=True)
        os.makedirs(self.tagger_dir, exist_ok=True)
        os.makedirs(self.lora_dir, exist_ok=True)

    def setup_models(self):
        # Fetch the ControlNet, tagger, and LoRA weights into their directories (see utils.utils).
        load_cn_model(self.cn_dir)
        load_cn_config(self.cn_dir)
        load_tagger_model(self.tagger_dir)
        load_lora_model(self.lora_dir)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16
        self.model = "cagliostrolab/animagine-xl-3.1"
        self.scheduler = DDIMScheduler.from_pretrained(self.model, subfolder="scheduler")
        self.controlnet = ControlNetModel.from_pretrained(self.cn_dir, torch_dtype=self.dtype, use_safetensors=True)
        self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
            self.model,
            controlnet=self.controlnet,
            torch_dtype=self.dtype,
            use_safetensors=True,
            scheduler=self.scheduler,
        )
        self.pipe.load_lora_weights(self.lora_dir, weight_name="sdxl_BWLine.safetensors")
        self.pipe = self.pipe.to(self.device)

    def setup_compel(self):
        # Compel builds prompt embeddings for both SDXL text encoders (pooled output from the second).
        return Compel(
            tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
            text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True],
        )

    def layout(self):
        css = """
        #intro{
            max-width: 32rem;
            text-align: center;
            margin: 0 auto;
        }
        """
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                with gr.Column():
                    self.input_image_path = gr.Image(label="Input image", type="filepath")
                    self.prompt_analysis = PromptAnalysis(self.tagger_dir)
                    self.prompt, self.negative_prompt = self.prompt_analysis.layout(self.input_image_path)
                    self.controlnet_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="Line-art fidelity")
                    generate_button = gr.Button("Generate")
                with gr.Column():
                    self.output_image = gr.Image(type="pil", label="Generated image")
            generate_button.click(
                fn=self.predict,
                inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
                outputs=self.output_image,
            )
        return demo

    @spaces.GPU
    def predict(self, input_image_path, prompt, negative_prompt, controlnet_scale):
        input_image_pil = Image.open(input_image_path)
        base_size = input_image_pil.size
        # Resize to an SDXL-friendly resolution while preserving the aspect ratio.
        resize_image = resize_image_aspect_ratio(input_image_pil)
        width, height = resize_image.size
        # A plain white canvas serves as the img2img base; the resized input is the ControlNet condition.
        white_base_pil = base_generation(resize_image.size, (255, 255, 255, 255)).convert("RGB")
        # Encode prompt and negative prompt in one batch, then slice it below.
        conditioning, pooled = self.compel([prompt, negative_prompt])
        generator = torch.manual_seed(0)
        last_time = time.time()
        output_image = self.pipe(
            image=white_base_pil,
            control_image=resize_image,
            strength=1.0,
            prompt_embeds=conditioning[0:1],
            pooled_prompt_embeds=pooled[0:1],
            negative_prompt_embeds=conditioning[1:2],
            negative_pooled_prompt_embeds=pooled[1:2],
            width=width,
            height=height,
            controlnet_conditioning_scale=float(controlnet_scale),
            control_guidance_start=0.0,
            control_guidance_end=1.0,
            generator=generator,
            num_inference_steps=30,
            guidance_scale=8.5,
            eta=1.0,
        ).images[0]
        print(f"Time taken: {time.time() - last_time}")
        # Scale the result back to the original input size.
        output_image = output_image.resize(base_size, Image.LANCZOS)
        return output_image

img2img = Img2Img()
img2img.demo.launch()