Vasudevakrishna committed
Commit 94f80f5
1 Parent(s): 0d94ec2

Upload 7 files

Files changed (7)
  1. README.md +57 -13
  2. configs.py +33 -0
  3. dataset.py +204 -0
  4. get_coco.py +41 -0
  5. main.py +41 -0
  6. model.py +378 -0
  7. requirements.txt +7 -0
README.md CHANGED
@@ -1,13 +1,57 @@
- ---
- title: MultiModel LLM ERAV2
- emoji: 🚀
- colorFrom: red
- colorTo: pink
- sdk: gradio
- sdk_version: 4.44.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Multi-Modal LLM Gradio App
+
+ ## Project Overview
+
+ This project is a **multi-modal language model** Gradio app that accepts **text**, **image**, and **audio** inputs and returns **text** responses. The app mimics a **ChatGPT-style interface**, allowing users to interact through multiple input modes.
+
+ The app leverages:
+ - **CLIP** for image processing
+ - **Whisper** for audio transcription (ASR)
+ - A **text-based model** (such as GPT or Phi) for generating text responses
+
+ ## Features
+
+ - **Text Input**: Users can enter text directly for response generation.
+ - **Image Input**: Users can upload images, which are processed by the CLIP model.
+ - **Audio Input**: Users can upload or record audio, which is transcribed by the Whisper model and then processed for a response.
+ - **ChatGPT-Like Interface**: A simple, intuitive interface that handles multi-modal inputs and returns text output.
+
+ ## Installation
+
+ 1. Clone the repository:
+ ```bash
+ git clone https://huggingface.co/spaces/Vasudevakrishna/MultiModel_LLM_ERAV2
+ cd MultiModel_LLM_ERAV2
+ ```
+
+ 2. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 3. Run the app:
+ ```bash
+ python app.py
+ ```
+
+ ## How It Works
+
+ 1. **Text Processing**: Input text is passed to a language model (such as GPT or Phi) to generate a response.
+ 2. **Image Processing**: Images are encoded with CLIP to extract embeddings, which are then projected into a format the text model can consume.
+ 3. **Audio Processing**: Audio files are transcribed into text using Whisper, and the transcript is passed to the language model for response generation.
+
+ ## Usage
+
+ - **Text Input**: Enter text in the provided textbox and click "Submit" to generate a response.
+ - **Image Input**: Upload an image and click "Submit" to generate a response based on the image.
+ - **Audio Input**: Upload or record an audio file and click "Submit" to transcribe it and generate a response.
+
+ ## Future Improvements
+
+ - Add features such as drag-and-drop file upload or live audio recording for a better user experience.
+ - Speed up image processing by computing CLIP embeddings in real time with more GPU resources.
+ - Train all components end to end for better response quality.
+
+ ## License
+
+ This project is licensed under the MIT License.
configs.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+ def get_config_phase1():
+     return {
+         "data_dir": "./data",
+         "clip_model_name": "openai/clip-vit-base-patch16",
+         "phi2_model_name": "microsoft/phi-2",
+         "train_batch_size": 2,
+         "val_batch_size": 1,
+         "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+         "epochs": 2,
+         "max_tokens": 20,
+         "clip_embed": 768,
+         "phi_embed": 2560,
+         "num_workers": 4,
+         "ckpts": "./ckpts"
+     }
+
+ def get_config_phase2():
+     return {
+         "data_dir": "./data",
+         "clip_model_name": "openai/clip-vit-base-patch16",
+         "phi2_model_name": "microsoft/phi-2",
+         "train_batch_size": 1,
+         "val_batch_size": 1,
+         "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+         "epochs": 10,
+         "max_tokens": 100,
+         "clip_embed": 768,
+         "phi_embed": 2560,
+         "num_workers": 0,
+         "ckpts": "./ckpts",
+         "vocab_size": 51200
+     }
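One caveat worth noting: the phase-2 loaders in `dataset.py` below read two config keys, `i150k_json` and `QA_datasetName`, that `get_config_phase2()` does not define, so they have to be added to the config before calling `get_data_loaders_phase2`. A minimal sketch with purely hypothetical values:

```python
from configs import get_config_phase2

config = get_config_phase2()
# Both values below are assumptions for illustration, not part of this commit:
config["i150k_json"] = "./data/llava_instruct_150k.json"  # hypothetical path to the instruct-150k JSON
config["QA_datasetName"] = "OpenAssistant/oasst1"         # hypothetical Hugging Face dataset id
```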
dataset.py ADDED
@@ -0,0 +1,204 @@
+ import os
+ import json
+ import torch
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from transformers import AutoProcessor
+ from torch.utils.data import DataLoader
+ import pickle
+ import requests
+ from datasets import load_dataset
+ import pandas as pd
+ import numpy as np
+
+
+ class ClipDataset(Dataset):
+     '''ClipDataset class for loading the CLIP caption dataset'''
+     def __init__(self, coco_data, model_name, tokenizer):
+
+         self.tokenizer = tokenizer
+         self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
+         self.caption_dataset = coco_data
+
+     def __len__(self):
+         # Return the length of the dataset
+         return len(self.caption_dataset)
+
+     def __getitem__(self, idx):
+         # Get the image url and caption
+         img_url = self.caption_dataset[idx]["image_url"]
+         caption = self.caption_dataset[idx]["caption"]
+
+         # Get the image and caption embeddings
+         image = Image.open(requests.get(img_url, stream=True).raw)
+         width, height = image.size
+         new_width = 224
+         new_height = new_width * height // width
+         new_height = 224
+         new_width = new_height * width // height
+         image = image.resize((new_width, new_height), Image.LANCZOS)
+         image_processed = self.processor(images=image, return_tensors="pt")['pixel_values']
+         image_sqeezed = image_processed.squeeze(0)
+         tokenized_caption = self.tokenizer(caption, return_tensors="pt", return_attention_mask=False)
+         tokenized_caption_ids = tokenized_caption['input_ids'].squeeze(0)
+         return (image_sqeezed, tokenized_caption_ids)
+
+
+ def collate_fn_phase1(batch):
+     # Unzip the batch
+     image_embeddings, captions = zip(*batch)
+     # Stack the image embeddings
+     image_embeddings_stacked = torch.stack(image_embeddings, dim=0)
+     # Pad the captions, padded value is the <eos> token
+     captions_padded = torch.nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=50256)
+     # Return the stacked image embeddings and padded captions
+     return (image_embeddings_stacked, captions_padded)
+
+
+ def get_data_loaders_phase1(data_dir, clip_model_name, tokenizer, train_batch_size, val_batch_size, num_workers):
+     # Load the data
+     with open(os.path.join(data_dir, 'coco_train.pkl'), 'rb') as fp:
+         train_pkl = pickle.load(fp)
+     with open(os.path.join(data_dir, "coco_val.pkl"), "rb") as fp:
+         val_pkl = pickle.load(fp)
+     # train data loaders
+     train_dataloader = DataLoader(ClipDataset(train_pkl, clip_model_name, tokenizer), collate_fn=collate_fn_phase1, batch_size=train_batch_size, num_workers=num_workers, shuffle=True, pin_memory=True)
+
+     # val data loaders
+     val_dataloader = DataLoader(ClipDataset(val_pkl, clip_model_name, tokenizer), collate_fn=collate_fn_phase1, batch_size=val_batch_size, num_workers=num_workers, shuffle=False, pin_memory=True)
+     return train_dataloader, val_dataloader
+
+ ##################################### Phase 2 #########################################
+
+
+ class ClipDatasetPhase2(Dataset):
+     '''Dataset class for the phase 2 image question-answer data'''
+     def __init__(self, data_frame, model_name, tokenizer):
+
+         self.tokenizer = tokenizer
+         self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
+         self.df = data_frame
+
+     def __len__(self):
+         # Return the length of the dataset
+         return len(self.df)
+
+     def __getitem__(self, idx):
+         # Get the image url and QAs
+         img_url = self.df.ImageUrl[idx]
+         que = self.df.Question[idx]
+         ans = self.df.Answer[idx]
+
+         print("img_url", img_url)
+         print("que", que)
+         print("ans", ans)
+
+         # Get the image and caption embeddings
+         if img_url is None:
+             print("img_url is None")
+             image_sqeezed = None
+         else:
+             image = Image.open(requests.get(img_url, stream=True).raw)
+             width, height = image.size
+             new_width = 224
+             new_height = new_width * height // width
+             new_height = 224
+             new_width = new_height * width // height
+             image = image.resize((new_width, new_height), Image.LANCZOS)
+             image_processed = self.processor(images=image, return_tensors="pt")['pixel_values']
+             image_sqeezed = image_processed.squeeze(0)
+         que_ids = self.tokenizer(que, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
+         ans_ids = self.tokenizer(ans, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
+         return (image_sqeezed, que_ids, ans_ids)
+
+
+ def collate_fn_phase2(batch):
+     # Unzip the batch
+     image_embeddings, ques, ans = zip(*batch)
+     # Stack the image embeddings
+     if image_embeddings[0] is None:
+         image_embeddings_stacked = None
+     else:
+         image_embeddings_stacked = torch.stack(image_embeddings, dim=0)
+     # Pad the QAs, padded value is the <eos> token
+     ques_padded = torch.nn.utils.rnn.pad_sequence(ques, batch_first=True, padding_value=50256)
+     ans_padded = torch.nn.utils.rnn.pad_sequence(ans, batch_first=True, padding_value=50256)
+     # Return the stacked image embeddings and padded QAs
+     return (image_embeddings_stacked, ques_padded, ans_padded)
+
+
+ def prep_data(df):
+     df_assistant = df[(df.role == "assistant") & (df["rank"] == 0.0)].copy()
+     df_prompter = df[(df.role == "prompter")].copy()
+     df_prompter = df_prompter.set_index("message_id")
+     df_assistant["Answer"] = df_assistant["text"].values
+
+     inputs = []
+     for _, row in df_assistant.iterrows():
+         input = df_prompter.loc[row.parent_id]
+         inputs.append(input.text)
+
+     df_assistant["Question"] = inputs
+     df_assistant["ImageUrl"] = None
+
+     df_assistant = df_assistant[df_assistant.lang == "en"]
+
+     df_assistant = df_assistant[
+         ["ImageUrl", "Question", "Answer", "message_id"]
+     ].rename(columns={"message_id": "Ids"})
+
+     return df_assistant
+
+
+ def get_i150_df(config):
+     with open(config.get("i150k_json"), "r") as fp:
+         i150k_json_read = json.load(fp)
+     max_tokens = 100
+     image_urls = []
+     ques_list = []
+     ans_list = []
+     id_list = []
+     for idx, data in enumerate(i150k_json_read):
+         image = data['image']
+         image_url = 'http://images.cocodataset.org/train2017/' + image
+         id_ = data["id"]
+         iterator = iter(data['conversations'])
+         for i in iterator:
+             ques = i
+             ans = next(iterator)
+             if (len(ques["value"]) > 100 or len(ans["value"]) > max_tokens):
+                 continue
+             if ques["from"] == "human" and ans["from"] == "gpt":
+                 image_urls.append(image_url)
+                 ques_list.append(ques["value"].replace("<image>\n", "").replace("<image>", ""))
+                 ans_list.append(ans["value"])
+                 id_list.append(id_)
+     df_i150k = pd.DataFrame(list(zip(image_urls, ques_list, ans_list, id_list)),
+                             columns=["ImageUrl", "Question", "Answer", "Ids"])
+     msk = np.random.rand(len(df_i150k)) < 0.96
+
+     train_df = df_i150k[msk]
+     test_df = df_i150k[~msk]
+     return train_df, test_df
+
+
+ def get_oas_df(config):
+     train_ds, val_ds = load_dataset(config.get("QA_datasetName"), split=["train", "validation"])
+     train_df = prep_data(train_ds.to_pandas())
+     test_df = prep_data(val_ds.to_pandas())
+     return train_df, test_df
+
+
+ def get_data_loaders_phase2(tokenizer, config):
+
+     train_i150k, test_i150k = get_i150_df(config)
+     train_oas, test_oas = get_oas_df(config)
+
+     train_df = pd.concat([train_i150k, train_oas]).reset_index(drop=True)
+     val_df = pd.concat([test_i150k, test_oas]).reset_index(drop=True)
+     # train data loaders
+     train_dataloader = DataLoader(ClipDatasetPhase2(train_df, config.get("clip_model_name"), tokenizer), collate_fn=collate_fn_phase2, batch_size=config.get("train_batch_size"), num_workers=config.get("num_workers"), shuffle=True, pin_memory=True)
+
+     # val data loaders
+     val_dataloader = DataLoader(ClipDatasetPhase2(val_df, config.get("clip_model_name"), tokenizer), collate_fn=collate_fn_phase2, batch_size=config.get("val_batch_size"), num_workers=config.get("num_workers"), shuffle=False, pin_memory=True)
+     return train_dataloader, val_dataloader
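For reference, a minimal sketch (not part of this commit) of how the phase-1 loaders above are meant to be driven, assuming `./data/coco_train.pkl` and `./data/coco_val.pkl` have already been produced by `get_coco.py` below:

```python
# Minimal sketch, assuming ./data/coco_train.pkl and ./data/coco_val.pkl exist.
# Shows the tensors collate_fn_phase1 returns.
from transformers import AutoTokenizer
from configs import get_config_phase1
from dataset import get_data_loaders_phase1

config = get_config_phase1()
tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
train_dl, val_dl = get_data_loaders_phase1(config.get("data_dir"), config.get("clip_model_name"),
                                           tokenizer, config.get("train_batch_size"),
                                           config.get("val_batch_size"), config.get("num_workers"))

images, captions = next(iter(train_dl))
print(images.shape)    # (train_batch_size, 3, 224, 224) CLIP pixel values
print(captions.shape)  # (train_batch_size, longest caption in batch), padded with EOS id 50256
```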
get_coco.py ADDED
@@ -0,0 +1,41 @@
+ import os, shutil, json
+ import pickle, argparse
+
+ """Unzip the data and save it as a pickle file."""
+
+ def make_pkl(data_dir, dataset_json, train_flag=False):
+     coco_data_list = []
+     for i, data in enumerate(dataset_json['annotations']):
+         image_id = data['image_id']
+         caption = data['caption']
+         for img in dataset_json['images']:
+             if img['id'] == image_id:
+                 image_url = img['coco_url']
+                 file_name = img['file_name']
+                 break
+         coco_data_list.append({'image_id': image_id, 'image_url': image_url, 'file_name': file_name, 'caption': caption})
+     if train_flag:
+         with open(os.path.join(data_dir, 'coco_train.pkl'), 'wb') as f:
+             pickle.dump(coco_data_list, f)
+     else:
+         with open(os.path.join(data_dir, 'coco_val.pkl'), 'wb') as f:
+             pickle.dump(coco_data_list, f)
+
+
+ def main(coco_path, data_dir):
+     coco_dir = os.path.dirname(coco_path)
+     # shutil.unpack_archive(coco_path, coco_dir)
+     with open(os.path.join(coco_dir, 'annotations/captions_train2017.json')) as f:
+         coco_train_dataset = json.load(f)
+     with open(os.path.join(coco_dir, 'annotations/captions_val2017.json')) as f:
+         coco_val_dataset = json.load(f)
+     make_pkl(data_dir, coco_train_dataset, train_flag=True)
+     # make_pkl(data_dir, coco_val_dataset)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--coco_path', type=str, default='coco.zip')
+     parser.add_argument('--data_dir', type=str, default='data')
+     args = parser.parse_args()
+     main(args.coco_path, args.data_dir)
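A hedged usage note for the script above: the archive-unpacking call is commented out, so the COCO 2017 caption annotation JSONs are assumed to already sit next to the zip; the script then writes the pickle that the phase-1 loaders in `dataset.py` consume.

```python
# Hypothetical invocation; assumes annotations/captions_train2017.json and
# annotations/captions_val2017.json are already extracted next to coco.zip,
# and that the "data" directory exists.
from get_coco import main

main("coco.zip", "data")  # writes data/coco_train.pkl (val pickling is commented out)
# equivalently: python get_coco.py --coco_path coco.zip --data_dir data
```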
main.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+ from dataset import get_data_loaders_phase1, get_data_loaders_phase2
+ from transformers import AutoTokenizer
+ from model import CustomClipPhi2, MainQLoraModel, train_model_phase1, train_model_phase2
+ from configs import get_config_phase1, get_config_phase2
+
+ def phase_1():
+     # get config
+     config = get_config_phase1()
+     # tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
+
+     # data loaders
+     train_dataloader, val_dataloader = get_data_loaders_phase1(config.get("data_dir"), config.get("clip_model_name"), tokenizer, config.get("train_batch_size"), config.get("val_batch_size"), config.get("num_workers"))
+
+     llmModel = CustomClipPhi2(tokenizer, config.get("phi2_model_name"), config.get("clip_model_name"), clip_embed=768, phi_embed=2560).to(config.get("device"))
+     print(llmModel)
+     # optimizer
+     optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, llmModel.parameters()), lr=1e-3)
+     # train model
+     train_model_phase1(llmModel, train_dataloader, val_dataloader, optimizer, tokenizer, config)
+
+
+ def phase_2():
+     # get config
+     config = get_config_phase2()
+     # tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
+
+     # data loaders
+     train_dataloader, val_dataloader = get_data_loaders_phase2(tokenizer, config)
+
+     llmModel = MainQLoraModel(tokenizer, config).to(config.get("device"))
+     print(llmModel)
+     # train model
+     train_model_phase2(llmModel, train_dataloader, val_dataloader, tokenizer, config)
+
+ if __name__ == "__main__":
+     torch.set_float32_matmul_precision('medium')
+     phase_1()
+     # phase_2()
model.py ADDED
@@ -0,0 +1,378 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn.functional import cross_entropy
+ from transformers import CLIPVisionModel, AutoModelForCausalLM, BitsAndBytesConfig
+ from peft import LoraConfig
+ from tqdm import tqdm
+ import os, peft
+
+
+ class CustomClipPhi2(nn.Module):
+     def __init__(self, tokenizer, phi2_model_name, clip_model_name, clip_embed=768, phi_embed=2560):
+         super().__init__()
+
+         self.tokenizer = tokenizer
+         # These two models are not finetuned
+         # pretrained Microsoft phi2 model
+         self.phi2_model = AutoModelForCausalLM.from_pretrained(phi2_model_name, torch_dtype=torch.float32, trust_remote_code=True)
+         # pretrained OpenAI clip model
+         self.clip_model = CLIPVisionModel.from_pretrained(clip_model_name)
+
+         self.EOS_TOKEN_ID = self.tokenizer.eos_token_id  # 50256
+         self.IMAGE_TOKEN_ID = 23903  # token for Comments
+         self.clip_embed = clip_embed
+         self.phi_embed = phi_embed
+
+         # projection layers
+         # Trainable projection layer
+         self.projection_layer = torch.nn.Linear(clip_embed, phi_embed)
+
+         # Freeze Weights
+         for models in [self.phi2_model, self.clip_model]:
+             for param in models.parameters():
+                 param.requires_grad_(False)
+
+         # load checkpoint weights
+         if os.path.exists('./ckpts/model_phase1.pth'):
+             self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location='cpu'))
+             print("Loaded checkpoint weights for projection layer")
+         else:
+             print("No checkpoint weights for projection layer")
+             print("Initializing projection layer with random weights")
+             self.projection_layer.weight.data.normal_(mean=0.0, std=0.02)
+             self.projection_layer.bias.data.zero_()
+
+
+     def generate(self, images, tokenizer, config):
+         clip_outputs = self.clip_model(**images)
+         # remove cls token
+         images = clip_outputs.last_hidden_state[:, 1:, :]
+         image_embeddings = self.projection_layer(images).to(torch.float16)
+
+         batch_size = images.size()[0]
+         predicted_caption = torch.full((batch_size, config.get("max_tokens")), self.EOS_TOKEN_ID, dtype=torch.long, device=config.get('device'))
+         img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
+         img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))
+         combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)
+
+         for pos in range(config.get("max_tokens") - 1):
+             model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
+             predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
+             predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
+             predicted_caption[:, pos] = predicted_word_token.view(1, -1).to('cpu')
+             next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
+             combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
+         return predicted_caption
+
+
+     def forward(self, images, target_captions):
+
+         batch_size = target_captions.size()[0]
+         target_length = target_captions.size()[1]
+         print("---", target_length)
+
+         # clip model output for image
+         clip_outputs = self.clip_model(**images)  # See this for loading https://huggingface.co/openai/clip-vit-base-patch36
+         images = clip_outputs.last_hidden_state[:, 1:, :]  # remove CLS token
+
+         # projection layer
+         image_embeddings = self.projection_layer(images).to(torch.float16)
+
+         # add comment token from phi2
+         img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
+         img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))
+         combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)  # 4,49,2560
+         del clip_outputs
+         del image_embeddings
+
+         # for loss
+         loss = 0
+         for pos in range(target_length - 1):
+
+             model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
+             predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
+             pos_loss = cross_entropy(predicted_word_token_logits.view(-1, predicted_word_token_logits.size(-1)), target_captions[:, pos].contiguous().view(-1), ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
+             loss += pos_loss
+
+             predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
+             next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
+             combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
+         loss = loss / target_length
+
+         # Delete variables to free up memory
+         del combined_embeds
+         del model_output_logits
+         torch.cuda.empty_cache()
+
+         return loss
+
+
+ def show_results_for_samples_phase1(model, val_dataloader, tokenizer, config, num_samples=2):
+     model.eval()
+     with torch.no_grad():
+         for i in range(num_samples):
+             for images, target_captions in val_dataloader:
+                 images = {'pixel_values': images.to(config.get('device'))}
+                 target_captions = target_captions.to(config.get('device'))
+                 target_captions_decoded = tokenizer.batch_decode(target_captions, ignore_index=tokenizer.eos_token_id)
+                 predicted_captions = model.generate(images, tokenizer, config)
+                 predicted_captions_decoded = tokenizer.batch_decode(predicted_captions, ignore_index=tokenizer.eos_token_id)
+
+                 for idx, pc in enumerate(predicted_captions_decoded):
+                     print(f"{idx} - Target captions: {target_captions_decoded[idx]} \n {'---------------------'*10} \n Predicted_captions:{pc} ")
+                 break
+
+
+ def validate_model_phase1(model, val_dataloader, tokenizer, config):
+     model.eval()
+     total_loss = 0
+     with torch.no_grad():
+         try:
+             for images, target_captions in tqdm(val_dataloader):
+                 images = {'pixel_values': images.to(config.get('device'))}
+                 target_captions = target_captions.to(config.get('device'))
+                 loss = model(images, target_captions)
+                 total_loss += loss.item()
+             print(f"Validation Loss: {total_loss/len(val_dataloader)}")
+         except Exception as e:
+             pass
+     model.train()
+
+
+ def train_model_phase1(model, train_loader, val_dataloader, optimizer, tokenizer, config):
+     model.train()
+
+     pbar = tqdm(train_loader)
+     for epoch in range(1, config.get("epochs")):
+         print(f"Epoch: {epoch}")
+         torch.cuda.empty_cache()
+         step = 1
+         try:
+             for idx, (images, target_captions) in enumerate(pbar):
+                 try:
+                     if target_captions.shape[1] >= config.get("max_tokens"):
+                         # print(f"Skipping batch {idx} due to long caption")
+                         continue
+
+                     images = {'pixel_values': images.to(config.get('device'))}
+                     target_captions = target_captions.to(config.get('device'))
+
+                     optimizer.zero_grad()
+                     loss = model(images, target_captions)
+                     loss.backward()
+                     optimizer.step()
+                     pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
+                     torch.cuda.empty_cache()
+                     step += 1
+                     if (step % 1000 == 0):
+                         torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
+                 except Exception as e:
+                     print(e)
+                     continue
+
+             # # save model
+             # if ((epoch % 2) == 0):
+             # Only save last checkpoint
+             validate_model_phase1(model, val_dataloader, tokenizer, config)
+             show_results_for_samples_phase1(model, val_dataloader, tokenizer, config)
+             torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
+
+         except Exception as e:
+             print(e)
+             continue
+
+
+
+
+ ######################################## Phase 2 #########################################
+
+ class MainQLoraModel(nn.Module):
+     def __init__(self, tokenizer, config):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.config = config
+         self.clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name"))
+
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.float16,
+         )
+
+         phi2_model = AutoModelForCausalLM.from_pretrained(
+             config.get("phi2_model_name"),
+             quantization_config=bnb_config,
+             trust_remote_code=True
+         )
+         phi2_model.config.use_cache = False
+
+         ## 4 - LORA config
+
+         lora_alpha = 16
+         lora_dropout = 0.1
+         lora_r = 64
+
+         peft_config = LoraConfig(
+             lora_alpha=lora_alpha,
+             lora_dropout=lora_dropout,
+             r=lora_r,
+             bias="none",
+             task_type="CAUSAL_LM",
+             target_modules=[
+                 "q_proj",
+                 "k_proj",
+                 "v_proj",
+                 "dense",
+                 "fc1",
+                 "fc2"
+             ]
+         )
+         self.phi2_model = peft.get_peft_model(phi2_model, peft_config).to(config.get("device"))
+
+         self.EOS_TOKEN_ID = self.tokenizer.eos_token_id
+         self.clip_embed = config.get("clip_embed")
+         self.phi_embed = config.get("phi_embed")
+
+         # projection layers
+         # Trainable projection layer
+         self.projection_layer = torch.nn.Linear(self.clip_embed, self.phi_embed)
+
+         # Freeze Weights
+         for models in [self.clip_model]:
+             for param in models.parameters():
+                 param.requires_grad_(False)
+
+         # load checkpoint weights
+         if os.path.exists('./ckpts/model_phase2.pth'):
+             self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))
+             self.phi2_model.from_pretrained(self.phi2_model, './ckpts/Qlora_adaptor')
+             print("Loaded checkpoint weights for projection layer")
+         else:
+             # Load weights from phase 1
+             self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location=config.get("device")))
+
+
+     def generate(self, tokenizer, config, images=None, ques=None, max_tokens=100):
+         batch_size = 1
+
+         predicted_caption = torch.full((batch_size, max_tokens), self.EOS_TOKEN_ID, dtype=torch.long, device=self.config.get('device'))
+         start_iq = self.tokenizer.encode("<iQ>")
+         end_iq = self.tokenizer.encode("</iQ>")
+         start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
+         end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
+         start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
+         end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))
+         questions_embed = self.phi2_model.model.model.embed_tokens(ques)
+         if images is not None:
+             clip_outputs = self.clip_model(**images)
+             # remove cls token
+             images = clip_outputs.last_hidden_state[:, 1:, :]
+             image_embeddings = self.projection_layer(images).to(torch.float16)
+             combined_embeds = torch.cat([start_iq_embeds, image_embeddings, questions_embed, end_iq_embeds], dim=1)
+         else:
+             combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds], dim=1)
+
+         for pos in range(max_tokens - 1):
+             model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
+             predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
+             predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
+             predicted_caption[:, pos] = predicted_word_token.view(1, -1).to('cpu')
+             next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
+             combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
+         return predicted_caption
+
+
+     def forward(self, images, ques, ans):
+
+         batch_size = ques.size()[0]
+         questions = ques.to(self.config.get("device"))
+         answers = ans.to(self.config.get("device"))
+         target_length = ans.size()[1]
+         start_iq = self.tokenizer.encode("<iQ>")
+         end_iq = self.tokenizer.encode("</iQ>")
+         start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
+         end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
+         start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
+         end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))
+
+         questions_embed = self.phi2_model.model.model.embed_tokens(questions)
+         answers_embed = self.phi2_model.model.model.embed_tokens(answers)
+
+         are_all_zeros = torch.all(images == 0).item()
+         if are_all_zeros:
+             combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)
+         else:
+             images = {'pixel_values': images.to(self.config.get("device"))}
+             clip_outputs = self.clip_model(**images)
+             images_embeds = clip_outputs.last_hidden_state[:, 1:, :]  # remove cls token
+
+             # projection
+             image_embeds = self.projection_layer(images_embeds).to(torch.float16)
+             combined_embeds = torch.cat([start_iq_embeds, image_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)
+
+         model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
+         # # for loss
+         loss = 0
+         for pos in range(target_length - 1):
+             predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
+             pos_loss = cross_entropy(predicted_word_token_logits.view(-1, predicted_word_token_logits.size(-1)), answers[:, pos].contiguous().view(-1), ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
+             loss += pos_loss
+         loss = loss / target_length
+
+         # Delete variables to free up memory
+         del combined_embeds
+         del model_output_logits
+         torch.cuda.empty_cache()
+         return loss
+
+ def validate_model_phase2(model, val_dataloader, tokenizer, config):
+     model.eval()
+     total_loss = 0
+     with torch.no_grad():
+         # try:
+         for images, ques, ans in tqdm(val_dataloader):
+             loss = model(images, ques, ans)
+             total_loss += loss.item()
+         print(f"Validation Loss: {total_loss/len(val_dataloader)}")
+         # except Exception as e:
+         #     pass
+     model.train()
+
+
+ def train_model_phase2(model, train_loader, val_dataloader, tokenizer, config):
+     phi2_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.phi2_model.parameters()), lr=1e-5)
+     proj_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.projection_layer.parameters()), lr=1e-5)
+     model.phi2_model.train()
+     model.projection_layer.train()
+
+     pbar = tqdm(train_loader)
+     for epoch in range(1, config.get("epochs")):
+         print(f"Epoch: {epoch}")
+         torch.cuda.empty_cache()
+         step = 1
+         try:
+             for idx, (images, ques, ans) in enumerate(pbar):
+                 try:
+                     phi2_optim.zero_grad()
+                     proj_optim.zero_grad()
+                     loss = model(images, ques, ans)
+                     loss.backward()
+                     phi2_optim.step()
+                     proj_optim.step()
+                     pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
+                     torch.cuda.empty_cache()
+                     step += 1
+                     if (step % 1000 == 0):
+                         torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
+                         model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)
+                 except Exception as e:
+                     print("in frp", e)
+                     continue
+
+             validate_model_phase2(model, val_dataloader, tokenizer, config)
+             torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
+             model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)
+
+         except Exception as e:
+             print(e)
+             continue
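Similarly, a hedged sketch (not part of this commit) of text-only inference through `MainQLoraModel.generate` above. It assumes the phase-1 or phase-2 projection checkpoint exists under `./ckpts` (the constructor loads one of them) and that a CUDA device with `bitsandbytes` is available for the 4-bit Phi-2.

```python
# Hedged sketch: a text-only question (images=None) through the phase-2 model.
# Assumes ./ckpts/model_phase1.pth or ./ckpts/model_phase2.pth exists and a GPU
# with bitsandbytes is available for 4-bit loading.
import torch
from transformers import AutoTokenizer
from configs import get_config_phase2
from model import MainQLoraModel

config = get_config_phase2()
tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
model = MainQLoraModel(tokenizer, config).to(config.get("device"))
model.eval()

ques_ids = tokenizer("Explain what a multi-modal model is.", return_tensors="pt",
                     return_attention_mask=False)["input_ids"].to(config.get("device"))
with torch.no_grad():
    out_ids = model.generate(tokenizer, config, images=None, ques=ques_ids)
print(tokenizer.batch_decode(out_ids)[0])
```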
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ torchvision
+ git+https://github.com/huggingface/peft.git
+ accelerate
+ transformers
+ einops
+ git+https://github.com/m-bain/whisperx.git