Spaces:
Sleeping
Sleeping
import torch | |
from ldm.models.diffusion.ddim import DDIMSampler | |
from ldm.models.diffusion.plms import PLMSSampler | |
from ldm.util import instantiate_from_config | |
import numpy as np | |
import random | |
from dataset.concat_dataset import ConCatDataset #, collate_fn | |
from torch.utils.data import DataLoader | |
from torch.utils.data.distributed import DistributedSampler | |
import os | |
from tqdm import tqdm | |
from distributed import get_rank, synchronize, get_world_size | |
from trainer import read_official_ckpt, batch_to_device, ImageCaptionSaver, wrap_loader #, get_padded_boxes | |
from PIL import Image | |
import math | |
import json | |
#hello | |
def draw_masks_from_boxes(boxes, size):
    """Build one binary mask per image with the given boxes zeroed out.

    Args:
        boxes: One collection of boxes per image. Each box is indexable as
            (x0, y0, x1, y1); the coordinates are fractions of the mask
            size (they are multiplied by the mask width / height below).
        size: Mask resolution. Either a single int (square mask) or an
            indexable (height, width) pair.

    Returns:
        A tensor of shape (len(boxes), 1, height, width) that is 1
        everywhere except inside the boxes, where it is 0.
    """
    if isinstance(size, int):
        height = width = size
    else:
        height, width = size[0], size[1]

    image_masks = []
    for box in boxes:
        image_mask = torch.ones(height, width)
        for bx in box:
            # The mask is indexed [row, column] = [y, x], so x must be scaled
            # by the width and y by the height.
            x0, x1 = bx[0] * width, bx[2] * width
            y0, y1 = bx[1] * height, bx[3] * height
            image_mask[int(y0):int(y1), int(x0):int(x1)] = 0
        image_masks.append(image_mask)

    if not image_masks:
        # torch.stack rejects an empty list; return an empty batch instead.
        return torch.ones(0, 1, height, width)
    return torch.stack(image_masks).unsqueeze(1)
def set_alpha_scale(model, alpha_scale):
    """Assign ``alpha_scale`` to ``scale`` on every gated attention layer.

    Walks ``model.modules()`` and sets the ``scale`` attribute of each module
    whose type is exactly ``GatedCrossAttentionDense`` or
    ``GatedSelfAttentionDense``.
    """
    from ldm.modules.attention import GatedCrossAttentionDense, GatedSelfAttentionDense

    gated_layer_types = (GatedCrossAttentionDense, GatedSelfAttentionDense)
    for layer in model.modules():
        # Exact type comparison: subclasses of the gated layers do not match.
        if type(layer) in gated_layer_types:
            layer.scale = alpha_scale
def save_images(samples, image_ids, folder, to256):
    """Write each sample into ``folder`` as ``<int(image_id)>.png``.

    Args:
        samples: Iterable of image tensors; each is clamped to [-1, 1],
            rescaled to [0, 255] and transposed with axes (1, 2, 0) before
            being handed to PIL.
        image_ids: Iterable of ids, paired with ``samples`` in order; each
            is converted with ``int`` to form the file name.
        folder: Destination directory.
        to256: When truthy, resize every image to 256x256 (bicubic) before
            saving.
    """
    for tensor, image_id in zip(samples, image_ids):
        # [-1, 1] -> [0, 1]
        unit_range = torch.clamp(tensor, min=-1, max=1) * 0.5 + 0.5
        # [0, 1] -> [0, 255], with the first axis moved last
        pixels = unit_range.cpu().numpy().transpose(1, 2, 0) * 255
        target_path = os.path.join(folder, str(int(image_id)) + '.png')
        picture = Image.fromarray(pixels.astype(np.uint8))
        if to256:
            picture = picture.resize((256, 256), Image.BICUBIC)
        picture.save(target_path)
