from __future__ import annotations import logging import os import random import sys import numpy as np import PIL.Image import torch from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline, DiffusionPipeline, PNDMPipeline, PNDMScheduler) HF_TOKEN = os.environ['HF_TOKEN'] formatter = logging.Formatter( '[%(asctime)s] %(name)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S') stream_handler = logging.StreamHandler(stream=sys.stdout) stream_handler.setLevel(logging.INFO) stream_handler.setFormatter(formatter) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) logger.propagate = False logger.addHandler(stream_handler) class Model: MODEL_NAMES = [ 'ddpm-128-exp000', ] def __init__(self, device: str | torch.device): self.device = torch.device(device) self._download_all_models() self.model_name = self.MODEL_NAMES[0] self.scheduler_type = 'DDIM' self.pipeline = self._load_pipeline(self.model_name, self.scheduler_type) self.rng = random.Random() @staticmethod def _load_pipeline(model_name: str, scheduler_type: str) -> DiffusionPipeline: repo_id = f'hysts/diffusers-anime-faces-{model_name}' if scheduler_type == 'DDPM': pipeline = DDPMPipeline.from_pretrained(repo_id, use_auth_token=HF_TOKEN) elif scheduler_type == 'DDIM': pipeline = DDIMPipeline.from_pretrained(repo_id, use_auth_token=HF_TOKEN) config, _ = DDIMScheduler.extract_init_dict( dict(pipeline.scheduler.config)) pipeline.scheduler = DDIMScheduler(**config) elif scheduler_type == 'PNDM': pipeline = PNDMPipeline.from_pretrained(repo_id, use_auth_token=HF_TOKEN) config, _ = PNDMScheduler.extract_init_dict( dict(pipeline.scheduler.config)) pipeline.scheduler = PNDMScheduler(**config) else: raise ValueError return pipeline def set_pipeline(self, model_name: str, scheduler_type: str) -> None: logger.info('--- set_pipeline ---') logger.info(f'{model_name=}, {scheduler_type=}') if model_name == self.model_name and scheduler_type == self.scheduler_type: logger.info('Skipping') logger.info('--- done ---') return self.model_name = model_name self.scheduler_type = scheduler_type self.pipeline = self._load_pipeline(model_name, scheduler_type) logger.info('--- done ---') def _download_all_models(self) -> None: for name in self.MODEL_NAMES: self._load_pipeline(name, 'DDPM') def generate(self, seed: int, num_steps: int, num_images: int = 1) -> list[PIL.Image.Image]: logger.info('--- generate ---') logger.info(f'{seed=}, {num_steps=}') torch.manual_seed(seed) if self.scheduler_type == 'DDPM': res = self.pipeline(batch_size=num_images, torch_device=self.device)['sample'] elif self.scheduler_type in ['DDIM', 'PNDM']: res = self.pipeline(batch_size=num_images, torch_device=self.device, num_inference_steps=num_steps)['sample'] else: raise ValueError logger.info('--- done ---') return res def run( self, model_name: str, scheduler_type: str, num_steps: int, seed: int, ) -> PIL.Image.Image: self.set_pipeline(model_name, scheduler_type) if scheduler_type == 'PNDM': num_steps = max(4, min(num_steps, 100)) return self.generate(seed, num_steps)[0] @staticmethod def to_grid(images: list[PIL.Image.Image], ncols: int = 2) -> PIL.Image.Image: images = [np.asarray(image) for image in images] nrows = (len(images) + ncols - 1) // ncols h, w = images[0].shape[:2] if (d := nrows * ncols - len(images)) > 0: images += [np.full((h, w, 3), 255, dtype=np.uint8)] * d grid = np.asarray(images).reshape(nrows, ncols, h, w, 3).transpose( 0, 2, 1, 3, 4).reshape(nrows * h, ncols * w, 3) return PIL.Image.fromarray(grid) def run_simple(self) -> PIL.Image.Image: self.set_pipeline(self.MODEL_NAMES[0], 'DDIM') seed = self.rng.randint(0, 1000000) images = self.generate(seed, num_steps=10, num_images=4) return self.to_grid(images, 2)