Spaces:
Sleeping
Sleeping
File size: 8,760 Bytes
281df87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
import torch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
import numpy as np
import random
from dataset.concat_dataset import ConCatDataset #, collate_fn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import os
from tqdm import tqdm
from distributed import get_rank, synchronize, get_world_size
from trainer import read_official_ckpt, batch_to_device, ImageCaptionSaver, wrap_loader #, get_padded_boxes
from PIL import Image
import math
import json
#hello
def draw_masks_from_boxes(boxes, size):
    """Build one mask per sample that is 1 everywhere except inside its boxes.

    Args:
        boxes: Iterable over samples; each sample is an iterable of boxes
            indexed as ``(x0, y0, x1, y1)``. Coordinates are multiplied by
            the mask size, so they are presumably normalized to [0, 1] --
            confirm against the dataset.
        size: ``(height, width)`` of each mask, or a single int for a square
            mask.

    Returns:
        Tensor of shape ``(len(boxes), 1, height, width)`` holding 0 inside
        the boxes and 1 elsewhere.
    """
    if isinstance(size, int):
        # Callers that only track one side length can pass it directly.
        size = (size, size)
    height, width = size[0], size[1]

    image_masks = []
    for sample_boxes in boxes:
        image_mask = torch.ones(height, width)
        for bx in sample_boxes:
            # The mask is indexed [row, column] == [y, x]: x must scale with
            # the width and y with the height.
            x0, x1 = int(bx[0] * width), int(bx[2] * width)
            y0, y1 = int(bx[1] * height), int(bx[3] * height)
            image_mask[y0:y1, x0:x1] = 0
        image_masks.append(image_mask)
    return torch.stack(image_masks).unsqueeze(1)
def set_alpha_scale(model, alpha_scale):
    """Set ``scale`` to ``alpha_scale`` on every gated attention layer of ``model``.

    Args:
        model: Object exposing ``modules()`` (iterated to find the layers).
        alpha_scale: Value assigned to each matching layer's ``scale``.
    """
    from ldm.modules.attention import GatedCrossAttentionDense, GatedSelfAttentionDense

    gated_types = (GatedCrossAttentionDense, GatedSelfAttentionDense)
    # Exact type comparison on purpose: subclasses of the gated layers are
    # left untouched.
    gated_layers = (layer for layer in model.modules() if type(layer) in gated_types)
    for layer in gated_layers:
        layer.scale = alpha_scale
def save_images(samples, image_ids, folder, to256):
    """Write every sample to ``folder`` as ``<int(image_id)>.png``.

    Args:
        samples: Iterable of image tensors; values are clamped to [-1, 1]
            before conversion. Axis 0 is moved last before handing the array
            to PIL, so each tensor is presumably channel-first -- confirm.
        image_ids: Iterable of ids paired with ``samples``; each is passed
            through ``int`` to build the file name.
        folder: Destination directory.
        to256: When true, resize each image to 256x256 (bicubic) before
            saving.
    """
    for sample, image_id in zip(samples, image_ids):
        # Clip, then map [-1, 1] onto [0, 1].
        unit_range = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5
        scaled = unit_range.cpu().numpy().transpose(1, 2, 0) * 255
        file_name = str(int(image_id)) + '.png'
        # astype truncates toward zero; kept as-is so saved pixels match.
        picture = Image.fromarray(scaled.astype(np.uint8))
        if to256:
            picture = picture.resize((256, 256), Image.BICUBIC)
        picture.save(os.path.join(folder, file_name))
def ckpt_to_folder_name(basename):
    """Turn a checkpoint file name into a short iteration-count label.

    All digit characters in ``basename`` are concatenated and read as one
    integer, which is divided by 1000, rounded to one decimal, right-padded
    with zeros to at least four characters and suffixed with ``k``
    (``"checkpoint_00450000.pth"`` -> ``"450.0k"``).

    A name that contains no digit at all (for example the special ``"real"``
    checkpoint value) is returned unchanged instead of failing on
    ``int("")``.

    Args:
        basename: File name of the checkpoint (no directory part expected).

    Returns:
        The label string described above.
    """
    digits = "".join(ch for ch in basename if ch.isdigit())
    if not digits:
        # Nothing to convert; fall back to the name itself.
        return basename
    seen = round(int(digits) / 1000, 1)
    return str(seen).ljust(4, '0') + 'k'
