Vasudevakrishna committed
Commit • 94f80f5
1 Parent(s): 0d94ec2
Upload 7 files

Files changed:
- README.md +57 -13
- configs.py +33 -0
- dataset.py +204 -0
- get_coco.py +41 -0
- main.py +41 -0
- model.py +378 -0
- requirements.txt +7 -0
README.md
CHANGED
@@ -1,13 +1,57 @@
# Multi-Modal LLM Gradio App

## Project Overview

This project is a **multi-modal language model** Gradio app that accepts **text**, **image**, and **audio** inputs and returns **text** responses. The app mimics a **ChatGPT-style interface**, allowing users to interact through multiple input modes.

The app leverages:
- **CLIP** for image processing
- **Whisper** for audio transcription (ASR)
- A **text-based model** (such as GPT or Phi) for generating text responses

## Features

- **Text Input**: Users can enter text directly for response generation.
- **Image Input**: Users can upload images, which are processed by the CLIP model.
- **Audio Input**: Users can upload or record audio, which is transcribed by the Whisper model and then used to generate a response (see the transcription sketch below).
- **ChatGPT-Like Interface**: A simple, intuitive interface that handles multi-modal inputs and returns text output.

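The audio mode is transcription followed by ordinary text generation. The commit pins `whisperx` in `requirements.txt`; purely as an illustration of that step (not the repo's actual `app.py` code), the same idea with the Hugging Face Transformers ASR pipeline looks like this, with `question.wav` standing in for an uploaded clip:

```python
# Illustration only: the repo pins whisperx, but any Whisper front end plays the same role.
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-base")
transcript = asr("question.wav")["text"]   # "question.wav" is a hypothetical uploaded clip
print(transcript)                          # this text is then sent to the language model
```
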
## Installation

1. Clone the repository:
```bash
git clone https://huggingface.co/spaces/Vasudevakrishna/MultiModel_LLM_ERAV2
cd MultiModel_LLM_ERAV2
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

3. Run the app:
```bash
python app.py
```

## How It Works

1. **Text Processing**: Input text is passed to a language model (such as GPT or Phi) to generate a response.
2. **Image Processing**: Images are processed by CLIP, which extracts patch embeddings. These embeddings are then projected into a format the text model understands (see the sketch below).
3. **Audio Processing**: Audio files are transcribed into text with Whisper; the transcript is passed to the language model for response generation.

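For the image path, this repo's approach (see `model.py`) is to take CLIP patch embeddings and map them into phi-2's embedding space with a linear projection trained in phase 1. A minimal sketch of that step, assuming the model names from `configs.py`, a blank dummy image in place of a real upload, and an untrained `nn.Linear` standing in for the learned projection:

```python
import torch
from PIL import Image
from transformers import AutoProcessor, CLIPVisionModel

clip_name = "openai/clip-vit-base-patch16"          # from configs.py
processor = AutoProcessor.from_pretrained(clip_name)
clip = CLIPVisionModel.from_pretrained(clip_name).eval()

image = Image.new("RGB", (224, 224))                # stand-in for an uploaded image
pixels = processor(images=image, return_tensors="pt")["pixel_values"]

with torch.no_grad():
    patches = clip(pixel_values=pixels).last_hidden_state[:, 1:, :]   # drop the CLS token

projection = torch.nn.Linear(768, 2560)             # clip_embed -> phi_embed (trained in phase 1)
phi_inputs = projection(patches)                    # passed to phi-2 as inputs_embeds
print(phi_inputs.shape)                             # torch.Size([1, 196, 2560])
```
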
## Usage

- **Text Input**: Enter text in the textbox and click "Submit" to generate a response.
- **Image Input**: Upload an image and click "Submit" to generate a response based on the image.
- **Audio Input**: Upload or record an audio clip and click "Submit" to transcribe it and generate a response.

## Future Improvements

- Add conveniences such as drag-and-drop file upload and live audio recording for a better user experience.
- Run CLIP image embedding in real time with more GPU resources.
- Train all components end to end for better response quality.

## License

This project is licensed under the MIT License.
configs.py
ADDED
@@ -0,0 +1,33 @@
import torch

def get_config_phase1():
    return {
        "data_dir": "./data",
        "clip_model_name": "openai/clip-vit-base-patch16",
        "phi2_model_name": "microsoft/phi-2",
        "train_batch_size": 2,
        "val_batch_size": 1,
        "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        "epochs": 2,
        "max_tokens": 20,
        "clip_embed": 768,
        "phi_embed": 2560,
        "num_workers": 4,
        "ckpts": "./ckpts"
    }

def get_config_phase2():
    return {
        "data_dir": "./data",
        "clip_model_name": "openai/clip-vit-base-patch16",
        "phi2_model_name": "microsoft/phi-2",
        "train_batch_size": 1,
        "val_batch_size": 1,
        "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        "epochs": 10,
        "max_tokens": 100,
        "clip_embed": 768,
        "phi_embed": 2560,
        "num_workers": 0,
        "ckpts": "./ckpts",
        "vocab_size": 51200
    }
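A quick note on how these configs are consumed (hypothetical snippet, not part of the commit): the training code reads values with `.get()`, so keys that `dataset.py` expects in phase 2 but that are not defined here, such as `i150k_json` and `QA_datasetName`, silently come back as `None`.

```python
# Hypothetical usage sketch, not part of the commit.
from configs import get_config_phase1, get_config_phase2

cfg1 = get_config_phase1()
print(cfg1["device"])                              # cuda if available, else cpu
print(cfg1["clip_embed"], cfg1["phi_embed"])       # 768 (CLIP ViT-B/16) and 2560 (phi-2 hidden size)

cfg2 = get_config_phase2()
print(cfg2.get("i150k_json"), cfg2.get("QA_datasetName"))  # None None - must be added before phase 2
```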
dataset.py
ADDED
@@ -0,0 +1,204 @@
import os
import json
import pickle

import torch
import requests
import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor
from datasets import load_dataset  # keep torch's Dataset; importing datasets.Dataset would shadow it


class ClipDataset(Dataset):
    '''Dataset of (CLIP pixel values, tokenized COCO caption) pairs for phase 1.'''
    def __init__(self, coco_data, model_name, tokenizer):
        self.tokenizer = tokenizer
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        self.caption_dataset = coco_data

    def __len__(self):
        # Return the length of the dataset
        return len(self.caption_dataset)

    def __getitem__(self, idx):
        # Get the image url and caption
        img_url = self.caption_dataset[idx]["image_url"]
        caption = self.caption_dataset[idx]["caption"]

        # Download the image and scale it so its height is 224 px
        # (the CLIP processor then produces the final 224x224 pixel values)
        image = Image.open(requests.get(img_url, stream=True).raw)
        width, height = image.size
        new_height = 224
        new_width = new_height * width // height
        image = image.resize((new_width, new_height), Image.LANCZOS)

        image_processed = self.processor(images=image, return_tensors="pt")['pixel_values']
        image_squeezed = image_processed.squeeze(0)
        tokenized_caption = self.tokenizer(caption, return_tensors="pt", return_attention_mask=False)
        tokenized_caption_ids = tokenized_caption['input_ids'].squeeze(0)
        return image_squeezed, tokenized_caption_ids


def collate_fn_phase1(batch):
    # Unzip the batch
    image_embeddings, captions = zip(*batch)
    # Stack the image tensors
    image_embeddings_stacked = torch.stack(image_embeddings, dim=0)
    # Pad the captions; the padding value is the <eos> token id (50256)
    captions_padded = torch.nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=50256)
    # Return the stacked image tensors and padded captions
    return (image_embeddings_stacked, captions_padded)


def get_data_loaders_phase1(data_dir, clip_model_name, tokenizer, train_batch_size, val_batch_size, num_workers):
    # Load the data produced by get_coco.py
    with open(os.path.join(data_dir, 'coco_train.pkl'), 'rb') as fp:
        train_pkl = pickle.load(fp)
    with open(os.path.join(data_dir, "coco_val.pkl"), "rb") as fp:
        val_pkl = pickle.load(fp)
    # train data loader
    train_dataloader = DataLoader(ClipDataset(train_pkl, clip_model_name, tokenizer), collate_fn=collate_fn_phase1, batch_size=train_batch_size, num_workers=num_workers, shuffle=True, pin_memory=True)
    # val data loader
    val_dataloader = DataLoader(ClipDataset(val_pkl, clip_model_name, tokenizer), collate_fn=collate_fn_phase1, batch_size=val_batch_size, num_workers=num_workers, shuffle=False, pin_memory=True)
    return train_dataloader, val_dataloader


##################################### Phase 2 #########################################


class ClipDatasetPhase2(Dataset):
    '''Dataset of (CLIP pixel values, question ids, answer ids) triples for phase 2.'''
    def __init__(self, data_frame, model_name, tokenizer):
        self.tokenizer = tokenizer
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        self.df = data_frame

    def __len__(self):
        # Return the length of the dataset
        return len(self.df)

    def __getitem__(self, idx):
        # Get the image url and QA pair (the dataframe has a 0..n-1 RangeIndex)
        img_url = self.df.ImageUrl[idx]
        que = self.df.Question[idx]
        ans = self.df.Answer[idx]

        if img_url is None:
            # Text-only sample (from the QA dataset): use an all-zero pixel tensor so the
            # batch can still be stacked and MainQLoraModel.forward can detect it via
            # torch.all(images == 0).
            image_squeezed = torch.zeros(3, 224, 224)
        else:
            # Download the image and scale it so its height is 224 px
            image = Image.open(requests.get(img_url, stream=True).raw)
            width, height = image.size
            new_height = 224
            new_width = new_height * width // height
            image = image.resize((new_width, new_height), Image.LANCZOS)
            image_processed = self.processor(images=image, return_tensors="pt")['pixel_values']
            image_squeezed = image_processed.squeeze(0)

        que_ids = self.tokenizer(que, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
        ans_ids = self.tokenizer(ans, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
        return image_squeezed, que_ids, ans_ids


def collate_fn_phase2(batch):
    # Unzip the batch
    image_embeddings, ques, ans = zip(*batch)
    # Stack the image tensors (text-only samples are all-zero tensors)
    image_embeddings_stacked = torch.stack(image_embeddings, dim=0)
    # Pad the questions and answers; the padding value is the <eos> token id (50256)
    ques_padded = torch.nn.utils.rnn.pad_sequence(ques, batch_first=True, padding_value=50256)
    ans_padded = torch.nn.utils.rnn.pad_sequence(ans, batch_first=True, padding_value=50256)
    # Return the stacked image tensors and padded QAs
    return (image_embeddings_stacked, ques_padded, ans_padded)


def prep_data(df):
    # Pair each top-ranked assistant reply with its prompter message to form (Question, Answer) rows
    df_assistant = df[(df.role == "assistant") & (df["rank"] == 0.0)].copy()
    df_prompter = df[(df.role == "prompter")].copy()
    df_prompter = df_prompter.set_index("message_id")
    df_assistant["Answer"] = df_assistant["text"].values

    inputs = []
    for _, row in df_assistant.iterrows():
        prompt_row = df_prompter.loc[row.parent_id]
        inputs.append(prompt_row.text)

    df_assistant["Question"] = inputs
    df_assistant["ImageUrl"] = None

    df_assistant = df_assistant[df_assistant.lang == "en"]

    df_assistant = df_assistant[
        ["ImageUrl", "Question", "Answer", "message_id"]
    ].rename(columns={"message_id": "Ids"})

    return df_assistant


def get_i150_df(config):
    # Expects an "i150k_json" path in the phase-2 config (not defined in configs.py as committed)
    with open(config.get("i150k_json"), "r") as fp:
        i150k_json_read = json.load(fp)
    max_tokens = 100
    image_urls = []
    ques_list = []
    ans_list = []
    id_list = []
    for idx, data in enumerate(i150k_json_read):
        image = data['image']
        image_url = 'http://images.cocodataset.org/train2017/' + image
        id_ = data["id"]
        # Conversations alternate human question / gpt answer
        iterator = iter(data['conversations'])
        for i in iterator:
            ques = i
            ans = next(iterator)
            if (len(ques["value"]) > 100 or len(ans["value"]) > max_tokens):
                continue
            if ques["from"] == "human" and ans["from"] == "gpt":
                image_urls.append(image_url)
                ques_list.append(ques["value"].replace("<image>\n", "").replace("<image>", ""))
                ans_list.append(ans["value"])
                id_list.append(id_)
    df_i150k = pd.DataFrame(list(zip(image_urls, ques_list, ans_list, id_list)),
                            columns=["ImageUrl", "Question", "Answer", "Ids"])
    # Random 96/4 train/test split
    msk = np.random.rand(len(df_i150k)) < 0.96
    train_df = df_i150k[msk]
    test_df = df_i150k[~msk]
    return train_df, test_df


def get_oas_df(config):
    # Expects a "QA_datasetName" dataset id in the phase-2 config (not defined in configs.py as committed)
    train_ds, val_ds = load_dataset(config.get("QA_datasetName"), split=["train", "validation"])
    train_df = prep_data(train_ds.to_pandas())
    test_df = prep_data(val_ds.to_pandas())
    return train_df, test_df


def get_data_loaders_phase2(tokenizer, config):
    train_i150k, test_i150k = get_i150_df(config)
    train_oas, test_oas = get_oas_df(config)

    train_df = pd.concat([train_i150k, train_oas]).reset_index(drop=True)
    val_df = pd.concat([test_i150k, test_oas]).reset_index(drop=True)
    # train data loader
    train_dataloader = DataLoader(ClipDatasetPhase2(train_df, config.get("clip_model_name"), tokenizer), collate_fn=collate_fn_phase2, batch_size=config.get("train_batch_size"), num_workers=config.get("num_workers"), shuffle=True, pin_memory=True)
    # val data loader
    val_dataloader = DataLoader(ClipDatasetPhase2(val_df, config.get("clip_model_name"), tokenizer), collate_fn=collate_fn_phase2, batch_size=config.get("val_batch_size"), num_workers=config.get("num_workers"), shuffle=False, pin_memory=True)
    return train_dataloader, val_dataloader
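A quick way to see what `collate_fn_phase1` produces (hypothetical check, not part of the commit): captions of different lengths are right-padded with 50256, phi-2's `<|endoftext|>` id, and the image tensors are stacked along a new batch dimension.

```python
import torch
from dataset import collate_fn_phase1

batch = [
    (torch.zeros(3, 224, 224), torch.tensor([10, 11, 12])),
    (torch.zeros(3, 224, 224), torch.tensor([20, 21])),
]
images, captions = collate_fn_phase1(batch)
print(images.shape)   # torch.Size([2, 3, 224, 224])
print(captions)       # tensor([[   10,    11,    12], [   20,    21, 50256]])
```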
get_coco.py
ADDED
@@ -0,0 +1,41 @@
import os, shutil, json
import pickle, argparse

"""Unzip the COCO captions annotations and save them as pickle files."""


def make_pkl(data_dir, dataset_json, train_flag=False):
    coco_data_list = []
    for i, data in enumerate(dataset_json['annotations']):
        image_id = data['image_id']
        caption = data['caption']
        # Linear scan over all image records for every annotation (simple but slow; see the note below)
        for img in dataset_json['images']:
            if img['id'] == image_id:
                image_url = img['coco_url']
                file_name = img['file_name']
                break
        coco_data_list.append({'image_id': image_id, 'image_url': image_url, 'file_name': file_name, 'caption': caption})
    if train_flag:
        with open(os.path.join(data_dir, 'coco_train.pkl'), 'wb') as f:
            pickle.dump(coco_data_list, f)
    else:
        with open(os.path.join(data_dir, 'coco_val.pkl'), 'wb') as f:
            pickle.dump(coco_data_list, f)


def main(coco_path, data_dir):
    coco_dir = os.path.dirname(coco_path)
    # shutil.unpack_archive(coco_path, coco_dir)
    with open(os.path.join(coco_dir, 'annotations/captions_train2017.json')) as f:
        coco_train_dataset = json.load(f)
    with open(os.path.join(coco_dir, 'annotations/captions_val2017.json')) as f:
        coco_val_dataset = json.load(f)
    make_pkl(data_dir, coco_train_dataset, train_flag=True)
    # make_pkl(data_dir, coco_val_dataset)  # note: dataset.py also expects coco_val.pkl


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--coco_path', type=str, default='coco.zip')
    parser.add_argument('--data_dir', type=str, default='data')
    args = parser.parse_args()
    main(args.coco_path, args.data_dir)
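`make_pkl` rescans every COCO image record for each annotation. A hypothetical variant (not part of the commit) that builds an id-to-image lookup once and therefore runs in linear time:

```python
import os, pickle

def make_pkl_fast(data_dir, dataset_json, train_flag=False):
    # Build the id -> image record map once instead of scanning per annotation.
    images_by_id = {img['id']: img for img in dataset_json['images']}
    coco_data_list = []
    for data in dataset_json['annotations']:
        img = images_by_id[data['image_id']]
        coco_data_list.append({
            'image_id': data['image_id'],
            'image_url': img['coco_url'],
            'file_name': img['file_name'],
            'caption': data['caption'],
        })
    name = 'coco_train.pkl' if train_flag else 'coco_val.pkl'
    with open(os.path.join(data_dir, name), 'wb') as f:
        pickle.dump(coco_data_list, f)
```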
main.py
ADDED
@@ -0,0 +1,41 @@
import torch
from dataset import get_data_loaders_phase1, get_data_loaders_phase2
from transformers import AutoTokenizer
from model import CustomClipPhi2, MainQLoraModel, train_model_phase1, train_model_phase2
from configs import get_config_phase1, get_config_phase2

def phase_1():
    # get config
    config = get_config_phase1()
    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)

    # data loaders
    train_dataloader, val_dataloader = get_data_loaders_phase1(config.get("data_dir"), config.get("clip_model_name"), tokenizer, config.get("train_batch_size"), config.get("val_batch_size"), config.get("num_workers"))

    llmModel = CustomClipPhi2(tokenizer, config.get("phi2_model_name"), config.get("clip_model_name"), clip_embed=768, phi_embed=2560).to(config.get("device"))
    print(llmModel)
    # optimizer
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, llmModel.parameters()), lr=1e-3)
    # train model
    train_model_phase1(llmModel, train_dataloader, val_dataloader, optimizer, tokenizer, config)


def phase_2():
    # get config
    config = get_config_phase2()
    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)

    # data loaders
    train_dataloader, val_dataloader = get_data_loaders_phase2(tokenizer, config)

    llmModel = MainQLoraModel(tokenizer, config).to(config.get("device"))
    print(llmModel)
    # train model
    train_model_phase2(llmModel, train_dataloader, val_dataloader, tokenizer, config)

if __name__ == "__main__":
    torch.set_float32_matmul_precision('medium')
    phase_1()
    # phase_2()
model.py
ADDED
@@ -0,0 +1,378 @@
import torch
import torch.nn as nn
from torch.nn.functional import cross_entropy
from transformers import CLIPVisionModel, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from tqdm import tqdm
import os, peft


class CustomClipPhi2(nn.Module):
    def __init__(self, tokenizer, phi2_model_name, clip_model_name, clip_embed=768, phi_embed=2560):
        super().__init__()

        self.tokenizer = tokenizer
        # These two models are not finetuned in phase 1
        # pretrained Microsoft phi-2 model
        self.phi2_model = AutoModelForCausalLM.from_pretrained(phi2_model_name, torch_dtype=torch.float32, trust_remote_code=True)
        # pretrained OpenAI CLIP vision model
        self.clip_model = CLIPVisionModel.from_pretrained(clip_model_name)

        self.EOS_TOKEN_ID = self.tokenizer.eos_token_id  # 50256
        self.IMAGE_TOKEN_ID = 23903  # token for "Comments"
        self.clip_embed = clip_embed
        self.phi_embed = phi_embed

        # Trainable projection layer from CLIP space to phi-2 embedding space
        self.projection_layer = torch.nn.Linear(clip_embed, phi_embed)

        # Freeze weights of the pretrained models
        for models in [self.phi2_model, self.clip_model]:
            for param in models.parameters():
                param.requires_grad_(False)

        # Load checkpoint weights for the projection layer if they exist
        if os.path.exists('./ckpts/model_phase1.pth'):
            self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location='cpu'))
            print("Loaded checkpoint weights for projection layer")
        else:
            print("No checkpoint weights for projection layer")
            print("Initializing projection layer with random weights")
            self.projection_layer.weight.data.normal_(mean=0.0, std=0.02)
            self.projection_layer.bias.data.zero_()

    def generate(self, images, tokenizer, config):
        clip_outputs = self.clip_model(**images)
        # remove CLS token
        images = clip_outputs.last_hidden_state[:, 1:, :]
        image_embeddings = self.projection_layer(images).to(torch.float16)

        batch_size = images.size()[0]
        predicted_caption = torch.full((batch_size, config.get("max_tokens")), self.EOS_TOKEN_ID, dtype=torch.long, device=config.get('device'))
        img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
        img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))
        combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)

        # Greedy decoding: append the arg-max token's embedding at every step
        for pos in range(config.get("max_tokens") - 1):
            model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            predicted_caption[:, pos] = predicted_word_token.view(1, -1).to('cpu')
            next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
        return predicted_caption

    def forward(self, images, target_captions):

        batch_size = target_captions.size()[0]
        target_length = target_captions.size()[1]

        # CLIP output for the image (see https://huggingface.co/openai/clip-vit-base-patch16 for loading)
        clip_outputs = self.clip_model(**images)
        images = clip_outputs.last_hidden_state[:, 1:, :]  # remove CLS token

        # projection layer
        image_embeddings = self.projection_layer(images).to(torch.float16)

        # append the image-token embedding from phi-2 to the projected patch embeddings
        img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
        img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))
        combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)  # (B, 196 + 1, phi_embed) for clip-vit-base-patch16
        del clip_outputs
        del image_embeddings

        # Autoregressive loss over the caption tokens
        loss = 0
        for pos in range(target_length - 1):
            model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
            pos_loss = cross_entropy(predicted_word_token_logits.view(-1, predicted_word_token_logits.size(-1)), target_captions[:, pos].contiguous().view(-1), ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
            loss += pos_loss

            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
        loss = loss / target_length

        # Delete variables to free up memory
        del combined_embeds
        del model_output_logits
        torch.cuda.empty_cache()

        return loss


def show_results_for_samples_phase1(model, val_dataloader, tokenizer, config, num_samples=2):
    model.eval()
    with torch.no_grad():
        for i in range(num_samples):
            for images, target_captions in val_dataloader:
                images = {'pixel_values': images.to(config.get('device'))}
                target_captions = target_captions.to(config.get('device'))
                target_captions_decoded = tokenizer.batch_decode(target_captions, skip_special_tokens=True)
                predicted_captions = model.generate(images, tokenizer, config)
                predicted_captions_decoded = tokenizer.batch_decode(predicted_captions, skip_special_tokens=True)

                for idx, pc in enumerate(predicted_captions_decoded):
                    print(f"{idx} - Target captions: {target_captions_decoded[idx]} \n {'---------------------'*10} \n Predicted_captions: {pc}")
                break
    model.train()


def validate_model_phase1(model, val_dataloader, tokenizer, config):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        try:
            for images, target_captions in tqdm(val_dataloader):
                images = {'pixel_values': images.to(config.get('device'))}
                target_captions = target_captions.to(config.get('device'))
                loss = model(images, target_captions)
                total_loss += loss.item()
            print(f"Validation Loss: {total_loss/len(val_dataloader)}")
        except Exception as e:
            pass
    model.train()


def train_model_phase1(model, train_loader, val_dataloader, optimizer, tokenizer, config):
    model.train()

    for epoch in range(1, config.get("epochs") + 1):
        print(f"Epoch: {epoch}")
        torch.cuda.empty_cache()
        step = 1
        pbar = tqdm(train_loader)
        try:
            for idx, (images, target_captions) in enumerate(pbar):
                try:
                    if target_captions.shape[1] >= config.get("max_tokens"):
                        # print(f"Skipping batch {idx} due to long caption")
                        continue

                    images = {'pixel_values': images.to(config.get('device'))}
                    target_captions = target_captions.to(config.get('device'))

                    optimizer.zero_grad()
                    loss = model(images, target_captions)
                    loss.backward()
                    optimizer.step()
                    pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
                    torch.cuda.empty_cache()
                    step += 1
                    if (step % 1000 == 0):
                        torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
                except Exception as e:
                    print(e)
                    continue

            # Validate, show a few sample captions and save the latest checkpoint at the end of each epoch
            validate_model_phase1(model, val_dataloader, tokenizer, config)
            show_results_for_samples_phase1(model, val_dataloader, tokenizer, config)
            torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')

        except Exception as e:
            print(e)
            continue


######################################## Phase 2 #########################################

class MainQLoraModel(nn.Module):
    def __init__(self, tokenizer, config):
        super().__init__()
        self.tokenizer = tokenizer
        self.config = config
        self.clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name"))

        # 4-bit (QLoRA) quantization config for the phi-2 base model
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        phi2_model = AutoModelForCausalLM.from_pretrained(
            config.get("phi2_model_name"),
            quantization_config=bnb_config,
            trust_remote_code=True
        )
        phi2_model.config.use_cache = False

        # LoRA config
        lora_alpha = 16
        lora_dropout = 0.1
        lora_r = 64

        peft_config = LoraConfig(
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            r=lora_r,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "dense",
                "fc1",
                "fc2"
            ]
        )

        self.EOS_TOKEN_ID = self.tokenizer.eos_token_id
        self.clip_embed = config.get("clip_embed")
        self.phi_embed = config.get("phi_embed")

        # Trainable projection layer from CLIP space to phi-2 embedding space
        self.projection_layer = torch.nn.Linear(self.clip_embed, self.phi_embed)

        # Freeze CLIP weights
        for models in [self.clip_model]:
            for param in models.parameters():
                param.requires_grad_(False)

        # Load checkpoint weights
        if os.path.exists('./ckpts/model_phase2.pth'):
            # Resume phase 2: projection layer plus the saved QLoRA adapter
            self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))
            self.phi2_model = peft.PeftModel.from_pretrained(phi2_model, './ckpts/Qlora_adaptor', is_trainable=True).to(config.get("device"))
            print("Loaded checkpoint weights for projection layer and QLoRA adapter")
        else:
            # Fresh LoRA adapters; start the projection layer from the phase 1 weights
            self.phi2_model = peft.get_peft_model(phi2_model, peft_config).to(config.get("device"))
            self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location=config.get("device")))

    def generate(self, tokenizer, config, images=None, ques=None, max_tokens=100):
        batch_size = 1

        predicted_caption = torch.full((batch_size, max_tokens), self.EOS_TOKEN_ID, dtype=torch.long, device=self.config.get('device'))
        start_iq = self.tokenizer.encode("<iQ>")
        end_iq = self.tokenizer.encode("</iQ>")
        start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
        end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
        start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
        end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))
        questions_embed = self.phi2_model.model.model.embed_tokens(ques)
        if images is not None:
            clip_outputs = self.clip_model(**images)
            # remove CLS token
            images = clip_outputs.last_hidden_state[:, 1:, :]
            image_embeddings = self.projection_layer(images).to(torch.float16)
            combined_embeds = torch.cat([start_iq_embeds, image_embeddings, questions_embed, end_iq_embeds], dim=1)
        else:
            combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds], dim=1)

        # Greedy decoding
        for pos in range(max_tokens - 1):
            model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            predicted_caption[:, pos] = predicted_word_token.view(1, -1).to('cpu')
            next_token_embeds = self.phi2_model.model.model.embed_tokens(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
        return predicted_caption

    def forward(self, images, ques, ans):

        batch_size = ques.size()[0]
        questions = ques.to(self.config.get("device"))
        answers = ans.to(self.config.get("device"))
        target_length = ans.size()[1]
        start_iq = self.tokenizer.encode("<iQ>")
        end_iq = self.tokenizer.encode("</iQ>")
        start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
        end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
        start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
        end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))

        questions_embed = self.phi2_model.model.model.embed_tokens(questions)
        answers_embed = self.phi2_model.model.model.embed_tokens(answers)

        # Text-only samples arrive as all-zero image tensors (see ClipDatasetPhase2)
        are_all_zeros = torch.all(images == 0).item()
        if are_all_zeros:
            combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)
        else:
            images = {'pixel_values': images.to(self.config.get("device"))}
            clip_outputs = self.clip_model(**images)
            images_embeds = clip_outputs.last_hidden_state[:, 1:, :]  # remove CLS token

            # projection
            image_embeds = self.projection_layer(images_embeds).to(torch.float16)
            combined_embeds = torch.cat([start_iq_embeds, image_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)

        model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']

        # Loss over the answer tokens: the logits at position (prefix_length + pos - 1)
        # predict the answer token at position pos
        loss = 0
        prefix_length = combined_embeds.size(1) - target_length
        for pos in range(target_length - 1):
            predicted_word_token_logits = model_output_logits[:, prefix_length + pos - 1, :]
            pos_loss = cross_entropy(predicted_word_token_logits, answers[:, pos].contiguous().view(-1), ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
            loss += pos_loss
        loss = loss / target_length

        # Delete variables to free up memory
        del combined_embeds
        del model_output_logits
        torch.cuda.empty_cache()
        return loss


def validate_model_phase2(model, val_dataloader, tokenizer, config):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for images, ques, ans in tqdm(val_dataloader):
            loss = model(images, ques, ans)
            total_loss += loss.item()
        print(f"Validation Loss: {total_loss/len(val_dataloader)}")
    model.train()


def train_model_phase2(model, train_loader, val_dataloader, tokenizer, config):
    phi2_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.phi2_model.parameters()), lr=1e-5)
    proj_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.projection_layer.parameters()), lr=1e-5)
    model.phi2_model.train()
    model.projection_layer.train()

    for epoch in range(1, config.get("epochs") + 1):
        print(f"Epoch: {epoch}")
        torch.cuda.empty_cache()
        step = 1
        pbar = tqdm(train_loader)
        try:
            for idx, (images, ques, ans) in enumerate(pbar):
                try:
                    phi2_optim.zero_grad()
                    proj_optim.zero_grad()
                    loss = model(images, ques, ans)
                    loss.backward()
                    phi2_optim.step()
                    proj_optim.step()
                    pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
                    torch.cuda.empty_cache()
                    step += 1
                    if (step % 1000 == 0):
                        torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
                        model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)
                except Exception as e:
                    print("Error in training step:", e)
                    continue

            validate_model_phase2(model, val_dataloader, tokenizer, config)
            torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
            model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)

        except Exception as e:
            print(e)
            continue
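To make the phase-2 sequence layout in `MainQLoraModel.forward` concrete, here is a hypothetical shape check with dummy tensors (no model download). The `<iQ>`/`</iQ>` token counts and the question/answer lengths are made-up; 196 is the patch count for CLIP ViT-B/16 at 224x224 and 2560 is phi-2's hidden size.

```python
import torch

B, D = 1, 2560                        # batch size, phi-2 hidden size
start_iq = torch.zeros(B, 3, D)       # embeddings of "<iQ>" (token count assumed)
image    = torch.zeros(B, 196, D)     # projected CLIP patch embeddings
question = torch.zeros(B, 12, D)      # question token embeddings
end_iq   = torch.zeros(B, 4, D)       # embeddings of "</iQ>" (token count assumed)
answer   = torch.zeros(B, 20, D)      # answer token embeddings (teacher-forced during training)

combined = torch.cat([start_iq, image, question, end_iq, answer], dim=1)
print(combined.shape)                 # torch.Size([1, 235, 2560])
prefix_length = combined.size(1) - answer.size(1)
print(prefix_length)                  # 215: logits at prefix_length - 1 predict the first answer token
```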
requirements.txt
ADDED
@@ -0,0 +1,7 @@
torch
torchvision
git+https://github.com/huggingface/peft.git
accelerate
transformers
einops
git+https://github.com/m-bain/whisperx.git