# Copyright (C) 2024 Ronan Le Meillat
# License: Apache License 2.0
# Description: Fine-tune Idefics3-8B-Llama3 (QLoRA by default) on the
#              eltorio/ROCO-radiology image-captioning dataset and push the
#              result to the Hugging Face Hub.

import os
import sys

import torch
from datasets import load_dataset
from datasets.utils.logging import disable_progress_bar
from huggingface_hub import login as hf_login
from peft import LoraConfig
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Idefics3ForConditionalGeneration,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_utils import get_last_checkpoint

disable_progress_bar()

# --- Authentication ---------------------------------------------------------
# Token lookup order: the HF_TOKEN environment variable first, then a first
# command-line argument that looks like a token (starts with "hf_").
HF_TOKEN = ""
arguments = sys.argv[1:]
if os.environ.get('HF_TOKEN') is not None:
    HF_TOKEN = os.environ.get('HF_TOKEN')
    print("Hugging Face token found in environment variable")
if not HF_TOKEN and arguments and arguments[0].startswith("hf_"):
    HF_TOKEN = arguments[0]
    print("Hugging Face token found in script arguments")

if HF_TOKEN:
    hf_login(token=HF_TOKEN, add_to_git_credential=True)
else:
    # An empty string cannot authenticate; rather than failing here, rely on
    # credentials already cached locally (e.g. from `huggingface-cli login`).
    print("No Hugging Face token provided; using cached Hugging Face credentials if any")

# --- Configuration ----------------------------------------------------------
dataset_id = "eltorio/ROCO-radiology"
prompt = "You are an expert radiologist certified with over 15 years of experience in diagnostic imaging, describe this image"
source_model_id = "HuggingFaceM4/Idefics3-8B-Llama3"
destination_model_id = "eltorio/IDEFICS3_ROCOv2"
output_dir = "IDEFICS3_ROCOv2"
cache_dir = "/workspace/data"  # NOTE(review): currently not passed to any loader

full_dataset = load_dataset(dataset_id, keep_in_memory=False)
train_dataset = full_dataset["train"]
eval_dataset = full_dataset["validation"]

DEVICE = "cuda:0"
USE_LORA = False
USE_QLORA = True

processor = AutoProcessor.from_pretrained(
    source_model_id,
    do_image_splitting=False,
)

# --- Model ------------------------------------------------------------------
if USE_QLORA or USE_LORA:
    # Adapters on the text model, modality projection and perceiver resampler
    # attention/MLP projections. DoRA is only enabled for plain (non-4-bit) LoRA.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules='.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$',
        use_dora=False if USE_QLORA else True,
        init_lora_weights="gaussian",
    )
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
    model = Idefics3ForConditionalGeneration.from_pretrained(
        source_model_id,
        torch_dtype=torch.float16,
        quantization_config=bnb_config if USE_QLORA else None,
    )
    model.add_adapter(lora_config)
    model.enable_adapters()
else:
    # Full fine-tuning path.
    model = Idefics3ForConditionalGeneration.from_pretrained(
        source_model_id,
        torch_dtype=torch.float16,
        _attn_implementation="flash_attention_2",  # This works for A100 or H100
    ).to(DEVICE)


# --- Data collation -----------------------------------------------------------
class MyDataCollator:
    """Turn a list of dataset rows into one padded training batch.

    Each sample is expected to provide ``image`` (an object with ``.convert``,
    presumably a PIL image) and ``caption`` (the target description). Every
    sample becomes a system/user/assistant chat where the system turn carries
    the module-level ``prompt``, the user turn carries the image and the
    assistant turn carries the caption.
    """

    def __init__(self, processor):
        self.processor = processor
        # Token id of the "<image>" placeholder inserted by the processor.
        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
            processor.tokenizer.additional_special_tokens.index("<image>")
        ]

    def __call__(self, samples):
        """Return the processor's batch dict with an added ``labels`` tensor."""
        texts = []
        images = []
        for sample in samples:
            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": prompt}],
                },
                {
                    "role": "user",
                    "content": [{"type": "image"}],
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": sample["caption"]}],
                },
            ]
            text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
            texts.append(text.strip())
            images.append([sample["image"].convert('RGB')])

        batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)

        # Only real text tokens are supervised: padding and image placeholder
        # positions are set to -100, the default ignore_index of the
        # cross-entropy loss, so they do not contribute to the loss.
        labels = batch["input_ids"].clone()
        labels[labels == self.processor.tokenizer.pad_token_id] = -100
        labels[labels == self.image_token_id] = -100
        batch["labels"] = labels
        return batch


data_collator = MyDataCollator(processor)

# --- Training -----------------------------------------------------------------
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=False,
    auto_find_batch_size=True,
    learning_rate=2e-4,
    fp16=True,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    dataloader_pin_memory=False,
    save_total_limit=3,
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps=100,
    save_steps=10,  # checkpoint each 10 steps
    resume_from_checkpoint=True,  # informational only; see trainer.train() below
    logging_steps=5,
    remove_unused_columns=False,  # the collator needs the raw "image"/"caption" columns
    push_to_hub=True,
    hub_model_id=destination_model_id,
    label_names=["labels"],
    load_best_model_at_end=False,
    report_to="none",
    optim="paged_adamw_8bit",
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Trainer does not read TrainingArguments.resume_from_checkpoint itself, so the
# most recent checkpoint in output_dir (if any) is passed to train() explicitly.
last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
if last_checkpoint is not None:
    print(f"Resuming training from {last_checkpoint}")
trainer.train(resume_from_checkpoint=last_checkpoint)