File size: 1,741 Bytes
94f80f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from dataset import get_data_loaders_phase1, get_data_loaders_phase2
from transformers import AutoTokenizer
from model import CustomClipPhi2, MainQLoraModel, train_model_phase1, train_model_phase2
from configs import get_config_phase1, get_config_phase2

def phase_1(*, lr=1e-3, clip_embed=768, phi_embed=2560):
    """Run phase-1 training of ``CustomClipPhi2``.

    Loads the phase-1 config, builds the tokenizer named by
    ``phi2_model_name``, creates the phase-1 train/val data loaders, moves a
    ``CustomClipPhi2`` model to the configured device, and trains it with
    Adam over the parameters that have ``requires_grad`` set.

    Args:
        lr: Adam learning rate. Defaults to 1e-3.
        clip_embed: Value passed to ``CustomClipPhi2`` as ``clip_embed``.
            Defaults to 768 (presumably must match the CLIP model's
            embedding width -- confirm when changing ``clip_model_name``).
        phi_embed: Value passed to ``CustomClipPhi2`` as ``phi_embed``.
            Defaults to 2560 (presumably must match the Phi-2 hidden size
            -- confirm when changing ``phi2_model_name``).
    """
    config = get_config_phase1()
    tokenizer = AutoTokenizer.from_pretrained(
        config.get("phi2_model_name"), trust_remote_code=True
    )

    train_dataloader, val_dataloader = get_data_loaders_phase1(
        config.get("data_dir"),
        config.get("clip_model_name"),
        tokenizer,
        config.get("train_batch_size"),
        config.get("val_batch_size"),
        config.get("num_workers"),
    )

    model = CustomClipPhi2(
        tokenizer,
        config.get("phi2_model_name"),
        config.get("clip_model_name"),
        clip_embed=clip_embed,
        phi_embed=phi_embed,
    ).to(config.get("device"))
    print(model)

    # Hand only trainable parameters (requires_grad=True) to the optimizer.
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=lr
    )
    train_model_phase1(
        model, train_dataloader, val_dataloader, optimizer, tokenizer, config
    )


def phase_2():
    """Run phase-2 training of ``MainQLoraModel``.

    Loads the phase-2 config, builds the tokenizer named by
    ``phi2_model_name``, obtains the phase-2 train/val loaders, places the
    model on the configured device, prints it, and delegates training to
    ``train_model_phase2``. Unlike phase 1, no optimizer is built here.
    """
    cfg = get_config_phase2()
    phi2_tokenizer = AutoTokenizer.from_pretrained(
        cfg.get("phi2_model_name"),
        trust_remote_code=True,
    )

    train_loader, val_loader = get_data_loaders_phase2(phi2_tokenizer, cfg)

    qlora_model = MainQLoraModel(phi2_tokenizer, cfg)
    qlora_model = qlora_model.to(cfg.get("device"))
    print(qlora_model)

    train_model_phase2(qlora_model, train_loader, val_loader, phi2_tokenizer, cfg)

if __name__ == "__main__":
    import argparse

    # Select the training phase on the command line instead of commenting
    # calls in and out; with no arguments this still runs phase 1 only.
    parser = argparse.ArgumentParser(description="Run phase 1 or phase 2 training.")
    parser.add_argument(
        "--phase",
        type=int,
        choices=(1, 2),
        default=1,
        help="training phase to run (default: 1)",
    )
    args = parser.parse_args()

    # Allow reduced-precision internal math for float32 matmuls in exchange
    # for speed (see torch.set_float32_matmul_precision docs).
    torch.set_float32_matmul_precision('medium')
    if args.phase == 1:
        phase_1()
    else:
        phase_2()