|
import ast
import json
import os

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from PIL import Image, ImageDraw, ImageFont
from sklearn.model_selection import train_test_split
from timm.models._helpers import load_state_dict
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm

from mivolo.model.mivolo_model import MiVOLOModel
|
|
|
# ImageNet channel statistics; must stay in sync with the Normalize() values
# used in the transform pipelines below.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)

IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Mean / std of the age labels on the training split: ages are z-scored with
# these in CustomDataset and un-scaled (x * STD + MEAN) when reporting.
MEAN_TRAIN = 36.64

STD_TRAIN = 21.74
|
|
|
import torch |
|
from torch.utils.data import Dataset, DataLoader |
|
from PIL import Image |
|
import pandas as pd |
|
import torch.nn as nn |
|
import torchvision.transforms as transforms |
|
|
|
import pandas as pd |
|
|
|
|
|
|
|
|
|
class CustomDataset(Dataset):
    """Dataset of stacked (face + person) crops with gender/age labels.

    Each row of *csv_data* must provide ``img_name``, a face box
    (``face_x0``..``face_y1``), a person box (``person_x0``..``person_y1``)
    and ``y_label`` — a string literal of a 3-tuple ``(g0, g1, age)``.

    Returns per item:
        image_  : 6-channel tensor — face crop stacked on person crop.
        y_label : float32 tensor ``(g0, g1, z-scored age)``.
        name    : CSV ``img_name`` (train/val) or the file basename (test).
    """

    def __init__(self, csv_data, test=False, transform=None,
                 root_dir="/home/duyht/MiVOLO/MiVOLO/lag_benchmark/"):
        # root_dir is prepended to img_name for train/val rows; test rows
        # (test=True) are expected to carry usable paths already.
        self.data = csv_data
        self.transform = transform
        self.test = test
        self.root_dir = root_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        img_path = row['img_name'] if self.test else self.root_dir + row['img_name']
        basename = os.path.basename(img_path)

        image = Image.open(img_path).convert('RGB')

        face_x0, face_y0, face_x1, face_y1 = row[['face_x0', 'face_y0', 'face_x1', 'face_y1']]
        person_x0, person_y0, person_x1, person_y1 = row[['person_x0', 'person_y0', 'person_x1', 'person_y1']]

        face_image = image.crop((int(face_x0), int(face_y0), int(face_x1), int(face_y1)))
        person_image = image.crop((int(person_x0), int(person_y0), int(person_x1), int(person_y1)))

        face_image = face_image.resize((224, 224))
        person_image = person_image.resize((224, 224))

        if self.transform:
            face_image = self.transform(face_image)
            person_image = self.transform(person_image)

        # Stack the two crops channel-wise -> 6-channel model input.
        image_ = torch.cat((face_image, person_image), dim=0)

        # y_label is stored as a tuple literal in the CSV. Parse it with
        # ast.literal_eval instead of eval(): the CSV is external input and
        # eval() would execute arbitrary code embedded in it.
        y1, y2, y3 = ast.literal_eval(row['y_label'])

        # z-score the age so the MSE target matches the model's output scale.
        y3 = (y3 - MEAN_TRAIN) / STD_TRAIN

        y_label = torch.tensor((y1, y2, y3), dtype=torch.float32)

        return image_, y_label, basename if self.test else row['img_name']
|
|
|
|
|
# Training augmentation: RandAugment + flip + color jitter + random crop,
# followed by the fixed 224x224 bicubic resize and ImageNet normalization.
transform_train = transforms.Compose([

    transforms.RandAugment(magnitude=22),

    transforms.RandomHorizontalFlip(p=0.5),

    transforms.RandomApply([transforms.ColorJitter()], p=0.5),

    transforms.RandomResizedCrop(224, scale=(0.8, 1.0), ratio=(0.8, 1.2)),

    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),

    transforms.ToTensor(),

    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),

])



# Validation/test preprocessing: deterministic resize + normalization only.
transform_valid = transforms.Compose([

    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),

    transforms.ToTensor(),

    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),

])
|
|
|
def denormalize(image, mean, std):
    """Undo channel-wise normalization: return image * std + mean.

    *mean* and *std* are 3-element sequences broadcast over the spatial dims.
    """
    scale = torch.tensor(std).reshape(3, 1, 1)
    shift = torch.tensor(mean).reshape(3, 1, 1)
    return image.mul(scale).add(shift)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# CSV splits: each row carries an image path, face/person boxes and a
# y_label tuple literal (see CustomDataset).
train_data = pd.read_csv('csv/data_train.csv')

val_data = pd.read_csv('csv/data_valid.csv')

test_data = pd.read_csv('csv/data_test.csv')



# Extra children benchmark; test=True below makes img_name be used as-is.
kid_data = pd.read_csv('csv/children_test.csv')



train_dataset = CustomDataset(train_data, transform=transform_train)

val_dataset = CustomDataset(val_data, transform=transform_valid)

test_dataset = CustomDataset(test_data, transform=transform_valid)

kid_dataset = CustomDataset(kid_data, test=True, transform=transform_valid)



