File size: 5,014 Bytes
44fc622 e235434 44fc622 e235434 44fc622 e235434 44fc622 e235434 44fc622 e235434 44fc622 e235434 44fc622 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
# Copyright (C) 2024 Ronan Le Meillat
# License: Apache License 2.0
# Description: Fine-tune HuggingFaceM4/Idefics3-8B-Llama3 (LoRA/QLoRA) on the eltorio/ROCO-radiology image-captioning dataset
import os
import sys
import torch
from huggingface_hub import login as hf_login
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics3ForConditionalGeneration, TrainingArguments, Trainer
from datasets.utils.logging import disable_progress_bar
# Silence the `datasets` progress bars so the training log stays readable.
disable_progress_bar()

# Resolve a Hugging Face token. The HF_TOKEN environment variable takes
# precedence; otherwise the first command-line argument is used when it looks
# like a token (i.e. starts with "hf_").
HF_TOKEN = ""
arguments = sys.argv[1:]
env_token = os.environ.get('HF_TOKEN')
if env_token:
    HF_TOKEN = env_token
    print("Hugging Face token found in environment variable")
if not HF_TOKEN and arguments and arguments[0].startswith("hf_"):
    HF_TOKEN = arguments[0]
    print("Hugging Face token found in script arguments")

if HF_TOKEN:
    hf_login(
        token=HF_TOKEN,
        add_to_git_credential=True
    )
else:
    # Calling login with an empty token cannot succeed; skip it and let the
    # Hub client use whatever credentials are already cached on this machine.
    print("No Hugging Face token found; skipping login and using any cached credentials")
# --- Identifiers, prompt and local paths -----------------------------------
dataset_id = "eltorio/ROCO-radiology"
# System-turn instruction that prefixes every training conversation.
prompt = "You are an expert radiologist certified with over 15 years of experience in diagnostic imaging, describe this image"
source_model_id = "HuggingFaceM4/Idefics3-8B-Llama3"
destination_model_id = "eltorio/IDEFICS3_ROCOv2"
output_dir = "IDEFICS3_ROCOv2"
# NOTE(review): cache_dir is not passed to load_dataset or from_pretrained
# anywhere in this script, so it currently has no effect — confirm intent.
cache_dir = "/workspace/data"

# --- Dataset ---------------------------------------------------------------
# Splits are memory-mapped rather than copied into RAM (keep_in_memory=False).
full_dataset = load_dataset(dataset_id, keep_in_memory=False)
train_dataset, eval_dataset = full_dataset["train"], full_dataset["validation"]
# --- Fine-tuning mode ------------------------------------------------------
# USE_QLORA: LoRA adapters on a 4-bit NF4-quantized model (takes precedence).
# USE_LORA:  DoRA adapters (use_dora=True) on the fp16 model.
# Neither:   the fp16 model is loaded on DEVICE with no adapters attached.
DEVICE = "cuda:0"
USE_LORA = False
USE_QLORA = True

# Image splitting is disabled, so each image is handled as a single tile.
processor = AutoProcessor.from_pretrained(source_model_id, do_image_splitting=False)

if not (USE_QLORA or USE_LORA):
    # NOTE(review): training fp16 weights directly with fp16=True mixed
    # precision is likely to fail ("Attempting to unscale FP16 gradients");
    # confirm dtype/precision settings before enabling this path.
    model = Idefics3ForConditionalGeneration.from_pretrained(
        source_model_id,
        torch_dtype=torch.float16,
        _attn_implementation="flash_attention_2",  # This works for A100 or H100
    ).to(DEVICE)
else:
    # Adapters target the MLP and attention projections inside the text model,
    # the modality projection and the perceiver resampler.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules='.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$',
        use_dora=not USE_QLORA,  # DoRA only for the non-quantized variant
        init_lora_weights="gaussian",
    )
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
    model = Idefics3ForConditionalGeneration.from_pretrained(
        source_model_id,
        torch_dtype=torch.float16,
        quantization_config=bnb_config if USE_QLORA else None,
    )
    model.add_adapter(lora_config)
    model.enable_adapters()
class MyDataCollator:
    """Turn ROCO samples into a padded Idefics3 training batch.

    Each sample (a mapping with "image" and "caption") is rendered as a
    three-turn chat: a system turn with the instruction prompt, a user turn
    holding only the image, and an assistant turn holding the caption. The
    chats and images are tokenized together by the processor. The returned
    batch gets a ``labels`` tensor copied from ``input_ids`` in which padding
    and ``<image>`` placeholder tokens are replaced by ``IGNORE_INDEX``, so
    they do not contribute to the loss.
    """

    # Label value skipped by the cross-entropy loss.
    IGNORE_INDEX = -100

    def __init__(self, processor, system_prompt=None):
        """
        Args:
            processor: the model's processor; its tokenizer must list
                "<image>" among its additional special tokens.
            system_prompt: text of the system turn. Defaults to the
                module-level ``prompt``.
        """
        self.processor = processor
        self.system_prompt = prompt if system_prompt is None else system_prompt
        tokenizer = processor.tokenizer
        self.image_token_id = tokenizer.additional_special_tokens_ids[
            tokenizer.additional_special_tokens.index("<image>")
        ]

    def __call__(self, samples):
        """Build the model inputs (plus ``labels``) for a list of samples."""
        texts = []
        images = []
        for sample in samples:
            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": self.system_prompt}],
                },
                {
                    "role": "user",
                    "content": [{"type": "image"}],
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": sample["caption"]}],
                },
            ]
            text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
            texts.append(text.strip())
            # One image per conversation, hence the single-element list.
            images.append([sample["image"].convert('RGB')])

        batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)

        labels = batch["input_ids"].clone()
        pad_token_id = self.processor.tokenizer.pad_token_id
        # Padding and image placeholders must not be predicted: mask them out.
        if pad_token_id is not None:
            labels[labels == pad_token_id] = self.IGNORE_INDEX
        labels[labels == self.image_token_id] = self.IGNORE_INDEX
        batch["labels"] = labels
        return batch
data_collator = MyDataCollator(processor)

# Trainer.train() ignores TrainingArguments.resume_from_checkpoint, so the
# checkpoint to resume from is looked up here and passed to train() explicitly.
# When output_dir holds no checkpoint yet, training starts from scratch.
from transformers.trainer_utils import get_last_checkpoint

last_checkpoint = None
if os.path.isdir(output_dir):
    last_checkpoint = get_last_checkpoint(output_dir)
    if last_checkpoint is not None:
        print(f"Resuming training from {last_checkpoint}")

training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=False,
    auto_find_batch_size=True,
    learning_rate=2e-4,
    fp16=True,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    dataloader_pin_memory=False,
    save_total_limit=3,
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps=100,
    save_steps=10,  # checkpoint each 10 steps
    resume_from_checkpoint=last_checkpoint,  # informational; train() below does the resuming
    logging_steps=5,
    remove_unused_columns=False,  # the collator needs the raw "image"/"caption" columns
    push_to_hub=True,
    hub_model_id=destination_model_id,  # push to the declared destination repo
    label_names=["labels"],
    load_best_model_at_end=False,
    report_to="none",
    optim="paged_adamw_8bit",
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # validation split, not the training split
)
trainer.train(resume_from_checkpoint=last_checkpoint)
|