class Evaluator:
def __init__(self, config):
    """Set up models, the evaluation data loader and the output folders.

    When ``config.ckpt`` is the string ``"real"`` no model is built and the
    evaluator only saves ground-truth images (``just_save_real`` is True).
    Otherwise the model, autoencoder, text encoder and diffusion objects are
    instantiated from ``config``, the latter three are initialised from the
    official checkpoint, and the model weights come from ``config.ckpt``.

    Args:
        config: Configuration object; the attributes read here are ``ckpt``,
            ``model``, ``autoencoder``, ``text_encoder``, ``diffusion``,
            ``DATA_ROOT``, ``official_ckpt_name``, ``which_state``,
            ``alpha_scale``, ``val_dataset_names``, ``which_embedder``,
            ``distributed``, ``batch_size``, ``workers``, ``OUTPUT_ROOT``
            and ``to256``.
    """
    self.config = config
    self.device = torch.device("cuda")
    self.just_save_real = self.config.ckpt == "real"

    if self.just_save_real:
        print("Saving all real images...")
    else:
        # = = = = = create model and diffusion = = = = = #
        self.model = instantiate_from_config(config.model).to(self.device)
        self.autoencoder = instantiate_from_config(config.autoencoder).to(self.device)
        self.text_encoder = instantiate_from_config(config.text_encoder).to(self.device)
        self.diffusion = instantiate_from_config(config.diffusion).to(self.device)

        # No need to load the official weights into self.model here, since
        # it is loaded from our own checkpoint below.
        state_dict = read_official_ckpt(os.path.join(config.DATA_ROOT, config.official_ckpt_name))
        self.autoencoder.load_state_dict(state_dict["autoencoder"])
        self.text_encoder.load_state_dict(state_dict["text_encoder"])
        self.diffusion.load_state_dict(state_dict["diffusion"])

        # = = = = = load from our ckpt = = = = = #
        # NOTE(review): torch.load unpickles the file; only point
        # config.ckpt at checkpoints from a trusted source.
        checkpoint = torch.load(self.config.ckpt, map_location="cpu")
        which_state = config.which_state
        if which_state is None:
            # Prefer the EMA weights when the checkpoint carries them.
            which_state = 'ema' if 'ema' in checkpoint else "model"
        self.model.load_state_dict(checkpoint[which_state])
        print("ckpt is loaded")

        set_alpha_scale(self.model, self.config.alpha_scale)
        # These objects only exist on this branch, so they are switched to
        # eval mode here rather than unconditionally.
        self.autoencoder.eval()
        self.model.eval()
        self.text_encoder.eval()

    # = = = = = create data = = = = = #
    self.dataset_eval = ConCatDataset(config.val_dataset_names, config.DATA_ROOT, config.which_embedder, train=False)
    print("total eval images: ", len(self.dataset_eval))
    sampler = DistributedSampler(self.dataset_eval, shuffle=False) if config.distributed else None
    # shuffle is left at its default (False)
    self.loader_eval = DataLoader(
        self.dataset_eval,
        batch_size=config.batch_size,
        num_workers=config.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=False,
    )

    # = = = = = create output folder = = = = = #
    folder_name = ckpt_to_folder_name(os.path.basename(config.ckpt))
    self.outdir = os.path.join(config.OUTPUT_ROOT, folder_name)
    self.outdir_real = os.path.join(self.outdir, 'real')
    self.outdir_fake = os.path.join(self.outdir, 'fake')
    if config.to256:
        self.outdir_real256 = os.path.join(self.outdir, 'real256')
        self.outdir_fake256 = os.path.join(self.outdir, 'fake256')

    synchronize()  # if rank0 is faster, it may mkdir before the other rank call os.listdir()
    if get_rank() == 0:
        os.makedirs(self.outdir, exist_ok=True)
        os.makedirs(self.outdir_real, exist_ok=True)
        os.makedirs(self.outdir_fake, exist_ok=True)
        if config.to256:
            os.makedirs(self.outdir_real256, exist_ok=True)
            os.makedirs(self.outdir_fake256, exist_ok=True)
        print(self.outdir)  # double check
    # Only rank 0 creates the folders; wait here so that no other rank can
    # start writing images before they exist.
    synchronize()

    # A score.txt in the output folder marks this checkpoint as already
    # evaluated.
    self.evaluation_finished = os.path.exists(os.path.join(self.outdir, 'score.txt'))
