import os

import torch
from torch.optim import AdamW  # transformers' AdamW is deprecated/removed; use the torch implementation
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor, get_scheduler

from data import ObjectDetectionDataset

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model and processor
# (alternatively, from a local checkpoint:)
# model = AutoModelForCausalLM.from_pretrained("model/Florence-2-base-ft", trust_remote_code=True).to(device)
# processor = AutoProcessor.from_pretrained("model/Florence-2-base-ft", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large-ft", revision="refs/pr/10", trust_remote_code=True
).to(device)  # move the model to the GPU when one is available
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-large-ft", revision="refs/pr/10", trust_remote_code=True
)

IGNORE_ID = -100  # PyTorch's default ignore_index for the cross-entropy loss
MAX_LENGTH = 512

def collate_fn(examples):
    prompt_texts = [example[0] for example in examples]
    label_texts = [example[1] for example in examples]
    images = [example[2] for example in examples]

    inputs = processor(
        images=images,
        text=prompt_texts,
        return_tensors="pt",
        padding="longest",
        max_length=MAX_LENGTH,
        truncation=True,  # without truncation, max_length has no effect alongside padding="longest"
    ).to(device)

    return inputs, label_texts
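
# `ObjectDetectionDataset` comes from the accompanying data.py (not shown here).
# Judging from the indexing in collate_fn above, each example is assumed to be a
# (prompt_text, label_text, image) tuple: the "<OD>" task prompt, a target string
# using Florence-2 location tokens, and a PIL image.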


# Create datasets
train_dataset = ObjectDetectionDataset("train", processor=processor)
val_dataset = ObjectDetectionDataset("test", processor=processor)

# Create DataLoader
batch_size = 4
num_workers = 0

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn,
    num_workers=num_workers,
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers
)
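
# Optional sanity check (a minimal sketch, not part of the original training flow):
# pull a single batch to confirm the processor produced the expected tensors.
# inputs, label_texts = next(iter(train_loader))
# print(inputs["input_ids"].shape, inputs["pixel_values"].shape)
# print(label_texts[0])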


def train_model(train_loader, val_loader, model, processor, epochs=10, lr=1e-6):
    optimizer = AdamW(model.parameters(), lr=lr)
    num_training_steps = epochs * len(train_loader)
    lr_scheduler = get_scheduler(
        name="cosine",
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=num_training_steps,
    )

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0
        for i, batch in enumerate(
            tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}")
        ):
            inputs, label_texts = batch

            labels = processor.tokenizer(
                label_texts,
                return_tensors="pt",
                padding="longest",
                max_length=MAX_LENGTH,
                truncation=True,
                return_token_type_ids=False,  # Florence-2's BART-based language model does not use token type ids
            )["input_ids"].to(device)

            labels[labels == processor.tokenizer.pad_token_id] = IGNORE_ID # do not learn to predict pad tokens during training

            input_ids = inputs["input_ids"]
            pixel_values = inputs["pixel_values"]

            outputs = model(
                input_ids=input_ids, pixel_values=pixel_values, labels=labels
            )
            loss = outputs.loss

            if i % 25 == 0:
                # Log the loss and decode a few predictions to eyeball training progress
                print(f"step {i}: loss = {loss.item():.4f}")

                generated_ids = model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=128,
                    early_stopping=False,
                    do_sample=False,
                    num_beams=3,
                )
                generated_texts = processor.batch_decode(
                    generated_ids, skip_special_tokens=False
                )

                for generated_text, answer in zip(generated_texts, label_texts):
                    parsed_answer = processor.post_process_generation(
                        generated_text,
                        task="<OD>",
                        # image_size is (width, height); pixel_values is (B, C, H, W),
                        # so boxes here are in processed-image coordinates
                        image_size=(
                            inputs["pixel_values"].shape[-1],
                            inputs["pixel_values"].shape[-2],
                        ),
                    )
                    print("GT:", answer)
                    print("Generated Text:", generated_text)
                    print("Pred:", parsed_answer["<OD>"])

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Average Training Loss: {avg_train_loss}")

        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(
                val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"
            ):
                inputs, label_texts = batch

                input_ids = inputs["input_ids"]
                pixel_values = inputs["pixel_values"]
                labels = processor.tokenizer(
                    label_texts,
                    return_tensors="pt",
                    padding="longest",
                    max_length=MAX_LENGTH,
                    truncation=True,
                    return_token_type_ids=False,
                ).input_ids.to(device)
                labels[labels == processor.tokenizer.pad_token_id] = IGNORE_ID  # mask pads as in training

                outputs = model(
                    input_ids=input_ids, pixel_values=pixel_values, labels=labels
                )
                loss = outputs.loss

                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Average Validation Loss: {avg_val_loss}")

        # Save model checkpoint
        output_dir = f"./model_checkpoints/epoch_{epoch+1}"
        os.makedirs(output_dir, exist_ok=True)
        model.save_pretrained(output_dir)
        processor.save_pretrained(output_dir)

# Freeze the vision tower so that only the language model and projection weights are fine-tuned
for param in model.vision_tower.parameters():
    param.requires_grad = False
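
# Note: train_model builds its optimizer from model.parameters(); the frozen
# parameters never receive gradients, so AdamW simply skips them.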

model_total_params = sum(p.numel() for p in model.parameters())
model_train_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Number of trainable parameters {model_train_params} out of {model_total_params}, rate: {model_train_params/model_total_params:0.3f}")

train_model(train_loader, val_loader, model, processor, epochs=3, lr=1e-6)

model.push_to_hub("danelcsb/Florence-2-FT-cppe-5")
processor.push_to_hub("danelcsb/Florence-2-FT-cppe-5")
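
# Loading the fine-tuned checkpoint back for inference (a minimal, commented-out
# sketch; the repo id matches the push_to_hub calls above, and `image` is assumed
# to be a PIL image):
# model = AutoModelForCausalLM.from_pretrained(
#     "danelcsb/Florence-2-FT-cppe-5", trust_remote_code=True
# ).to(device)
# processor = AutoProcessor.from_pretrained(
#     "danelcsb/Florence-2-FT-cppe-5", trust_remote_code=True
# )
# inputs = processor(text="<OD>", images=image, return_tensors="pt").to(device)
# generated_ids = model.generate(
#     input_ids=inputs["input_ids"],
#     pixel_values=inputs["pixel_values"],
#     max_new_tokens=128,
#     num_beams=3,
# )
# generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
# print(
#     processor.post_process_generation(
#         generated_text, task="<OD>", image_size=(image.width, image.height)
#     )
# )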