import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (AdamW, AutoModelForCausalLM, AutoProcessor,
get_scheduler)
from data import ObjectDetectionDataset
# Set device: use the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model and processor.
# Alternative kept for reference (local base checkpoint):
# model = AutoModelForCausalLM.from_pretrained("model/Florence-2-base-ft", trust_remote_code=True).to(device)
# processor = AutoProcessor.from_pretrained("model/Florence-2-base-ft", trust_remote_code=True)
#
# The weights are placed on `device` rather than a hard-coded "cuda" so the
# model always lives on the same device that `collate_fn` moves batches to,
# including on hosts without a GPU.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large-ft",
    revision="refs/pr/10",
    trust_remote_code=True,
    device_map=device,
)
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-large-ft",
    revision="refs/pr/10",
    trust_remote_code=True,
)

IGNORE_ID = -100  # Pytorch ignore index when computing loss
MAX_LENGTH = 512  # value passed as `max_length` when tokenizing prompts and labels
def collate_fn(examples):
    """Batch dataset samples into model inputs plus their raw target strings.

    Args:
        examples: Sequence of samples. Index 0 of each sample is used as the
            prompt text, index 1 as the target text and index 2 as the image.

    Returns:
        A ``(inputs, label_texts)`` pair: the processor output moved to
        ``device``, and the list of untokenized target strings (they are
        tokenized later by the training loop).
    """
    prompt_texts = [example[0] for example in examples]
    label_texts = [example[1] for example in examples]
    images = [example[2] for example in examples]

    inputs = processor(
        images=images,
        text=prompt_texts,
        return_tensors="pt",
        padding="longest",
        # Without an explicit truncation strategy `max_length` is ignored
        # when padding to the longest sequence, so enable it here.
        truncation=True,
        max_length=MAX_LENGTH,
    ).to(device)
    return inputs, label_texts
# Datasets: the "train" split drives optimisation, the "test" split is used
# for the per-epoch validation pass.
train_dataset = ObjectDetectionDataset("train", processor=processor)
val_dataset = ObjectDetectionDataset("test", processor=processor)

# DataLoaders
batch_size = 4
# Batches are assembled in the main process; note that `collate_fn` already
# moves tensors to `device`.
num_workers = 0

# Arguments common to both loaders; only shuffling differs between them.
_shared_loader_kwargs = {
    "batch_size": batch_size,
    "collate_fn": collate_fn,
    "num_workers": num_workers,
}
train_loader = DataLoader(train_dataset, shuffle=True, **_shared_loader_kwargs)
val_loader = DataLoader(val_dataset, **_shared_loader_kwargs)
def _encode_labels(label_texts, processor):
    """Tokenize target strings into a label tensor with padding masked out."""
    tokenizer = processor.tokenizer
    labels = tokenizer(
        text=label_texts,
        return_tensors="pt",
        padding="longest",
        truncation=True,  # `max_length` has no effect unless truncation is enabled
        max_length=MAX_LENGTH,
        return_token_type_ids=False,  # no need to set this to True since BART does not use token type ids
    )["input_ids"].to(device)
    # Do not learn to predict (or get scored on) pad tokens.
    labels[labels == tokenizer.pad_token_id] = IGNORE_ID
    return labels


def _log_sample_predictions(model, processor, inputs, label_texts):
    """Generate predictions for the current batch and print them next to the targets."""
    was_training = model.training
    # Generate in eval mode so train-only behaviour (e.g. dropout) does not
    # perturb the preview, then restore whatever mode the model was in.
    model.eval()
    try:
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=128,
                early_stopping=False,
                do_sample=False,
                num_beams=3,
            )
    finally:
        model.train(was_training)

    generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=False)
    # NOTE(review): this passes the last two dims of the processed pixel tensor,
    # so parsed boxes are presumably in resized-image coordinates rather than
    # original-image coordinates -- confirm against the processor's contract.
    image_size = (
        inputs["pixel_values"].shape[-2],
        inputs["pixel_values"].shape[-1],
    )
    for generated_text, answer in zip(generated_texts, label_texts):
        parsed_answer = processor.post_process_generation(
            generated_text,
            task="<OD>",
            image_size=image_size,
        )
        print("GT:", answer)
        print("Generated Text:", generated_text)
        print("Pred:", parsed_answer["<OD>"])


def train_model(train_loader, val_loader, model, processor, epochs=10, lr=1e-6):
    """Fine-tune ``model`` and evaluate it on ``val_loader`` after every epoch.

    Each epoch runs one pass over ``train_loader`` (AdamW with a cosine
    schedule and 100 warmup steps), prints sample predictions every 25 steps,
    prints the average training and validation loss, and saves the model and
    processor to ``./model_checkpoints/epoch_<n>``.

    Args:
        train_loader: Iterable of ``(inputs, label_texts)`` batches where
            ``inputs`` provides ``"input_ids"`` and ``"pixel_values"``.
        val_loader: Same batch format, used for the validation pass.
        model: Model to optimise; called with ``input_ids``, ``pixel_values``
            and ``labels`` and expected to expose ``.loss`` on its output.
        processor: Processor whose tokenizer encodes the target strings.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate handed to the optimizer.
    """
    optimizer = AdamW(model.parameters(), lr=lr)
    num_training_steps = epochs * len(train_loader)
    lr_scheduler = get_scheduler(
        name="cosine",
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=num_training_steps,
    )

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_batches = tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}")
        for step, (inputs, label_texts) in enumerate(train_batches):
            labels = _encode_labels(label_texts, processor)
            outputs = model(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                labels=labels,
            )
            loss = outputs.loss

            if step % 25 == 0:
                print(loss)
                _log_sample_predictions(model, processor, inputs, label_texts)

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            train_loss += loss.item()

        # max(..., 1) avoids a ZeroDivisionError on an empty loader.
        avg_train_loss = train_loss / max(len(train_loader), 1)
        print(f"Average Training Loss: {avg_train_loss}")

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            val_batches = tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}")
            for inputs, label_texts in val_batches:
                # Same label encoding as training (pad tokens masked), so the
                # two losses are computed over the same kind of positions.
                labels = _encode_labels(label_texts, processor)
                outputs = model(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    labels=labels,
                )
                val_loss += outputs.loss.item()

        avg_val_loss = val_loss / max(len(val_loader), 1)
        print(f"Average Validation Loss: {avg_val_loss}")

        # Save model checkpoint
        output_dir = f"./model_checkpoints/epoch_{epoch+1}"
        os.makedirs(output_dir, exist_ok=True)
        model.save_pretrained(output_dir)
        processor.save_pretrained(output_dir)
# Freeze the vision tower: none of its parameters will receive gradients.
model.vision_tower.requires_grad_(False)

# Report how many parameters remain trainable after freezing.
model_total_params = 0
model_train_params = 0
for _p in model.parameters():
    _count = _p.numel()
    model_total_params += _count
    if _p.requires_grad:
        model_train_params += _count
print(f"Number of trainable parameters {model_train_params} out of {model_total_params}, rate: {model_train_params/model_total_params:0.3f}")

# Fine-tune, then publish both the model and the processor to the Hub.
train_model(train_loader, val_loader, model, processor, epochs=3, lr=1e-6)

_hub_repo = "danelcsb/Florence-2-FT-cppe-5"
model.push_to_hub(_hub_repo)
processor.push_to_hub(_hub_repo)
- Downloads last month: 13
Inference API (serverless) does not yet support model repos that contain custom code.