hungdang1610 committed on
Commit
87d78a8
1 Parent(s): efe06d3
Files changed (1)
  1. models/train.py +396 -0
models/train.py ADDED
@@ -0,0 +1,396 @@
import pandas as pd
from tqdm import tqdm
from mivolo.model.mivolo_model import MiVOLOModel
from torchvision.transforms.functional import to_pil_image
from torch.utils.data import Dataset, DataLoader, TensorDataset
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from sklearn.model_selection import train_test_split
import torch.optim as optim
import torch
import torch.nn as nn
from timm.models._helpers import load_state_dict
import os
import torchvision.transforms as transforms
import json
import ast

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# Age statistics of the training set, used to z-score the age regression target
MEAN_TRAIN = 36.64
STD_TRAIN = 21.74

# Define a custom dataset
class CustomDataset(Dataset):
    def __init__(self, csv_data, test=False, transform=None):
        self.data = csv_data
        self.transform = transform
        self.test = test

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if not self.test:
            img_path = "/home/duyht/MiVOLO/MiVOLO/lag_benchmark/" + self.data.iloc[idx]['img_name']
        else:
            img_path = self.data.iloc[idx]['img_name']
        basename = os.path.basename(img_path)
        # print("img_path: ", img_path)
        image = Image.open(img_path).convert('RGB')

        # Get the face and person coordinates from the dataframe
        face_x0, face_y0, face_x1, face_y1 = self.data.iloc[idx][['face_x0', 'face_y0', 'face_x1', 'face_y1']]
        person_x0, person_y0, person_x1, person_y1 = self.data.iloc[idx][['person_x0', 'person_y0', 'person_x1', 'person_y1']]

        # Crop the image to the face and person boxes
        face_image = image.crop((int(face_x0), int(face_y0), int(face_x1), int(face_y1)))
        person_image = image.crop((int(person_x0), int(person_y0), int(person_x1), int(person_y1)))

        # Resize both crops to (224, 224)
        face_image = face_image.resize((224, 224))
        person_image = person_image.resize((224, 224))

        if self.transform:
            face_image = self.transform(face_image)
            person_image = self.transform(person_image)

        # Stack the face and person crops along the channel dimension -> 6-channel input
        image_ = torch.cat((face_image, person_image), dim=0)

        y_label = ast.literal_eval(self.data.iloc[idx]['y_label'])  # y_label is stored as a string representation of a list
        y1, y2, y3 = y_label
        # y3 = (y3 - 48.0) / (95 - 1)
        # y3 = (y3 - 36.77) / 21.6
        y3 = (y3 - MEAN_TRAIN) / STD_TRAIN  # z-score the age target
        # y_label = (y2, y1, y3)
        y_label = (y1, y2, y3)
        y_label = torch.tensor(y_label, dtype=torch.float32)

        return image_, y_label, (self.data.iloc[idx]['img_name'] if not self.test else basename)

transform_train = transforms.Compose([
    transforms.RandAugment(magnitude=22),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter()], p=0.5),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0), ratio=(0.8, 1.2)),
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

transform_valid = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def denormalize(image, mean, std):
    # Invert transforms.Normalize so the tensor can be displayed as an image
    mean = torch.tensor(mean).reshape(3, 1, 1)
    std = torch.tensor(std).reshape(3, 1, 1)
    image = image * std + mean
    return image

# Read the pre-split CSV files
train_data = pd.read_csv('csv/data_train.csv')
val_data = pd.read_csv('csv/data_valid.csv')
test_data = pd.read_csv('csv/data_test.csv')

kid_data = pd.read_csv('csv/children_test.csv')

# Create datasets for train, validation and test
train_dataset = CustomDataset(train_data, transform=transform_train)
val_dataset = CustomDataset(val_data, transform=transform_valid)
test_dataset = CustomDataset(test_data, transform=transform_valid)
kid_dataset = CustomDataset(kid_data, test=True, transform=transform_valid)

# Create dataloaders for train, validation and test
train_dataloader = DataLoader(train_dataset, batch_size=50, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=50, shuffle=False, num_workers=4)
test_dataloader = DataLoader(test_dataset, batch_size=50, shuffle=False, num_workers=4)
kid_dataloader = DataLoader(kid_dataset, batch_size=50, shuffle=False, num_workers=4)

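# Illustrative sanity check of the sample layout produced above (channels 0-2: face crop,
# channels 3-5: person crop; target = two gender values plus a z-scored age). Kept
# commented out so it does not change what the script does; it assumes the CSVs and
# image paths referenced above are available.
# sample_x, sample_y, _ = train_dataset[0]
# assert sample_x.shape == (6, 224, 224) and sample_y.shape == (3,)
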
# Initialize the model and the other components
model = MiVOLOModel(
    layers=(4, 4, 8, 2),
    img_size=224,
    in_chans=6,
    num_classes=3,
    patch_size=8,
    stem_hidden_dim=64,
    embed_dims=(192, 384, 384, 384),
    num_heads=(6, 12, 12, 12),
).to('cuda')

state = torch.load("models/model_imdb_cross_person_4.22_99.46.pth.tar", map_location="cpu")
state_dict = state["state_dict"]
model.load_state_dict(state_dict, strict=True)
# state = torch.load("modelstrain/best_model_weights_10.pth", map_location="cpu")
# model.load_state_dict(state, strict=True)

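# The 3-way head is used here as two gender logits plus one age value: output[:, :2] is
# read as the gender scores and output[:, 2] as the z-scored age. This reading follows
# how the outputs are sliced in the loops below, not a documented MiVOLO contract.
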
criterion_bce = nn.BCEWithLogitsLoss()
criterion_mse = nn.MSELoss()

# Initialize the optimizer (AdamW)
optimizer = optim.AdamW(model.parameters(), lr=1.0e-6, weight_decay=5e-6)

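# A fixed learning rate is used throughout. If a schedule were wanted, one option (an
# assumption, not something this script sets up) is cosine annealing, stepped once per epoch:
# scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
# ...and then call scheduler.step() at the end of each epoch in the loop below.
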
# Train the model
num_epochs = 50
best_val_loss = float('inf')
# best_val_loss = 39.2124
stop_training = False

def get_optimizer_info(optimizer):
    # Report the learning rate of the first parameter group
    for param_group in optimizer.param_groups:
        lr = param_group['lr']
        return f"LR: {lr}"

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    if stop_training:
        break

    # Per-epoch progress bar, kept separate from train_dataloader so each epoch starts a fresh bar
    train_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch")

    for i, (inputs, labels, _) in enumerate(train_bar):
        inputs = inputs.to('cuda')
        labels = [label.to('cuda') for label in labels]

        optimizer.zero_grad()
        batch_loss = 0
        for j in range(inputs.size(0)):
            input_image = inputs[j].unsqueeze(0)
            target = labels[j].unsqueeze(0)

            output = model(input_image)

            # BCEWithLogitsLoss expects raw logits, so the gender head is not softmaxed here
            output_bce = output[:, :2]
            target_bce = target[:, :2]
            output_mse = output[:, 2]
            target_mse = target[:, 2]
            true_age = target_mse.item() * STD_TRAIN + MEAN_TRAIN
            loss_bce = criterion_bce(output_bce, target_bce)
            loss_mse = criterion_mse(output_mse, target_mse)

            loss = loss_bce + loss_mse
            batch_loss += loss
            # loss = loss_mse
            # if true_age >= 1:
            #     batch_loss += loss
            # else:
            #     batch_loss += loss_mse

            if torch.isnan(loss_bce).any() or torch.isnan(loss_mse).any() or torch.isnan(loss).any():
                print(f'Epoch [{epoch + 1}], Batch [{i + 1}] - NaN detected in loss computation')
                stop_training = True
                break

        if stop_training:
            break

        batch_loss /= inputs.size(0)
        # print("batch_loss: ", batch_loss)
        optimizer_info = get_optimizer_info(optimizer)
        train_bar.set_postfix(batch_loss=batch_loss.item(), optimizer_info=optimizer_info)

        batch_loss.backward()
        optimizer.step()

    # Compute the validation loss after each epoch
    model.eval()
    val_loss = 0.0

    val_bar = tqdm(val_dataloader, desc="Validating", unit="batch")

    with torch.no_grad():
        for i, (inputs, labels, _) in enumerate(val_bar):
            inputs = inputs.to('cuda')
            labels = labels.to('cuda')

            for j in range(inputs.size(0)):
                input_image = inputs[j].unsqueeze(0)
                target = labels[j].unsqueeze(0)
                output = model(input_image)

                # Raw gender logits go into BCEWithLogitsLoss, matching the training loop
                output_bce = output[:, :2]
                target_bce = target[:, :2]
                output_mse = output[:, 2]
                target_mse = target[:, 2]

                loss_bce = criterion_bce(output_bce, target_bce)
                loss_mse = criterion_mse(output_mse, target_mse)
                true_age = target_mse.item() * STD_TRAIN + MEAN_TRAIN
                loss = loss_bce + loss_mse
                # if true_age >= 1:
                #     loss = loss_bce + loss_mse
                # else:
                #     loss = loss_mse

                # loss = loss_mse

                val_loss += loss.item()

    val_loss /= len(val_dataloader)
    print(f'Epoch [{epoch + 1}], Validation Loss: {val_loss:.4f}')

    # Save the best weights
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'modelstrain/best_model_weights_10.pth')
        print(f'Saved best model weights with validation loss: {best_val_loss:.4f}')

print('Finished Training')

####################################### Evaluate the model on the test set #######################################
model.load_state_dict(torch.load('modelstrain/best_model_weights_10.pth'))
model.eval()
test_loss = 0.0
correct_gender = 0
total = 0


def tensor_to_image_with_text(tensor, true_age, predicted_age):
    unloader = transforms.ToPILImage()
    image = unloader(tensor.cpu().squeeze(0))

    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()

    text_true = f'True Age: {true_age:.2f}'
    text_predicted = f'Predicted Age: {predicted_age:.2f}'

    # Text positions
    text_position_true = (10, 10)
    text_position_predicted = (10, 30)

    # Calculate bounding boxes for the text
    bbox_true = draw.textbbox(text_position_true, text_true, font=font)
    bbox_predicted = draw.textbbox(text_position_predicted, text_predicted, font=font)

    # Draw white rectangles behind the text
    draw.rectangle(bbox_true, fill="white")
    draw.rectangle(bbox_predicted, fill="white")

    # Draw the text on top of the rectangles
    draw.text(text_position_true, text_true, font=font, fill="green")
    draw.text(text_position_predicted, text_predicted, font=font, fill="blue")

    return image

save_dir = 'children_test'
# save_dir_under18 = 'saved_images_under18'
os.makedirs(save_dir, exist_ok=True)
# os.makedirs(save_dir_under18, exist_ok=True)
true_ages = []
predicted_ages = []
with torch.no_grad():
    test_loss = 0.0
    correct_gender = 0
    total = 0
    # Initialize lists to store paths and ages
    image_data = []

    # Load existing data from the JSON file if it exists
    try:
        with open('image_data.json', 'r') as json_file:
            image_data = json.load(json_file)
    except FileNotFoundError:
        # If the file does not exist, start with an empty list
        image_data = []
    # for i, (inputs, labels) in enumerate(test_dataloader):
    for i, (inputs, labels, img_paths) in tqdm(enumerate(kid_dataloader), total=len(kid_dataloader), desc="Processing batches"):
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        for j in range(inputs.size(0)):
            input_image = inputs[j].unsqueeze(0)
            print("input_image: ", input_image.shape)
            target = labels[j].unsqueeze(0)
            # Channels 3-5 hold the person crop; denormalize it for visualization
            target_image = denormalize(input_image[:, 3:].to('cpu'), [*IMAGENET_DEFAULT_MEAN], [*IMAGENET_DEFAULT_STD])
            output = model(input_image)
            print("output[:, :2]: ", output[:, :2])
            gender_output = output[:, :2].softmax(dim=-1)
            print("gender_output: ", gender_output)
            output_bce = output[:, :2]
            target_bce = target[:, :2]
            output_mse = output[:, 2]
            target_mse = target[:, 2]
            # y3 = (y3 - 36.77) / 21.6
            # predicted_age = output_mse.item() * (95 - 1) + 48.0
            # true_age = target_mse.item() * (95 - 1) + 48.0
            # predicted_age = output_mse.item() * 21.6 + 36.77 - 1.0
            # true_age = target_mse.item() * 21.6 + 36.77
            predicted_age = output_mse.item() * STD_TRAIN + MEAN_TRAIN
            true_age = target_mse.item() * STD_TRAIN + MEAN_TRAIN
            true_ages.append(true_age)
            predicted_ages.append(predicted_age)

            # Compute losses
            loss_bce = criterion_bce(output_bce, target_bce)
            loss_mse = criterion_mse(output_mse, target_mse)
            loss = loss_bce + loss_mse

            test_loss += loss.item()
            _, predicted_gender = torch.max(gender_output, 1)
            print("predicted_gender: ", predicted_gender)
            _, target_gender = torch.max(target_bce, 1)
            correct_gender += (predicted_gender == target_gender).sum().item()
            total += target_gender.size(0)

            # Convert to a PIL image and overlay the true/predicted ages
            target_image_pil = tensor_to_image_with_text(target_image, true_age, predicted_age)
            if predicted_age >= 15:
                print(img_paths[j], " ", predicted_age)

            # Save the image
            image_path = os.path.join(save_dir, f'{img_paths[j]}')
            # if true_age < 18:
            #     image_path = os.path.join(save_dir_under18, f'{img_paths[j]}')
            # else:
            #     image_path = os.path.join(save_dir, f'{img_paths[j]}')
            image_data.append({"path": img_paths[j], "predicted_age": predicted_age, "predicted_gender": predicted_gender.item()})
            target_image_pil.save(image_path)

    # Save the collected data to a JSON file
    with open('image_data.json', 'w') as json_file:
        json.dump(image_data, json_file, indent=4)

    # Average over the batches that were actually evaluated
    test_loss /= len(kid_dataloader)
    gender_accuracy = correct_gender / total
    print(f'Test Loss: {test_loss:.4f}, Gender Accuracy: {gender_accuracy:.4f}')

# Plot true ages vs. predicted ages and save the figure
plt.figure(figsize=(10, 6))
plt.scatter(true_ages, predicted_ages, c='blue', label='Predicted Age')
plt.plot([min(true_ages), max(true_ages)], [min(true_ages), max(true_ages)], color='red', linestyle='--', label='Perfect Prediction')
plt.xlabel('True Age')
plt.ylabel('Predicted Age')
plt.title('True Age vs Predicted Age')
plt.legend()
plt.grid(True)
plt.savefig('age_prediction_comparison.png')
plt.close()
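
# A simple aggregate that could be reported alongside the plot (an illustrative addition,
# assuming at least one sample was evaluated): mean absolute age error in years.
mae = sum(abs(t - p) for t, p in zip(true_ages, predicted_ages)) / max(len(true_ages), 1)
print(f'Age MAE: {mae:.2f} years')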