import os

import torch
import torch.nn as nn
from torch.nn.functional import cross_entropy

import peft
from peft import LoraConfig
from tqdm import tqdm
from transformers import CLIPVisionModel, AutoModelForCausalLM, BitsAndBytesConfig
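
# Two-phase training for a CLIP + phi-2 image-to-text model:
#   Phase 1 (CustomClipPhi2): CLIP and phi-2 stay frozen; only a linear projection layer mapping CLIP
#   patch embeddings into phi-2's embedding space is trained (checkpoint: ./ckpts/model_phase1.pth).
#   Phase 2 (MainQLoraModel): phi-2 is loaded in 4-bit and fine-tuned with QLoRA adapters together
#   with the projection layer on image/question/answer data (checkpoints: ./ckpts/model_phase2.pth
#   and ./ckpts/Qlora_adaptor/).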


class CustomClipPhi2(nn.Module):
    """Phase-1 model: frozen CLIP vision encoder + frozen phi-2, with a trainable projection layer."""

    def __init__(self, tokenizer, phi2_model_name, clip_model_name, clip_embed=768, phi_embed=2560):
        super().__init__()

        self.tokenizer = tokenizer

        self.phi2_model = AutoModelForCausalLM.from_pretrained(
            phi2_model_name, torch_dtype=torch.float32, trust_remote_code=True
        )
        self.clip_model = CLIPVisionModel.from_pretrained(clip_model_name)

        self.EOS_TOKEN_ID = self.tokenizer.eos_token_id
        self.IMAGE_TOKEN_ID = 23903  # hard-coded id of the marker token appended after the image embeddings
        self.clip_embed = clip_embed
        self.phi_embed = phi_embed

        # Trainable projection from CLIP's hidden size to phi-2's embedding size.
        self.projection_layer = torch.nn.Linear(clip_embed, phi_embed)

        # Freeze both pretrained backbones; only the projection layer is trained in phase 1.
        for model in [self.phi2_model, self.clip_model]:
            for param in model.parameters():
                param.requires_grad_(False)

        if os.path.exists('./ckpts/model_phase1.pth'):
            self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location='cpu'))
            print("Loaded checkpoint weights for projection layer")
        else:
            print("No checkpoint weights for projection layer")
            print("Initializing projection layer with random weights")
            self.projection_layer.weight.data.normal_(mean=0.0, std=0.02)
            self.projection_layer.bias.data.zero_()

    def generate(self, images, tokenizer, config):
        """Greedy decoding: feed the projected image embeddings and generate up to max_tokens tokens."""
        clip_outputs = self.clip_model(**images)

        # Drop the CLS token and project the patch embeddings into phi-2's embedding space.
        images = clip_outputs.last_hidden_state[:, 1:, :]
        # Keep the projected embeddings in phi-2's dtype so they can be concatenated with its token embeddings.
        image_embeddings = self.projection_layer(images).to(self.phi2_model.dtype)

        batch_size = images.size()[0]
        predicted_caption = torch.full((batch_size, config.get("max_tokens")), self.EOS_TOKEN_ID,
                                       dtype=torch.long, device=config.get('device'))
        img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
        img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))
        combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)

        for pos in range(config.get("max_tokens") - 1):
            model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            predicted_caption[:, pos] = predicted_word_token.view(-1)
            next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
        return predicted_caption

    def forward(self, images, target_captions):
        batch_size = target_captions.size()[0]
        target_length = target_captions.size()[1]

        clip_outputs = self.clip_model(**images)
        images = clip_outputs.last_hidden_state[:, 1:, :]

        image_embeddings = self.projection_layer(images).to(self.phi2_model.dtype)

        img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
        img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))
        combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)
        del clip_outputs
        del image_embeddings

        # Free-running rollout: at each step the greedy prediction (not the ground-truth token) is fed
        # back in, and that step's logits are scored against the target token at the same position.
        loss = 0
        for pos in range(target_length - 1):
            model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
            pos_loss = cross_entropy(predicted_word_token_logits.view(-1, predicted_word_token_logits.size(-1)),
                                     target_captions[:, pos].contiguous().view(-1),
                                     ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
            loss += pos_loss

            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
        loss = loss / (target_length - 1)  # average over the decoding steps actually taken

        del combined_embeds
        del model_output_logits
        torch.cuda.empty_cache()

        return loss
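

# The phase-1 helpers below expect dataloaders that yield (pixel_values, caption_token_ids) batches:
# pixel_values are CLIP-preprocessed image tensors (wrapped as {'pixel_values': ...} before the forward
# pass) and caption_token_ids are padded with the EOS token, which the loss ignores.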
def show_results_for_samples_phase1(model, val_dataloader, tokenizer, config, num_samples=2):
    model.eval()
    with torch.no_grad():
        # Decode a few validation batches and print target vs. predicted captions.
        for i, (images, target_captions) in enumerate(val_dataloader):
            if i >= num_samples:
                break
            images = {'pixel_values': images.to(config.get('device'))}
            target_captions = target_captions.to(config.get('device'))
            target_captions_decoded = tokenizer.batch_decode(target_captions, skip_special_tokens=True)
            predicted_captions = model.generate(images, tokenizer, config)
            predicted_captions_decoded = tokenizer.batch_decode(predicted_captions, skip_special_tokens=True)

            for idx, pc in enumerate(predicted_captions_decoded):
                print(f"{idx} - Target caption: {target_captions_decoded[idx]} \n {'-' * 100} \n Predicted caption: {pc}")


def validate_model_phase1(model, val_dataloader, tokenizer, config):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        try:
            for images, target_captions in tqdm(val_dataloader):
                images = {'pixel_values': images.to(config.get('device'))}
                target_captions = target_captions.to(config.get('device'))
                loss = model(images, target_captions)
                total_loss += loss.item()
            print(f"Validation Loss: {total_loss / len(val_dataloader)}")
        except Exception as e:
            print(f"Validation failed: {e}")
    model.train()


def train_model_phase1(model, train_loader, val_dataloader, optimizer, tokenizer, config):
    model.train()

    for epoch in range(1, config.get("epochs") + 1):
        print(f"Epoch: {epoch}")
        torch.cuda.empty_cache()
        step = 1
        # A fresh progress bar per epoch, so the loader is re-iterated each time.
        pbar = tqdm(train_loader)
        try:
            for idx, (images, target_captions) in enumerate(pbar):
                try:
                    # Skip captions longer than the generation budget.
                    if target_captions.shape[1] >= config.get("max_tokens"):
                        continue

                    images = {'pixel_values': images.to(config.get('device'))}
                    target_captions = target_captions.to(config.get('device'))

                    optimizer.zero_grad()
                    loss = model(images, target_captions)
                    loss.backward()
                    optimizer.step()
                    pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
                    torch.cuda.empty_cache()
                    step += 1
                    if step % 1000 == 0:
                        torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
                except Exception as e:
                    print(e)
                    continue

            validate_model_phase1(model, val_dataloader, tokenizer, config)
            show_results_for_samples_phase1(model, val_dataloader, tokenizer, config)
            torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')

        except Exception as e:
            print(e)
            continue
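

# --- Example wiring for phase 1 (a minimal sketch, not part of the training code above). ---
# Assumptions: the plain phi-2 tokenizer is used, the model names below are placeholders consistent
# with the default embed sizes (clip_embed=768, phi_embed=2560), the epoch count, max_tokens and
# learning rate are placeholders, and train_loader / val_loader are dataloaders you supply that
# yield (pixel_values, caption_token_ids) batches.
#
# from transformers import AutoTokenizer
#
# config = {"device": "cuda", "epochs": 2, "max_tokens": 75}
# tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
# model = CustomClipPhi2(tokenizer,
#                        phi2_model_name="microsoft/phi-2",
#                        clip_model_name="openai/clip-vit-base-patch32").to(config["device"])
# optimizer = torch.optim.Adam(model.projection_layer.parameters(), lr=1e-3)
# train_model_phase1(model, train_loader, val_loader, optimizer, tokenizer, config)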


class MainQLoraModel(nn.Module):
    """Phase-2 model: frozen CLIP + 4-bit phi-2 with QLoRA adapters, plus the trainable projection layer."""

    def __init__(self, tokenizer, config):
        super().__init__()
        self.tokenizer = tokenizer
        self.config = config
        self.clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name"))

        # Load phi-2 in 4-bit (NF4) with float16 compute for QLoRA fine-tuning.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        phi2_model = AutoModelForCausalLM.from_pretrained(
            config.get("phi2_model_name"),
            quantization_config=bnb_config,
            trust_remote_code=True
        )
        phi2_model.config.use_cache = False

        lora_alpha = 16
        lora_dropout = 0.1
        lora_r = 64

        peft_config = LoraConfig(
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            r=lora_r,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "dense",
                "fc1",
                "fc2"
            ]
        )
        self.phi2_model = peft.get_peft_model(phi2_model, peft_config).to(config.get("device"))

        self.EOS_TOKEN_ID = self.tokenizer.eos_token_id
        self.clip_embed = config.get("clip_embed")
        self.phi_embed = config.get("phi_embed")

        self.projection_layer = torch.nn.Linear(self.clip_embed, self.phi_embed)

        # CLIP stays frozen; the LoRA adapters and the projection layer are the trainable parts.
        for param in self.clip_model.parameters():
            param.requires_grad_(False)

        if os.path.exists('./ckpts/model_phase2.pth'):
            self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))
            # Load the previously saved QLoRA adapter weights into the PEFT model and make them the
            # active, trainable adapter.
            self.phi2_model.load_adapter('./ckpts/Qlora_adaptor', adapter_name='phase2', is_trainable=True)
            self.phi2_model.set_adapter('phase2')
            print("Loaded checkpoint weights for projection layer and QLoRA adapter")
        else:
            # Warm-start the projection layer from the phase-1 checkpoint.
            self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location=config.get("device")))

    def generate(self, tokenizer, config, images=None, ques=None, max_tokens=100):
        """Greedy decoding for an image+question (or question-only) prompt wrapped in <iQ> ... </iQ>."""
        batch_size = 1

        predicted_caption = torch.full((batch_size, max_tokens), self.EOS_TOKEN_ID,
                                       dtype=torch.long, device=self.config.get('device'))
        start_iq = self.tokenizer.encode("<iQ>")
        end_iq = self.tokenizer.encode("</iQ>")
        start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
        end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
        start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
        end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))
        questions_embed = self.phi2_model.model.model.embed_tokens(ques)

        if images is not None:
            clip_outputs = self.clip_model(**images)
            images = clip_outputs.last_hidden_state[:, 1:, :]
            image_embeddings = self.projection_layer(images).to(torch.float16)
            combined_embeds = torch.cat([start_iq_embeds, image_embeddings, questions_embed, end_iq_embeds], dim=1)
        else:
            combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds], dim=1)

        for pos in range(max_tokens - 1):
            model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            predicted_caption[:, pos] = predicted_word_token.view(-1)
            next_token_embeds = self.phi2_model.model.model.embed_tokens(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
        return predicted_caption

    def forward(self, images, ques, ans):
        batch_size = ques.size()[0]
        questions = ques.to(self.config.get("device"))
        answers = ans.to(self.config.get("device"))
        target_length = ans.size()[1]

        start_iq = self.tokenizer.encode("<iQ>")
        end_iq = self.tokenizer.encode("</iQ>")
        start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
        end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
        start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
        end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))

        questions_embed = self.phi2_model.model.model.embed_tokens(questions)
        answers_embed = self.phi2_model.model.model.embed_tokens(answers)

        # An all-zero image tensor marks a text-only (question/answer) sample.
        are_all_zeros = torch.all(images == 0).item()
        if are_all_zeros:
            combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)
        else:
            images = {'pixel_values': images.to(self.config.get("device"))}
            clip_outputs = self.clip_model(**images)
            images_embeds = clip_outputs.last_hidden_state[:, 1:, :]

            image_embeds = self.projection_layer(images_embeds).to(torch.float16)
            combined_embeds = torch.cat([start_iq_embeds, image_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)

        # Single teacher-forced pass: the answer embeddings are part of the input, so the logits at
        # position (sequence_length - target_length - 1 + pos) are the ones that predict answers[:, pos].
        model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']

        loss = 0
        for pos in range(target_length - 1):
            predicted_word_token_logits = model_output_logits[:, -target_length - 1 + pos, :]
            pos_loss = cross_entropy(predicted_word_token_logits,
                                     answers[:, pos].contiguous().view(-1),
                                     ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
            loss += pos_loss
        loss = loss / (target_length - 1)

        del combined_embeds
        del model_output_logits
        torch.cuda.empty_cache()
        return loss
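

# The phase-2 helpers below expect dataloaders that yield (images, question_token_ids, answer_token_ids)
# batches; text-only samples pass an all-zero image tensor, which the model's forward pass detects.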
def validate_model_phase2(model, val_dataloader, tokenizer, config):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for images, ques, ans in tqdm(val_dataloader):
            loss = model(images, ques, ans)
            total_loss += loss.item()
        print(f"Validation Loss: {total_loss / len(val_dataloader)}")

    model.train()


def train_model_phase2(model, train_loader, val_dataloader, tokenizer, config):
    # Separate optimizers for the QLoRA adapter parameters and the projection layer.
    phi2_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.phi2_model.parameters()), lr=1e-5)
    proj_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.projection_layer.parameters()), lr=1e-5)
    model.phi2_model.train()
    model.projection_layer.train()

    for epoch in range(1, config.get("epochs") + 1):
        print(f"Epoch: {epoch}")
        torch.cuda.empty_cache()
        step = 1
        # A fresh progress bar per epoch, so the loader is re-iterated each time.
        pbar = tqdm(train_loader)
        try:
            for idx, (images, ques, ans) in enumerate(pbar):
                try:
                    phi2_optim.zero_grad()
                    proj_optim.zero_grad()
                    loss = model(images, ques, ans)
                    loss.backward()
                    phi2_optim.step()
                    proj_optim.step()
                    pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
                    torch.cuda.empty_cache()
                    step += 1
                    if step % 1000 == 0:
                        torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
                        model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)
                except Exception as e:
                    print("Error in phase-2 training step:", e)
                    continue

            validate_model_phase2(model, val_dataloader, tokenizer, config)
            torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
            model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)

        except Exception as e:
            print(e)
            continue
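

# --- Example wiring for phase 2 (a minimal sketch, not part of the training code above). ---
# Assumptions: the same phi-2 tokenizer as phase 1 (the <iQ> / </iQ> markers are simply tokenized by
# it), the model names and values below are placeholders, and train_loader / val_loader yield
# (images_or_zeros, question_token_ids, answer_token_ids) batches.
#
# from transformers import AutoTokenizer
#
# config = {
#     "device": "cuda",
#     "epochs": 1,
#     "phi2_model_name": "microsoft/phi-2",
#     "clip_model_name": "openai/clip-vit-base-patch32",
#     "clip_embed": 768,
#     "phi_embed": 2560,
# }
# tokenizer = AutoTokenizer.from_pretrained(config["phi2_model_name"], trust_remote_code=True)
# model = MainQLoraModel(tokenizer, config)
# train_model_phase2(model, train_loader, val_loader, tokenizer, config)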