# estimate_age/models/train.py
# Author: hungdang1610
# File: train.py
# Commit: 87d78a8 (verified)
import ast
import json
import os

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from PIL import Image, ImageDraw, ImageFont
from sklearn.model_selection import train_test_split
from timm.models._helpers import load_state_dict
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm

from mivolo.model.mivolo_model import MiVOLOModel
# ImageNet per-channel statistics; used below to undo input normalization
# before images are rendered for inspection.
IMAGENET_DEFAULT_MEAN = 0.485, 0.456, 0.406
IMAGENET_DEFAULT_STD = 0.229, 0.224, 0.225
# Age-target standardization: the dataset feeds (age - MEAN_TRAIN) / STD_TRAIN
# to the model, and predictions are mapped back with the inverse.
# (Presumably computed on the training split — confirm.)
MEAN_TRAIN, STD_TRAIN = 36.64, 21.74
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import torch.nn as nn
import torchvision.transforms as transforms
import pandas as pd
# Custom dataset definition.
class CustomDataset(Dataset):
    """Dataset of stacked (face, person) crops with gender/age targets.

    Rows of ``csv_data`` (a pandas DataFrame, indexed by position) must provide:
    ``img_name``; the face box ``face_x0, face_y0, face_x1, face_y1``; the person
    box ``person_x0, person_y0, person_x1, person_y1``; and ``y_label``, a
    three-element list ``[y1, y2, age]`` (normally stored as a string literal).

    Args:
        csv_data: annotation DataFrame.
        test: when False, ``img_name`` is relative and ``root_dir`` is prepended;
            when True, ``img_name`` is used as the path as-is.
        transform: callable applied separately to each 224x224 PIL crop; it must
            return a channel-first tensor so the two crops can be concatenated.
            When None, crops are converted with ``transforms.ToTensor()``.
        root_dir: prefix for relative image names. It is string-concatenated,
            so keep the trailing slash.

    Each item is ``(image, label, name)``:
        image: face crop and person crop concatenated along dim 0.
        label: float32 tensor ``(y1, y2, (age - MEAN_TRAIN) / STD_TRAIN)``.
        name: the CSV ``img_name`` when ``test`` is False, else the file basename.
    """

    def __init__(self, csv_data, test=False, transform=None,
                 root_dir="/home/duyht/MiVOLO/MiVOLO/lag_benchmark/"):
        self.data = csv_data
        self.transform = transform
        self.test = test
        self.root_dir = root_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        img_name = row['img_name']
        img_path = img_name if self.test else self.root_dir + img_name

        # Use a context manager so the underlying file handle is released
        # once the pixels have been decoded by convert().
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # Read the box coordinates from the annotation row.
        face_box = tuple(int(row[c]) for c in ('face_x0', 'face_y0', 'face_x1', 'face_y1'))
        person_box = tuple(int(row[c]) for c in ('person_x0', 'person_y0', 'person_x1', 'person_y1'))

        # Crop both regions and resize them to 224x224.
        face_image = image.crop(face_box).resize((224, 224))
        person_image = image.crop(person_box).resize((224, 224))

        # Without a transform, the PIL crops could not be concatenated with torch.cat.
        to_tensor = self.transform if self.transform is not None else transforms.ToTensor()
        # NOTE(review): the transform is called once per crop, so random
        # augmentations (flip, crop, ...) are sampled independently for the
        # face and the person crop. Confirm this is intended.
        image_ = torch.cat((to_tensor(face_image), to_tensor(person_image)), dim=0)

        # literal_eval only accepts Python literals, unlike eval(), which would
        # execute arbitrary code found in the CSV. Already-parsed sequences
        # (e.g. an in-memory DataFrame) are used directly.
        raw_label = row['y_label']
        y1, y2, age = ast.literal_eval(raw_label) if isinstance(raw_label, str) else raw_label
        # Standardize the age target (see MEAN_TRAIN / STD_TRAIN).
        age = (age - MEAN_TRAIN) / STD_TRAIN
        y_label = torch.tensor((y1, y2, age), dtype=torch.float32)

        name = os.path.basename(img_path) if self.test else img_name
        return image_, y_label, name