# Only the training loader shuffles; all loaders use 4 worker processes.
train_dataloader = DataLoader(train_dataset, batch_size=50, shuffle=True, num_workers=4)

val_dataloader = DataLoader(val_dataset, batch_size=50, shuffle=False, num_workers=4)

test_dataloader = DataLoader(test_dataset, batch_size=50, shuffle=False, num_workers=4)

kid_dataloader = DataLoader(kid_dataset, batch_size=50, shuffle=False, num_workers=4)
|
|
|
|
|
|
|
# MiVOLO with 6 input channels (face + person crops stacked) and 3 outputs:
# two gender logits plus one z-scored age regression value.
model = MiVOLOModel(

    layers=(4, 4, 8, 2),

    img_size=224,

    in_chans=6,

    num_classes=3,

    patch_size=8,

    stem_hidden_dim=64,

    embed_dims=(192, 384, 384, 384),

    num_heads=(6, 12, 12, 12),

).to('cuda')



# Load the pretrained cross-person IMDB checkpoint (strict key match).
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoint files.
state = torch.load("models/model_imdb_cross_person_4.22_99.46.pth.tar", map_location="cpu")

state_dict = state["state_dict"]

model.load_state_dict(state_dict, strict=True)
|
|
|
|
|
|
|
# BCEWithLogitsLoss for the two gender outputs, MSELoss for the z-scored age.
criterion_bce = nn.BCEWithLogitsLoss()

criterion_mse = nn.MSELoss()



# Very small LR: this is a fine-tune on top of pretrained weights.
optimizer = optim.AdamW(model.parameters(), lr=1.0e-6, weight_decay=5e-6)



num_epochs = 50

best_val_loss = float('inf')  # tracks the checkpointing criterion



stop_training = False  # set True when NaN losses are detected
|
def get_optimizer_info(optimizer):
    """Format the learning rate of the optimizer's first param group."""
    for group in optimizer.param_groups:
        return f"LR: {group['lr']}"
|
|
|
|
|
# Fine-tuning loop: one forward pass per sample (batch size 1 through the
# model), losses averaged over the batch, single backward/step per batch.
for epoch in range(num_epochs):
    model.train()
    if stop_training:
        break

    # Use a *local* progress bar: the previous code re-assigned the global
    # train_dataloader, wrapping tqdm around tqdm on every later epoch.
    train_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch")

    for i, (inputs, labels, _) in enumerate(train_bar):
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')

        optimizer.zero_grad()
        batch_loss = 0
        for j in range(inputs.size(0)):
            input_image = inputs[j].unsqueeze(0)
            target = labels[j].unsqueeze(0)

            output = model(input_image)

            # BCEWithLogitsLoss applies its own sigmoid, so it must receive
            # the raw logits output[:, :2] — not softmax probabilities as
            # before. (The evaluation section of this script already passes
            # logits; this makes training consistent with it.)
            output_bce = output[:, :2]
            target_bce = target[:, :2]
            output_mse = output[:, 2]
            target_mse = target[:, 2]

            loss_bce = criterion_bce(output_bce, target_bce)
            loss_mse = criterion_mse(output_mse, target_mse)

            loss = loss_bce + loss_mse
            batch_loss += loss

            # Abort training entirely if any loss goes NaN.
            if torch.isnan(loss_bce).any() or torch.isnan(loss_mse).any() or torch.isnan(loss).any():
                print(f'Epoch [{epoch + 1}], Batch [{i + 1}] - NaN detected in loss computation')
                stop_training = True
                break

        if stop_training:
            break

        # Average the per-sample losses over the batch.
        batch_loss /= inputs.size(0)

        optimizer_info = get_optimizer_info(optimizer)
        train_bar.set_postfix(batch_loss=batch_loss.item(), optimizer_info=optimizer_info)

        batch_loss.backward()
        optimizer.step()

    # ---- validation pass (same per-sample scheme, no gradients) ----
    model.eval()
    val_loss = 0.0

    val_bar = tqdm(val_dataloader, desc="Validating", unit="batch")

    with torch.no_grad():
        for i, (inputs, labels, _) in enumerate(val_bar):
            inputs = inputs.to('cuda')
            labels = labels.to('cuda')

            for j in range(inputs.size(0)):
                input_image = inputs[j].unsqueeze(0)
                target = labels[j].unsqueeze(0)
                output = model(input_image)

                # Same convention as training: raw logits into the BCE loss.
                loss_bce = criterion_bce(output[:, :2], target[:, :2])
                loss_mse = criterion_mse(output[:, 2], target[:, 2])
                loss = loss_bce + loss_mse

                val_loss += loss.item()

    # NOTE: per-sample losses divided by the number of *batches*, as in the
    # original code; kept so checkpoint selection stays comparable.
    val_loss /= len(val_dataloader)
    print(f'Epoch [{epoch + 1}], Validation Loss: {val_loss:.4f}')

    # Checkpoint the weights whenever validation improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'modelstrain/best_model_weights_10.pth')
        print(f'Saved best model weights with validation loss: {best_val_loss:.4f}')

