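# Fine-tune PaliGemma on VQAv2 with a from-scratch LoRA implementation.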
import math

import torch
from torch.utils.data import Dataset
from datasets import load_dataset
from processor import MultiModalProcessor
from load_model import load_hf_model
from transformers import Trainer, TrainingArguments
from dataclasses import dataclass, field
from typing import List

@dataclass
class LoraConfig:
    r: int = 8
    lora_alpha: int = 16
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    lora_dropout: float = 0.05
    bias: str = "none"
    task_type: str = "CAUSAL_LM"

    def __post_init__(self):
        self.inference_mode = False
        # LoRA scales the low-rank update by alpha / r
        self.scaling = self.lora_alpha / self.r

class LoraLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, config: LoraConfig):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=False)
        self.lora_A = torch.nn.Parameter(torch.zeros((config.r, in_features)))
        self.lora_B = torch.nn.Parameter(torch.zeros((out_features, config.r)))
        # Initialize A randomly and keep B at zero: the LoRA update starts as a
        # no-op but still receives gradients (zeroing both matrices would leave
        # all gradients at zero and stall training)
        torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.scaling = config.scaling
        self.dropout = torch.nn.Dropout(p=config.lora_dropout)

    def forward(self, x):
        result = self.linear(x)
        # Low-rank update: x -> x A^T B^T, scaled by alpha / r
        lora_output = (self.dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
        return result + lora_output
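
# Quick shape check (illustrative only; the sizes here are arbitrary):
# a (2, 16) batch through LoraLinear(16, 32, ...) must come out as (2, 32)
_check = LoraLinear(16, 32, LoraConfig(r=4))
assert _check(torch.randn(2, 16)).shape == (2, 32)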

def apply_lora_to_model(model, config: LoraConfig):
    # Collect the targets first so we do not mutate the module tree mid-iteration
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
        and any(target in name for target in config.target_modules)
    ]
    for name, module in targets:
        lora_module = LoraLinear(module.in_features, module.out_features, config)
        # Copy the pretrained weights into the wrapped linear layer
        lora_module.linear.weight.data = module.weight.data
        if module.bias is not None:
            lora_module.linear.bias = module.bias
        # "name" is a dotted path (e.g. "...self_attn.q_proj"); setattr on the
        # root model cannot resolve it, so set the attribute on the parent module
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, lora_module)
    return model
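
# Minimal helper (a sketch of standard LoRA practice): freeze every parameter
# except the low-rank matrices, so the pretrained weights stay fixed
def mark_only_lora_as_trainable(model):
    for param_name, param in model.named_parameters():
        param.requires_grad = "lora_A" in param_name or "lora_B" in param_name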

# Load the dataset
ds = load_dataset('HuggingFaceM4/VQAv2', split="train[:10%]")
cols_remove = ["question_type", "answers", "answer_type", "image_id", "question_id"]
ds = ds.remove_columns(cols_remove)

# Create a small test split
split_ds = ds.train_test_split(test_size=0.05)
train_ds = split_ds["train"]
test_ds = split_ds["test"]
print(train_ds)
print(test_ds)

# Load the model and processor
model_id = "./paligemma-3b-pt-224"
model, tokenizer = load_hf_model(model_id, "cuda")
processor = MultiModalProcessor(tokenizer, model.config.vision_config.num_image_tokens, model.config.vision_config.image_size)

# Apply LoRA to the model, then move everything (base weights and the
# freshly created adapter parameters) to the GPU if one is available
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], lora_dropout=0.05)
model = apply_lora_to_model(model, lora_config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
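
# Train only the adapters: freeze everything except the LoRA matrices
mark_only_lora_as_trainable(model)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,}")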

# Define a custom dataset
class PaliGemmaDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        prompt = "answer " + item["question"]
        image = item["image"].convert("RGB")
        answer = item["multiple_choice_answer"]
        # Process prompt and answer together so the labels align with the
        # token sequence the model actually sees
        inputs = self.processor(text=[prompt + " " + answer], images=[image])
        # Length of the image + prompt prefix (assumes the prompt-only encoding
        # is a prefix of the combined encoding)
        prompt_len = self.processor(text=[prompt], images=[image])['input_ids'].shape[-1]
        # Compute loss only on the answer: mask image and prompt tokens with -100
        labels = inputs['input_ids'][0].clone()
        labels[:prompt_len] = -100
        inputs['labels'] = labels
        return inputs

# Create datasets
train_dataset = PaliGemmaDataset(train_ds, processor)
eval_dataset = PaliGemmaDataset(test_ds, processor)

# Define a custom data collator that pads the variable-length sequences in a batch
def custom_data_collator(features):
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    return {
        'pixel_values': torch.stack([f['pixel_values'][0] for f in features]),
        # Questions and answers vary in length, so pad rather than stack directly
        'input_ids': torch.nn.utils.rnn.pad_sequence(
            [f['input_ids'][0] for f in features], batch_first=True, padding_value=pad_id),
        'attention_mask': torch.nn.utils.rnn.pad_sequence(
            [f['attention_mask'][0] for f in features], batch_first=True, padding_value=0),
        # Pad labels with -100 so padded positions are ignored by the loss
        'labels': torch.nn.utils.rnn.pad_sequence(
            [f['labels'] for f in features], batch_first=True, padding_value=-100),
    }

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=custom_data_collator,
)

# Fine-tune the model
trainer.train()

# Save the fine-tuned model
trainer.save_model("lora_paligemma_vqa")

# Function to save LoRA weights separately
def save_lora_weights(model, path):
    lora_state_dict = {}
    for name, module in model.named_modules():
        if isinstance(module, LoraLinear):
            lora_state_dict[f"{name}.lora_A"] = module.lora_A.data
            lora_state_dict[f"{name}.lora_B"] = module.lora_B.data
    torch.save(lora_state_dict, path)

# Save LoRA weights
save_lora_weights(model, "lora_weights.pt")
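
# Companion loader (a sketch): the saved keys mirror this model's parameter
# names, so strict=False restores just the adapter matrices
def load_lora_weights(model, path):
    lora_state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(lora_state_dict, strict=False)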
print("Fine-tuning completed and model saved.") |