# Scraped Hugging Face file-viewer metadata (not part of the source); kept as comments:
# hysts's picture
# hysts HF staff
# Clean up
# 80b7378
# raw
# history blame
# 4.84 kB
from __future__ import annotations
import logging
import os
import random
import sys
import numpy as np
import PIL.Image
import torch
from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline,
DiffusionPipeline, PNDMPipeline, PNDMScheduler)
# Hugging Face access token, read eagerly at import time so a missing
# token fails fast (KeyError) instead of during the first model download.
HF_TOKEN = os.environ['HF_TOKEN']

# Module-level logger emitting timestamped INFO records to stdout.
formatter = logging.Formatter(
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Propagation is disabled so records are not duplicated by handlers
# attached to the root logger.
logger.propagate = False
logger.addHandler(stream_handler)
class Model:
    """Anime-face image generator backed by ``diffusers`` pipelines.

    Downloads unconditional diffusion checkpoints from the Hugging Face
    Hub (``hysts/diffusers-anime-faces-<name>``) and exposes helpers to
    generate single images or small image grids using a DDPM, DDIM or
    PNDM scheduler.
    """

    # Checkpoint names available under hysts/diffusers-anime-faces-*.
    MODEL_NAMES = [
        'ddpm-128-exp000',
    ]

    def __init__(self, device: str | torch.device):
        """Pre-download every checkpoint and load the default pipeline.

        Args:
            device: Target inference device (e.g. ``'cuda'``, ``'cpu'``).
        """
        self.device = torch.device(device)
        self._download_all_models()
        self.model_name = self.MODEL_NAMES[0]
        self.scheduler_type = 'DDIM'
        self.pipeline = self._load_pipeline(self.model_name,
                                            self.scheduler_type)
        # Used only by run_simple() to draw a random seed.
        self.rng = random.Random()

    @staticmethod
    def _load_pipeline(model_name: str,
                       scheduler_type: str) -> DiffusionPipeline:
        """Load a pipeline from the Hub and install the requested scheduler.

        Args:
            model_name: One of ``MODEL_NAMES``.
            scheduler_type: ``'DDPM'``, ``'DDIM'`` or ``'PNDM'``.

        Returns:
            The loaded pipeline with its scheduler replaced when a
            non-DDPM scheduler was requested.

        Raises:
            ValueError: If ``scheduler_type`` is not recognized.
        """
        repo_id = f'hysts/diffusers-anime-faces-{model_name}'
        if scheduler_type == 'DDPM':
            pipeline = DDPMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
        elif scheduler_type == 'DDIM':
            pipeline = DDIMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
            # Rebuild the scheduler from the checkpoint's saved config so
            # only DDIM-compatible init arguments are carried over.
            config, _ = DDIMScheduler.extract_init_dict(
                dict(pipeline.scheduler.config))
            pipeline.scheduler = DDIMScheduler(**config)
        elif scheduler_type == 'PNDM':
            pipeline = PNDMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
            config, _ = PNDMScheduler.extract_init_dict(
                dict(pipeline.scheduler.config))
            pipeline.scheduler = PNDMScheduler(**config)
        else:
            # Was a bare `raise ValueError`; include the offending value
            # so misconfiguration is diagnosable from the traceback.
            raise ValueError(f'Unknown scheduler type: {scheduler_type!r}')
        return pipeline

    def set_pipeline(self, model_name: str, scheduler_type: str) -> None:
        """Switch the active pipeline; no-op if the selection is unchanged."""
        logger.info('--- set_pipeline ---')
        logger.info(f'{model_name=}, {scheduler_type=}')
        if model_name == self.model_name and scheduler_type == self.scheduler_type:
            logger.info('Skipping')
            logger.info('--- done ---')
            return
        self.model_name = model_name
        self.scheduler_type = scheduler_type
        self.pipeline = self._load_pipeline(model_name, scheduler_type)
        logger.info('--- done ---')

    def _download_all_models(self) -> None:
        """Fetch every checkpoint once so later pipeline switches hit cache."""
        for name in self.MODEL_NAMES:
            self._load_pipeline(name, 'DDPM')

    def generate(self,
                 seed: int,
                 num_steps: int,
                 num_images: int = 1) -> list[PIL.Image.Image]:
        """Generate images with the currently active pipeline.

        Args:
            seed: Seed for torch's global RNG (deterministic sampling).
            num_steps: Inference step count. NOTE: ignored for DDPM,
                which is invoked with its scheduler's default trajectory.
            num_images: Batch size.

        Returns:
            List of generated PIL images.

        Raises:
            ValueError: If the active scheduler type is unrecognized
                (should be unreachable via set_pipeline).
        """
        logger.info('--- generate ---')
        logger.info(f'{seed=}, {num_steps=}')
        torch.manual_seed(seed)
        if self.scheduler_type == 'DDPM':
            res = self.pipeline(batch_size=num_images,
                                torch_device=self.device)['sample']
        elif self.scheduler_type in ['DDIM', 'PNDM']:
            res = self.pipeline(batch_size=num_images,
                                torch_device=self.device,
                                num_inference_steps=num_steps)['sample']
        else:
            raise ValueError(
                f'Unknown scheduler type: {self.scheduler_type!r}')
        logger.info('--- done ---')
        return res

    def run(
        self,
        model_name: str,
        scheduler_type: str,
        num_steps: int,
        seed: int,
    ) -> PIL.Image.Image:
        """Select a pipeline and generate a single image."""
        self.set_pipeline(model_name, scheduler_type)
        if scheduler_type == 'PNDM':
            # Clamp to [4, 100]; presumably PNDM misbehaves outside this
            # range — TODO confirm against the diffusers PNDM docs.
            num_steps = max(4, min(num_steps, 100))
        return self.generate(seed, num_steps)[0]

    @staticmethod
    def to_grid(images: list[PIL.Image.Image],
                ncols: int = 2) -> PIL.Image.Image:
        """Tile images row-major into an ``nrows x ncols`` grid.

        Missing cells (when ``len(images)`` is not a multiple of
        ``ncols``) are filled with white. Assumes all images share the
        same size and are RGB.
        """
        images = [np.asarray(image) for image in images]
        nrows = (len(images) + ncols - 1) // ncols
        h, w = images[0].shape[:2]
        if (d := nrows * ncols - len(images)) > 0:
            images += [np.full((h, w, 3), 255, dtype=np.uint8)] * d
        # (rows, cols, h, w, c) -> (rows, h, cols, w, c) -> (rows*h, cols*w, c)
        grid = np.asarray(images).reshape(nrows, ncols, h, w, 3).transpose(
            0, 2, 1, 3, 4).reshape(nrows * h, ncols * w, 3)
        return PIL.Image.fromarray(grid)

    def run_simple(self) -> PIL.Image.Image:
        """Generate a randomly-seeded 2x2 grid with the default DDIM setup."""
        self.set_pipeline(self.MODEL_NAMES[0], 'DDIM')
        seed = self.rng.randint(0, 1000000)
        images = self.generate(seed, num_steps=10, num_images=4)
        return self.to_grid(images, 2)