MiVOLO / mivolo /data /dataset /age_gender_dataset.py
admin
sync
319d3b5
import logging
from typing import Any, List, Optional, Set
import cv2
import numpy as np
import torch
from mivolo.data.dataset.reader_age_gender import ReaderAgeGender
from PIL import Image
from torchvision import transforms
# Module-level logger for this dataset module; it uses a fixed, explicit name
# ("AgeGenderDataset") rather than the module's __name__.
_logger = logging.getLogger("AgeGenderDataset")
class AgeGenderDataset(torch.utils.data.Dataset):
    """Torch dataset yielding ``(image, [age, gender])`` pairs.

    Samples are produced by a ``ReaderAgeGender`` instance; this class turns
    the reader's raw age/gender labels into numeric targets and applies the
    configured transform pipeline to the returned crops.

    Target encoding:
        * age    -- ``(age - avg_age) / (max_age - min_age)`` after rounding
          the raw age to the nearest integer, or ``-1.0`` if the label is
          ``"-1"`` (unknown).
        * gender -- ``0.0`` for ``"M"``/``"0"``, ``1.0`` for any other known
          label, or ``-1.0`` if the label is ``"-1"`` (unknown).
    """

    # Fallback used until a non-empty transform is assigned. The ``transform``
    # setter ignores falsy values, so without this default the getter would
    # raise AttributeError for a dataset created with ``transform=None``.
    _transform: Optional[Any] = None

    def __init__(
        self,
        images_path,
        annotations_path,
        name=None,
        split="train",
        load_bytes=False,
        img_mode="RGB",
        transform=None,
        is_training=False,
        seed=1234,
        target_size=224,
        min_age=None,
        max_age=None,
        model_with_persons=False,
        use_persons=False,
        disable_faces=False,
        only_age=False,
    ):
        """Create the underlying reader and derive age/gender class info.

        ``images_path``, ``annotations_path``, ``split``, ``seed``,
        ``target_size``, ``use_persons`` (as ``with_persons``),
        ``disable_faces`` and ``only_age`` are forwarded to
        ``ReaderAgeGender`` unchanged.

        Args:
            name: Optional dataset name, stored as-is.
            load_bytes: Stored as-is; not used inside this class.
            img_mode: Colour mode handed to ``convert_to_pil``.
            transform: Optional composed transform (must expose
                ``.transforms``); Resize/Crop steps are dropped by the setter.
            is_training: Stored as-is; not used inside this class.
            min_age, max_age: Predefined age range. Pass both or neither;
                when omitted the range is inferred from the annotations.
            model_with_persons: If true, ``__getitem__`` expects the reader to
                return a ``(face, person)`` pair and stacks both crops.
        """
        super().__init__()

        reader = ReaderAgeGender(
            images_path,
            annotations_path,
            split=split,
            seed=seed,
            target_size=target_size,
            with_persons=use_persons,
            disable_faces=disable_faces,
            only_age=only_age,
        )

        self.name = name
        self.model_with_persons = model_with_persons
        self.reader = reader
        self.load_bytes = load_bytes
        self.img_mode = img_mode
        self.transform = transform  # goes through the filtering setter below
        self._consecutive_errors = 0
        self.is_training = is_training
        self.random_flip = 0.0

        # Setting up classes.
        # If min and max ages are passed - use them to have the same
        # preprocessing for validation.
        self.max_age: Optional[float] = None
        self.min_age: Optional[float] = None
        self.avg_age: Optional[float] = None
        self.set_ages_min_max(min_age, max_age)

        self.genders = ["M", "F"]
        self.num_classes_gender = len(self.genders)

        self.age_classes: Optional[List[str]] = self.set_age_classes()
        self.num_classes_age = 1 if self.age_classes is None else len(self.age_classes)
        self.num_classes: int = self.num_classes_age + self.num_classes_gender
        self.target_dtype = torch.float32

    def set_age_classes(self) -> Optional[List[str]]:
        """Return the list of age classes, or ``None`` for age regression."""
        return None  # for regression dataset

    def set_ages_min_max(self, min_age: Optional[float], max_age: Optional[float]):
        """Set ``min_age``, ``max_age`` and ``avg_age``.

        When both bounds are given they are used verbatim. When both are
        ``None`` the bounds are taken from the rounded ages found in
        ``self.reader._ann``, skipping samples whose age is ``"-1"``.

        Raises:
            AssertionError: If only one of the two bounds is provided.
            ValueError: If the bounds must be inferred but no sample has a
                known age.
        """
        assert (min_age is None) == (
            max_age is None
        ), "Both min and max age must be passed or none of them"

        if min_age is not None:
            _logger.info("Received predefined min_age %s and max_age %s", min_age, max_age)
            self.max_age = max_age
            self.min_age = min_age
        else:
            # collect statistics from loaded dataset
            all_ages_set: Set[int] = {
                round(float(sample.age))
                for image_samples in self.reader._ann.values()
                for sample in image_samples
                if sample.age != "-1"
            }
            if not all_ages_set:
                raise ValueError(
                    "Cannot infer min/max age: the dataset has no samples with a known age"
                )
            self.max_age = max(all_ages_set)
            self.min_age = min(all_ages_set)

        # Midpoint of the range; used to centre ages in ``_norm_age``.
        self.avg_age = (self.max_age + self.min_age) / 2.0

    def _norm_age(self, age):
        """Centre ``age`` on ``avg_age`` and scale it by the age range."""
        return (age - self.avg_age) / (self.max_age - self.min_age)

    def parse_gender(self, _gender: str) -> float:
        """Encode a gender label as a float.

        Returns ``-1.0`` for the unknown marker ``"-1"``, ``0.0`` for ``"M"``
        or ``"0"`` and ``1.0`` for every other label.
        """
        if _gender == "-1":
            # Always a float so that targets have one consistent numeric type.
            return -1.0
        return 0.0 if _gender in ("M", "0") else 1.0

    def parse_target(self, _age: str, gender: str) -> List[Any]:
        """Build the ``[age, gender]`` target from raw string labels.

        A known age is rounded to the nearest integer and then normalised with
        ``_norm_age``; the unknown marker ``"-1"`` is encoded as ``-1.0``.
        """
        if _age == "-1":
            age = -1.0
        else:
            age = self._norm_age(float(round(float(_age))))
        return [age, self.parse_gender(gender)]

    @property
    def transform(self):
        """Filtered transform pipeline, or ``None`` if none has been set."""
        return self._transform

    @transform.setter
    def transform(self, transform):
        # Disable pretrained monkey-patched transforms: every step whose string
        # representation mentions "Resize" or "Crop" is dropped.
        if not transform:
            # Falsy assignments are ignored; the current pipeline is kept.
            return

        kept = []
        for step in transform.transforms:
            description = str(step)
            if "Resize" in description or "Crop" in description:
                continue
            kept.append(step)
        self._transform = transforms.Compose(kept)

    def apply_tranforms(self, image: Optional[np.ndarray]) -> Optional[Any]:
        """Run the transform pipeline on a single crop.

        (The misspelled method name is kept so existing callers still work.)

        Returns ``None`` for a ``None`` input and the image unchanged when no
        transform is configured. Otherwise the image is converted to PIL via
        ``convert_to_pil`` and each transform step is applied in order; the
        result type is whatever the last step returns.
        """
        if image is None:
            return None

        pipeline = self.transform
        if pipeline is None:
            return image

        image = convert_to_pil(image, self.img_mode)
        for step in pipeline.transforms:
            image = step(image)
        return image

    def __getitem__(self, index):
        """Return ``(img, target)`` for the sample at ``index``."""
        # get preprocessed face and person crops (np.ndarray)
        # resize + pad, for person crops: cut off other bboxes
        images, target = self.reader[index]
        target = self.parse_target(*target)

        if self.model_with_persons:
            face_image, person_image = images
            person_image = self.apply_tranforms(person_image)
        else:
            face_image = images[0]
            person_image = None

        face_image = self.apply_tranforms(face_image)

        if person_image is not None:
            # Face and person crops are stacked along the first axis.
            img = np.concatenate([face_image, person_image], axis=0)
        else:
            img = face_image

        return img, target

    def __len__(self):
        """Number of samples, as reported by the reader."""
        return len(self.reader)

    def filename(self, index, basename=False, absolute=False):
        """Delegate to ``reader.filename`` for the sample at ``index``."""
        return self.reader.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        """Delegate to ``reader.filenames``."""
        return self.reader.filenames(basename, absolute)
def convert_to_pil(cv_im: Optional[np.ndarray], img_mode: str = "RGB") -> Optional["Image.Image"]:
    """Convert an OpenCV image array to a PIL image.

    Args:
        cv_im: Image array in BGR channel order (it is passed through
            ``cv2.COLOR_BGR2RGB``), or ``None``.
        img_mode: Target colour mode; only ``"RGB"`` is supported.

    Returns:
        The converted PIL image, or ``None`` when ``cv_im`` is ``None``.

    Raises:
        ValueError: If ``img_mode`` is anything other than ``"RGB"``.
    """
    if cv_im is None:
        return None

    if img_mode != "RGB":
        # ValueError is still caught by callers handling a generic Exception.
        raise ValueError("Incorrect image mode has been passed!")

    rgb = cv2.cvtColor(cv_im, cv2.COLOR_BGR2RGB)
    # Make the buffer contiguous before handing it to PIL.
    return Image.fromarray(np.ascontiguousarray(rgb))