import spaces import gradio as gr import torch from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, DDIMScheduler from PIL import Image import os import time from utils.utils import load_cn_model, load_cn_config, load_tagger_model, load_lora_model, resize_image_aspect_ratio, base_generation from utils.prompt_analysis import PromptAnalysis class Img2Img: def __init__(self): self.setup_paths() self.setup_models() self.demo = self.layout() def setup_paths(self): self.path = os.getcwd() self.cn_dir = f"{self.path}/controlnet" self.tagger_dir = f"{self.path}/tagger" self.lora_dir = f"{self.path}/lora" os.makedirs(self.cn_dir, exist_ok=True) os.makedirs(self.tagger_dir, exist_ok=True) os.makedirs(self.lora_dir, exist_ok=True) def setup_models(self): load_cn_model(self.cn_dir) load_cn_config(self.cn_dir) load_tagger_model(self.tagger_dir) load_lora_model(self.lora_dir) self.device = "cuda" if torch.cuda.is_available() else "cpu" self.dtype = torch.float16 self.model = "cagliostrolab/animagine-xl-3.1" self.scheduler = DDIMScheduler.from_pretrained(self.model, subfolder="scheduler") self.controlnet = ControlNetModel.from_pretrained(self.cn_dir, torch_dtype=self.dtype, use_safetensors=True) self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained( self.model, controlnet=self.controlnet, torch_dtype=self.dtype, use_safetensors=True, scheduler=self.scheduler, ) self.pipe.load_lora_weights(self.lora_dir, weight_name="sdxl_BWLine.safetensors") self.pipe = self.pipe.to(self.device) def layout(self): css = """ #intro{ max-width: 32rem; text-align: center; margin: 0 auto; } """ with gr.Blocks(css=css) as demo: with gr.Row(): with gr.Column(): self.input_image_path = gr.Image(label="input_image", type='filepath') self.prompt_analysis = PromptAnalysis(self.tagger_dir) self.prompt, self.negative_prompt = self.prompt_analysis.layout(self.input_image_path) self.controlnet_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="線画忠実度") generate_button = gr.Button("生成") with gr.Column(): self.output_image = gr.Image(type="pil", label="出力画像") generate_button.click( fn=self.predict, inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale], outputs=self.output_image ) return demo @spaces.GPU def predict(self, input_image_path, prompt, negative_prompt, controlnet_scale): input_image_pil = Image.open(input_image_path) base_size = input_image_pil.size resize_image = resize_image_aspect_ratio(input_image_pil) resize_image_size = resize_image.size width, height = resize_image_size white_base_pil = base_generation(resize_image.size, (255, 255, 255, 255)).convert("RGB") conditioning, pooled = self.compel([prompt, negative_prompt]) generator = torch.manual_seed(0) last_time = time.time() output_image = self.pipe( image=white_base_pil, control_image=resize_image, strength=1.0, prompt_embeds=conditioning[0:1], pooled_prompt_embeds=pooled[0:1], negative_prompt_embeds=conditioning[1:2], negative_pooled_prompt_embeds=pooled[1:2], width=width, height=height, controlnet_conditioning_scale=float(controlnet_scale), controlnet_start=0.0, controlnet_end=1.0, generator=generator, num_inference_steps=30, guidance_scale=8.5, eta=1.0, ) print(f"Time taken: {time.time() - last_time}") output_image = output_image.resize(base_size, Image.LANCZOS) return output_image img2img = Img2Img() img2img.demo.launch()