# Training-time augmentation, applied to each 224x224 crop before stacking.
# NOTE(review): transforms.ColorJitter() with no arguments has brightness,
# contrast, saturation and hue all set to 0, so the RandomApply below is
# effectively a no-op. Pass explicit jitter strengths if colour jitter is wanted.
transform_train = transforms.Compose([
    transforms.RandAugment(magnitude=22),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter()], p=0.5),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0), ratio=(0.8, 1.2)),
    # RandomResizedCrop already yields 224x224, so this Resize only guards the size.
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    # Use the shared constants so normalization and denormalization cannot drift apart.
    transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])
# Deterministic preprocessing for validation/test/children splits.
transform_valid = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])
def denormalize(image, mean, std):
    """Undo per-channel normalization and return ``image * std + mean``.

    Args:
        image: tensor whose channel axis is third from the end, e.g.
            ``(C, H, W)`` or ``(N, C, H, W)``.
        mean: per-channel means (length C).
        std: per-channel standard deviations (length C).

    Returns:
        A new tensor with the same shape as ``image``.
    """
    # Build the statistics on the image's device, so CUDA inputs do not hit a
    # device mismatch. reshape(-1, 1, 1) broadcasts over H, W and any leading dims.
    mean = torch.as_tensor(mean, device=image.device).reshape(-1, 1, 1)
    std = torch.as_tensor(std, device=image.device).reshape(-1, 1, 1)
    return image * std + mean
# Read the pre-split annotation CSV files.
train_data = pd.read_csv('csv/data_train.csv')
val_data = pd.read_csv('csv/data_valid.csv')
test_data = pd.read_csv('csv/data_test.csv')
kid_data = pd.read_csv('csv/children_test.csv')

# Datasets: only the training split is augmented. The children's split is
# built with test=True, so its img_name values are used as paths unchanged.
train_dataset = CustomDataset(csv_data=train_data, test=False, transform=transform_train)
val_dataset = CustomDataset(csv_data=val_data, test=False, transform=transform_valid)
test_dataset = CustomDataset(csv_data=test_data, test=False, transform=transform_valid)
kid_dataset = CustomDataset(csv_data=kid_data, test=True, transform=transform_valid)

# Loaders: batches of 50 with 4 workers; only the training loader shuffles.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=50, shuffle=True, num_workers=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=50, shuffle=False, num_workers=4)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=50, shuffle=False, num_workers=4)
kid_dataloader = DataLoader(dataset=kid_dataset, batch_size=50, shuffle=False, num_workers=4)

# Model: 6 input channels (face crop + person crop stacked by CustomDataset)
# and 3 outputs (two gender columns followed by the standardized age).
model = MiVOLOModel(
    layers=(4, 4, 8, 2),
    img_size=224,
    in_chans=6,
    num_classes=3,
    patch_size=8,
    stem_hidden_dim=64,
    embed_dims=(192, 384, 384, 384),
    num_heads=(6, 12, 12, 12),
)
model = model.to('cuda')

# Initialise from the pretrained checkpoint; its weights live under "state_dict".
# To resume fine-tuning instead, load 'modelstrain/best_model_weights_10.pth'
# (a plain state dict) with strict=True.
state = torch.load("models/model_imdb_cross_person_4.22_99.46.pth.tar", map_location="cpu")
state_dict = state["state_dict"]
model.load_state_dict(state_dict, strict=True)

# Losses: BCE-with-logits for the gender columns, MSE for the standardized age.
criterion_bce = nn.BCEWithLogitsLoss()
criterion_mse = nn.MSELoss()

# AdamW with a very small learning rate for fine-tuning (no LR scheduler is created).
optimizer = optim.AdamW(model.parameters(), lr=1e-6, weight_decay=5e-6)

# Training-loop state.
num_epochs = 50
best_val_loss = float('inf')  # lowest validation loss seen so far
stop_training = False         # set once a NaN loss is detected
def get_optimizer_info(optimizer):
    """Return a short learning-rate summary of *optimizer* for progress bars.

    With one parameter group the result is ``"LR: <lr>"``. With several groups,
    every group's learning rate is listed, comma-separated, in group order.
    An optimizer with no parameter groups yields ``"LR: n/a"``; the previous
    version raised ``UnboundLocalError`` in that case.
    """
    lrs = [str(group['lr']) for group in optimizer.param_groups]
    return f"LR: {', '.join(lrs) if lrs else 'n/a'}"
