from __future__ import annotations

import logging
import os
import random
import sys
import tempfile

import gradio as gr
import imageio
import numpy as np
import PIL.Image
import torch
import tqdm.auto
from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline,
                       DiffusionPipeline, PNDMPipeline, PNDMScheduler)

HF_TOKEN = os.environ['HF_TOKEN']

formatter = logging.Formatter(
    '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.propagate = False
logger.addHandler(stream_handler)


class Model:

    MODEL_NAMES = [
        'ddpm-128-exp000',
    ]

    def __init__(self, device: str | torch.device):
        self.device = torch.device(device)
        self._download_all_models()

        self.model_name = self.MODEL_NAMES[0]
        self.scheduler_type = 'DDIM'
        self.pipeline = self._load_pipeline(self.model_name,
                                            self.scheduler_type)
        self.rng = random.Random()
        self.real_esrgan = gr.Interface.load('spaces/hysts/Real-ESRGAN-anime')

    @staticmethod
    def _load_pipeline(model_name: str,
                       scheduler_type: str) -> DiffusionPipeline:
        repo_id = f'hysts/diffusers-anime-faces-{model_name}'
        if scheduler_type == 'DDPM':
            pipeline = DDPMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
        elif scheduler_type == 'DDIM':
            pipeline = DDIMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
            pipeline.scheduler = DDIMScheduler.from_config(
                repo_id, subfolder='scheduler', use_auth_token=HF_TOKEN)
        elif scheduler_type == 'PNDM':
            pipeline = PNDMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
            pipeline.scheduler = PNDMScheduler.from_config(
                repo_id, subfolder='scheduler', use_auth_token=HF_TOKEN)
        else:
            raise ValueError(f'Unknown scheduler type: {scheduler_type}')
        return pipeline

    def set_pipeline(self, model_name: str, scheduler_type: str) -> None:
        logger.info('--- set_pipeline ---')
        logger.info(f'{model_name=}, {scheduler_type=}')
        if model_name == self.model_name and scheduler_type == self.scheduler_type:
            logger.info('Skipping')
            logger.info('--- done ---')
            return
        self.model_name = model_name
        self.scheduler_type = scheduler_type
        self.pipeline = self._load_pipeline(model_name, scheduler_type)
        logger.info('--- done ---')

    def _download_all_models(self) -> None:
        # Instantiating each pipeline once caches the weights locally.
        for name in self.MODEL_NAMES:
            self._load_pipeline(name, 'DDPM')

    def generate(self,
                 seed: int,
                 num_steps: int,
                 num_images: int = 1) -> list[PIL.Image.Image]:
        logger.info('--- generate ---')
        logger.info(f'{seed=}, {num_steps=}')

        torch.manual_seed(seed)
        if self.scheduler_type == 'DDPM':
            # DDPM runs the full diffusion schedule, so num_steps is ignored.
            res = self.pipeline(batch_size=num_images,
                                torch_device=self.device)['sample']
        elif self.scheduler_type in ['DDIM', 'PNDM']:
            res = self.pipeline(batch_size=num_images,
                                torch_device=self.device,
                                num_inference_steps=num_steps)['sample']
        else:
            raise ValueError(f'Unknown scheduler type: {self.scheduler_type}')

        logger.info('--- done ---')
        return res

    @staticmethod
    def postprocess(sample: torch.Tensor) -> np.ndarray:
        # Map samples from [-1, 1] to uint8 images in NHWC layout.
        res = (sample / 2 + 0.5).clamp(0, 1)
        res = (res * 255).to(torch.uint8)
        res = res.cpu().permute(0, 2, 3, 1).numpy()
        return res

    @torch.inference_mode()
    def generate_with_video(self, seed: int,
                            num_steps: int) -> tuple[PIL.Image.Image, str]:
        """Generates an image along with a video of its denoising process."""
        logger.info('--- generate_with_video ---')
        if self.scheduler_type == 'DDPM':
            num_steps = 1000
            fps = 100
        else:
            fps = 10
        logger.info(f'{seed=}, {num_steps=}')

        model = self.pipeline.unet.to(self.device)
        scheduler = self.pipeline.scheduler
        scheduler.set_timesteps(num_inference_steps=num_steps)
        input_shape = (1, model.config.in_channels, model.config.sample_size,
                       model.config.sample_size)

        torch.manual_seed(seed)
        out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
        writer = imageio.get_writer(out_file.name, fps=fps)

        # Run the reverse diffusion loop manually so that each intermediate
        # sample can be written out as a video frame.
        sample = torch.randn(input_shape).to(self.device)
        for t in tqdm.auto.tqdm(scheduler.timesteps):
            out = model(sample, t)['sample']
            sample = scheduler.step(out, t, sample)['prev_sample']
            res = self.postprocess(sample)[0]
            writer.append_data(res)
        writer.close()

        logger.info('--- done ---')
        return PIL.Image.fromarray(res), out_file.name

    def superresolve(self, image: PIL.Image.Image) -> PIL.Image.Image:
        logger.info('--- superresolve ---')
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            image.save(f.name)
            out_file = self.real_esrgan(f.name)
        logger.info('--- done ---')
        return PIL.Image.open(out_file)

    def run(self, model_name: str, scheduler_type: str, num_steps: int,
            randomize_seed: bool,
            seed: int) -> tuple[PIL.Image.Image, PIL.Image.Image, int, str]:
        self.set_pipeline(model_name, scheduler_type)
        if scheduler_type == 'PNDM':
            num_steps = max(4, min(num_steps, 100))
        if randomize_seed:
            seed = self.rng.randint(0, 100000)
        res, filename = self.generate_with_video(seed, num_steps)
        superresolved = self.superresolve(res)
        return superresolved, res, seed, filename

    @staticmethod
    def to_grid(images: list[PIL.Image.Image],
                ncols: int = 2) -> PIL.Image.Image:
        # Arrange images into an nrows x ncols grid, padding with white tiles
        # when the batch does not fill the grid exactly.
        images = [np.asarray(image) for image in images]
        nrows = (len(images) + ncols - 1) // ncols
        h, w = images[0].shape[:2]
        if (d := nrows * ncols - len(images)) > 0:
            images += [np.full((h, w, 3), 255, dtype=np.uint8)] * d
        grid = np.asarray(images).reshape(nrows, ncols, h, w, 3).transpose(
            0, 2, 1, 3, 4).reshape(nrows * h, ncols * w, 3)
        return PIL.Image.fromarray(grid)

    def run_simple(self) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        self.set_pipeline(self.MODEL_NAMES[0], 'DDIM')
        seed = self.rng.randint(0, 1000000)
        images = self.generate(seed, num_steps=10, num_images=4)
        superresolved = [self.superresolve(image) for image in images]
        return self.to_grid(superresolved, 2), self.to_grid(images, 2)
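

# --- Usage sketch (illustrative only, not part of the original app wiring) ---
# A minimal example of how the Model class above could be driven directly,
# assuming HF_TOKEN grants access to the `hysts/diffusers-anime-faces-*`
# repos and that a GPU is optional. In the actual Space this class is wired
# into a Gradio UI instead; the entry points exercised here (`run` and
# `run_simple`) and the output filenames are chosen for illustration.
if __name__ == '__main__':
    model = Model('cuda' if torch.cuda.is_available() else 'cpu')

    # Single image plus a video of its denoising trajectory, using DDIM
    # with a randomized seed.
    upscaled, raw, seed, video_path = model.run(
        model_name=Model.MODEL_NAMES[0],
        scheduler_type='DDIM',
        num_steps=10,
        randomize_seed=True,
        seed=0)
    logger.info(f'seed={seed}, video saved to {video_path}')
    upscaled.save('sample_superresolved.png')

    # 2x2 grids of raw and super-resolved samples.
    grid_superresolved, grid_raw = model.run_simple()
    grid_superresolved.save('grid_superresolved.png')
    grid_raw.save('grid_raw.png')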