def ckpt_to_folder_name(basename):
    """Derive an output folder name from a checkpoint file name.

    Every digit character in ``basename`` is concatenated and read as one
    integer, which is divided by 1000, rounded to one decimal, padded with
    trailing zeros to at least four characters and suffixed with ``k``
    (``"checkpoint_00450001.pth"`` -> ``"450.0k"``).

    When ``basename`` contains no digits at all (for example ``"real"`` or
    ``"checkpoint_latest.pth"``) there is no number to format, so the name
    without its extension is returned instead of failing on ``int("")``.
    """
    digits = "".join(ch for ch in basename if ch.isdigit())
    if not digits:
        return os.path.splitext(basename)[0]
    seen = round(int(digits) / 1000, 1)
    return str(seen).ljust(4, '0') + 'k'
class Evaluator: | |
def __init__(self, config):
    """Set up models, evaluation data and output folders from ``config``.

    When ``config.ckpt`` is the literal string ``"real"`` no model is built
    and the evaluator only saves the dataset images. Otherwise the model,
    autoencoder, text encoder and diffusion objects are instantiated, the
    official weights are loaded into the latter three, and the model weights
    are loaded from the ``config.ckpt`` checkpoint file.

    Attributes set here: ``config``, ``device``, ``just_save_real``,
    ``dataset_eval``, ``loader_eval``, ``outdir``, ``outdir_real``,
    ``outdir_fake``, ``evaluation_finished``; ``outdir_real256`` and
    ``outdir_fake256`` only when ``config.to256`` is truthy; ``model``,
    ``autoencoder``, ``text_encoder`` and ``diffusion`` only when not in
    real-only mode.
    """
    self.config = config
    self.device = torch.device("cuda")
    self.just_save_real = self.config.ckpt == "real"

    if self.just_save_real:
        print("Saving all real images...")
    else:
        # = = = = = create model and diffusion = = = = = #
        self.model = instantiate_from_config(config.model).to(self.device)
        self.autoencoder = instantiate_from_config(config.autoencoder).to(self.device)
        self.text_encoder = instantiate_from_config(config.text_encoder).to(self.device)
        self.diffusion = instantiate_from_config(config.diffusion).to(self.device)

        # No need to load the official ckpt into self.model here, since its
        # weights come from our own checkpoint below.
        state_dict = read_official_ckpt(os.path.join(config.DATA_ROOT, config.official_ckpt_name))
        self.autoencoder.load_state_dict(state_dict["autoencoder"])
        self.text_encoder.load_state_dict(state_dict["text_encoder"])
        self.diffusion.load_state_dict(state_dict["diffusion"])

        # = = = = = load from our ckpt = = = = = #
        # NOTE(review): torch.load unpickles the file; only point
        # config.ckpt at checkpoints from a trusted source.
        checkpoint = torch.load(self.config.ckpt, map_location="cpu")
        if config.which_state is not None:
            which_state = config.which_state
        else:
            # Prefer the EMA weights when the checkpoint carries them.
            which_state = 'ema' if 'ema' in checkpoint else "model"
        self.model.load_state_dict(checkpoint[which_state])
        print("ckpt is loaded")
        set_alpha_scale(self.model, self.config.alpha_scale)

        self.autoencoder.eval()
        self.model.eval()
        self.text_encoder.eval()

    # = = = = = create data = = = = = #
    self.dataset_eval = ConCatDataset(config.val_dataset_names, config.DATA_ROOT, config.which_embedder, train=False)
    print("total eval images: ", len(self.dataset_eval))
    sampler = DistributedSampler(self.dataset_eval, shuffle=False) if config.distributed else None
    self.loader_eval = DataLoader(
        self.dataset_eval,
        batch_size=config.batch_size,
        num_workers=config.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=False,
    )  # shuffle default is False

    # = = = = = create output folder = = = = = #
    # "real" has no iteration number to format, so it gets a fixed folder
    # name instead of going through the digit-based naming.
    if self.just_save_real:
        folder_name = "real"
    else:
        folder_name = ckpt_to_folder_name(os.path.basename(config.ckpt))
    self.outdir = os.path.join(config.OUTPUT_ROOT, folder_name)
    self.outdir_real = os.path.join(self.outdir, 'real')
    self.outdir_fake = os.path.join(self.outdir, 'fake')
    output_dirs = [self.outdir, self.outdir_real, self.outdir_fake]
    if config.to256:
        self.outdir_real256 = os.path.join(self.outdir, 'real256')
        self.outdir_fake256 = os.path.join(self.outdir, 'fake256')
        output_dirs += [self.outdir_real256, self.outdir_fake256]

    synchronize()  # if rank0 is faster, it may mkdir before the other rank call os.listdir()
    # Every rank creates the folders (idempotent thanks to exist_ok), so no
    # rank can start writing images before its target directory exists.
    for directory in output_dirs:
        os.makedirs(directory, exist_ok=True)
    print(self.outdir)  # double check

    # A score.txt in the output folder marks the evaluation as finished.
    self.evaluation_finished = os.path.exists(os.path.join(self.outdir, 'score.txt'))
