hungdang1610
commited on
Commit
•
87d78a8
1
Parent(s):
efe06d3
train.py
Browse files- models/train.py +396 -0
models/train.py
ADDED
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from tqdm import tqdm
|
3 |
+
from mivolo.model.mivolo_model import MiVOLOModel
|
4 |
+
from torchvision.transforms.functional import to_pil_image
|
5 |
+
from torch.utils.data import DataLoader, TensorDataset
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from PIL import Image, ImageDraw, ImageFont
|
8 |
+
from sklearn.model_selection import train_test_split
|
9 |
+
import torch.optim as optim
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from torch.utils.data import Dataset, DataLoader
|
13 |
+
from timm.models._helpers import load_state_dict
|
14 |
+
from PIL import Image
|
15 |
+
import os
|
16 |
+
import torchvision.transforms as transforms
|
17 |
+
import json
|
18 |
+
|
19 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
20 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
21 |
+
MEAN_TRAIN = 36.64
|
22 |
+
STD_TRAIN = 21.74
|
23 |
+
|
24 |
+
import torch
|
25 |
+
from torch.utils.data import Dataset, DataLoader
|
26 |
+
from PIL import Image
|
27 |
+
import pandas as pd
|
28 |
+
import torch.nn as nn
|
29 |
+
import torchvision.transforms as transforms
|
30 |
+
|
31 |
+
import pandas as pd
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
# Định nghĩa dataset tùy chỉnh
|
36 |
+
class CustomDataset(Dataset):
|
37 |
+
def __init__(self, csv_data, test=False, transform=None):
|
38 |
+
self.data = csv_data
|
39 |
+
self.transform = transform
|
40 |
+
self.test = test
|
41 |
+
|
42 |
+
def __len__(self):
|
43 |
+
return len(self.data)
|
44 |
+
|
45 |
+
def __getitem__(self, idx):
|
46 |
+
img_path = "/home/duyht/MiVOLO/MiVOLO/lag_benchmark/" + self.data.iloc[idx]['img_name'] if self.test==False else self.data.iloc[idx]['img_name']
|
47 |
+
basename = os.path.basename(img_path)
|
48 |
+
# print("img_path: ", img_path)
|
49 |
+
image = Image.open(img_path).convert('RGB')
|
50 |
+
|
51 |
+
# Lấy tọa độ từ dataframe
|
52 |
+
face_x0, face_y0, face_x1, face_y1 = self.data.iloc[idx][['face_x0', 'face_y0', 'face_x1', 'face_y1']]
|
53 |
+
person_x0, person_y0, person_x1, person_y1 = self.data.iloc[idx][['person_x0', 'person_y0', 'person_x1', 'person_y1']]
|
54 |
+
|
55 |
+
# Cắt ảnh theo các tọa độ
|
56 |
+
face_image = image.crop((int(face_x0), int(face_y0), int(face_x1), int(face_y1)))
|
57 |
+
|
58 |
+
person_image = image.crop((int(person_x0), int(person_y0), int(person_x1), int(person_y1)))
|
59 |
+
|
60 |
+
# Resize ảnh về (224, 224)
|
61 |
+
face_image = face_image.resize((224, 224))
|
62 |
+
person_image = person_image.resize((224, 224))
|
63 |
+
|
64 |
+
if self.transform:
|
65 |
+
face_image = self.transform(face_image)
|
66 |
+
person_image = self.transform(person_image)
|
67 |
+
|
68 |
+
|
69 |
+
image_ = torch.cat((face_image, person_image), dim=0)
|
70 |
+
|
71 |
+
y_label = eval(self.data.iloc[idx]['y_label']) # assuming y_label is a string representation of a list
|
72 |
+
y1, y2, y3 = y_label
|
73 |
+
# y3 = (y3 - 48.0) / (95 - 1)
|
74 |
+
# y3 = (y3 - 36.77) / 21.6
|
75 |
+
y3 = (y3 - MEAN_TRAIN) / STD_TRAIN
|
76 |
+
# y_label = (y2, y1, y3)
|
77 |
+
y_label = (y1, y2, y3)
|
78 |
+
y_label = torch.tensor(y_label, dtype=torch.float32)
|
79 |
+
|
80 |
+
return image_, y_label, self.data.iloc[idx]['img_name'] if self.test==False else basename
|
81 |
+
|
82 |
+
|
83 |
+
transform_train = transforms.Compose([
|
84 |
+
transforms.RandAugment(magnitude=22),
|
85 |
+
transforms.RandomHorizontalFlip(p=0.5),
|
86 |
+
transforms.RandomApply([transforms.ColorJitter()], p=0.5),
|
87 |
+
transforms.RandomResizedCrop(224, scale=(0.8, 1.0), ratio=(0.8, 1.2)),
|
88 |
+
transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
|
89 |
+
transforms.ToTensor(),
|
90 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
91 |
+
])
|
92 |
+
|
93 |
+
transform_valid = transforms.Compose([
|
94 |
+
transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
|
95 |
+
transforms.ToTensor(),
|
96 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
97 |
+
])
|
98 |
+
|
99 |
+
def denormalize(image, mean, std):
|
100 |
+
mean = torch.tensor(mean).reshape(3, 1, 1)
|
101 |
+
std = torch.tensor(std).reshape(3, 1, 1)
|
102 |
+
image = image * std + mean
|
103 |
+
return image
|
104 |
+
|
105 |
+
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
# Đọc dữ liệu từ các file CSV đã tách
|
111 |
+
train_data = pd.read_csv('csv/data_train.csv')
|
112 |
+
val_data = pd.read_csv('csv/data_valid.csv')
|
113 |
+
test_data = pd.read_csv('csv/data_test.csv')
|
114 |
+
|
115 |
+
kid_data = pd.read_csv('csv/children_test.csv')
|
116 |
+
|
117 |
+
|
118 |
+
# Tạo dataset cho train, validation và test
|
119 |
+
train_dataset = CustomDataset(train_data, transform=transform_train)
|
120 |
+
val_dataset = CustomDataset(val_data, transform=transform_valid)
|
121 |
+
test_dataset = CustomDataset(test_data, transform=transform_valid)
|
122 |
+
kid_dataset = CustomDataset(kid_data, test=True, transform=transform_valid)
|
123 |
+
|
124 |
+
# Tạo dataloader cho train, validation và test
|
125 |
+
train_dataloader = DataLoader(train_dataset, batch_size=50, shuffle=True, num_workers=4)
|
126 |
+
val_dataloader = DataLoader(val_dataset, batch_size=50, shuffle=False, num_workers=4)
|
127 |
+
test_dataloader = DataLoader(test_dataset, batch_size=50, shuffle=False, num_workers=4)
|
128 |
+
kid_dataloader = DataLoader(kid_dataset, batch_size=50, shuffle=False, num_workers=4)
|
129 |
+
|
130 |
+
|
131 |
+
# Khởi tạo mô hình và các thành phần khác
|
132 |
+
model = MiVOLOModel(
|
133 |
+
layers=(4, 4, 8, 2),
|
134 |
+
img_size=224,
|
135 |
+
in_chans=6,
|
136 |
+
num_classes=3,
|
137 |
+
patch_size=8,
|
138 |
+
stem_hidden_dim=64,
|
139 |
+
embed_dims=(192, 384, 384, 384),
|
140 |
+
num_heads=(6, 12, 12, 12),
|
141 |
+
).to('cuda')
|
142 |
+
|
143 |
+
state = torch.load("models/model_imdb_cross_person_4.22_99.46.pth.tar", map_location="cpu")
|
144 |
+
state_dict = state["state_dict"]
|
145 |
+
model.load_state_dict(state_dict, strict=True)
|
146 |
+
# state = torch.load("modelstrain/best_model_weights_10.pth", map_location="cpu")
|
147 |
+
# model.load_state_dict(state, strict=True)
|
148 |
+
|
149 |
+
criterion_bce = nn.BCEWithLogitsLoss()
|
150 |
+
criterion_mse = nn.MSELoss()
|
151 |
+
|
152 |
+
|
153 |
+
# Khởi tạo optimizer với AdamW và scheduler
|
154 |
+
optimizer = optim.AdamW(model.parameters(), lr=1.0e-6, weight_decay=5e-6)
|
155 |
+
|
156 |
+
|
157 |
+
# Huấn luyện mô hình
|
158 |
+
num_epochs = 50
|
159 |
+
best_val_loss = float('inf')
|
160 |
+
# best_val_loss = 39.2124
|
161 |
+
stop_training = False
|
162 |
+
def get_optimizer_info(optimizer):
|
163 |
+
for param_group in optimizer.param_groups:
|
164 |
+
lr = param_group['lr']
|
165 |
+
return f"LR: {lr}"
|
166 |
+
|
167 |
+
|
168 |
+
for epoch in range(num_epochs):
|
169 |
+
model.train()
|
170 |
+
running_loss = 0.0
|
171 |
+
if stop_training:
|
172 |
+
break
|
173 |
+
|
174 |
+
train_dataloader = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch")
|
175 |
+
|
176 |
+
for i, (inputs, labels, _) in enumerate(train_dataloader):
|
177 |
+
inputs = inputs.to('cuda')
|
178 |
+
labels = [label.to('cuda') for label in labels]
|
179 |
+
|
180 |
+
optimizer.zero_grad()
|
181 |
+
batch_loss = 0
|
182 |
+
for j in range(inputs.size(0)):
|
183 |
+
input_image = inputs[j].unsqueeze(0)
|
184 |
+
target = labels[j].unsqueeze(0)
|
185 |
+
|
186 |
+
output = model(input_image)
|
187 |
+
gender_output = output[:, :2].softmax(dim=-1)
|
188 |
+
|
189 |
+
output_bce = output[:, :2]
|
190 |
+
target_bce = target[:, :2]
|
191 |
+
output_mse = output[:, 2]
|
192 |
+
target_mse = target[:, 2]
|
193 |
+
true_age = target_mse.item() *STD_TRAIN + MEAN_TRAIN
|
194 |
+
loss_bce = criterion_bce(gender_output, target_bce)
|
195 |
+
loss_mse = criterion_mse(output_mse, target_mse)
|
196 |
+
|
197 |
+
loss = loss_bce + loss_mse
|
198 |
+
batch_loss += loss
|
199 |
+
# loss = loss_mse
|
200 |
+
# if true_age >=1:
|
201 |
+
# batch_loss += loss
|
202 |
+
# else:
|
203 |
+
# batch_loss+=loss_mse
|
204 |
+
|
205 |
+
|
206 |
+
if torch.isnan(loss_bce).any() or torch.isnan(loss_mse).any() or torch.isnan(loss).any():
|
207 |
+
print(f'Epoch [{epoch + 1}], Batch [{i + 1}] - NaN detected in loss computation')
|
208 |
+
stop_training = True
|
209 |
+
break
|
210 |
+
|
211 |
+
if stop_training:
|
212 |
+
break
|
213 |
+
|
214 |
+
optimizer.zero_grad()
|
215 |
+
batch_loss /= inputs.size(0)
|
216 |
+
# print("batch_loss: ", batch_loss)
|
217 |
+
optimizer_info = get_optimizer_info(optimizer)
|
218 |
+
train_dataloader.set_postfix(batch_loss=batch_loss.item(), optimizer_info=optimizer_info)
|
219 |
+
|
220 |
+
batch_loss.backward()
|
221 |
+
optimizer.step()
|
222 |
+
|
223 |
+
# Tính toán validation loss sau mỗi epoch
|
224 |
+
model.eval()
|
225 |
+
val_loss = 0.0
|
226 |
+
|
227 |
+
val_dataloader = tqdm(val_dataloader, desc="Validating", unit="batch")
|
228 |
+
|
229 |
+
with torch.no_grad():
|
230 |
+
for i, (inputs, labels, _) in enumerate(val_dataloader):
|
231 |
+
inputs = inputs.to('cuda')
|
232 |
+
labels = labels.to('cuda')
|
233 |
+
|
234 |
+
for j in range(inputs.size(0)):
|
235 |
+
input_image = inputs[j].unsqueeze(0)
|
236 |
+
target = labels[j].unsqueeze(0)
|
237 |
+
output = model(input_image)
|
238 |
+
gender_output = output[:, :2].softmax(dim=-1)
|
239 |
+
|
240 |
+
output_bce = output[:, :2]
|
241 |
+
target_bce = target[:, :2]
|
242 |
+
output_mse = output[:, 2]
|
243 |
+
target_mse = target[:, 2]
|
244 |
+
|
245 |
+
loss_bce = criterion_bce(gender_output, target_bce)
|
246 |
+
loss_mse = criterion_mse(output_mse, target_mse)
|
247 |
+
true_age = target_mse.item() *STD_TRAIN + MEAN_TRAIN
|
248 |
+
loss = loss_bce + loss_mse
|
249 |
+
# if true_age >=1:
|
250 |
+
# loss = loss_bce + loss_mse
|
251 |
+
# else:
|
252 |
+
# loss = loss_mse
|
253 |
+
|
254 |
+
# loss = loss_mse
|
255 |
+
|
256 |
+
val_loss += loss.item()
|
257 |
+
|
258 |
+
val_loss /= len(val_dataloader)
|
259 |
+
print(f'Epoch [{epoch + 1}], Validation Loss: {val_loss:.4f}')
|
260 |
+
|
261 |
+
# Lưu lại trọng số tốt nhất
|
262 |
+
if val_loss < best_val_loss:
|
263 |
+
best_val_loss = val_loss
|
264 |
+
torch.save(model.state_dict(), 'modelstrain/best_model_weights_10.pth')
|
265 |
+
print(f'Saved best model weights with validation loss: {best_val_loss:.4f}')
|
266 |
+
|
267 |
+
print('Finished Training')
|
268 |
+
|
269 |
+
|
270 |
+
####################################### Đánh giá mô hình trên tập test ###################################################3
|
271 |
+
model.load_state_dict(torch.load('modelstrain/best_model_weights_10.pth'))
|
272 |
+
model.eval()
|
273 |
+
test_loss = 0.0
|
274 |
+
correct_gender = 0
|
275 |
+
total = 0
|
276 |
+
|
277 |
+
|
278 |
+
def tensor_to_image_with_text(tensor, true_age, predicted_age):
|
279 |
+
unloader = transforms.ToPILImage()
|
280 |
+
image = unloader(tensor.cpu().squeeze(0))
|
281 |
+
|
282 |
+
draw = ImageDraw.Draw(image)
|
283 |
+
font = ImageFont.load_default()
|
284 |
+
|
285 |
+
text_true = f'True Age: {true_age:.2f}'
|
286 |
+
text_predicted = f'Predicted Age: {predicted_age:.2f}'
|
287 |
+
|
288 |
+
# Text positions
|
289 |
+
text_position_true = (10, 10)
|
290 |
+
text_position_predicted = (10, 30)
|
291 |
+
|
292 |
+
# Calculate bounding box for the text
|
293 |
+
bbox_true = draw.textbbox(text_position_true, text_true, font=font)
|
294 |
+
bbox_predicted = draw.textbbox(text_position_predicted, text_predicted, font=font)
|
295 |
+
|
296 |
+
# Draw white rectangles behind the text
|
297 |
+
draw.rectangle(bbox_true, fill="white")
|
298 |
+
draw.rectangle(bbox_predicted, fill="white")
|
299 |
+
|
300 |
+
# Draw the text on top of the rectangles
|
301 |
+
draw.text(text_position_true, text_true, font=font, fill="green")
|
302 |
+
draw.text(text_position_predicted, text_predicted, font=font, fill="blue")
|
303 |
+
|
304 |
+
return image
|
305 |
+
|
306 |
+
save_dir = 'children_test'
|
307 |
+
# save_dir_under18 = 'saved_images_under18'
|
308 |
+
os.makedirs(save_dir, exist_ok=True)
|
309 |
+
# os.makedirs(save_dir_under18, exist_ok=True)
|
310 |
+
true_ages = []
|
311 |
+
predicted_ages = []
|
312 |
+
with torch.no_grad():
|
313 |
+
test_loss = 0.0
|
314 |
+
correct_gender = 0
|
315 |
+
total = 0
|
316 |
+
# Initialize lists to store paths and ages
|
317 |
+
image_data = []
|
318 |
+
|
319 |
+
# Load existing data from the JSON file if it exists
|
320 |
+
try:
|
321 |
+
with open('image_data.json', 'r') as json_file:
|
322 |
+
image_data = json.load(json_file)
|
323 |
+
except FileNotFoundError:
|
324 |
+
# If the file does not exist, start with an empty list
|
325 |
+
image_data = []
|
326 |
+
# for i, (inputs, labels) in enumerate(test_dataloader):
|
327 |
+
for i, (inputs, labels, img_paths) in tqdm(enumerate(kid_dataloader), total=len(kid_dataloader), desc="Processing batches"):
|
328 |
+
inputs = inputs.to('cuda')
|
329 |
+
labels = labels.to('cuda')
|
330 |
+
for j in range(inputs.size(0)):
|
331 |
+
input_image = inputs[j].unsqueeze(0)
|
332 |
+
print("input_image: ", input_image.shape)
|
333 |
+
target = labels[j].unsqueeze(0)
|
334 |
+
target_image = denormalize(input_image[:,3:].to('cpu'), [*IMAGENET_DEFAULT_MEAN], [*IMAGENET_DEFAULT_STD])
|
335 |
+
output = model(input_image)
|
336 |
+
print("output[:, :2]: ", output[:, :2])
|
337 |
+
gender_output = output[:, :2].softmax(dim=-1)
|
338 |
+
print("gender_output: ", gender_output)
|
339 |
+
output_bce = output[:, :2]
|
340 |
+
target_bce = target[:, :2]
|
341 |
+
output_mse = output[:, 2]
|
342 |
+
target_mse = target[:, 2]
|
343 |
+
# y3 = (y3 - 36.77) / 21.6
|
344 |
+
# predicted_age = output_mse.item() * (95 - 1) + 48.0
|
345 |
+
# true_age = target_mse.item() * (95 - 1) + 48.0
|
346 |
+
# predicted_age = output_mse.item() *21.6 + 36.77 - 1.0
|
347 |
+
# true_age = target_mse.item() *21.6 + 36.77
|
348 |
+
predicted_age = output_mse.item() *STD_TRAIN + MEAN_TRAIN
|
349 |
+
true_age = target_mse.item() *STD_TRAIN + MEAN_TRAIN
|
350 |
+
true_ages.append(true_age)
|
351 |
+
predicted_ages.append(predicted_age)
|
352 |
+
|
353 |
+
# Compute losses
|
354 |
+
loss_bce = criterion_bce(output_bce, target_bce)
|
355 |
+
loss_mse = criterion_mse(output_mse, target_mse)
|
356 |
+
loss = loss_bce + loss_mse
|
357 |
+
|
358 |
+
test_loss += loss.item()
|
359 |
+
_, predicted_gender = torch.max(gender_output, 1)
|
360 |
+
print("predicted_gender: ", predicted_gender)
|
361 |
+
_, target_gender = torch.max(target_bce, 1)
|
362 |
+
correct_gender += (predicted_gender == target_gender).sum().item()
|
363 |
+
total += target_gender.size(0)
|
364 |
+
|
365 |
+
# Convert to PIL image and add text
|
366 |
+
target_image_pil = tensor_to_image_with_text(target_image, true_age, predicted_age)
|
367 |
+
if predicted_age >=15:
|
368 |
+
print(img_paths[j], " ", predicted_age)
|
369 |
+
|
370 |
+
# Save the image
|
371 |
+
image_path = os.path.join(save_dir, f'{img_paths[j]}')
|
372 |
+
# if true_age < 18:
|
373 |
+
# image_path = os.path.join(save_dir_under18, f'{img_paths[j]}')
|
374 |
+
# else:
|
375 |
+
# image_path = os.path.join(save_dir, f'{img_paths[j]}')
|
376 |
+
image_data.append({"path": img_paths[j], "predicted_age": predicted_age, "predicted_gender": predicted_gender.item()})
|
377 |
+
target_image_pil.save(image_path)
|
378 |
+
# Save the data to a JSON file
|
379 |
+
|
380 |
+
with open('image_data.json', 'w') as json_file:
|
381 |
+
json.dump(image_data, json_file, indent=4)
|
382 |
+
test_loss /= len(test_dataloader)
|
383 |
+
gender_accuracy = correct_gender / total
|
384 |
+
print(f'Test Loss: {test_loss:.4f}, Gender Accuracy: {gender_accuracy:.4f}')
|
385 |
+
|
386 |
+
# Plotting true ages vs. predicted ages and save the plot
|
387 |
+
plt.figure(figsize=(10, 6))
|
388 |
+
plt.scatter(true_ages, predicted_ages, c='blue', label='Predicted Age')
|
389 |
+
plt.plot([min(true_ages), max(true_ages)], [min(true_ages), max(true_ages)], color='red', linestyle='--', label='Perfect Prediction')
|
390 |
+
plt.xlabel('True Age')
|
391 |
+
plt.ylabel('Predicted Age')
|
392 |
+
plt.title('True Age vs Predicted Age')
|
393 |
+
plt.legend()
|
394 |
+
plt.grid(True)
|
395 |
+
plt.savefig('age_prediction_comparison.png')
|
396 |
+
plt.close()
|