import torch
import torch.nn as nn
from torch.nn.functional import cross_entropy
from transformers import CLIPVisionModel, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from tqdm import tqdm
import os
import peft
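# Two-phase training for a small CLIP -> phi-2 vision-language model:
#   Phase 1 (CustomClipPhi2): CLIP and phi-2 stay frozen; only a linear projection from the
#   CLIP patch embeddings (clip_embed) into the phi-2 embedding space (phi_embed) is trained
#   on image-caption pairs.
#   Phase 2 (MainQLoraModel): phi-2 is loaded in 4-bit and fine-tuned with QLoRA adapters,
#   jointly with the projection layer, on (image, question, answer) data.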
class CustomClipPhi2(nn.Module):
def __init__(self,tokenizer, phi2_model_name, clip_model_name, clip_embed=768, phi_embed=2560):
super().__init__()
self.tokenizer = tokenizer
# These two models are not finetuned
# pretrained Microsoft phi2 model
self.phi2_model = AutoModelForCausalLM.from_pretrained(phi2_model_name,torch_dtype=torch.float32, trust_remote_code=True)
# pretrained OpenAI clip model
self.clip_model = CLIPVisionModel.from_pretrained(clip_model_name)
self.EOS_TOKEN_ID = self.tokenizer.eos_token_id # 50256
        self.IMAGE_TOKEN_ID = 23903 # phi-2/GPT-2 token id for "Comments", used as a marker appended after the image embeddings
self.clip_embed = clip_embed
self.phi_embed = phi_embed
# projection layers
# Trainable projection layer
self.projection_layer = torch.nn.Linear(clip_embed, phi_embed)
# Freeze Weights
for models in [self.phi2_model, self.clip_model]:
for param in models.parameters():
param.requires_grad_(False)
# load checkpoint weights
if os.path.exists('./ckpts/model_phase1.pth'):
self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location='cpu'))
print("Loaded checkpoint weights for projection layer")
else:
print("No checkpoint weights for projection layer")
print("Initializing projection layer with random weights")
self.projection_layer.weight.data.normal_(mean=0.0, std=0.02)
self.projection_layer.bias.data.zero_()
def generate(self, images, tokenizer, config):
clip_outputs = self.clip_model(**images)
# remove cls token
images = clip_outputs.last_hidden_state[:, 1:, :]
image_embeddings = self.projection_layer(images).to(torch.float16)
batch_size = images.size()[0]
predicted_caption = torch.full((batch_size, config.get("max_tokens")), self.EOS_TOKEN_ID, dtype=torch.long, device=config.get('device'))
img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))
combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)
for pos in range(config.get("max_tokens") - 1):
model_output_logits = self.phi2_model.forward(inputs_embeds = combined_embeds)['logits']
predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
predicted_caption[:, pos] = predicted_word_token.view(1,-1).to('cpu')
next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
return predicted_caption
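    # The returned token ids can be decoded back to text with the same tokenizer, e.g.
    # (illustrative usage, not called anywhere in this file):
    #   captions = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)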
def forward(self, images, target_captions):
batch_size = target_captions.size()[0]
target_length = target_captions.size()[1]
print("---", target_length)
# clip model output for image
        clip_outputs = self.clip_model(**images) # See https://huggingface.co/openai/clip-vit-base-patch32 for loading
images = clip_outputs.last_hidden_state[:, 1:, :] # remove CLS token
# projection layer
image_embeddings = self.projection_layer(images).to(torch.float16)
# add comment token from phi2
img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))
        combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1) # (batch, num_image_patches + 1, 2560)
del clip_outputs
del image_embeddings
# for loss
loss = 0
for pos in range(target_length - 1):
model_output_logits = self.phi2_model.forward(inputs_embeds = combined_embeds)['logits']
predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
pos_loss = cross_entropy(predicted_word_token_logits.view(-1,predicted_word_token_logits.size(-1)), target_captions[:, pos].contiguous().view(-1), ignore_index=self.EOS_TOKEN_ID,label_smoothing=0.1)
loss += pos_loss
predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
loss = loss / target_length
# Delete variables to free up memory
del combined_embeds
del model_output_logits
torch.cuda.empty_cache()
return loss
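    # Note on the phase 1 loss above: decoding is run autoregressively during training (the
    # embedding of the model's own predicted token is appended each step rather than the
    # ground-truth token), EOS targets are ignored by cross_entropy, and label smoothing of 0.1 is applied.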
def show_results_for_samples_phase1(model, val_dataloader, tokenizer, config, num_samples = 2):
    model.eval()
    with torch.no_grad():
        for i, (images, target_captions) in enumerate(val_dataloader):
            if i >= num_samples:
                break
            images = {'pixel_values': images.to(config.get('device'))}
            target_captions = target_captions.to(config.get('device'))
            target_captions_decoded = tokenizer.batch_decode(target_captions, skip_special_tokens=True)
            predicted_captions = model.generate(images, tokenizer, config)
            predicted_captions_decoded = tokenizer.batch_decode(predicted_captions, skip_special_tokens=True)
            for idx, pc in enumerate(predicted_captions_decoded):
                print(f"{idx} - Target caption: {target_captions_decoded[idx]}\n{'-' * 80}\nPredicted caption: {pc}")
    model.train()
def validate_model_phase1(model, val_dataloader, tokenizer, config):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        try:
            for images, target_captions in tqdm(val_dataloader):
                images = {'pixel_values': images.to(config.get('device'))}
                target_captions = target_captions.to(config.get('device'))
                loss = model(images, target_captions)
                total_loss += loss.item()
            print(f"Validation Loss: {total_loss / len(val_dataloader)}")
        except Exception as e:
            print(f"Validation failed with: {e}")
    model.train()
def train_model_phase1(model, train_loader, val_dataloader, optimizer, tokenizer, config):
    model.train()
    for epoch in range(1, config.get("epochs") + 1):
        print(f"Epoch: {epoch}")
        torch.cuda.empty_cache()
        step = 1
        pbar = tqdm(train_loader)
        try:
            for idx, (images, target_captions) in enumerate(pbar):
                try:
                    # Skip captions that exceed the generation budget
                    if target_captions.shape[1] >= config.get("max_tokens"):
                        continue
                    images = {'pixel_values': images.to(config.get('device'))}
                    target_captions = target_captions.to(config.get('device'))
                    optimizer.zero_grad()
                    loss = model(images, target_captions)
                    loss.backward()
                    optimizer.step()
                    pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
                    torch.cuda.empty_cache()
                    step += 1
                    # Periodically checkpoint the projection layer
                    if step % 1000 == 0:
                        torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
                except Exception as e:
                    print(e)
                    continue
            # Validate, show a few sample captions and save the latest checkpoint after every epoch
            validate_model_phase1(model, val_dataloader, tokenizer, config)
            show_results_for_samples_phase1(model, val_dataloader, tokenizer, config)
            torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
        except Exception as e:
            print(e)
            continue
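# Illustrative phase 1 driver (a minimal sketch, not used by the code above; the tokenizer,
# config values, learning rate and dataloaders are assumptions based on how they are consumed here):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
#   config = {"device": "cuda", "epochs": 5, "max_tokens": 50}
#   model = CustomClipPhi2(tokenizer, "microsoft/phi-2", "openai/clip-vit-base-patch32").to(config["device"])
#   optimizer = torch.optim.Adam(model.projection_layer.parameters(), lr=1e-4)
#   train_model_phase1(model, train_loader, val_loader, optimizer, tokenizer, config)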
######################################## Phase 2 #########################################
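# Phase 2: the frozen CLIP encoder and the phase 1 projection layer are combined with a 4-bit
# (NF4) quantized phi-2 wrapped in LoRA adapters (QLoRA). The projection layer and the adapters
# are trained jointly on (image, question, answer) triples; an all-zero image tensor marks
# text-only samples.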
class MainQLoraModel(nn.Module):
def __init__(self, tokenizer, config):
super().__init__()
self.tokenizer = tokenizer
self.config = config
self.clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name"))
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
phi2_model = AutoModelForCausalLM.from_pretrained(
config.get("phi2_model_name"),
quantization_config=bnb_config,
trust_remote_code=True
)
phi2_model.config.use_cache = False
## 4 - LORA config
lora_alpha = 16
lora_dropout = 0.1
lora_r = 64
peft_config = LoraConfig(
lora_alpha = lora_alpha,
lora_dropout = lora_dropout,
r = lora_r,
bias="none",
task_type="CAUSAL_LM",
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"dense",
"fc1",
"fc2"
]
)
        # NOTE: phi2_model is wrapped with PEFT below, either with fresh LoRA adapters or with
        # adapters restored from the phase 2 checkpoint.
self.EOS_TOKEN_ID = self.tokenizer.eos_token_id
self.clip_embed = config.get("clip_embed")
self.phi_embed = config.get("phi_embed")
# projection layers
# Trainable projection layer
self.projection_layer = torch.nn.Linear(self.clip_embed, self.phi_embed)
# Freeze Weights
for models in [self.clip_model]:
for param in models.parameters():
param.requires_grad_(False)
        # Load checkpoint weights: resume phase 2 if a checkpoint exists, otherwise start from phase 1
        if os.path.exists('./ckpts/model_phase2.pth'):
            self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))
            self.phi2_model = peft.PeftModel.from_pretrained(phi2_model, './ckpts/Qlora_adaptor', is_trainable=True).to(config.get("device"))
            print("Loaded phase 2 checkpoint weights for projection layer and QLoRA adapter")
        else:
            # Load weights from phase 1 and attach fresh LoRA adapters
            self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location=config.get("device")))
            self.phi2_model = peft.get_peft_model(phi2_model, peft_config).to(config.get("device"))
            print("Loaded phase 1 projection layer weights; initialized new QLoRA adapter")
def generate(self, tokenizer, config, images = None, ques = None, max_tokens = 100):
batch_size = 1
predicted_caption = torch.full((batch_size, max_tokens), self.EOS_TOKEN_ID, dtype=torch.long, device=self.config.get('device'))
start_iq = self.tokenizer.encode("")
end_iq = self.tokenizer.encode("")
start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))
questions_embed = self.phi2_model.model.model.embed_tokens(ques)
if images is not None:
clip_outputs = self.clip_model(**images)
# remove cls token
images = clip_outputs.last_hidden_state[:, 1:, :]
image_embeddings = self.projection_layer(images).to(torch.float16)
combined_embeds = torch.cat([start_iq_embeds, image_embeddings, questions_embed, end_iq_embeds], dim=1)
else:
combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds], dim=1)
for pos in range(max_tokens - 1):
model_output_logits = self.phi2_model.forward(inputs_embeds = combined_embeds)['logits']
predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
predicted_caption[:, pos] = predicted_word_token.view(1,-1).to('cpu')
            next_token_embeds = self.phi2_model.model.model.embed_tokens(predicted_word_token)
combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
return predicted_caption
def forward(self, images, ques, ans):
batch_size = ques.size()[0]
questions = ques.to(self.config.get("device"))
answers = ans.to(self.config.get("device"))
target_length = ans.size()[1]
start_iq = self.tokenizer.encode("")
end_iq = self.tokenizer.encode("")
start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))
questions_embed = self.phi2_model.model.model.embed_tokens(questions)
answers_embed = self.phi2_model.model.model.embed_tokens(answers)
are_all_zeros = torch.all(images == 0).item()
if are_all_zeros:
combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)
else:
images = {'pixel_values': images.to(self.config.get("device"))}
clip_outputs = self.clip_model(**images)
images_embeds = clip_outputs.last_hidden_state[:,1:,:] # remove cls token
# projection
image_embeds = self.projection_layer(images_embeds).to(torch.float16)
combined_embeds = torch.cat([start_iq_embeds, image_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)
        model_output_logits = self.phi2_model.forward(inputs_embeds = combined_embeds)['logits']
        # Teacher-forced loss over the answer tokens: logits at position t predict token t + 1,
        # so the logits aligned with answers[:, pos] are taken just before that answer position.
        loss = 0
        answer_start = combined_embeds.size(1) - target_length
        for pos in range(target_length - 1):
            predicted_word_token_logits = model_output_logits[:, answer_start + pos - 1, :].unsqueeze(1)
            pos_loss = cross_entropy(predicted_word_token_logits.view(-1, predicted_word_token_logits.size(-1)), answers[:, pos].contiguous().view(-1), ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
            loss += pos_loss
        loss = loss / target_length
# Delete variables to free up memory
del combined_embeds
del model_output_logits
torch.cuda.empty_cache()
return loss
def validate_model_phase2(model, val_dataloader, tokenizer, config):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for images, ques, ans in tqdm(val_dataloader):
            loss = model(images, ques, ans)
            total_loss += loss.item()
        print(f"Validation Loss: {total_loss / len(val_dataloader)}")
    model.train()
def train_model_phase2(model, train_loader, val_dataloader, tokenizer, config):
phi2_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.phi2_model.parameters()), lr=1e-5)
proj_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.projection_layer.parameters()), lr=1e-5)
model.phi2_model.train()
model.projection_layer.train()
    for epoch in range(1, config.get("epochs") + 1):
        print(f"Epoch: {epoch}")
        torch.cuda.empty_cache()
        step = 1
        pbar = tqdm(train_loader)
        try:
            for idx, (images, ques, ans) in enumerate(pbar):
                try:
                    phi2_optim.zero_grad()
                    proj_optim.zero_grad()
                    loss = model(images, ques, ans)
                    loss.backward()
                    phi2_optim.step()
                    proj_optim.step()
                    pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
                    torch.cuda.empty_cache()
                    step += 1
                    # Periodically checkpoint the projection layer and the QLoRA adapter
                    if step % 1000 == 0:
                        torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
                        model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)
                except Exception as e:
                    print(f"Skipping batch {idx} due to: {e}")
                    continue
            # Validate and save the latest checkpoint after every epoch
            validate_model_phase2(model, val_dataloader, tokenizer, config)
            torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
            model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)
        except Exception as e:
            print(e)
            continue
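# Illustrative phase 2 driver (a minimal sketch, not used by the code above; the tokenizer,
# config values and dataloaders are assumptions - the dataloaders are expected to yield
# (image_pixel_values, question_token_ids, answer_token_ids) batches):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
#   config = {"device": "cuda", "epochs": 3, "clip_embed": 768, "phi_embed": 2560,
#             "clip_model_name": "openai/clip-vit-base-patch32", "phi2_model_name": "microsoft/phi-2"}
#   model = MainQLoraModel(tokenizer, config)
#   train_model_phase2(model, train_loader, val_loader, tokenizer, config)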