# Fine-tune the model, validate after every epoch and keep the weights with
# the lowest validation loss.
os.makedirs('modelstrain', exist_ok=True)  # torch.save does not create directories
for epoch in range(num_epochs):
    if stop_training:
        break
    model.train()
    # Wrap the loader in a fresh progress bar each epoch. Re-assigning
    # train_dataloader to its tqdm wrapper would nest one more bar per epoch.
    train_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch")
    for i, (inputs, labels, _) in enumerate(train_bar):
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        optimizer.zero_grad()
        batch_loss = 0
        # Samples are forwarded one at a time. Their losses are summed here and
        # averaged over the batch before a single backward pass.
        for j in range(inputs.size(0)):
            output = model(inputs[j].unsqueeze(0))
            target = labels[j].unsqueeze(0)
            # Columns 0-1 are gender, column 2 is the standardized age.
            # BCEWithLogitsLoss applies its own sigmoid, so it must receive the
            # raw gender logits (softmax probabilities would be squashed twice).
            loss_bce = criterion_bce(output[:, :2], target[:, :2])
            loss_mse = criterion_mse(output[:, 2], target[:, 2])
            loss = loss_bce + loss_mse
            if torch.isnan(loss_bce).any() or torch.isnan(loss_mse).any() or torch.isnan(loss).any():
                print(f'Epoch [{epoch + 1}], Batch [{i + 1}] - NaN detected in loss computation')
                stop_training = True
                break
            batch_loss += loss
        if stop_training:
            break
        batch_loss /= inputs.size(0)
        train_bar.set_postfix(batch_loss=batch_loss.item(), optimizer_info=get_optimizer_info(optimizer))
        batch_loss.backward()
        optimizer.step()

    # Validation: mean per-sample loss over the whole validation set.
    model.eval()
    val_loss = 0.0
    val_count = 0
    with torch.no_grad():
        for inputs, labels, _ in tqdm(val_dataloader, desc="Validating", unit="batch"):
            inputs = inputs.to('cuda')
            labels = labels.to('cuda')
            for j in range(inputs.size(0)):
                output = model(inputs[j].unsqueeze(0))
                target = labels[j].unsqueeze(0)
                loss_bce = criterion_bce(output[:, :2], target[:, :2])
                loss_mse = criterion_mse(output[:, 2], target[:, 2])
                val_loss += (loss_bce + loss_mse).item()
                val_count += 1
    # Losses are accumulated per sample, so average over samples, not batches.
    val_loss /= val_count
    print(f'Epoch [{epoch + 1}], Validation Loss: {val_loss:.4f}')
    # Keep the best weights seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'modelstrain/best_model_weights_10.pth')
        print(f'Saved best model weights with validation loss: {best_val_loss:.4f}')
print('Finished Training')
# ------------------------------ Evaluate the model on the test set ------------------------------
# Restore the best weights written by the training loop, then switch to inference mode.
model.load_state_dict(torch.load('modelstrain/best_model_weights_10.pth'))
model.eval()
test_loss, correct_gender, total = 0.0, 0, 0
def tensor_to_image_with_text(tensor, true_age, predicted_age):
    """Render *tensor* as a PIL image annotated with the true and predicted age.

    Args:
        tensor: image tensor accepted by ``transforms.ToPILImage``. A leading
            dimension of size 1 (e.g. ``(1, 3, H, W)``) is squeezed. Float
            tensors are expected in [0, 1] and are clamped to that range.
        true_age: ground-truth age, printed with two decimals.
        predicted_age: predicted age, printed with two decimals.

    Returns:
        A PIL image with both labels drawn on white boxes in the top-left corner.
    """
    tensor = tensor.detach().cpu().squeeze(0)
    if tensor.is_floating_point():
        # ToPILImage multiplies floats by 255 and casts to uint8. That cast does
        # not saturate, so out-of-range values must be clamped first.
        tensor = tensor.clamp(0, 1)
    image = transforms.ToPILImage()(tensor)
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()

    # (text, top-left position, text colour) for each label line.
    labels = (
        (f'True Age: {true_age:.2f}', (10, 10), "green"),
        (f'Predicted Age: {predicted_age:.2f}', (10, 30), "blue"),
    )
    # Draw all background boxes first, then the text on top of them.
    for text, position, _ in labels:
        draw.rectangle(draw.textbbox(position, text, font=font), fill="white")
    for text, position, colour in labels:
        draw.text(position, text, font=font, fill=colour)
    return image