print('Finished Training')
|
|
|
|
|
|
|
# Reload the best (lowest validation loss) weights before evaluation.
model.load_state_dict(torch.load('modelstrain/best_model_weights_10.pth'))

model.eval()

# Running totals for test loss and gender accuracy.
test_loss = 0.0

correct_gender = 0

total = 0
|
|
|
|
|
def tensor_to_image_with_text(tensor, true_age, predicted_age):
    """Convert a (1, 3, H, W) image tensor to PIL and overlay the two ages."""
    pil_image = transforms.ToPILImage()(tensor.cpu().squeeze(0))

    drawer = ImageDraw.Draw(pil_image)
    default_font = ImageFont.load_default()

    text_true = f'True Age: {true_age:.2f}'
    text_predicted = f'Predicted Age: {predicted_age:.2f}'

    pos_true = (10, 10)
    pos_predicted = (10, 30)

    # White boxes behind each label keep the text readable on any image.
    drawer.rectangle(drawer.textbbox(pos_true, text_true, font=default_font), fill="white")
    drawer.rectangle(drawer.textbbox(pos_predicted, text_predicted, font=default_font), fill="white")

    drawer.text(pos_true, text_true, font=default_font, fill="green")
    drawer.text(pos_predicted, text_predicted, font=default_font, fill="blue")

    return pil_image
|
|
|
# Directory where the annotated person crops are written.
save_dir = 'children_test'

os.makedirs(save_dir, exist_ok=True)
|
|
|
# Accumulators for the true-vs-predicted age scatter plot below.
true_ages = []
predicted_ages = []

# Evaluate the children benchmark: per-sample loss, gender accuracy, and an
# annotated crop + JSON record saved for every image.
with torch.no_grad():
    test_loss = 0.0
    correct_gender = 0
    total = 0

    # Resume from a previous run's records if present; start fresh otherwise.
    try:
        with open('image_data.json', 'r') as json_file:
            image_data = json.load(json_file)
    except FileNotFoundError:
        image_data = []

    for i, (inputs, labels, img_paths) in tqdm(enumerate(kid_dataloader), total=len(kid_dataloader), desc="Processing batches"):
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        for j in range(inputs.size(0)):
            input_image = inputs[j].unsqueeze(0)
            print("input_image: ", input_image.shape)
            target = labels[j].unsqueeze(0)
            # Channels 3:6 are the person crop; un-normalize for visualization.
            target_image = denormalize(input_image[:, 3:].to('cpu'), [*IMAGENET_DEFAULT_MEAN], [*IMAGENET_DEFAULT_STD])
            output = model(input_image)
            print("output[:, :2]: ", output[:, :2])
            gender_output = output[:, :2].softmax(dim=-1)
            print("gender_output: ", gender_output)
            output_bce = output[:, :2]
            target_bce = target[:, :2]
            output_mse = output[:, 2]
            target_mse = target[:, 2]

            # Undo the z-scoring from CustomDataset to report ages in years.
            predicted_age = output_mse.item() * STD_TRAIN + MEAN_TRAIN
            true_age = target_mse.item() * STD_TRAIN + MEAN_TRAIN
            true_ages.append(true_age)
            predicted_ages.append(predicted_age)

            loss_bce = criterion_bce(output_bce, target_bce)
            loss_mse = criterion_mse(output_mse, target_mse)
            loss = loss_bce + loss_mse

            test_loss += loss.item()
            _, predicted_gender = torch.max(gender_output, 1)
            print("predicted_gender: ", predicted_gender)
            _, target_gender = torch.max(target_bce, 1)
            correct_gender += (predicted_gender == target_gender).sum().item()
            total += target_gender.size(0)

            target_image_pil = tensor_to_image_with_text(target_image, true_age, predicted_age)
            if predicted_age >= 15:
                print(img_paths[j], " ", predicted_age)

            image_path = os.path.join(save_dir, f'{img_paths[j]}')

            image_data.append({"path": img_paths[j], "predicted_age": predicted_age, "predicted_gender": predicted_gender.item()})
            target_image_pil.save(image_path)

    with open('image_data.json', 'w') as json_file:
        json.dump(image_data, json_file, indent=4)
    # BUG FIX: normalize by the dataloader actually iterated above
    # (kid_dataloader); the old code divided by len(test_dataloader).
    test_loss /= len(kid_dataloader)
    gender_accuracy = correct_gender / total
    print(f'Test Loss: {test_loss:.4f}, Gender Accuracy: {gender_accuracy:.4f}')
|
|
|
|
|
# Scatter of predicted vs. true ages with the y = x "perfect" reference line;
# written to disk rather than shown interactively.
plt.figure(figsize=(10, 6))

plt.scatter(true_ages, predicted_ages, c='blue', label='Predicted Age')

plt.plot([min(true_ages), max(true_ages)], [min(true_ages), max(true_ages)], color='red', linestyle='--', label='Perfect Prediction')

plt.xlabel('True Age')

plt.ylabel('Predicted Age')

plt.title('True Age vs Predicted Age')

plt.legend()

plt.grid(True)

plt.savefig('age_prediction_comparison.png')

plt.close()