Vasudevakrishna's picture
Upload 7 files
94f80f5 verified
raw
history blame contribute delete
No virus
1.74 kB
import torch
from dataset import get_data_loaders_phase1, get_data_loaders_phase2
from transformers import AutoTokenizer
from model import CustomClipPhi2, MainQLoraModel, train_model_phase1, train_model_phase2
from configs import get_config_phase1, get_config_phase2
def phase_1():
# get config
config = get_config_phase1()
# tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
# data loaders
train_dataloader, val_dataloader = get_data_loaders_phase1(config.get("data_dir"), config.get("clip_model_name"), tokenizer, config.get("train_batch_size"), config.get("val_batch_size"), config.get("num_workers"))
llmModel = CustomClipPhi2(tokenizer, config.get("phi2_model_name"), config.get("clip_model_name"), clip_embed=768, phi_embed=2560).to(config.get("device"))
print(llmModel)
# optimizer
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, llmModel.parameters()), lr=1e-3)
# train model
train_model_phase1(llmModel, train_dataloader, val_dataloader, optimizer, tokenizer, config)
def phase_2():
# get config
config = get_config_phase2()
# tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
# data loaders
train_dataloader, val_dataloader = get_data_loaders_phase2(tokenizer, config)
llmModel = MainQLoraModel(tokenizer, config).to(config.get("device"))
print(llmModel)
# train model
train_model_phase2(llmModel, train_dataloader, val_dataloader, tokenizer, config)
if __name__ == "__main__":
torch.set_float32_matmul_precision('medium')
phase_1()
# phase_2()