# Copyright (C) 2024 Ronan Le Meillat
# License: Apache License 2.0
# Description: Fine-tune HuggingFaceM4/Idefics3-8B-Llama3 on the eltorio/ROCO-radiology dataset
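# Usage (the token can come from the environment or the first CLI argument, see below):
#   HF_TOKEN=hf_xxx python learn.py
#   python learn.py hf_xxx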
import os
import sys
import torch
from huggingface_hub import login as hf_login
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics3ForConditionalGeneration, TrainingArguments, Trainer
from datasets.utils.logging import disable_progress_bar
disable_progress_bar()
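# Authentication: the Hub token is taken from the HF_TOKEN environment variable
# when present, otherwise from the first CLI argument if it looks like a token.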
HF_TOKEN = ""
arguments = sys.argv[1:]
if os.environ.get('HF_TOKEN') is not None:
HF_TOKEN = os.environ.get('HF_TOKEN')
print(f"Hugging Face token found in environment variable")
# If HF_TOKEN is empty checks if the first argument seems to be the token (ie starts with "hf_" )
if not HF_TOKEN and arguments and arguments[0].startswith("hf_"):
HF_TOKEN = arguments[0]
print(f"Hugging Face token found in script arguments")
hf_login(
token=HF_TOKEN,
add_to_git_credential=True
)
dataset_id = "eltorio/ROCO-radiology"
prompt = "You are an expert radiologist certified with over 15 years of experience in diagnostic imaging, describe this image"
source_model_id = "HuggingFaceM4/Idefics3-8B-Llama3"
destination_model_id = "eltorio/IDEFICS3_ROCOv2"  # target Hub repo; with push_to_hub=True the Trainer derives the repo name from output_dir
output_dir = "IDEFICS3_ROCOv2"
cache_dir = "/workspace/data"
full_dataset = load_dataset(dataset_id, cache_dir=cache_dir, keep_in_memory=False)
train_dataset = full_dataset["train"]
eval_dataset = full_dataset["validation"]
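# The splits expose an "image" (PIL.Image) and a "caption" (str) column, which is
# what the collator below relies on; an optional quick peek can confirm this:
# sample = train_dataset[0]
# print(type(sample["image"]), sample["caption"][:80])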
DEVICE = "cuda:0"
USE_LORA = False
USE_QLORA = True
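# USE_QLORA loads the base model 4-bit quantized via bitsandbytes and trains LoRA
# adapters on top (QLoRA); USE_LORA alone trains the same adapters on the
# half-precision model; with both False the full fp16 model is fine-tuned.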
processor = AutoProcessor.from_pretrained(
    source_model_id,
    do_image_splitting=False  # keep each image as a single tile instead of splitting into sub-images, shortening the sequence
)
if USE_QLORA or USE_LORA:
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules='.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$',
        use_dora=not USE_QLORA,  # DoRA only when the base weights are not quantized
        init_lora_weights="gaussian"
    )
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16
        )
    model = Idefics3ForConditionalGeneration.from_pretrained(
        source_model_id,
        torch_dtype=torch.float16,
        quantization_config=bnb_config if USE_QLORA else None,
    )
    model.add_adapter(lora_config)
    model.enable_adapters()
else:
    # Full fine-tuning in fp16; flash_attention_2 requires an Ampere or newer GPU (A100/H100)
    model = Idefics3ForConditionalGeneration.from_pretrained(
        source_model_id,
        torch_dtype=torch.float16,
        _attn_implementation="flash_attention_2",
    ).to(DEVICE)
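# The collator below turns (image, caption) pairs into model-ready batches: it
# renders a chat-formatted prompt per sample, tokenizes text and pixels together,
# and derives labels from input_ids for causal-LM training.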
class MyDataCollator:
    def __init__(self, processor):
        self.processor = processor
        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
            processor.tokenizer.additional_special_tokens.index("<image>")
        ]

    def __call__(self, samples):
        texts = []
        images = []
        for sample in samples:
            image = sample["image"]
            answer = sample["caption"]
            # System turn carries the radiologist prompt, user turn carries the
            # image, assistant turn carries the reference caption (training target).
            messages = [
                {
                    "role": "system",
                    "content": [
                        {"type": "text", "text": prompt}
                    ]
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                    ]
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": answer}
                    ]
                }
            ]
            text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
            texts.append(text.strip())
            images.append([image.convert('RGB')])
        batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)
        labels = batch["input_ids"].clone()
        # Remap pad tokens to the image token id, as in Hugging Face's Idefics
        # fine-tuning examples; note that many recipes instead mask pad positions
        # with -100 so the loss ignores them.
        labels[labels == self.processor.tokenizer.pad_token_id] = self.image_token_id
        batch["labels"] = labels
        return batch
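# Optional smoke test (hypothetical, not part of the training run): collate two
# samples and check the tensor shapes before launching the Trainer.
# _batch = MyDataCollator(processor)([train_dataset[0], train_dataset[1]])
# print({k: tuple(v.shape) for k, v in _batch.items()})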
data_collator = MyDataCollator(processor)
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=False,
    auto_find_batch_size=True,
    learning_rate=2e-4,
    fp16=True,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    dataloader_pin_memory=False,
    save_total_limit=3,
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps=100,
    save_steps=10,  # checkpoint every 10 steps
    resume_from_checkpoint=True,  # note: Trainer does not read this field; pass resume_from_checkpoint to trainer.train() to actually resume
    logging_steps=5,
    remove_unused_columns=False,
    push_to_hub=True,
    label_names=["labels"],
    load_best_model_at_end=False,
    report_to="none",
    optim="paged_adamw_8bit",
)
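# Effective batch size = per_device_train_batch_size (2) x gradient_accumulation_steps (8)
# = 16 samples per optimizer step per device; auto_find_batch_size may lower the
# per-device value if the GPU runs out of memory.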
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
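# To resume an interrupted run, call trainer.train(resume_from_checkpoint=True)
# instead, which picks up the latest checkpoint in output_dir. Pushing the
# processor as well keeps the Hub repo self-contained (assumes write access):
# processor.push_to_hub(destination_model_id)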