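"""Wrapper around the DeepFloyd IF cascade: stage I text-to-image, stage II
super-resolution, and an optional Stable Diffusion x4 upscaler as stage III."""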
from __future__ import annotations

import gc
import json
import tempfile
from typing import Generator

import numpy as np
import PIL.Image
import torch
from diffusers import DiffusionPipeline, StableDiffusionUpscalePipeline
from diffusers.pipelines.deepfloyd_if import (fast27_timesteps,
                                              smart27_timesteps,
                                              smart50_timesteps,
                                              smart100_timesteps,
                                              smart185_timesteps)

from settings import (DISABLE_AUTOMATIC_CPU_OFFLOAD, DISABLE_SD_X4_UPSCALER,
                      HF_TOKEN, MAX_NUM_IMAGES, MAX_NUM_STEPS, MAX_SEED,
                      RUN_GARBAGE_COLLECTION)
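
# The `settings` module is not shown here; inferred from usage below, it is
# expected to provide:
#   HF_TOKEN                       - Hugging Face token for the gated
#                                    DeepFloyd checkpoints
#   MAX_SEED, MAX_NUM_IMAGES,      - validation bounds for user inputs
#   MAX_NUM_STEPS
#   DISABLE_AUTOMATIC_CPU_OFFLOAD  - keep pipelines fully on the GPU
#   DISABLE_SD_X4_UPSCALER         - skip loading the stage-3 upscaler
#   RUN_GARBAGE_COLLECTION         - free memory before each stage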


class Model:
    def __init__(self) -> None:
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.pipe = None
        self.super_res_1_pipe = None
        self.super_res_2_pipe = None
        self.watermark_image = None

        if torch.cuda.is_available():
            self.load_weights()
            # Cache the IF watermark as a PIL image so it can also be pasted
            # onto the stage-3 upscaler results.
            self.watermark_image = PIL.Image.fromarray(
                self.pipe.watermarker.watermark_image.to(
                    torch.uint8).cpu().numpy(),
                mode='RGBA')

    def load_weights(self) -> None:
        # Stage I: base 64px text-to-image pipeline.
        self.pipe = DiffusionPipeline.from_pretrained(
            'DeepFloyd/IF-I-XL-v1.0',
            torch_dtype=torch.float16,
            variant='fp16',
            use_safetensors=True,
            use_auth_token=HF_TOKEN)
        # Stage II: 64px -> 256px super-resolution. The text encoder is
        # dropped because stage I's prompt embeddings are reused.
        self.super_res_1_pipe = DiffusionPipeline.from_pretrained(
            'DeepFloyd/IF-II-L-v1.0',
            text_encoder=None,
            torch_dtype=torch.float16,
            variant='fp16',
            use_safetensors=True,
            use_auth_token=HF_TOKEN)

        # Stage III (optional): Stable Diffusion x4 upscaler, 256px -> 1024px.
        if not DISABLE_SD_X4_UPSCALER:
            self.super_res_2_pipe = StableDiffusionUpscalePipeline.from_pretrained(
                'stabilityai/stable-diffusion-x4-upscaler',
                torch_dtype=torch.float16)

        if DISABLE_AUTOMATIC_CPU_OFFLOAD:
            # Keep every pipeline resident on the GPU.
            self.pipe.to(self.device)
            self.super_res_1_pipe.to(self.device)

            self.pipe.unet.to(memory_format=torch.channels_last)
            self.pipe.unet = torch.compile(self.pipe.unet,
                                           mode='reduce-overhead',
                                           fullgraph=True)

            if not DISABLE_SD_X4_UPSCALER:
                self.super_res_2_pipe.to(self.device)
        else:
            # Trade speed for memory: move submodules to the GPU only while
            # they are in use.
            self.pipe.enable_model_cpu_offload()
            self.super_res_1_pipe.enable_model_cpu_offload()
            if not DISABLE_SD_X4_UPSCALER:
                self.super_res_2_pipe.enable_model_cpu_offload()

    def apply_watermark_to_sd_x4_upscaler_results(
            self, images: list[PIL.Image.Image]) -> None:
        w, h = images[0].size

        stability_x4_upscaler_sample_size = 128

        # If the image is smaller than the upscaler's native sample size,
        # compute the size it would have after scaling up to that size.
        coef = min(h / stability_x4_upscaler_sample_size,
                   w / stability_x4_upscaler_sample_size)
        img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w)

        # Scale the watermark relative to a 1024x1024 reference image,
        # mirroring the sizing rule of the IF watermarker.
        S1, S2 = 1024**2, img_w * img_h
        K = (S2 / S1)**0.5
        watermark_size = int(K * 62)
        watermark_x = img_w - int(14 * K)
        watermark_y = img_h - int(14 * K)

        watermark_image = self.watermark_image.copy().resize(
            (watermark_size, watermark_size),
            PIL.Image.Resampling.BICUBIC,
            reducing_gap=None)

        # Paste into the bottom-right corner, using the watermark's alpha
        # channel as the mask.
        for image in images:
            image.paste(watermark_image,
                        box=(
                            watermark_x - watermark_size,
                            watermark_y - watermark_size,
                            watermark_x,
                            watermark_y,
                        ),
                        mask=watermark_image.split()[-1])

    @staticmethod
    def to_pil_images(images: torch.Tensor) -> list[PIL.Image.Image]:
        # Map from [-1, 1] model space to [0, 255] uint8 HWC images.
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        images = np.round(images * 255).astype(np.uint8)
        return [PIL.Image.fromarray(image) for image in images]

    @staticmethod
    def check_seed(seed: int) -> None:
        if not 0 <= seed <= MAX_SEED:
            raise ValueError(f'seed must be in [0, {MAX_SEED}], got {seed}.')

    @staticmethod
    def check_num_images(num_images: int) -> None:
        if not 1 <= num_images <= MAX_NUM_IMAGES:
            raise ValueError(
                f'num_images must be in [1, {MAX_NUM_IMAGES}], got {num_images}.')

    @staticmethod
    def check_num_inference_steps(num_steps: int) -> None:
        if not 1 <= num_steps <= MAX_NUM_STEPS:
            raise ValueError(
                f'num_steps must be in [1, {MAX_NUM_STEPS}], got {num_steps}.')

    @staticmethod
    def get_custom_timesteps(name: str) -> list[int] | None:
        if name == 'none':
            timesteps = None
        elif name == 'fast27':
            timesteps = fast27_timesteps
        elif name == 'smart27':
            timesteps = smart27_timesteps
        elif name == 'smart50':
            timesteps = smart50_timesteps
        elif name == 'smart100':
            timesteps = smart100_timesteps
        elif name == 'smart185':
            timesteps = smart185_timesteps
        else:
            raise ValueError(f'Unknown custom timestep schedule: {name!r}')
        return timesteps

    @staticmethod
    def run_garbage_collection() -> None:
        gc.collect()
        torch.cuda.empty_cache()

    def run_stage1(
            self,
            prompt: str,
            negative_prompt: str = '',
            seed: int = 0,
            num_images: int = 1,
            guidance_scale_1: float = 7.0,
            custom_timesteps_1: str = 'smart100',
            num_inference_steps_1: int = 100,
    ) -> tuple[list[PIL.Image.Image], str, str]:
        self.check_seed(seed)
        self.check_num_images(num_images)
        self.check_num_inference_steps(num_inference_steps_1)

        if RUN_GARBAGE_COLLECTION:
            self.run_garbage_collection()

        generator = torch.Generator(device=self.device).manual_seed(seed)

        prompt_embeds, negative_embeds = self.pipe.encode_prompt(
            prompt=prompt, negative_prompt=negative_prompt)

        timesteps = self.get_custom_timesteps(custom_timesteps_1)

        # output_type='pt' keeps the raw tensors so they can be fed directly
        # into the stage-2 super-resolution pipeline.
        images = self.pipe(prompt_embeds=prompt_embeds,
                           negative_prompt_embeds=negative_embeds,
                           num_images_per_prompt=num_images,
                           guidance_scale=guidance_scale_1,
                           timesteps=timesteps,
                           num_inference_steps=num_inference_steps_1,
                           generator=generator,
                           output_type='pt').images
        pil_images = self.to_pil_images(images)
        self.pipe.watermarker.apply_watermark(
            pil_images, self.pipe.unet.config.sample_size)

        stage1_params = {
            'prompt': prompt,
            'negative_prompt': negative_prompt,
            'seed': seed,
            'num_images': num_images,
            'guidance_scale_1': guidance_scale_1,
            'custom_timesteps_1': custom_timesteps_1,
            'num_inference_steps_1': num_inference_steps_1,
        }
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as param_file:
            param_file.write(json.dumps(stage1_params))
        # Persist the embeddings and raw tensors so stage 2 can be run later
        # without repeating stage 1.
        stage1_result = {
            'prompt_embeds': prompt_embeds,
            'negative_embeds': negative_embeds,
            'images': images,
            'pil_images': pil_images,
        }
        with tempfile.NamedTemporaryFile(delete=False) as result_file:
            torch.save(stage1_result, result_file.name)
        return pil_images, param_file.name, result_file.name

    def run_stage2(
            self,
            stage1_result_path: str,
            stage2_index: int,
            seed_2: int = 0,
            guidance_scale_2: float = 4.0,
            custom_timesteps_2: str = 'smart50',
            num_inference_steps_2: int = 50,
            disable_watermark: bool = False,
    ) -> PIL.Image.Image:
        self.check_seed(seed_2)
        self.check_num_inference_steps(num_inference_steps_2)

        if RUN_GARBAGE_COLLECTION:
            self.run_garbage_collection()

        generator = torch.Generator(device=self.device).manual_seed(seed_2)

        stage1_result = torch.load(stage1_result_path)
        prompt_embeds = stage1_result['prompt_embeds']
        negative_embeds = stage1_result['negative_embeds']
        images = stage1_result['images']
        # Select a single image while keeping the batch dimension.
        images = images[[stage2_index]]

        timesteps = self.get_custom_timesteps(custom_timesteps_2)

        out = self.super_res_1_pipe(image=images,
                                    prompt_embeds=prompt_embeds,
                                    negative_prompt_embeds=negative_embeds,
                                    num_images_per_prompt=1,
                                    guidance_scale=guidance_scale_2,
                                    timesteps=timesteps,
                                    num_inference_steps=num_inference_steps_2,
                                    generator=generator,
                                    output_type='pt',
                                    noise_level=250).images
        pil_images = self.to_pil_images(out)

        # Watermarking is optional here so run_stage2_3 can watermark a copy
        # for display while passing the clean image on to stage 3.
        if disable_watermark:
            return pil_images[0]

        self.super_res_1_pipe.watermarker.apply_watermark(
            pil_images, self.super_res_1_pipe.unet.config.sample_size)
        return pil_images[0]

    def run_stage3(
            self,
            image: PIL.Image.Image,
            prompt: str = '',
            negative_prompt: str = '',
            seed_3: int = 0,
            guidance_scale_3: float = 9.0,
            num_inference_steps_3: int = 75,
    ) -> PIL.Image.Image:
        self.check_seed(seed_3)
        self.check_num_inference_steps(num_inference_steps_3)

        if RUN_GARBAGE_COLLECTION:
            self.run_garbage_collection()

        generator = torch.Generator(device=self.device).manual_seed(seed_3)
        out = self.super_res_2_pipe(image=image,
                                    prompt=prompt,
                                    negative_prompt=negative_prompt,
                                    num_images_per_prompt=1,
                                    guidance_scale=guidance_scale_3,
                                    num_inference_steps=num_inference_steps_3,
                                    generator=generator,
                                    noise_level=100).images
        self.apply_watermark_to_sd_x4_upscaler_results(out)
        return out[0]

    def run_stage2_3(
            self,
            stage1_result_path: str,
            stage2_index: int,
            seed_2: int = 0,
            guidance_scale_2: float = 4.0,
            custom_timesteps_2: str = 'smart50',
            num_inference_steps_2: int = 50,
            prompt: str = '',
            negative_prompt: str = '',
            seed_3: int = 0,
            guidance_scale_3: float = 9.0,
            num_inference_steps_3: int = 75,
    ) -> Generator[PIL.Image.Image, None, None]:
        self.check_seed(seed_3)
        self.check_num_inference_steps(num_inference_steps_3)

        # Run stage 2 without a watermark so stage 3 upscales a clean image;
        # a watermarked copy is yielded for intermediate display.
        out_image = self.run_stage2(
            stage1_result_path=stage1_result_path,
            stage2_index=stage2_index,
            seed_2=seed_2,
            guidance_scale_2=guidance_scale_2,
            custom_timesteps_2=custom_timesteps_2,
            num_inference_steps_2=num_inference_steps_2,
            disable_watermark=True)
        temp_image = out_image.copy()
        self.super_res_1_pipe.watermarker.apply_watermark(
            [temp_image], self.super_res_1_pipe.unet.config.sample_size)
        yield temp_image
        yield self.run_stage3(image=out_image,
                              prompt=prompt,
                              negative_prompt=negative_prompt,
                              seed_3=seed_3,
                              guidance_scale_3=guidance_scale_3,
                              num_inference_steps_3=num_inference_steps_3)
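
# ----------------------------------------------------------------------------
# Usage sketch (illustrative only; assumes a CUDA GPU, access to the gated
# DeepFloyd checkpoints, and a prompt of your choosing):
#
#     model = Model()
#     pil_images, param_path, result_path = model.run_stage1(
#         prompt='a photo of a corgi wearing a party hat')
#     for i, image in enumerate(
#             model.run_stage2_3(stage1_result_path=result_path,
#                                stage2_index=0)):
#         image.save(f'stage{i + 2}.png')  # stage 2: 256px, stage 3: 1024px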