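# Train Paella on VQGAN image tokens, conditioned on ByT5-XL caption embeddings:
# token latents are noised and the model learns to predict the original tokens
# with a token-wise cross-entropy objective.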
import os
import torch
import numpy as np
from tqdm import tqdm
from modules import Paella
from torch import nn, optim
from warmup_scheduler import GradualWarmupScheduler
from utils import get_dataloader, load_conditional_models

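# Training hyperparameters.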
steps = 100_000
warmup_updates = 10_000
batch_size = 16
checkpoint_frequency = 2000
lr = 1e-4

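# Run configuration; dataset_path and vqmodel_path are left empty for the user to fill in.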
train_device = "cuda"
dataset_path = ""
byt5_model_name = "google/byt5-xl"
vqmodel_path = ""
run_name = "Paella-ByT5-XL-v1"
output_path = "output"
checkpoint_path = f"{run_name}.pt"


def train():
    os.makedirs(output_path, exist_ok=True)
    device = torch.device(train_device)

    dataloader = get_dataloader(dataset_path, batch_size=batch_size)
    checkpoint = torch.load(checkpoint_path, map_location=device) if os.path.exists(checkpoint_path) else None

    model = Paella(byt5_embd=2560).to(device)
    vqgan, (byt5_tokenizer, byt5) = load_conditional_models(byt5_model_name, vqmodel_path, device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_updates)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1, reduction='none')

    start_iter = 1
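    # Resume model/optimizer/scheduler state if a checkpoint exists.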
    if checkpoint is not None:
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.last_epoch = checkpoint['scheduler_last_step']
        start_iter = checkpoint['scheduler_last_step'] + 1
        del checkpoint

    pbar = tqdm(range(start_iter, steps + 1))
    model.train()
    # Zip the step counter with the dataloader so resuming honors `start_iter`,
    # the progress bar advances, and training stops after `steps` updates.
    for i, (images, captions) in zip(pbar, dataloader):
        images = images.to(device)

        with torch.no_grad():
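            # Drop the caption for ~5% of batches to enable classifier-free guidance.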
            if np.random.rand() < 0.05:
                byt5_captions = [''] * len(captions)
            else:
                byt5_captions = captions
            byt5_tokens = byt5_tokenizer(
                byt5_captions, padding="longest", return_tensors="pt",
                max_length=768, truncation=True,
            ).input_ids.to(device)
            byt_embeddings = byt5(input_ids=byt5_tokens).last_hidden_state

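            # Sample a noising timestep t in (0, 1], encode the images into VQGAN
            # token indices, and noise those tokens for the model to denoise.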
            t = 1 - torch.rand(images.size(0), device=device)
            latents = vqgan.encode(images)[2]
            noised_latents, _ = model.add_noise(latents, t)

        pred = model(noised_latents, t, byt_embeddings)
        # `reduction='none'` yields a per-token loss; reduce it to a scalar before backprop.
        loss = criterion(pred, latents).mean()

        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

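        # Token-level accuracy of the model's denoising prediction.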
        acc = (pred.argmax(1) == latents).float().mean()

        pbar.set_postfix({
            'bs': images.size(0),
            'loss': loss.item(),
            'acc': acc.item(),
            'grad_norm': grad_norm.item(),
            'lr': optimizer.param_groups[0]['lr'],
            'total_steps': scheduler.last_epoch,
        })

        if i % checkpoint_frequency == 0:
            torch.save({
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_last_step': scheduler.last_epoch,
                'iter': i,
            }, checkpoint_path)


if __name__ == '__main__':
    train()