import torch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
import numpy as np
import random
import time
from dataset.concat_dataset import ConCatDataset  # , collate_fn
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os
import shutil
import torchvision
import math
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from distributed import get_rank, synchronize, get_world_size
from transformers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
from copy import deepcopy
try:
    from apex import amp
except ImportError:
    pass
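# Note: apex is optional. It is only needed when config.use_o2 is set, in which
# case Trainer.__init__ calls amp.initialize(..., opt_level="O2"); otherwise
# start_training() falls back to PyTorch-native torch.cuda.amp mixed precision.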
# = = = = = = = = = = = = = = = = = = useful functions = = = = = = = = = = = = = = = = = #

class ImageCaptionSaver:
    def __init__(self, base_path, nrow=8, normalize=True, scale_each=True, range=(-1, 1)):
        self.base_path = base_path
        self.nrow = nrow
        self.normalize = normalize
        self.scale_each = scale_each
        self.range = range

    def __call__(self, images, real, captions, seen):
        save_path = os.path.join(self.base_path, str(seen).zfill(8) + '.png')
        torchvision.utils.save_image(images, save_path, nrow=self.nrow, normalize=self.normalize, scale_each=self.scale_each, range=self.range)

        save_path = os.path.join(self.base_path, str(seen).zfill(8) + '_real.png')
        torchvision.utils.save_image(real, save_path, nrow=self.nrow)

        assert images.shape[0] == len(captions)
        save_path = os.path.join(self.base_path, 'captions.txt')
        with open(save_path, "a") as f:
            f.write(str(seen).zfill(8) + ':\n')
            for cap in captions:
                f.write(cap + '\n')
            f.write('\n')
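
# Note on ImageCaptionSaver: the `range` kwarg passed to save_image follows the
# older torchvision API; newer torchvision releases renamed it to `value_range`,
# so this call may need a small tweak depending on the installed version. Each
# call writes a generated-image grid, a real-image grid, and appends the batch's
# captions to captions.txt under the same zero-padded iteration name.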

def read_official_ckpt(ckpt_path):
    "Read the official pretrained ckpt and convert it into my style"
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    out = {}
    out["model"] = {}
    out["text_encoder"] = {}
    out["autoencoder"] = {}
    out["unexpected"] = {}
    out["diffusion"] = {}

    for k, v in state_dict.items():
        if k.startswith('model.diffusion_model'):
            out["model"][k.replace("model.diffusion_model.", "")] = v
        elif k.startswith('cond_stage_model'):
            out["text_encoder"][k.replace("cond_stage_model.", "")] = v
        elif k.startswith('first_stage_model'):
            out["autoencoder"][k.replace("first_stage_model.", "")] = v
        elif k in ["model_ema.decay", "model_ema.num_updates"]:
            out["unexpected"][k] = v
        else:
            out["diffusion"][k] = v
    return out
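
# The returned dict splits the flat LDM checkpoint into the modules this trainer
# instantiates separately: "model" (the UNet, keys prefixed model.diffusion_model),
# "text_encoder" (cond_stage_model), "autoencoder" (first_stage_model), and
# "diffusion" (the remaining keys, typically the noise-schedule buffers). EMA
# bookkeeping scalars are collected under "unexpected" and ignored downstream.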

def batch_to_device(batch, device):
    for k in batch:
        if isinstance(batch[k], torch.Tensor):
            batch[k] = batch[k].to(device)
    return batch


def sub_batch(batch, num=1):
    # choose first num in given batch
    num = num if num > 1 else 1
    for k in batch:
        batch[k] = batch[k][0:num]
    return batch

def wrap_loader(loader):
    while True:
        for batch in loader:  # TODO: it seems each time you have the same order for all epoch??
            yield batch
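
# Likely cause of the TODO above: when a DistributedSampler is used, set_epoch()
# is never called on it, so the shuffling seed stays fixed and every pass over the
# loader yields batches in the same order. Calling loader.sampler.set_epoch(epoch)
# at the start of each pass would reshuffle per epoch.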

def disable_grads(model):
    for p in model.parameters():
        p.requires_grad = False


def count_params(params):
    total_trainable_params_count = 0
    for p in params:
        total_trainable_params_count += p.numel()
    print("total_trainable_params_count is: ", total_trainable_params_count)

def update_ema(target_params, source_params, rate=0.99):
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)
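
# In-place exponential moving average: targ <- rate * targ + (1 - rate) * src,
# i.e. the EMA weights track the live weights with decay `rate` (0.99 by default;
# the training loop passes config.ema_rate instead).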

def create_expt_folder_with_auto_resuming(OUTPUT_ROOT, name):
    # curr_folder_name = os.getcwd().split("/")[-1]
    name = os.path.join(OUTPUT_ROOT, name)
    writer = None
    checkpoint = None

    if os.path.exists(name):
        all_tags = os.listdir(name)
        all_existing_tags = [tag for tag in all_tags if tag.startswith('tag')]
        all_existing_tags.sort()
        all_existing_tags = all_existing_tags[::-1]
        for previous_tag in all_existing_tags:
            potential_ckpt = os.path.join(name, previous_tag, 'checkpoint_latest.pth')
            if os.path.exists(potential_ckpt):
                checkpoint = potential_ckpt
                if get_rank() == 0:
                    print('ckpt found ' + potential_ckpt)
                break
        curr_tag = 'tag' + str(len(all_existing_tags)).zfill(2)
        name = os.path.join(name, curr_tag)  # output/name/tagxx
    else:
        name = os.path.join(name, 'tag00')  # output/name/tag00

    if get_rank() == 0:
        os.makedirs(name)
        os.makedirs(os.path.join(name, 'Log'))
        writer = SummaryWriter(os.path.join(name, 'Log'))

    return name, writer, checkpoint
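
# Each run gets a fresh OUTPUT_ROOT/name/tagNN folder (tag00, tag01, ...). If a
# previous tag folder contains a checkpoint_latest.pth, its path is returned so the
# Trainer can resume model, EMA, optimizer, scheduler and iteration state from it,
# while new logs and checkpoints go into the newly created tag folder.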

# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = #
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = #
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = #

class Trainer:
    def __init__(self, config):

        self.config = config
        self.device = torch.device("cuda")

        self.l_simple_weight = 1

        self.name, self.writer, checkpoint = create_expt_folder_with_auto_resuming(config.OUTPUT_ROOT, config.name)
        if get_rank() == 0:
            shutil.copyfile(config.yaml_file, os.path.join(self.name, "train_config_file.yaml"))
            torch.save(vars(config), os.path.join(self.name, "config_dict.pth"))

        # = = = = = = = = = = create model and diffusion = = = = = = = = = = #
        self.model = instantiate_from_config(config.model).to(self.device)
        self.autoencoder = instantiate_from_config(config.autoencoder).to(self.device)
        self.text_encoder = instantiate_from_config(config.text_encoder).to(self.device)
        self.diffusion = instantiate_from_config(config.diffusion).to(self.device)

        state_dict = read_official_ckpt(os.path.join(config.DATA_ROOT, config.official_ckpt_name))
        missing_keys, unexpected_keys = self.model.load_state_dict(state_dict["model"], strict=False)
        assert unexpected_keys == []
        original_params_names = list(state_dict["model"].keys())
        self.autoencoder.load_state_dict(state_dict["autoencoder"])
        self.text_encoder.load_state_dict(state_dict["text_encoder"])
        self.diffusion.load_state_dict(state_dict["diffusion"])

        self.autoencoder.eval()
        self.text_encoder.eval()
        disable_grads(self.autoencoder)
        disable_grads(self.text_encoder)
        # = = load from ckpt: (usually second-stage whole-model finetune) = = #
        if self.config.ckpt is not None:
            first_stage_ckpt = torch.load(self.config.ckpt, map_location="cpu")
            self.model.load_state_dict(first_stage_ckpt["model"])

        # = = = = = = = = = = create opt = = = = = = = = = = #
        print(" ")
        print("IMPORTANT: the following code decides which params are trainable!")
        print(" ")

        if self.config.whole:
            print("Entire model is trainable")
            params = list(self.model.parameters())
        else:
            print("Only newly added components will be updated")
            params = []
            trainable_names = []
            for name, p in self.model.named_parameters():
                if ("transformer_blocks" in name) and ("fuser" in name):
                    params.append(p)
                    trainable_names.append(name)
                elif "position_net" in name:
                    params.append(p)
                    trainable_names.append(name)
                else:
                    # all newly added trainable params have to be handled above,
                    # otherwise the following assert will be triggered
                    assert name in original_params_names, name

            all_params_name = list(self.model.state_dict().keys())
            assert set(all_params_name) == set(trainable_names + original_params_names)

        self.opt = torch.optim.AdamW(params, lr=config.base_learning_rate, weight_decay=config.weight_decay)
        count_params(params)

        self.master_params = list(self.model.parameters())  # note: cannot reuse `params` above as master_params, since that only covers the trainable ones

        if config.enable_ema:
            self.ema = deepcopy(self.model)
            self.ema_params = list(self.ema.parameters())
            self.ema.eval()
        # = = = = = = = = = = create scheduler = = = = = = = = = = #
        if config.scheduler_type == "cosine":
            self.scheduler = get_cosine_schedule_with_warmup(self.opt, num_warmup_steps=config.warmup_steps, num_training_steps=config.total_iters)
        elif config.scheduler_type == "constant":
            self.scheduler = get_constant_schedule_with_warmup(self.opt, num_warmup_steps=config.warmup_steps)
        else:
            assert False

        # = = = = = = = = = = create data = = = = = = = = = = #
        train_dataset_repeats = config.train_dataset_repeats if 'train_dataset_repeats' in config else None
        dataset_train = ConCatDataset(config.train_dataset_names, config.DATA_ROOT, config.which_embedder, train=True, repeats=train_dataset_repeats)
        sampler = DistributedSampler(dataset_train) if config.distributed else None
        loader_train = DataLoader(dataset_train, batch_size=config.batch_size,
                                  shuffle=(sampler is None),
                                  num_workers=config.workers,
                                  pin_memory=True,
                                  sampler=sampler)
        self.dataset_train = dataset_train
        self.loader_train = wrap_loader(loader_train)

        if get_rank() == 0:
            total_image = dataset_train.total_images()
            print("Total training images: ", total_image)
        # = = = = = = = = = = load from auto-resuming ckpt = = = = = = = = = = #
        self.starting_iter = 0
        if checkpoint is not None:
            checkpoint = torch.load(checkpoint, map_location="cpu")
            self.model.load_state_dict(checkpoint["model"])
            if config.enable_ema:
                self.ema.load_state_dict(checkpoint["ema"])
            self.opt.load_state_dict(checkpoint["opt"])
            self.scheduler.load_state_dict(checkpoint["scheduler"])
            self.starting_iter = checkpoint["iters"]
            if self.starting_iter >= config.total_iters:
                synchronize()
                print("Training finished. Exiting")
                exit()

        # = = = = = misc = = = = = #
        if get_rank() == 0:
            print("Total number of images the model will see: ", config.total_iters * config.total_batch_size)
            print("Equivalent number of training epochs: ", (config.total_iters * config.total_batch_size) / len(dataset_train))
            self.image_caption_saver = ImageCaptionSaver(self.name)
            # self.counter = Counter(config.total_batch_size, config.save_every_images)

        if config.use_o2:
            self.model, self.opt = amp.initialize(self.model, self.opt, opt_level="O2")
            self.model.use_o2 = True

        # = = = = = wrap into ddp = = = = = #
        if config.distributed:
            self.model = DDP(self.model, device_ids=[config.local_rank], output_device=config.local_rank, broadcast_buffers=False)

    def get_input(self, batch):

        z = self.autoencoder.encode(batch["image"])
        context = self.text_encoder.encode(batch["caption"])

        _t = torch.rand(z.shape[0]).to(z.device)
        t = (torch.pow(_t, self.config.resample_step_gamma) * 1000).long()
        t = torch.where(t != 1000, t, 999)  # if t == 1000, replace it with 999

        return z, t, context
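
    # Note on get_input's timestep sampling: _t is uniform in [0, 1) and is raised
    # to config.resample_step_gamma before scaling to [0, 1000). With gamma = 1 this
    # is the usual uniform timestep draw; gamma > 1 skews the draw toward small
    # (less noisy) timesteps, gamma < 1 toward large ones. The final torch.where
    # clamps t into the valid range [0, 999].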

    def run_one_step(self, batch):
        x_start, t, context = self.get_input(batch)
        noise = torch.randn_like(x_start)
        x_noisy = self.diffusion.q_sample(x_start=x_start, t=t, noise=noise)

        input = dict(x=x_noisy,
                     timesteps=t,
                     context=context,
                     boxes=batch['boxes'],
                     masks=batch['masks'],
                     text_masks=batch['text_masks'],
                     image_masks=batch['image_masks'],
                     text_embeddings=batch["text_embeddings"],
                     image_embeddings=batch["image_embeddings"])
        model_output = self.model(input)
        loss = torch.nn.functional.mse_loss(model_output, noise) * self.l_simple_weight

        self.loss_dict = {"loss": loss.item()}

        return loss
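
    # run_one_step implements the standard epsilon-prediction ("simple" DDPM)
    # objective: the latent x_start is noised to x_noisy = q_sample(x_start, t, noise),
    # the model predicts the injected noise from x_noisy plus the grounding inputs
    # (boxes, masks, text/image embeddings), and the loss is the MSE between that
    # prediction and the true noise, scaled by l_simple_weight (1 here).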

    def start_training(self):

        if not self.config.use_o2:
            # use PyTorch-native mixed precision, which is similar to O1 but faster
            scaler = torch.cuda.amp.GradScaler()

        iterator = tqdm(range(self.starting_iter, self.config.total_iters), desc='Training progress', disable=get_rank() != 0)
        self.model.train()
        for iter_idx in iterator:  # note: iter_idx does not start from 0 when resuming training
            self.iter_idx = iter_idx

            self.opt.zero_grad()
            batch = next(self.loader_train)
            batch_to_device(batch, self.device)

            if self.config.use_o2:
                loss = self.run_one_step(batch)
                with amp.scale_loss(loss, self.opt) as scaled_loss:
                    scaled_loss.backward()
                self.opt.step()
            else:
                enabled = True if self.config.use_mixed else False
                with torch.cuda.amp.autocast(enabled=enabled):  # with torch.autocast(enabled=True):
                    loss = self.run_one_step(batch)
                scaler.scale(loss).backward()
                scaler.step(self.opt)
                scaler.update()

            self.scheduler.step()

            if self.config.enable_ema:
                update_ema(self.ema_params, self.master_params, self.config.ema_rate)

            if get_rank() == 0:
                if iter_idx % 10 == 0:
                    self.log_loss()
                if (iter_idx == 0) or (iter_idx % self.config.save_every_iters == 0) or (iter_idx == self.config.total_iters - 1):
                    self.save_ckpt_and_result()
            synchronize()

        synchronize()
        print("Training finished. Exiting")
        exit()

    def log_loss(self):
        for k, v in self.loss_dict.items():
            self.writer.add_scalar(k, v, self.iter_idx + 1)  # add 1 so the logged step matches the checkpoint name

    def save_ckpt_and_result(self):

        model_wo_wrapper = self.model.module if self.config.distributed else self.model

        iter_name = self.iter_idx + 1  # add 1 so the saved name reflects the number of completed iterations

        if not self.config.disable_inference_in_training:
            # Do a quick inference on one training batch
            batch_here = self.config.batch_size
            batch = sub_batch(next(self.loader_train), batch_here)
            batch_to_device(batch, self.device)

            real_images_with_box_drawing = []  # saved during training for better visualization
            for i in range(batch_here):
                temp_data = {"image": batch["image"][i], "boxes": batch["boxes"][i]}
                im = self.dataset_train.datasets[0].vis_getitem_data(out=temp_data, return_tensor=True, print_caption=False)
                real_images_with_box_drawing.append(im)
            real_images_with_box_drawing = torch.stack(real_images_with_box_drawing)

            uc = self.text_encoder.encode(batch_here * [""])
            context = self.text_encoder.encode(batch["caption"])

            ddim_sampler = PLMSSampler(self.diffusion, model_wo_wrapper)  # despite the variable name, this uses the PLMS sampler
            shape = (batch_here, model_wo_wrapper.in_channels, model_wo_wrapper.image_size, model_wo_wrapper.image_size)
            input = dict(x=None,
                         timesteps=None,
                         context=context,
                         boxes=batch['boxes'],
                         masks=batch['masks'],
                         text_masks=batch['text_masks'],
                         image_masks=batch['image_masks'],
                         text_embeddings=batch["text_embeddings"],
                         image_embeddings=batch["image_embeddings"])
            samples = ddim_sampler.sample(S=50, shape=shape, input=input, uc=uc, guidance_scale=5)

            # old
            # autoencoder_wo_wrapper = self.autoencoder  # Note: it has no wrapper since we do not train it.
            # autoencoder_wo_wrapper = autoencoder_wo_wrapper.cpu()  # To save GPU
            # samples = autoencoder_wo_wrapper.decode(samples.cpu())
            # autoencoder_wo_wrapper = autoencoder_wo_wrapper.to(self.device)

            # new
            autoencoder_wo_wrapper = self.autoencoder  # Note: it has no wrapper since we do not train it.
            samples = autoencoder_wo_wrapper.decode(samples).cpu()

            self.image_caption_saver(samples, real_images_with_box_drawing, batch["caption"], iter_name)

        ckpt = dict(model=model_wo_wrapper.state_dict(),
                    opt=self.opt.state_dict(),
                    scheduler=self.scheduler.state_dict(),
                    iters=self.iter_idx + 1)
        if self.config.enable_ema:
            ckpt["ema"] = self.ema.state_dict()
        torch.save(ckpt, os.path.join(self.name, "checkpoint_" + str(iter_name).zfill(8) + ".pth"))
        torch.save(ckpt, os.path.join(self.name, "checkpoint_latest.pth"))