def alread_saved_this_batch(self, batch):
existing_real_files = os.listdir( self.outdir_real )
existing_fake_files = os.listdir( self.outdir_fake )
status = []
for image_id in batch["id"]:
img_name = str(int(image_id))+'.png'
status.append(img_name in existing_real_files)
status.append(img_name in existing_fake_files)
return all(status)
@torch.no_grad()
def start_evaluating(self):
    """Save the real image, and a sampled image unless in real-only mode, for every evaluation batch.

    For each batch of ``self.loader_eval`` the ground-truth images are
    written to ``self.outdir_real``. Unless ``self.just_save_real`` is set,
    images are also sampled from the model (DDIM with 250 steps when
    ``config.no_plms`` is true, otherwise PLMS with 50 steps), decoded by the
    autoencoder and written to ``self.outdir_fake``. When ``config.to256`` is
    true, 256x256 copies go to the ``*256`` folders as well.
    """
    iterator = tqdm(self.loader_eval, desc='Evaluating progress')
    for batch in iterator:
        # Skipping batches via self.alread_saved_this_batch(batch) is
        # disabled, so every batch is processed on every run.
        batch_to_device(batch, self.device)
        batch_size = batch["image"].shape[0]
        samples_real = batch["image"]

        samples_fake = None
        if not self.just_save_real:
            uc = self.text_encoder.encode(batch_size * [""])
            context = self.text_encoder.encode(batch["caption"])

            image_mask = x0 = None
            if self.config.inpaint:
                # image_size is used as a scalar side length in `shape`
                # below, while draw_masks_from_boxes indexes size[0] and
                # size[1]; hand it an explicit pair.
                image_size = self.model.image_size
                mask_size = (image_size, image_size) if isinstance(image_size, int) else image_size
                image_mask = draw_masks_from_boxes(batch['boxes'], mask_size).to(self.device)
                x0 = self.autoencoder.encode(batch["image"])

            shape = (batch_size, self.model.in_channels, self.model.image_size, self.model.image_size)
            if self.config.no_plms:
                sampler = DDIMSampler(self.diffusion, self.model)
                steps = 250
            else:
                sampler = PLMSSampler(self.diffusion, self.model)
                steps = 50

            sampler_input = dict(
                x=None,
                timesteps=None,
                context=context,
                boxes=batch['boxes'],
                masks=batch['masks'],
                positive_embeddings=batch["positive_embeddings"],
            )
            samples_fake = sampler.sample(
                S=steps,
                shape=shape,
                input=sampler_input,
                uc=uc,
                guidance_scale=self.config.guidance_scale,
                mask=image_mask,
                x0=x0,
            )
            samples_fake = self.autoencoder.decode(samples_fake)

        save_images(samples_real, batch['id'], self.outdir_real, to256=False)
        if self.config.to256:
            save_images(samples_real, batch['id'], self.outdir_real256, to256=True)

        if samples_fake is not None:
            save_images(samples_fake, batch['id'], self.outdir_fake, to256=False)
            if self.config.to256:
                save_images(samples_fake, batch['id'], self.outdir_fake256, to256=True)
def fire_fid(self):
paths = [self.outdir_real, self.outdir_fake]
if self.config.to256:
paths = [self.outdir_real256, self.outdir_fake256]
|