def alread_saved_this_batch(self, batch): | |
existing_real_files = os.listdir( self.outdir_real ) | |
existing_fake_files = os.listdir( self.outdir_fake ) | |
status = [] | |
for image_id in batch["id"]: | |
img_name = str(int(image_id))+'.png' | |
status.append(img_name in existing_real_files) | |
status.append(img_name in existing_fake_files) | |
return all(status) | |
def start_evaluating(self):
    """Run over the evaluation loader and save real and generated images.

    Every batch is moved to ``self.device`` and its ``"image"`` entry is
    saved to ``self.outdir_real`` (plus a 256x256 copy in
    ``self.outdir_real256`` when ``config.to256`` is set). Unless the
    evaluator is in real-only mode, images are also sampled from the model
    (DDIM with 250 steps when ``config.no_plms`` is set, otherwise PLMS with
    50 steps), decoded with the autoencoder and saved to the fake folders
    under the same ids.
    """
    iterator = tqdm(self.loader_eval, desc='Evaluating progress')
    # Pure inference: disabling autograd avoids building graphs for the
    # encoders/decoder and guarantees the outputs can be converted to numpy.
    with torch.no_grad():
        for batch in iterator:
            batch_to_device(batch, self.device)
            batch_size = batch["image"].shape[0]
            samples_real = batch["image"]

            if self.just_save_real:
                samples_fake = None
            else:
                uc = self.text_encoder.encode(batch_size * [""])
                context = self.text_encoder.encode(batch["caption"])

                image_mask = x0 = None
                if self.config.inpaint:
                    # image_size is used as a scalar dimension in `shape`
                    # below, while the mask helper indexes size[0]/size[1];
                    # hand it an explicit pair.
                    mask_size = self.model.image_size
                    if isinstance(mask_size, int):
                        mask_size = (mask_size, mask_size)
                    image_mask = draw_masks_from_boxes(batch['boxes'], mask_size).to(self.device)
                    x0 = self.autoencoder.encode(batch["image"])

                shape = (batch_size, self.model.in_channels, self.model.image_size, self.model.image_size)
                if self.config.no_plms:
                    sampler = DDIMSampler(self.diffusion, self.model)
                    steps = 250
                else:
                    sampler = PLMSSampler(self.diffusion, self.model)
                    steps = 50

                # Named sampler_input so the builtin `input` is not shadowed;
                # the sampler still receives it through its `input` keyword.
                sampler_input = dict(
                    x=None,
                    timesteps=None,
                    context=context,
                    boxes=batch['boxes'],
                    masks=batch['masks'],
                    positive_embeddings=batch["positive_embeddings"],
                )
                samples_fake = sampler.sample(
                    S=steps,
                    shape=shape,
                    input=sampler_input,
                    uc=uc,
                    guidance_scale=self.config.guidance_scale,
                    mask=image_mask,
                    x0=x0,
                )
                samples_fake = self.autoencoder.decode(samples_fake)

            save_images(samples_real, batch['id'], self.outdir_real, to256=False)
            if self.config.to256:
                save_images(samples_real, batch['id'], self.outdir_real256, to256=True)

            if samples_fake is not None:
                save_images(samples_fake, batch['id'], self.outdir_fake, to256=False)
                if self.config.to256:
                    save_images(samples_fake, batch['id'], self.outdir_fake256, to256=True)
def fire_fid(self): | |
paths = [self.outdir_real, self.outdir_fake] | |
if self.config.to256: | |
paths = [self.outdir_real256, self.outdir_fake256] | |