# Hugging Face Space file-viewer header (page residue, not part of the program):
# Spaces: Running on Zero
# File size: 4,334 Bytes
# Commit gutter: 3542be4 2b32e3d 3542be4 c689a76 3542be4 a50b44f 2b32e3d 5826348 2b32e3d 5826348 2b32e3d a50b44f 2b32e3d a50b44f 2b32e3d a50b44f 2b32e3d a50b44f 2b32e3d a50b44f 2b32e3d a50b44f 2b32e3d a50b44f 3542be4 a50b44f 2b32e3d
# Line-number gutter: 1-110
import spaces
import gradio as gr
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, DDIMScheduler
from PIL import Image
import os
import time
from utils.utils import load_cn_model, load_cn_config, load_tagger_model, load_lora_model, resize_image_aspect_ratio, base_generation
from utils.prompt_analysis import PromptAnalysis
class Img2Img:
    """Gradio app that renders an image from line art with SDXL + ControlNet.

    Construction prepares the working directories, loads every model, and
    builds the Gradio UI, which is exposed as ``self.demo``.
    """

    def __init__(self):
        """Create directories, load models, then build the UI (in that order)."""
        self.setup_paths()
        self.setup_models()
        self.demo = self.layout()

    def setup_paths(self):
        """Define and create the model directories under the current working directory."""
        self.path = os.getcwd()
        self.cn_dir = f"{self.path}/controlnet"
        self.tagger_dir = f"{self.path}/tagger"
        self.lora_dir = f"{self.path}/lora"
        for directory in (self.cn_dir, self.tagger_dir, self.lora_dir):
            os.makedirs(directory, exist_ok=True)

    def setup_models(self):
        """Load the scheduler, ControlNet, SDXL pipeline and LoRA onto the device."""
        # Project helpers; presumably they fetch the weights/config into the
        # given directories -- TODO confirm against utils.utils.
        load_cn_model(self.cn_dir)
        load_cn_config(self.cn_dir)
        load_tagger_model(self.tagger_dir)
        load_lora_model(self.lora_dir)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision is only appropriate on GPU; diffusers warns that
        # float16 pipelines cannot run on CPU, so fall back to float32 there.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.model = "cagliostrolab/animagine-xl-3.1"
        self.scheduler = DDIMScheduler.from_pretrained(self.model, subfolder="scheduler")
        self.controlnet = ControlNetModel.from_pretrained(
            self.cn_dir, torch_dtype=self.dtype, use_safetensors=True
        )
        self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
            self.model,
            controlnet=self.controlnet,
            torch_dtype=self.dtype,
            use_safetensors=True,
            scheduler=self.scheduler,
        )
        self.pipe.load_lora_weights(self.lora_dir, weight_name="sdxl_BWLine.safetensors")
        self.pipe = self.pipe.to(self.device)

    def layout(self):
        """Build the Gradio Blocks UI and wire the generate button to ``predict``.

        Returns:
            The constructed ``gr.Blocks`` instance.
        """
        css = """
#intro{
max-width: 32rem;
text-align: center;
margin: 0 auto;
}
"""
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                with gr.Column():
                    self.input_image_path = gr.Image(label="input_image", type='filepath')
                    self.prompt_analysis = PromptAnalysis(self.tagger_dir)
                    # PromptAnalysis.layout returns the prompt and negative-prompt
                    # components that are fed to predict below.
                    self.prompt, self.negative_prompt = self.prompt_analysis.layout(self.input_image_path)
                    # Slider label: "line-art fidelity" (ControlNet conditioning scale).
                    self.controlnet_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="線画忠実度")
                    # Button label: "generate".
                    generate_button = gr.Button("生成")
                with gr.Column():
                    # Image label: "output image".
                    self.output_image = gr.Image(type="pil", label="出力画像")

            generate_button.click(
                fn=self.predict,
                inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
                outputs=self.output_image
            )
        return demo

    @spaces.GPU
    def predict(self, input_image_path, prompt, negative_prompt, controlnet_scale):
        """Generate an image guided by the uploaded line art.

        Args:
            input_image_path: Filesystem path of the uploaded image (the input
                component uses ``type='filepath'``); ``None`` if nothing was uploaded.
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt.
            controlnet_scale: ControlNet conditioning scale from the slider.

        Returns:
            The first generated PIL image, resized back to the input's original size.

        Raises:
            gr.Error: If no input image was provided.
        """
        if not input_image_path:
            # Surface a readable message in the UI instead of crashing in Image.open.
            raise gr.Error("入力画像をアップロードしてください。")

        input_image_pil = Image.open(input_image_path)
        base_size = input_image_pil.size
        # Project helper; presumably rescales to a model-friendly resolution
        # while keeping the aspect ratio -- TODO confirm against utils.utils.
        resize_image = resize_image_aspect_ratio(input_image_pil)
        width, height = resize_image.size
        # Img2img start image built from an opaque white RGBA colour, then
        # converted to RGB; with strength=1.0 the ControlNet input drives the result.
        white_base_pil = base_generation(resize_image.size, (255, 255, 255, 255)).convert("RGB")

        # Fixed seed so repeated runs with the same inputs are reproducible.
        generator = torch.manual_seed(0)
        last_time = time.time()
        # Prompts are passed as plain text: the class never creates a prompt
        # encoder (the former self.compel attribute did not exist), so the
        # pipeline's own text encoders are used.
        result = self.pipe(
            image=white_base_pil,
            control_image=resize_image,
            strength=1.0,
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            controlnet_conditioning_scale=float(controlnet_scale),
            # Correct keyword names for the ControlNet guidance window.
            control_guidance_start=0.0,
            control_guidance_end=1.0,
            generator=generator,
            num_inference_steps=30,
            guidance_scale=8.5,
            eta=1.0,
        )
        print(f"Time taken: {time.time() - last_time}")
        # The pipeline returns an output object; the image itself is in .images.
        output_image = result.images[0]
        return output_image.resize(base_size, Image.LANCZOS)
# Instantiate the app (this loads every model and builds the UI), then
# start serving the Gradio interface.
img2img = Img2Img()
demo = img2img.demo
demo.launch()
# (trailing table delimiter left over from the web file viewer)