import os

import torch
import torch.nn as nn
from torch.nn.functional import cross_entropy
from transformers import CLIPVisionModel, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
import peft
from tqdm import tqdm


class CustomClipPhi2(nn.Module):
    """Phase 1 model: frozen CLIP vision encoder + frozen phi-2 with a trainable projection layer."""

    def __init__(self, tokenizer, phi2_model_name, clip_model_name, clip_embed=768, phi_embed=2560):
        super().__init__()
        self.tokenizer = tokenizer

        # These two models are not finetuned.
        # Pretrained Microsoft phi-2 model
        self.phi2_model = AutoModelForCausalLM.from_pretrained(
            phi2_model_name, torch_dtype=torch.float32, trust_remote_code=True)
        # Pretrained OpenAI CLIP vision model
        self.clip_model = CLIPVisionModel.from_pretrained(clip_model_name)

        self.EOS_TOKEN_ID = self.tokenizer.eos_token_id  # 50256
        self.IMAGE_TOKEN_ID = 23903  # token for "Comments"
        self.clip_embed = clip_embed
        self.phi_embed = phi_embed

        # Trainable projection layer from CLIP embeddings to phi-2 embeddings
        self.projection_layer = torch.nn.Linear(clip_embed, phi_embed)

        # Freeze the pretrained models
        for frozen_model in [self.phi2_model, self.clip_model]:
            for param in frozen_model.parameters():
                param.requires_grad_(False)

        # Load checkpoint weights for the projection layer if they exist
        if os.path.exists('./ckpts/model_phase1.pth'):
            self.projection_layer.load_state_dict(
                torch.load('./ckpts/model_phase1.pth', map_location='cpu'))
            print("Loaded checkpoint weights for projection layer")
        else:
            print("No checkpoint weights for projection layer")
            print("Initializing projection layer with random weights")
            self.projection_layer.weight.data.normal_(mean=0.0, std=0.02)
            self.projection_layer.bias.data.zero_()

    def generate(self, images, tokenizer, config):
        clip_outputs = self.clip_model(**images)
        # Remove the CLS token
        images = clip_outputs.last_hidden_state[:, 1:, :]
        image_embeddings = self.projection_layer(images).to(torch.float16)

        batch_size = images.size()[0]
        predicted_caption = torch.full((batch_size, config.get("max_tokens")), self.EOS_TOKEN_ID,
                                       dtype=torch.long, device=config.get('device'))

        img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
        img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))

        combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)

        # Greedy decoding: append the argmax token embedding at each step
        for pos in range(config.get("max_tokens") - 1):
            model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            predicted_caption[:, pos] = predicted_word_token.view(-1)
            next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
        return predicted_caption

    def forward(self, images, target_captions):
        batch_size = target_captions.size()[0]
        target_length = target_captions.size()[1]

        # CLIP model output for the image
        # See this for loading https://huggingface.co/openai/clip-vit-base-patch36
        clip_outputs = self.clip_model(**images)
        images = clip_outputs.last_hidden_state[:, 1:, :]  # remove CLS token

        # Projection layer
        image_embeddings = self.projection_layer(images).to(torch.float16)

        # Add the image marker token from phi-2
        img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
        img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))

        combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)  # (4, 49, 2560)

        del clip_outputs
        del image_embeddings

        # Accumulate the caption loss position by position, feeding back the
        # model's own greedy prediction at each step
        loss = 0
        for pos in range(target_length - 1):
            model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
            pos_loss = cross_entropy(predicted_word_token_logits.view(-1, predicted_word_token_logits.size(-1)),
                                     target_captions[:, pos].contiguous().view(-1),
                                     ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
            loss += pos_loss
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
        loss = loss / target_length

        # Delete variables to free up memory
        del combined_embeds
        del model_output_logits
        torch.cuda.empty_cache()
        return loss


def show_results_for_samples_phase1(model, val_dataloader, tokenizer, config, num_samples=2):
    model.eval()
    with torch.no_grad():
        for i, (images, target_captions) in enumerate(val_dataloader):
            if i >= num_samples:
                break
            images = {'pixel_values': images.to(config.get('device'))}
            target_captions = target_captions.to(config.get('device'))
            target_captions_decoded = tokenizer.batch_decode(target_captions, skip_special_tokens=True)
            predicted_captions = model.generate(images, tokenizer, config)
            predicted_captions_decoded = tokenizer.batch_decode(predicted_captions, skip_special_tokens=True)
            for idx, pc in enumerate(predicted_captions_decoded):
                print(f"{idx} - Target captions: {target_captions_decoded[idx]} \n {'---------------------'*10} \n Predicted_captions: {pc}")


def validate_model_phase1(model, val_dataloader, tokenizer, config):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        try:
            for images, target_captions in tqdm(val_dataloader):
                images = {'pixel_values': images.to(config.get('device'))}
                target_captions = target_captions.to(config.get('device'))
                loss = model(images, target_captions)
                total_loss += loss.item()
            print(f"Validation Loss: {total_loss/len(val_dataloader)}")
        except Exception as e:
            print(e)
    model.train()


def train_model_phase1(model, train_loader, val_dataloader, optimizer, tokenizer, config):
    model.train()
    for epoch in range(1, config.get("epochs") + 1):
        print(f"Epoch: {epoch}")
        torch.cuda.empty_cache()
        pbar = tqdm(train_loader)
        step = 1
        try:
            for idx, (images, target_captions) in enumerate(pbar):
                try:
                    if target_captions.shape[1] >= config.get("max_tokens"):
                        # Skip batches whose captions are too long
                        continue
                    images = {'pixel_values': images.to(config.get('device'))}
                    target_captions = target_captions.to(config.get('device'))
                    optimizer.zero_grad()
                    loss = model(images, target_captions)
                    loss.backward()
                    optimizer.step()
                    pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
                    torch.cuda.empty_cache()
                    step += 1
                    if step % 1000 == 0:
                        torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
                except Exception as e:
                    print(e)
                    continue
            # Only save the last checkpoint of each epoch
            validate_model_phase1(model, val_dataloader, tokenizer, config)
            show_results_for_samples_phase1(model, val_dataloader, tokenizer, config)
            torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
        except Exception as e:
            print(e)
            continue
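
# ---------------------------------------------------------------------------
# Minimal phase 1 usage sketch (not part of the original training script).
# It shows how the pieces above are expected to fit together. The checkpoint
# names "microsoft/phi-2" and "openai/clip-vit-base-patch32", the epoch/token
# counts, and the learning rate are assumptions for illustration; the
# dataloaders are expected to yield (pixel_values, caption_token_ids) batches.
# ---------------------------------------------------------------------------
def example_train_phase1(train_loader, val_dataloader):
    from transformers import AutoTokenizer

    config = {
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "epochs": 2,       # assumed value
        "max_tokens": 50,  # assumed value
    }
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

    model = CustomClipPhi2(tokenizer,
                           phi2_model_name="microsoft/phi-2",
                           clip_model_name="openai/clip-vit-base-patch32").to(config.get("device"))

    # Only the projection layer has trainable parameters in phase 1
    optimizer = torch.optim.Adam(model.projection_layer.parameters(), lr=1e-4)  # assumed lr
    train_model_phase1(model, train_loader, val_dataloader, optimizer, tokenizer, config)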
######################################## Phase 2 #########################################

class MainQLoraModel(nn.Module):
    """Phase 2 model: frozen CLIP encoder + 4-bit quantized phi-2 with QLoRA adapters
    and the trainable projection layer carried over from phase 1."""

    def __init__(self, tokenizer, config):
        super().__init__()
        self.tokenizer = tokenizer
        self.config = config

        self.clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name"))

        # 4-bit NF4 quantization for the phi-2 base model
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        phi2_model = AutoModelForCausalLM.from_pretrained(
            config.get("phi2_model_name"),
            quantization_config=bnb_config,
            trust_remote_code=True,
        )
        phi2_model.config.use_cache = False

        # LoRA config
        lora_alpha = 16
        lora_dropout = 0.1
        lora_r = 64
        peft_config = LoraConfig(
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            r=lora_r,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"],
        )

        self.EOS_TOKEN_ID = self.tokenizer.eos_token_id
        self.clip_embed = config.get("clip_embed")
        self.phi_embed = config.get("phi_embed")

        # Trainable projection layer
        self.projection_layer = torch.nn.Linear(self.clip_embed, self.phi_embed)

        # Freeze the CLIP encoder
        for param in self.clip_model.parameters():
            param.requires_grad_(False)

        # Load checkpoint weights
        if os.path.exists('./ckpts/model_phase2.pth'):
            # Resume phase 2: projection weights plus the saved QLoRA adapter
            self.projection_layer.load_state_dict(
                torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))
            self.phi2_model = peft.PeftModel.from_pretrained(
                phi2_model, './ckpts/Qlora_adaptor', is_trainable=True).to(config.get("device"))
            print("Loaded checkpoint weights for projection layer and QLoRA adapter")
        else:
            # Start phase 2 from the phase 1 projection weights and a fresh LoRA adapter
            self.projection_layer.load_state_dict(
                torch.load('./ckpts/model_phase1.pth', map_location=config.get("device")))
            self.phi2_model = peft.get_peft_model(phi2_model, peft_config).to(config.get("device"))

    def generate(self, tokenizer, config, images=None, ques=None, max_tokens=100):
        batch_size = 1
        predicted_caption = torch.full((batch_size, max_tokens), self.EOS_TOKEN_ID,
                                       dtype=torch.long, device=self.config.get('device'))

        # Delimiter tokens marking the start/end of the image + question block
        start_iq = self.tokenizer.encode("")
        end_iq = self.tokenizer.encode("")
        start_iq_embeds = torch.tensor(start_iq, dtype=torch.long).repeat(batch_size, 1)
        end_iq_embeds = torch.tensor(end_iq, dtype=torch.long).repeat(batch_size, 1)
        start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
        end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))
        questions_embed = self.phi2_model.model.model.embed_tokens(ques)

        if images is not None:
            clip_outputs = self.clip_model(**images)
            # Remove the CLS token
            images = clip_outputs.last_hidden_state[:, 1:, :]
            image_embeddings = self.projection_layer(images).to(torch.float16)
            combined_embeds = torch.cat([start_iq_embeds, image_embeddings, questions_embed, end_iq_embeds], dim=1)
        else:
            combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds], dim=1)

        # Greedy decoding
        for pos in range(max_tokens - 1):
            model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            predicted_caption[:, pos] = predicted_word_token.view(-1)
            next_token_embeds = self.phi2_model.model.model.embed_tokens(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
        return predicted_caption

    def forward(self, images, ques, ans):
        batch_size = ques.size()[0]
        questions = ques.to(self.config.get("device"))
        answers = ans.to(self.config.get("device"))
        target_length = ans.size()[1]

        # Delimiter tokens marking the start/end of the image + question block
        start_iq = self.tokenizer.encode("")
        end_iq = self.tokenizer.encode("")
        start_iq_embeds = torch.tensor(start_iq, dtype=torch.long).repeat(batch_size, 1)
        end_iq_embeds = torch.tensor(end_iq, dtype=torch.long).repeat(batch_size, 1)
        start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
        end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))
        questions_embed = self.phi2_model.model.model.embed_tokens(questions)
        answers_embed = self.phi2_model.model.model.embed_tokens(answers)

        # Text-only samples are passed in as all-zero image tensors
        are_all_zeros = torch.all(images == 0).item()
        if are_all_zeros:
            combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)
        else:
            images = {'pixel_values': images.to(self.config.get("device"))}
            clip_outputs = self.clip_model(**images)
            images_embeds = clip_outputs.last_hidden_state[:, 1:, :]  # remove CLS token
            # Projection
            image_embeds = self.projection_layer(images_embeds).to(torch.float16)
            combined_embeds = torch.cat([start_iq_embeds, image_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)

        # Single forward pass; the answer tokens occupy the last `target_length` positions
        model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
        answer_start = combined_embeds.size(1) - target_length

        loss = 0
        for pos in range(target_length - 1):
            # Logits at position (answer_start - 1 + pos) predict the answer token at `pos`
            predicted_word_token_logits = model_output_logits[:, answer_start - 1 + pos, :]
            pos_loss = cross_entropy(predicted_word_token_logits,
                                     answers[:, pos].contiguous().view(-1),
                                     ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
            loss += pos_loss
        loss = loss / target_length

        # Delete variables to free up memory
        del combined_embeds
        del model_output_logits
        torch.cuda.empty_cache()
        return loss
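
# ---------------------------------------------------------------------------
# Minimal phase 2 inference sketch (not part of the original script): ask a
# question about an image with MainQLoraModel.generate. The preprocessing via
# CLIPImageProcessor and the prompt handling are assumptions for illustration;
# the model is assumed to already sit on config["device"].
# ---------------------------------------------------------------------------
def example_answer_question(model, tokenizer, config, pil_image, question_text):
    from transformers import CLIPImageProcessor

    processor = CLIPImageProcessor.from_pretrained(config.get("clip_model_name"))
    pixel_values = processor(images=pil_image, return_tensors="pt")["pixel_values"].to(config.get("device"))
    images = {'pixel_values': pixel_values}

    # Question token ids with a batch dimension, as expected by generate()
    ques = torch.tensor(tokenizer.encode(question_text), dtype=torch.long).unsqueeze(0).to(config.get("device"))

    with torch.no_grad():
        predicted_tokens = model.generate(tokenizer, config, images=images, ques=ques, max_tokens=100)
    return tokenizer.batch_decode(predicted_tokens, skip_special_tokens=True)[0]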
def validate_model_phase2(model, val_dataloader, tokenizer, config):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for images, ques, ans in tqdm(val_dataloader):
            loss = model(images, ques, ans)
            total_loss += loss.item()
        print(f"Validation Loss: {total_loss/len(val_dataloader)}")
    model.train()


def train_model_phase2(model, train_loader, val_dataloader, tokenizer, config):
    phi2_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.phi2_model.parameters()), lr=1e-5)
    proj_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.projection_layer.parameters()), lr=1e-5)

    model.phi2_model.train()
    model.projection_layer.train()

    for epoch in range(1, config.get("epochs") + 1):
        print(f"Epoch: {epoch}")
        torch.cuda.empty_cache()
        pbar = tqdm(train_loader)
        step = 1
        try:
            for idx, (images, ques, ans) in enumerate(pbar):
                try:
                    phi2_optim.zero_grad()
                    proj_optim.zero_grad()
                    loss = model(images, ques, ans)
                    loss.backward()
                    phi2_optim.step()
                    proj_optim.step()
                    pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
                    torch.cuda.empty_cache()
                    step += 1
                    if step % 1000 == 0:
                        torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
                        model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/')
                except Exception as e:
                    print(e)
                    continue
            validate_model_phase2(model, val_dataloader, tokenizer, config)
            torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
            model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/')
        except Exception as e:
            print(e)
            continue
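
# ---------------------------------------------------------------------------
# Minimal phase 2 training sketch (not part of the original script). The
# config keys match the ones read by MainQLoraModel; the concrete values and
# checkpoint names are assumptions for illustration. The dataloaders are
# expected to yield (images, question_ids, answer_ids) batches, with an
# all-zero image tensor for text-only samples, and ./ckpts/model_phase1.pth
# from phase 1 is assumed to exist.
# ---------------------------------------------------------------------------
def example_train_phase2(train_loader, val_dataloader):
    from transformers import AutoTokenizer

    config = {
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "epochs": 1,                                        # assumed value
        "clip_model_name": "openai/clip-vit-base-patch32",  # assumed checkpoint
        "phi2_model_name": "microsoft/phi-2",               # assumed checkpoint
        "clip_embed": 768,
        "phi_embed": 2560,
    }
    tokenizer = AutoTokenizer.from_pretrained(config["phi2_model_name"], trust_remote_code=True)

    model = MainQLoraModel(tokenizer, config)
    # The phi-2 backbone is placed on the device inside __init__; move the
    # remaining submodules alongside it
    model.clip_model.to(config["device"])
    model.projection_layer.to(config["device"])

    train_model_phase2(model, train_loader, val_dataloader, tokenizer, config)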