# Evaluate on the children's split:
# - accumulate loss and gender accuracy,
# - save an annotated copy of every person crop,
# - record predictions in image_data.json,
# - plot true vs. predicted age.
save_dir = 'children_test'
os.makedirs(save_dir, exist_ok=True)
true_ages = []
predicted_ages = []
with torch.no_grad():
    test_loss = 0.0
    correct_gender = 0
    total = 0
    # Append to the records of earlier runs when image_data.json already exists.
    try:
        with open('image_data.json', 'r') as json_file:
            image_data = json.load(json_file)
    except FileNotFoundError:
        image_data = []
    for i, (inputs, labels, img_paths) in tqdm(enumerate(kid_dataloader), total=len(kid_dataloader), desc="Processing batches"):
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        for j in range(inputs.size(0)):
            input_image = inputs[j].unsqueeze(0)
            target = labels[j].unsqueeze(0)
            # Channels 3-5 hold the person crop; undo the normalization for display.
            target_image = denormalize(input_image[:, 3:].to('cpu'), [*IMAGENET_DEFAULT_MEAN], [*IMAGENET_DEFAULT_STD])
            output = model(input_image)
            gender_output = output[:, :2].softmax(dim=-1)
            output_bce = output[:, :2]
            target_bce = target[:, :2]
            output_mse = output[:, 2]
            target_mse = target[:, 2]
            # Map the standardized age back to years.
            predicted_age = output_mse.item() * STD_TRAIN + MEAN_TRAIN
            true_age = target_mse.item() * STD_TRAIN + MEAN_TRAIN
            true_ages.append(true_age)
            predicted_ages.append(predicted_age)
            # BCEWithLogitsLoss takes the raw logits (it applies the sigmoid itself).
            loss_bce = criterion_bce(output_bce, target_bce)
            loss_mse = criterion_mse(output_mse, target_mse)
            test_loss += (loss_bce + loss_mse).item()
            _, predicted_gender = torch.max(gender_output, 1)
            _, target_gender = torch.max(target_bce, 1)
            correct_gender += (predicted_gender == target_gender).sum().item()
            total += target_gender.size(0)
            # Convert to a PIL image and overlay the ages.
            target_image_pil = tensor_to_image_with_text(target_image, true_age, predicted_age)
            if predicted_age >= 15:
                print(img_paths[j], " ", predicted_age)
            # kid_dataset uses test=True, so img_paths holds file basenames.
            image_path = os.path.join(save_dir, f'{img_paths[j]}')
            image_data.append({"path": img_paths[j], "predicted_age": predicted_age, "predicted_gender": predicted_gender.item()})
            target_image_pil.save(image_path)
    # Persist all records, including those loaded from earlier runs.
    with open('image_data.json', 'w') as json_file:
        json.dump(image_data, json_file, indent=4)

# The loop above iterates kid_dataloader and accumulates losses per sample,
# so average over the number of evaluated samples.
if total:
    test_loss /= total
    gender_accuracy = correct_gender / total
    print(f'Test Loss: {test_loss:.4f}, Gender Accuracy: {gender_accuracy:.4f}')
else:
    print('No samples were evaluated.')

# Plot true vs. predicted ages and save the figure (skipped when nothing was evaluated).
if true_ages:
    plt.figure(figsize=(10, 6))
    plt.scatter(true_ages, predicted_ages, c='blue', label='Predicted Age')
    plt.plot([min(true_ages), max(true_ages)], [min(true_ages), max(true_ages)], color='red', linestyle='--', label='Perfect Prediction')
    plt.xlabel('True Age')
    plt.ylabel('Predicted Age')
    plt.title('True Age vs Predicted Age')
    plt.legend()
    plt.grid(True)
    plt.savefig('age_prediction_comparison.png')
    plt.close()