diff --git a/age_estimator/.DS_Store b/age_estimator/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..c7012a9446f53640aa9f7828d81d4df21da7d4bc
Binary files /dev/null and b/age_estimator/.DS_Store differ
diff --git a/age_estimator/__init__.py b/age_estimator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/age_estimator/__pycache__/__init__.cpython-38.pyc b/age_estimator/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee8f9f068487da04cd70ddab3bc83f13ad294f9d
Binary files /dev/null and b/age_estimator/__pycache__/__init__.cpython-38.pyc differ
diff --git a/age_estimator/__pycache__/admin.cpython-38.pyc b/age_estimator/__pycache__/admin.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2bfc1a54189550326c801a79eaa39a7b253a33aa
Binary files /dev/null and b/age_estimator/__pycache__/admin.cpython-38.pyc differ
diff --git a/age_estimator/__pycache__/apps.cpython-38.pyc b/age_estimator/__pycache__/apps.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..674e8aaed47e31e1c6de9176165d290cbf0315f6
Binary files /dev/null and b/age_estimator/__pycache__/apps.cpython-38.pyc differ
diff --git a/age_estimator/__pycache__/models.cpython-38.pyc b/age_estimator/__pycache__/models.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1223fb1df6113661748a85bc509b30106d745578
Binary files /dev/null and b/age_estimator/__pycache__/models.cpython-38.pyc differ
diff --git a/age_estimator/__pycache__/urls.cpython-38.pyc b/age_estimator/__pycache__/urls.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a63c98a4178e55b8cf65b65e6b3fca723c138f0
Binary files /dev/null and b/age_estimator/__pycache__/urls.cpython-38.pyc differ
diff --git a/age_estimator/__pycache__/views.cpython-38.pyc b/age_estimator/__pycache__/views.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3af0cfdbb41d85f501d0818b27d7e1fdd7db27f2
Binary files /dev/null and b/age_estimator/__pycache__/views.cpython-38.pyc differ
diff --git a/age_estimator/admin.py b/age_estimator/admin.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c38f3f3dad51e4585f3984282c2a4bec5349c1e
--- /dev/null
+++ b/age_estimator/admin.py
@@ -0,0 +1,3 @@
+from django.contrib import admin
+
+# Register your models here.
diff --git a/age_estimator/apps.py b/age_estimator/apps.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd67634fc0752a859ed6a8cba63dd9b80d97e8bb
--- /dev/null
+++ b/age_estimator/apps.py
@@ -0,0 +1,6 @@
+from django.apps import AppConfig
+
+
+class AgeEstimatorConfig(AppConfig):
+ default_auto_field = 'django.db.models.BigAutoField'
+ name = 'age_estimator'
diff --git a/age_estimator/migrations/__init__.py b/age_estimator/migrations/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/age_estimator/migrations/__pycache__/__init__.cpython-38.pyc b/age_estimator/migrations/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..586cdd5989324468cc02ad0c0f839bd7604f3116
Binary files /dev/null and b/age_estimator/migrations/__pycache__/__init__.cpython-38.pyc differ
diff --git a/age_estimator/mivolo/.DS_Store b/age_estimator/mivolo/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..c43ee62a22d53cc75b181ee2096e26474c200495
Binary files /dev/null and b/age_estimator/mivolo/.DS_Store differ
diff --git a/age_estimator/mivolo/.flake8 b/age_estimator/mivolo/.flake8
new file mode 100644
index 0000000000000000000000000000000000000000..c61c9b5694ff9580094261468d7e7e6c87abd51c
--- /dev/null
+++ b/age_estimator/mivolo/.flake8
@@ -0,0 +1,5 @@
+[flake8]
+max-line-length = 120
+inline-quotes = "
+multiline-quotes = "
+ignore = E203,W503
diff --git a/age_estimator/mivolo/.gitignore b/age_estimator/mivolo/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..8b048b39eff5d43df1234e5731f045960381751c
--- /dev/null
+++ b/age_estimator/mivolo/.gitignore
@@ -0,0 +1,85 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+*.DS_Store
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# PyTorch weights
+*.tar
+*.pth
+*.pt
+*.torch
+*.gz
+Untitled.ipynb
+Testing notebook.ipynb
+
+# Root dir exclusions
+/*.csv
+/*.yaml
+/*.json
+/*.jpg
+/*.png
+/*.zip
+/*.tar.*
+*.jpg
+*.png
+*.avi
+*.mp4
+*.svg
+
+.mypy_cache/
+.vscode/
+.idea
+
+output/
+input/
+
+run.sh
diff --git a/age_estimator/mivolo/.isort.cfg b/age_estimator/mivolo/.isort.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..63135f1820165af7421d82b767e964f8f2daa464
--- /dev/null
+++ b/age_estimator/mivolo/.isort.cfg
@@ -0,0 +1,5 @@
+[settings]
+profile = black
+line_length = 120
+src_paths = ["mivolo", "scripts", "tools"]
+filter_files = true
diff --git a/age_estimator/mivolo/.pre-commit-config.yaml b/age_estimator/mivolo/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f722055e25bcc3685762d14109f68518abfce1ff
--- /dev/null
+++ b/age_estimator/mivolo/.pre-commit-config.yaml
@@ -0,0 +1,31 @@
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.2.0
+ hooks:
+ - id: check-yaml
+ args: ['--unsafe']
+ - id: check-toml
+ - id: debug-statements
+ - id: end-of-file-fixer
+ exclude: poetry.lock
+ - id: trailing-whitespace
+- repo: https://github.com/PyCQA/isort
+ rev: 5.12.0
+ hooks:
+ - id: isort
+ args: [ "--profile", "black", "--filter-files" ]
+- repo: https://github.com/psf/black
+ rev: 22.3.0
+ hooks:
+ - id: black
+ args: ["--line-length", "120"]
+- repo: https://github.com/PyCQA/flake8
+ rev: 3.9.2
+ hooks:
+ - id: flake8
+ args: [ "--config", ".flake8" ]
+ additional_dependencies: [ "flake8-quotes" ]
+- repo: https://github.com/pre-commit/mirrors-mypy
+ rev: v0.942
+ hooks:
+ - id: mypy
diff --git a/age_estimator/mivolo/CHANGELOG.md b/age_estimator/mivolo/CHANGELOG.md
new file mode 100644
index 0000000000000000000000000000000000000000..91904fb73d6f47d730b38ebf97113ecf729abc59
--- /dev/null
+++ b/age_estimator/mivolo/CHANGELOG.md
@@ -0,0 +1,16 @@
+
+## 0.4.1dev (15.08.2023)
+
+### Added
+- Support for video streams, including YouTube URLs
+- Instructions and explanations for various export types.
+
+### Changed
+- Removed the CutOff operation from inference: it proved ineffective at inference time while being quite costly. It is now used only during training.
+
+## 0.4.2dev (22.09.2023)
+
+### Added
+
+- Script for converting the AgeDB dataset to CSV format
+- Additional metrics added to the README
diff --git a/age_estimator/mivolo/README.md b/age_estimator/mivolo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7fafe2cd6b32ca98dcf7bedc3c4fdaa1276cc83d
--- /dev/null
+++ b/age_estimator/mivolo/README.md
@@ -0,0 +1,417 @@
+
+
+
+
+
+
+
+
+
+
+## MiVOLO: Multi-input Transformer for Age and Gender Estimation
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/mivolo-multi-input-transformer-for-age-and/age-estimation-on-utkface)](https://paperswithcode.com/sota/age-estimation-on-utkface?p=mivolo-multi-input-transformer-for-age-and) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-specialization-assessing-the-1/age-estimation-on-imdb-clean)](https://paperswithcode.com/sota/age-estimation-on-imdb-clean?p=beyond-specialization-assessing-the-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-specialization-assessing-the-1/facial-attribute-classification-on-fairface)](https://paperswithcode.com/sota/facial-attribute-classification-on-fairface?p=beyond-specialization-assessing-the-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-specialization-assessing-the-1/age-and-gender-classification-on-adience)](https://paperswithcode.com/sota/age-and-gender-classification-on-adience?p=beyond-specialization-assessing-the-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-specialization-assessing-the-1/age-and-gender-classification-on-adience-age)](https://paperswithcode.com/sota/age-and-gender-classification-on-adience-age?p=beyond-specialization-assessing-the-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-specialization-assessing-the-1/age-and-gender-estimation-on-lagenda-age)](https://paperswithcode.com/sota/age-and-gender-estimation-on-lagenda-age?p=beyond-specialization-assessing-the-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-specialization-assessing-the-1/gender-prediction-on-lagenda)](https://paperswithcode.com/sota/gender-prediction-on-lagenda?p=beyond-specialization-assessing-the-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/mivolo-multi-input-transformer-for-age-and/age-estimation-on-agedb)](https://paperswithcode.com/sota/age-estimation-on-agedb?p=mivolo-multi-input-transformer-for-age-and) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/mivolo-multi-input-transformer-for-age-and/gender-prediction-on-agedb)](https://paperswithcode.com/sota/gender-prediction-on-agedb?p=mivolo-multi-input-transformer-for-age-and) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-specialization-assessing-the-1/age-estimation-on-cacd)](https://paperswithcode.com/sota/age-estimation-on-cacd?p=beyond-specialization-assessing-the-1)
+
+> [**MiVOLO: Multi-input Transformer for Age and Gender Estimation**](https://arxiv.org/abs/2307.04616),
+> Maksim Kuprashevich, Irina Tolstykh,
+> *2023 [arXiv 2307.04616](https://arxiv.org/abs/2307.04616)*
+
+> [**Beyond Specialization: Assessing the Capabilities of MLLMs in Age and Gender Estimation**](https://arxiv.org/abs/2403.02302),
+> Maksim Kuprashevich, Grigorii Alekseenko, Irina Tolstykh
+> *2024 [arXiv 2403.02302](https://arxiv.org/abs/2403.02302)*
+
+[[`Paper 2023`](https://arxiv.org/abs/2307.04616)] [[`Paper 2024`](https://arxiv.org/abs/2403.02302)] [[`Demo`](https://huggingface.co/spaces/iitolstykh/age_gender_estimation_demo)] [[`Telegram Bot`](https://t.me/AnyAgeBot)] [[`BibTex`](#citing)] [[`Data`](https://wildchlamydia.github.io/lagenda/)]
+
+
+## MiVOLO pretrained models
+
+Gender & Age recognition performance.
+
+| Model | Type | Dataset (train and test) | Age MAE | Age CS@5 | Gender Accuracy | Download |
+| ----- | ---- | ------------------------ | ------- | -------- | --------------- | -------- |
+| volo_d1 | face_only, age | IMDB-cleaned | 4.29 | 67.71 | - | checkpoint |
+| volo_d1 | face_only, age, gender | IMDB-cleaned | 4.22 | 68.68 | 99.38 | checkpoint |
+| mivolo_d1 | face_body, age, gender | IMDB-cleaned | 4.24 [face+body], 6.87 [body] | 68.32 [face+body], 46.32 [body] | 99.46 [face+body], 96.48 [body] | model_imdb_cross_person_4.24_99.46.pth.tar |
+| volo_d1 | face_only, age | UTKFace | 4.23 | 69.72 | - | checkpoint |
+| volo_d1 | face_only, age, gender | UTKFace | 4.23 | 69.78 | 97.69 | checkpoint |
+| mivolo_d1 | face_body, age, gender | Lagenda | 3.99 [face+body] | 71.27 [face+body] | 97.36 [face+body] | demo |
+| mivolov2_d1_384x384 | face_body, age, gender | Lagenda | 3.65 [face+body] | 74.48 [face+body] | 97.99 [face+body] | telegram bot |
+
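+*CS@5 is the cumulative score at 5 years: the share of samples whose absolute age error is at most 5 years (this corresponds to the `--l-for-cs 5` default in `eval_pretrained.py`).*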
+
+## MiVOLO regression benchmarks
+
+Gender & Age recognition performance.
+
+Use [valid_age_gender.sh](scripts/valid_age_gender.sh) to reproduce results with our checkpoints.
+
+| Model | Type | Train Dataset | Test Dataset | Age MAE | Age CS@5 | Gender Accuracy | Download |
+| ----- | ---- | ------------- | ------------ | ------- | -------- | --------------- | -------- |
+| mivolo_d1 | face_body, age, gender | Lagenda | AgeDB | 5.55 [face] | 55.08 [face] | 98.3 [face] | demo |
+| mivolo_d1 | face_body, age, gender | IMDB-cleaned | AgeDB | 5.58 [face] | 55.54 [face] | 97.93 [face] | model_imdb_cross_person_4.24_99.46.pth.tar |
+
+## MiVOLO classification benchmarks
+
+Gender & Age recognition performance.
+
+| Model | Type | Train Dataset | Test Dataset | Age Accuracy | Gender Accuracy |
+| ----- | ---- | ------------- | ------------ | ------------ | --------------- |
+| mivolo_d1 | face_body, age, gender | Lagenda | FairFace | 61.07 [face+body] | 95.73 [face+body] |
+| mivolo_d1 | face_body, age, gender | Lagenda | Adience | 68.69 [face] | 96.51 [face] |
+| mivolov2_d1_384 | face_body, age, gender | Lagenda | Adience | 69.43 [face] | 97.39 [face] |
+
+## Dataset
+
+**Please [cite our papers](#citing) if you use any of this data!**
+
+- Lagenda dataset: [images](https://drive.google.com/file/d/1QXO0NlkABPZT6x1_0Uc2i6KAtdcrpTbG/view?usp=sharing) and [annotation](https://drive.google.com/file/d/1mNYjYFb3MuKg-OL1UISoYsKObMUllbJx/view?usp=sharing).
+- IMDB-clean: follow [these instructions](https://github.com/yiminglin-ai/imdb-clean) to get images and [download](https://drive.google.com/file/d/17uEqyU3uQ5trWZ5vRJKzh41yeuDe5hyL/view?usp=sharing) our annotations.
+- UTK dataset: [origin full images](https://susanqq.github.io/UTKFace/) and our annotation: [split from the article](https://drive.google.com/file/d/1Fo1vPWrKtC5bPtnnVWNTdD4ZTKRXL9kv/view?usp=sharing), [random full split](https://drive.google.com/file/d/177AV631C3SIfi5nrmZA8CEihIt29cznJ/view?usp=sharing).
+- Adience dataset: follow [these instructions](https://talhassner.github.io/home/projects/Adience/Adience-data.html) to get images and [download](https://drive.google.com/file/d/1wS1Q4FpksxnCR88A1tGLsLIr91xHwcVv/view?usp=sharing) our annotations.
+
+
+ After downloading them, your `data` directory should look something like this:
+
+ ```console
+ data
+ └── Adience
+ ├── annotations (folder with our annotations)
+ ├── aligned (will not be used)
+ ├── faces
+ ├── fold_0_data.txt
+ ├── fold_1_data.txt
+ ├── fold_2_data.txt
+ ├── fold_3_data.txt
+ └── fold_4_data.txt
+ ```
+
+ We use coarse aligned images from `faces/` dir.
+
+ Using our detector we found a face bbox for each image (see [tools/prepare_adience.py](tools/prepare_adience.py)).
+
+ This dataset has five folds. The performance metric is accuracy on five-fold cross validation.
+
+ | images before removal | fold 0 | fold 1 | fold 2 | fold 3 | fold 4 |
+ | --------------------- | ------ | ------ | ------ | ------ | ------ |
+ | 19,370 | 4,484 | 3,730 | 3,894 | 3,446 | 3,816 |
+
+ Samples with incomplete annotations
+
+ | only age not found | only gender not found | SUM |
+ | ------------------ | --------------------- | ------------- |
+ | 40 | 1170 | 1,210 (6.2 %) |
+
+ Removed data
+
+ | failed to process image | age and gender not found | SUM |
+ | ----------------------- | ------------------------ | ----------- |
+ | 0 | 708 | 708 (3.6 %) |
+
+ Genders
+
+ | female | male |
+ | ------ | ----- |
+ | 9,372 | 8,120 |
+
+ Ages (8 classes) after mapping to non-overlapping age intervals
+
+ | 0-2 | 4-6 | 8-12 | 15-20 | 25-32 | 38-43 | 48-53 | 60-100 |
+ | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ------ |
+ | 2,509 | 2,140 | 2,293 | 1,791 | 5,589 | 2,490 | 909 | 901 |
+
+
+
+- FairFace dataset: follow [these instructions](https://github.com/joojs/fairface) to get images and [download](https://drive.google.com/file/d/1EdY30A1SQmox96Y39VhBxdgALYhbkzdm/view?usp=drive_link) our annotations.
+
+
+ After downloading them, your `data` directory should look something like this:
+
+ ```console
+ data
+ └── FairFace
+ ├── annotations (folder with our annotations)
+ ├── fairface-img-margin025-trainval (will not be used)
+ ├── train
+ ├── val
+ ├── fairface-img-margin125-trainval
+ ├── train
+ ├── val
+ ├── fairface_label_train.csv
+ ├── fairface_label_val.csv
+
+ ```
+
+ We use aligned images from `fairface-img-margin125-trainval/` dir.
+
+ Using our detector we found a face bbox for each image and added a person bbox if it was possible (see [tools/prepare_fairface.py](tools/prepare_fairface.py)).
+
+ This dataset has 2 splits: train and val. The performance metric is accuracy on validation.
+
+ | images train | images val |
+ | ------------ | ---------- |
+ | 86,744 | 10,954 |
+
+ Genders for **validation**
+
+ | female | male |
+ | ------ | ----- |
+ | 5,162 | 5,792 |
+
+ Ages for **validation** (9 classes):
+
+ | 0-2 | 3-9 | 10-19 | 20-29 | 30-39 | 40-49 | 50-59 | 60-69 | 70+ |
+ | --- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | --- |
+ | 199 | 1,356 | 1,181 | 3,300 | 2,330 | 1,353 | 796 | 321 | 118 |
+
+
+- AgeDB dataset: follow [these instructions](https://ibug.doc.ic.ac.uk/resources/agedb/) to get images and [download](https://drive.google.com/file/d/1Dp72BUlAsyUKeSoyE_DOsFRS1x6ZBJen/view) our annotations.
+
+
+ **Ages**: 1 - 101
+
+ **Genders**: 9788 faces of `M`, 6700 faces of `F`
+
+ Number of images per split:
+
+ | split 0 | split 1 | split 2 | split 3 | split 4 | split 5 | split 6 | split 7 | split 8 | split 9 |
+ |---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|
+ | 1701 | 1721 | 1615 | 1619 | 1626 | 1643 | 1634 | 1596 | 1676 | 1657 |
+
+ Data splits were taken from [here](https://github.com/paplhjak/Facial-Age-Estimation-Benchmark-Databases)
+
+ **Note: all splits (i.e., the entire dataset) were used for model evaluation.**
+
+
+## Install
+
+Install PyTorch 1.13+ and the other requirements:
+
+```bash
+pip install -r requirements.txt
+pip install .
+```
+
+
+## Demo
+
+1. [Download](https://drive.google.com/file/d/1CGNCkZQNj5WkP3rLpENWAOgrBQkUWRdw/view) body + face detector model to `models/yolov8x_person_face.pt`
+2. [Download](https://drive.google.com/file/d/11i8pKctxz3wVkDBlWKvhYIh7kpVFXSZ4/view) mivolo checkpoint to `models/mivolo_imbd.pth.tar`
+
+```bash
+wget https://variety.com/wp-content/uploads/2023/04/MCDNOHA_SP001.jpg -O jennifer_lawrence.jpg
+
+python3 demo.py \
+--input "jennifer_lawrence.jpg" \
+--output "output" \
+--detector-weights "models/yolov8x_person_face.pt" \
+--checkpoint "models/mivolo_imbd.pth.tar" \
+--device "cuda:0" \
+--with-persons \
+--draw
+```
+
+To run the demo on a YouTube video:
+```bash
+python3 demo.py \
+--input "https://www.youtube.com/shorts/pVh32k0hGEI" \
+--output "output" \
+--detector-weights "models/yolov8x_person_face.pt" \
+--checkpoint "models/mivolo_imbd.pth.tar" \
+--device "cuda:0" \
+--draw \
+--with-persons
+```
+
+
+## Validation
+
+To reproduce validation metrics:
+
+1. Download the prepared annotations for imdb-clean / utk / adience / lagenda / fairface.
+2. Download a checkpoint.
+3. Run validation:
+
+```bash
+python3 eval_pretrained.py \
+ --dataset_images /path/to/dataset/utk/images \
+ --dataset_annotations /path/to/dataset/utk/annotation \
+ --dataset_name utk \
+ --split valid \
+ --batch-size 512 \
+ --checkpoint models/mivolo_imbd.pth.tar \
+ --half \
+ --with-persons \
+ --device "cuda:0"
+```
+
+Supported dataset names: "utk", "imdb", "lagenda", "fairface", "adience".
+
+
+## Changelog
+
+[CHANGELOG.md](CHANGELOG.md)
+
+## ONNX and TensorRT export
+
+As of now (11.08.2023), while ONNX export is technically feasible, it is not advisable due to the poor performance of the resulting model with batch processing.
+**TensorRT** and **OpenVINO** export is not possible because they do not support the col2im operation.
+
+If you are still set on ONNX export, you can refer to [these instructions](https://github.com/WildChlamydia/MiVOLO/issues/14#issuecomment-1675245889).
+
+At present, the most recommended export method is **TorchScript**. You can obtain a traced model with a single call:
+```python
+torch.jit.trace(model, example_input)  # example_input: a sample batch matching the model's expected input
+```
+This approach provides you with a model that maintains its original speed and only requires a single file for usage, eliminating the need for additional code.
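+
+Below is a minimal sketch of tracing, saving, and reloading the model, not an official recipe: it assumes `model` is the already-loaded MiVOLO network in eval mode and that `example_input` matches the shape your checkpoint expects; the 6 x 224 x 224 face+body shape is taken from `measure_time.py` (the face-only model uses 3 x 224 x 224).
+
+```python
+import torch
+
+# `model`: the loaded MiVOLO network in eval mode (an assumption for this sketch)
+model.eval()
+example_input = torch.randn(1, 6, 224, 224)  # assumed shape for the face+body variant
+
+traced = torch.jit.trace(model, example_input)  # record the forward pass on the example batch
+torch.jit.save(traced, "mivolo_traced.pt")      # a single self-contained file
+
+loaded = torch.jit.load("mivolo_traced.pt")     # no Python model definition required at load time
+output = loaded(example_input)
+```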
+
+## License
+
+Please see [here](./license).
+
+
+## Citing
+
+If you use our models, code, or dataset, we kindly ask you to cite the following papers and give the repository a :star:
+
+```bibtex
+@article{mivolo2023,
+ Author = {Maksim Kuprashevich and Irina Tolstykh},
+ Title = {MiVOLO: Multi-input Transformer for Age and Gender Estimation},
+ Year = {2023},
+ Eprint = {arXiv:2307.04616},
+}
+```
+```bibtex
+@article{mivolo2024,
+ Author = {Maksim Kuprashevich and Grigorii Alekseenko and Irina Tolstykh},
+ Title = {Beyond Specialization: Assessing the Capabilities of MLLMs in Age and Gender Estimation},
+ Year = {2024},
+ Eprint = {arXiv:2403.02302},
+}
+```
diff --git a/age_estimator/mivolo/__pycache__/demo_copy.cpython-38.pyc b/age_estimator/mivolo/__pycache__/demo_copy.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32763331326c804b18d5cc804cf3f584cff1550a
Binary files /dev/null and b/age_estimator/mivolo/__pycache__/demo_copy.cpython-38.pyc differ
diff --git a/age_estimator/mivolo/demo.py b/age_estimator/mivolo/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..51e367b5035c7ce0bc21a1f57340bba4db95eba4
--- /dev/null
+++ b/age_estimator/mivolo/demo.py
@@ -0,0 +1,145 @@
+import argparse
+import logging
+import os
+import random
+
+import cv2
+import torch
+import yt_dlp
+from mivolo.data.data_reader import InputType, get_all_files, get_input_type
+from mivolo.predictor import Predictor
+from timm.utils import setup_default_logging
+
+_logger = logging.getLogger("inference")
+
+
+def get_direct_video_url(video_url):
+ ydl_opts = {
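+        # "bestvideo" selects the best available video-only stream; audio is not needed for frame extraction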
+ "format": "bestvideo",
+ "quiet": True, # Suppress terminal output
+ }
+
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+ info_dict = ydl.extract_info(video_url, download=False)
+
+ if "url" in info_dict:
+ direct_url = info_dict["url"]
+ resolution = (info_dict["width"], info_dict["height"])
+ fps = info_dict["fps"]
+ yid = info_dict["id"]
+ return direct_url, resolution, fps, yid
+
+ return None, None, None, None
+
+
+def get_local_video_info(vid_uri):
+ cap = cv2.VideoCapture(vid_uri)
+ if not cap.isOpened():
+ raise ValueError(f"Failed to open video source {vid_uri}")
+ res = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
+ fps = cap.get(cv2.CAP_PROP_FPS)
+ return res, fps
+
+
+def get_random_frames(cap, num_frames):
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ frame_indices = random.sample(range(total_frames), num_frames)
+
+ frames = []
+ for idx in frame_indices:
+ cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+ ret, frame = cap.read()
+ if ret:
+ frames.append(frame)
+ return frames
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(description="PyTorch MiVOLO Inference")
+ parser.add_argument("--input", type=str, default=None, required=True, help="image file or folder with images")
+ parser.add_argument("--output", type=str, default=None, required=True, help="folder for output results")
+ parser.add_argument("--detector-weights", type=str, default=None, required=True, help="Detector weights (YOLOv8).")
+ parser.add_argument("--checkpoint", default="", type=str, required=True, help="path to mivolo checkpoint")
+
+ parser.add_argument(
+ "--with-persons", action="store_true", default=False, help="If set model will run with persons, if available"
+ )
+ parser.add_argument(
+ "--disable-faces", action="store_true", default=False, help="If set model will use only persons if available"
+ )
+
+ parser.add_argument("--draw", action="store_true", default=False, help="If set, resulted images will be drawn")
+ parser.add_argument("--device", default="cuda", type=str, help="Device (accelerator) to use.")
+
+ return parser
+
+
+def main():
+ parser = get_parser()
+ setup_default_logging()
+ args = parser.parse_args()
+
+ if torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.benchmark = True
+ os.makedirs(args.output, exist_ok=True)
+
+ predictor = Predictor(args, verbose=True)
+
+ input_type = get_input_type(args.input)
+
+ if input_type == InputType.Video or input_type == InputType.VideoStream:
+ if "youtube" in args.input:
+ args.input, res, fps, yid = get_direct_video_url(args.input)
+ if not args.input:
+ raise ValueError(f"Failed to get direct video url {args.input}")
+ else:
+ cap = cv2.VideoCapture(args.input)
+ if not cap.isOpened():
+ raise ValueError(f"Failed to open video source {args.input}")
+
+ # Extract 4-5 random frames from the video
+ random_frames = get_random_frames(cap, num_frames=5)
+
+ age_list = []
+ for frame in random_frames:
+ detected_objects, out_im, age = predictor.recognize(frame)
+ age_list.append(age[0])
+
+ if args.draw:
+ bname = os.path.splitext(os.path.basename(args.input))[0]
+ filename = os.path.join(args.output, f"out_{bname}.jpg")
+ cv2.imwrite(filename, out_im)
+ _logger.info(f"Saved result to {filename}")
+
+ # Calculate and print average age
+ avg_age = sum(age_list) / len(age_list) if age_list else 0
+ print(f"Age list: {age_list}")
+ print(f"Average age: {avg_age:.2f}")
+ absolute_age = round(abs(avg_age))
+ # Define the range
+ lower_bound = absolute_age - 2
+ upper_bound = absolute_age + 2
+
+
+ return absolute_age, lower_bound, upper_bound
+
+ elif input_type == InputType.Image:
+ image_files = get_all_files(args.input) if os.path.isdir(args.input) else [args.input]
+
+ for img_p in image_files:
+ img = cv2.imread(img_p)
+ detected_objects, out_im, age = predictor.recognize(img)
+
+ if args.draw:
+ bname = os.path.splitext(os.path.basename(img_p))[0]
+ filename = os.path.join(args.output, f"out_{bname}.jpg")
+ cv2.imwrite(filename, out_im)
+ _logger.info(f"Saved result to {filename}")
+
+
+if __name__ == "__main__":
+    result = main()
+    # main() returns the age summary only for video inputs; image inputs just save drawn results
+    if result is not None:
+        absolute_age, lower_bound, upper_bound = result
+        print(f"Absolute Age: {absolute_age}")
+        print(f"Range: {lower_bound} - {upper_bound}")
\ No newline at end of file
diff --git a/age_estimator/mivolo/demo_copy.py b/age_estimator/mivolo/demo_copy.py
new file mode 100644
index 0000000000000000000000000000000000000000..a915cb1daae47b4d4fa8e72d531b4cb459b661ab
--- /dev/null
+++ b/age_estimator/mivolo/demo_copy.py
@@ -0,0 +1,144 @@
+import argparse
+import logging
+import os
+import random
+
+import cv2
+import torch
+import yt_dlp
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '././')))
+
+from mivolo.data.data_reader import InputType, get_all_files, get_input_type
+from mivolo.predictor import Predictor
+from timm.utils import setup_default_logging
+
+_logger = logging.getLogger("inference")
+
+
+def get_direct_video_url(video_url):
+ ydl_opts = {
+ "format": "bestvideo",
+ "quiet": True, # Suppress terminal output
+ }
+
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+ info_dict = ydl.extract_info(video_url, download=False)
+
+ if "url" in info_dict:
+ direct_url = info_dict["url"]
+ resolution = (info_dict["width"], info_dict["height"])
+ fps = info_dict["fps"]
+ yid = info_dict["id"]
+ return direct_url, resolution, fps, yid
+
+ return None, None, None, None
+
+
+def get_random_frames(cap, num_frames):
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ frame_indices = random.sample(range(total_frames), num_frames)
+
+ frames = []
+ for idx in frame_indices:
+ cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+ ret, frame = cap.read()
+ if ret:
+ frames.append(frame)
+ return frames
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(description="PyTorch MiVOLO Inference")
+ parser.add_argument("--input", type=str, default=None, required=True, help="image file or folder with images")
+ parser.add_argument("--output", type=str, default=None, required=True, help="folder for output results")
+ parser.add_argument("--detector-weights", type=str, default=None, required=True, help="Detector weights (YOLOv8).")
+ parser.add_argument("--checkpoint", default="", type=str, required=True, help="path to mivolo checkpoint")
+
+ parser.add_argument(
+ "--with_persons", action="store_true", default=False, help="If set model will run with persons, if available"
+ )
+ parser.add_argument(
+ "--disable_faces", action="store_true", default=False, help="If set model will use only persons if available"
+ )
+
+ parser.add_argument("--draw", action="store_true", default=False, help="If set, resulted images will be drawn")
+ parser.add_argument("--device", default="cpu", type=str, help="Device (accelerator) to use.")
+
+ return parser
+
+
+def main(video_path, output_folder, detector_weights, checkpoint, device, with_persons, disable_faces, draw=False):
+ setup_default_logging()
+
+ if torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.benchmark = True
+
+ os.makedirs(output_folder, exist_ok=True)
+
+ # Initialize predictor
+ args = argparse.Namespace(
+ input=video_path,
+ output=output_folder,
+ detector_weights=detector_weights,
+ checkpoint=checkpoint,
+ draw=draw,
+ device=device,
+ with_persons=with_persons,
+ disable_faces=disable_faces
+ )
+
+ predictor = Predictor(args, verbose=True)
+
+ if "youtube" in video_path:
+ video_path, res, fps, yid = get_direct_video_url(video_path)
+ if not video_path:
+ raise ValueError(f"Failed to get direct video url {video_path}")
+
+ cap = cv2.VideoCapture(video_path)
+ if not cap.isOpened():
+ raise ValueError(f"Failed to open video source {video_path}")
+
+    # Extract 10 random frames from the video
+ random_frames = get_random_frames(cap, num_frames=10)
+ age_list = []
+
+ for frame in random_frames:
+ detected_objects, out_im, age = predictor.recognize(frame)
+ try:
+ age_list.append(age[0]) # Attempt to access the first element of age
+ if draw:
+ bname = os.path.splitext(os.path.basename(video_path))[0]
+ filename = os.path.join(output_folder, f"out_{bname}.jpg")
+ cv2.imwrite(filename, out_im)
+ _logger.info(f"Saved result to {filename}")
+ except IndexError:
+ continue
+
+    if len(age_list) == 0:
+ raise ValueError("No person was detected in the frame. Please upload a proper face video.")
+
+
+
+ # Calculate and print average age
+ avg_age = sum(age_list) / len(age_list) if age_list else 0
+ print(f"Age list: {age_list}")
+ print(f"Average age: {avg_age:.2f}")
+ absolute_age = round(abs(avg_age))
+
+ # Define the range
+ lower_bound = absolute_age - 2
+ upper_bound = absolute_age + 2
+
+ return absolute_age, lower_bound, upper_bound
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args = parser.parse_args()
+
+    absolute_age, lower_bound, upper_bound = main(args.input, args.output, args.detector_weights, args.checkpoint, args.device, args.with_persons, args.disable_faces, args.draw)
+ # Output the results in the desired format
+ print(f"Absolute Age: {absolute_age}")
+ print(f"Range: {lower_bound} - {upper_bound}")
diff --git a/age_estimator/mivolo/eval_pretrained.py b/age_estimator/mivolo/eval_pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..567b5d5a464d4e8ca3234ba11e08148ecd2373ff
--- /dev/null
+++ b/age_estimator/mivolo/eval_pretrained.py
@@ -0,0 +1,232 @@
+import argparse
+import json
+import logging
+from typing import Tuple
+
+import matplotlib.pyplot as plt
+import seaborn as sns
+import torch
+from eval_tools import Metrics, time_sync, write_results
+from mivolo.data.dataset import build as build_data
+from mivolo.model.mi_volo import MiVOLO
+from timm.utils import setup_default_logging
+
+_logger = logging.getLogger("inference")
+LOG_FREQUENCY = 10
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(description="PyTorch MiVOLO Validation")
+ parser.add_argument("--dataset_images", default="", type=str, required=True, help="path to images")
+ parser.add_argument("--dataset_annotations", default="", type=str, required=True, help="path to annotations")
+ parser.add_argument(
+ "--dataset_name",
+ default=None,
+ type=str,
+ required=True,
+ choices=["utk", "imdb", "lagenda", "fairface", "adience", "agedb", "cacd"],
+ help="dataset name",
+ )
+ parser.add_argument("--split", default="validation", help="dataset splits separated by comma (default: validation)")
+ parser.add_argument("--checkpoint", default="", type=str, required=True, help="path to mivolo checkpoint")
+
+ parser.add_argument("--batch-size", default=64, type=int, help="batch size")
+ parser.add_argument(
+ "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
+ )
+ parser.add_argument("--device", default="cuda", type=str, help="Device (accelerator) to use.")
+ parser.add_argument("--l-for-cs", type=int, default=5, help="L for CS (cumulative score)")
+
+ parser.add_argument("--half", action="store_true", default=False, help="use half-precision model")
+ parser.add_argument(
+ "--with-persons", action="store_true", default=False, help="If the model will run with persons, if available"
+ )
+ parser.add_argument(
+ "--disable-faces", action="store_true", default=False, help="If the model will use only persons if available"
+ )
+
+ parser.add_argument("--draw-hist", action="store_true", help="Draws the hist of error by age")
+ parser.add_argument(
+ "--results-file",
+ default="",
+ type=str,
+ metavar="FILENAME",
+ help="Output csv file for validation results (summary)",
+ )
+ parser.add_argument(
+ "--results-format", default="csv", type=str, help="Format for results file one of (csv, json) (default: csv)."
+ )
+
+ return parser
+
+
+def process_batch(
+ mivolo_model: MiVOLO,
+ input: torch.tensor,
+ target: torch.tensor,
+ num_classes_gender: int = 2,
+):
+
+ start = time_sync()
+ output = mivolo_model.inference(input)
+ # target with age == -1 and gender == -1 marks that sample is not valid
+ assert not (all(target[:, 0] == -1) and all(target[:, 1] == -1))
+
+ if not mivolo_model.meta.only_age:
+ gender_out = output[:, :num_classes_gender]
+ gender_target = target[:, 1]
+ age_out = output[:, num_classes_gender:]
+ else:
+ age_out = output
+ gender_out, gender_target = None, None
+
+ # measure elapsed time
+ process_time = time_sync() - start
+
+ age_target = target[:, 0].unsqueeze(1)
+
+ return age_out, age_target, gender_out, gender_target, process_time
+
+
+def _filter_invalid_target(out: torch.tensor, target: torch.tensor):
+ # exclude samples where target gt == -1, that marks sample is not valid
+ mask = target != -1
+ return out[mask], target[mask]
+
+
+def postprocess_gender(gender_out: torch.tensor, gender_target: torch.tensor) -> Tuple[torch.tensor, torch.tensor]:
+ if gender_target is None:
+ return gender_out, gender_target
+ return _filter_invalid_target(gender_out, gender_target)
+
+
+def postprocess_age(age_out: torch.tensor, age_target: torch.tensor, dataset) -> Tuple[torch.tensor, torch.tensor]:
+ # Revert _norm_age() operation. Output is 2 float tensors
+
+ age_out, age_target = _filter_invalid_target(age_out, age_target)
+
+ age_out = age_out * (dataset.max_age - dataset.min_age) + dataset.avg_age
+ # clamp to 0 because age can be below zero
+ age_out = torch.clamp(age_out, min=0)
+
+ if dataset.age_classes is not None:
+ # classification case
+ age_out = torch.round(age_out)
+ if dataset._intervals.device != age_out.device:
+ dataset._intervals = dataset._intervals.to(age_out.device)
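+        # map each rounded age to the index of the class interval it falls into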
+ age_inds = torch.searchsorted(dataset._intervals, age_out, side="right") - 1
+ age_out = age_inds
+ else:
+ age_target = age_target * (dataset.max_age - dataset.min_age) + dataset.avg_age
+ return age_out, age_target
+
+
+def validate(args):
+
+ if torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.benchmark = True
+
+ mivolo_model = MiVOLO(
+ args.checkpoint,
+ args.device,
+ half=args.half,
+ use_persons=args.with_persons,
+ disable_faces=args.disable_faces,
+ verbose=True,
+ )
+
+ dataset, loader = build_data(
+ name=args.dataset_name,
+ images_path=args.dataset_images,
+ annotations_path=args.dataset_annotations,
+ split=args.split,
+ mivolo_model=mivolo_model, # to get meta information from model
+ workers=args.workers,
+ batch_size=args.batch_size,
+ )
+
+ d_stat = Metrics(args.l_for_cs, args.draw_hist, dataset.age_classes)
+
+ # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
+ mivolo_model.warmup(args.batch_size)
+
+ preproc_end = time_sync()
+ for batch_idx, (input, target) in enumerate(loader):
+
+ preprocess_time = time_sync() - preproc_end
+ # get output and calculate loss
+ age_out, age_target, gender_out, gender_target, process_time = process_batch(
+ mivolo_model, input, target, dataset.num_classes_gender
+ )
+
+ gender_out, gender_target = postprocess_gender(gender_out, gender_target)
+ age_out, age_target = postprocess_age(age_out, age_target, dataset)
+
+ d_stat.update_gender_accuracy(gender_out, gender_target)
+ if d_stat.is_regression:
+ d_stat.update_regression_age_metrics(age_out, age_target)
+ else:
+ d_stat.update_age_accuracy(age_out, age_target)
+ d_stat.update_time(process_time, preprocess_time, input.shape[0])
+
+ if batch_idx % LOG_FREQUENCY == 0:
+ _logger.info(
+ "Test: [{0:>4d}/{1}] " "{2}".format(batch_idx, len(loader), d_stat.get_info_str(input.size(0)))
+ )
+
+ preproc_end = time_sync()
+
+ # model info
+ results = dict(
+ model=args.checkpoint,
+ dataset_name=args.dataset_name,
+ param_count=round(mivolo_model.param_count / 1e6, 2),
+ img_size=mivolo_model.input_size,
+ use_faces=mivolo_model.meta.use_face_crops,
+ use_persons=mivolo_model.meta.use_persons,
+ in_chans=mivolo_model.meta.in_chans,
+ batch=args.batch_size,
+ )
+ # metrics info
+ results.update(d_stat.get_result())
+ return results
+
+
+def main():
+ parser = get_parser()
+ setup_default_logging()
+ args = parser.parse_args()
+
+ if torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.benchmark = True
+
+ results = validate(args)
+
+ result_str = " * Age Acc@1 {:.3f} ({:.3f})".format(results["agetop1"], results["agetop1_err"])
+ if "gendertop1" in results:
+ result_str += " Gender Acc@1 1 {:.3f} ({:.3f})".format(results["gendertop1"], results["gendertop1_err"])
+ result_str += " Mean inference time {:.3f} ms Mean preprocessing time {:.3f}".format(
+ results["mean_inference_time"], results["mean_preprocessing_time"]
+ )
+ _logger.info(result_str)
+
+ if args.draw_hist and "per_age_error" in results:
+ err = [sum(v) / len(v) for k, v in results["per_age_error"].items()]
+ ages = list(results["per_age_error"].keys())
+ sns.scatterplot(x=ages, y=err, hue=err)
+ plt.legend([], [], frameon=False)
+ plt.xlabel("Age")
+ plt.ylabel("MAE")
+ plt.savefig("age_error.png", dpi=300)
+
+ if args.results_file:
+ write_results(args.results_file, results, format=args.results_format)
+
+ # output results in JSON to stdout w/ delimiter for runner script
+ print(f"--result\n{json.dumps(results, indent=4)}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/age_estimator/mivolo/eval_tools.py b/age_estimator/mivolo/eval_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..db9a29973cc68796f98c80559f158ef21f812a65
--- /dev/null
+++ b/age_estimator/mivolo/eval_tools.py
@@ -0,0 +1,149 @@
+import csv
+import json
+import time
+from collections import OrderedDict, defaultdict
+
+import torch
+from mivolo.data.misc import cumulative_error, cumulative_score
+from timm.utils import AverageMeter, accuracy
+
+
+def time_sync():
+ # pytorch-accurate time
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ return time.time()
+
+
+def write_results(results_file, results, format="csv"):
+ with open(results_file, mode="w") as cf:
+ if format == "json":
+ json.dump(results, cf, indent=4)
+ else:
+ if not isinstance(results, (list, tuple)):
+ results = [results]
+ if not results:
+ return
+ dw = csv.DictWriter(cf, fieldnames=results[0].keys())
+ dw.writeheader()
+ for r in results:
+ dw.writerow(r)
+ cf.flush()
+
+
+class Metrics:
+ def __init__(self, l_for_cs, draw_hist, age_classes=None):
+ self.batch_time = AverageMeter()
+ self.preproc_batch_time = AverageMeter()
+ self.seen = 0
+
+ self.losses = AverageMeter()
+ self.top1_m_gender = AverageMeter()
+ self.top1_m_age = AverageMeter()
+
+ if age_classes is None:
+ self.is_regression = True
+ self.av_csl_age = AverageMeter()
+ self.max_error = AverageMeter()
+ self.per_age_error = defaultdict(list)
+ self.l_for_cs = l_for_cs
+ else:
+ self.is_regression = False
+
+ self.draw_hist = draw_hist
+
+ def update_regression_age_metrics(self, age_out, age_target):
+ batch_size = age_out.size(0)
+
+ age_abs_err = torch.abs(age_out - age_target)
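+        # Despite the "acc1" naming, in the regression case this value is the batch
+        # mean absolute error; it is what get_result() reports as the age metric.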
+ age_acc1 = torch.sum(age_abs_err) / age_out.shape[0]
+ age_csl = cumulative_score(age_out, age_target, self.l_for_cs)
+ me = cumulative_error(age_out, age_target, 20)
+
+ self.top1_m_age.update(age_acc1.item(), batch_size)
+ self.av_csl_age.update(age_csl.item(), batch_size)
+ self.max_error.update(me.item(), batch_size)
+
+ if self.draw_hist:
+ for i in range(age_out.shape[0]):
+ self.per_age_error[int(age_target[i].item())].append(age_abs_err[i].item())
+
+ def update_age_accuracy(self, age_out, age_target):
+ batch_size = age_out.size(0)
+ if batch_size == 0:
+ return
+ correct = torch.sum(age_out == age_target)
+ age_acc1 = correct * 100.0 / batch_size
+ self.top1_m_age.update(age_acc1.item(), batch_size)
+
+ def update_gender_accuracy(self, gender_out, gender_target):
+ if gender_out is None or gender_out.size(0) == 0:
+ return
+ batch_size = gender_out.size(0)
+ gender_acc1 = accuracy(gender_out, gender_target, topk=(1,))[0]
+ if gender_acc1 is not None:
+ self.top1_m_gender.update(gender_acc1.item(), batch_size)
+
+ def update_loss(self, loss, batch_size):
+ self.losses.update(loss.item(), batch_size)
+
+ def update_time(self, process_time, preprocess_time, batch_size):
+ self.seen += batch_size
+ self.batch_time.update(process_time)
+ self.preproc_batch_time.update(preprocess_time)
+
+ def get_info_str(self, batch_size):
+ avg_time = (self.preproc_batch_time.sum + self.batch_time.sum) / self.batch_time.count
+ cur_time = self.batch_time.val + self.preproc_batch_time.val
+ middle_info = (
+ "Time: {cur_time:.3f}s ({avg_time:.3f}s, {rate_avg:>7.2f}/s) "
+ "Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) "
+ "Gender Acc: {top1gender.val:>7.2f} ({top1gender.avg:>7.2f}) ".format(
+ cur_time=cur_time,
+ avg_time=avg_time,
+ rate_avg=batch_size / avg_time,
+ loss=self.losses,
+ top1gender=self.top1_m_gender,
+ )
+ )
+
+ if self.is_regression:
+ age_info = (
+ "Age CS@{l_for_cs}: {csl.val:>7.4f} ({csl.avg:>7.4f}) "
+ "Age CE@20: {max_error.val:>7.4f} ({max_error.avg:>7.4f}) "
+ "Age ME: {top1age.val:>7.2f} ({top1age.avg:>7.2f})".format(
+ top1age=self.top1_m_age, csl=self.av_csl_age, max_error=self.max_error, l_for_cs=self.l_for_cs
+ )
+ )
+ else:
+ age_info = "Age Acc: {top1age.val:>7.2f} ({top1age.avg:>7.2f})".format(top1age=self.top1_m_age)
+
+ return middle_info + age_info
+
+ def get_result(self):
+ age_top1a = self.top1_m_age.avg
+ gender_top1 = self.top1_m_gender.avg if self.top1_m_gender.count > 0 else None
+
+ mean_per_image_time = self.batch_time.sum / self.seen
+ mean_preprocessing_time = self.preproc_batch_time.sum / self.seen
+
+ results = OrderedDict(
+ mean_inference_time=mean_per_image_time * 1e3,
+ mean_preprocessing_time=mean_preprocessing_time * 1e3,
+ agetop1=round(age_top1a, 4),
+ agetop1_err=round(100 - age_top1a, 4),
+ )
+
+ if self.is_regression:
+ results.update(
+ dict(
+ max_error=self.max_error.avg,
+ csl=self.av_csl_age.avg,
+ per_age_error=self.per_age_error,
+ )
+ )
+
+ if gender_top1 is not None:
+ results.update(dict(gendertop1=round(gender_top1, 4), gendertop1_err=round(100 - gender_top1, 4)))
+
+ return results
diff --git a/age_estimator/mivolo/images/MiVOLO.jpg b/age_estimator/mivolo/images/MiVOLO.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4dbe50c3eb596321671ebc3443cf01467deb2593
Binary files /dev/null and b/age_estimator/mivolo/images/MiVOLO.jpg differ
diff --git a/age_estimator/mivolo/infer.py b/age_estimator/mivolo/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e3d79c79d3cd4e4444b2f2afc8428f19c99d711
--- /dev/null
+++ b/age_estimator/mivolo/infer.py
@@ -0,0 +1,88 @@
+#!/usr/bin/env python
+import pathlib
+import os
+import huggingface_hub
+import numpy as np
+import argparse
+from dataclasses import dataclass
+from mivolo.predictor import Predictor
+from PIL import Image
+
+@dataclass
+class Cfg:
+ detector_weights: str
+ checkpoint: str
+ device: str = "cpu"
+ with_persons: bool = True
+ disable_faces: bool = False
+ draw: bool = True
+
+
+def load_models():
+ detector_path = huggingface_hub.hf_hub_download('iitolstykh/demo_yolov8_detector',
+ 'yolov8x_person_face.pt')
+
+ age_gender_path_v1 = 'age_estimator/MiVOLO-main/models/model_imdb_cross_person_4.22_99.46.pth.tar'
+ predictor_cfg_v1 = Cfg(detector_path, age_gender_path_v1)
+
+ predictor_v1 = Predictor(predictor_cfg_v1)
+
+ return predictor_v1
+
+def detect(image: np.ndarray, score_threshold: float, iou_threshold: float, mode: str, predictor: Predictor) -> np.ndarray:
+ predictor.detector.detector_kwargs['conf'] = score_threshold
+ predictor.detector.detector_kwargs['iou'] = iou_threshold
+
+ if mode == "Use persons and faces":
+ use_persons = True
+ disable_faces = False
+ elif mode == "Use persons only":
+ use_persons = True
+ disable_faces = True
+ elif mode == "Use faces only":
+ use_persons = False
+ disable_faces = False
+
+ predictor.age_gender_model.meta.use_persons = use_persons
+ predictor.age_gender_model.meta.disable_faces = disable_faces
+
+ image = image[:, :, ::-1] # RGB -> BGR for OpenCV
+ detected_objects, out_im = predictor.recognize(image)
+ return out_im[:, :, ::-1] # BGR -> RGB
+
+def load_image(image_path: str):
+ image = Image.open(image_path)
+ image_np = np.array(image)
+ return image_np
+
+def main(args):
+ # Load models
+ predictor_v1 = load_models()
+
+ # Set parameters from args
+ score_threshold = args.score_threshold
+ iou_threshold = args.iou_threshold
+ mode = args.mode
+
+ # Load and process image
+ image_np = load_image(args.image_path)
+
+ # Predict with model
+ result = detect(image_np, score_threshold, iou_threshold, mode, predictor_v1)
+
+ output_image = Image.fromarray(result)
+ output_image.save(args.output_path)
+ output_image.show()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='Object Detection with YOLOv8 and Age/Gender Prediction')
+ parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
+ parser.add_argument('--output_path', type=str, default='output_image.jpg', help='Path to save the output image')
+ parser.add_argument('--score_threshold', type=float, default=0.4, help='Score threshold for detection')
+ parser.add_argument('--iou_threshold', type=float, default=0.7, help='IoU threshold for detection')
+ parser.add_argument('--mode', type=str, choices=["Use persons and faces", "Use persons only", "Use faces only"],
+ default="Use persons and faces", help='Detection mode')
+
+ args = parser.parse_args()
+ main(args)
+
diff --git a/age_estimator/mivolo/license/en_us.pdf b/age_estimator/mivolo/license/en_us.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..bbe714b77de6ce80dd7e09a086151688f33b6aff
Binary files /dev/null and b/age_estimator/mivolo/license/en_us.pdf differ
diff --git a/age_estimator/mivolo/license/ru.pdf b/age_estimator/mivolo/license/ru.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..a22946d5e9e6010cec1b21c0e219249df4c44fd2
Binary files /dev/null and b/age_estimator/mivolo/license/ru.pdf differ
diff --git a/age_estimator/mivolo/measure_time.py b/age_estimator/mivolo/measure_time.py
new file mode 100644
index 0000000000000000000000000000000000000000..c01f74ac729cde64e237c11b3183a380e6d3082a
--- /dev/null
+++ b/age_estimator/mivolo/measure_time.py
@@ -0,0 +1,77 @@
+import pandas as pd
+import torch
+import tqdm
+from eval_tools import time_sync
+from mivolo.model.create_timm_model import create_model
+
+if __name__ == "__main__":
+
+ face_person_ckpt_path = "/data/dataset/iikrasnova/age_gender/pretrained/checkpoint-377.pth.tar"
+ face_person_input_size = [6, 224, 224]
+
+ face_age_ckpt_path = "/data/dataset/iikrasnova/age_gender/pretrained/model_only_age_imdb_4.32.pth.tar"
+ face_input_size = [3, 224, 224]
+
+ model_names = ["face_body_model", "face_model"]
+ # batch_size = 16
+ steps = 1000
+ warmup_steps = 10
+ device = torch.device("cuda:1")
+
+ df_data = []
+ batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
+
+ for ckpt_path, input_size, model_name, num_classes in zip(
+ [face_person_ckpt_path, face_age_ckpt_path], [face_person_input_size, face_input_size], model_names, [3, 1]
+ ):
+
+ in_chans = input_size[0]
+ print(f"Collecting stat for {ckpt_path} ...")
+ model = create_model(
+ "mivolo_d1_224",
+ num_classes=num_classes,
+ in_chans=in_chans,
+ pretrained=False,
+ checkpoint_path=ckpt_path,
+ filter_keys=["fds."],
+ )
+ model = model.to(device)
+ model.eval()
+ model = model.half()
+
+ time_per_batch = {}
+ for batch_size in batch_sizes:
+ create_t0 = time_sync()
+ for _ in range(steps):
+ inputs = torch.randn((batch_size,) + tuple(input_size)).to(device).half()
+ create_t1 = time_sync()
+ create_taken = create_t1 - create_t0
+
+ with torch.no_grad():
+ inputs = torch.randn((batch_size,) + tuple(input_size)).to(device).half()
+ for _ in range(warmup_steps):
+ out = model(inputs)
+
+ all_time = 0
+ for _ in tqdm.tqdm(range(steps), desc=f"{model_name} batch {batch_size}"):
+ start = time_sync()
+ inputs = torch.randn((batch_size,) + tuple(input_size)).to(device).half()
+ out = model(inputs)
+ out += 1
+ end = time_sync()
+ all_time += end - start
+
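+            # Subtract the input-creation overhead estimated above, convert to milliseconds,
+            # and normalize to time per image.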
+ time_taken = (all_time - create_taken) * 1000 / steps / batch_size
+ print(f"Inference {inputs.shape}, steps: {steps}. Mean time taken {time_taken} ms / image")
+
+ time_per_batch[str(batch_size)] = f"{time_taken:.2f}"
+ df_data.append(time_per_batch)
+
+ headers = list(map(str, batch_sizes))
+ output_df = pd.DataFrame(df_data, columns=headers)
+ output_df.index = model_names
+
+ df2_transposed = output_df.T
+ out_file = "batch_sizes.csv"
+ df2_transposed.to_csv(out_file, sep=",")
+ print(f"Saved time stat for {len(df2_transposed)} batches to {out_file}")
diff --git a/age_estimator/mivolo/mivolo/__init__.py b/age_estimator/mivolo/mivolo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/age_estimator/mivolo/mivolo/__pycache__/__init__.cpython-38.pyc b/age_estimator/mivolo/mivolo/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0551922d63de57d3a5afbe87a5c0c959a104ce8d
Binary files /dev/null and b/age_estimator/mivolo/mivolo/__pycache__/__init__.cpython-38.pyc differ
diff --git a/age_estimator/mivolo/mivolo/__pycache__/predictor.cpython-38.pyc b/age_estimator/mivolo/mivolo/__pycache__/predictor.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f94779610f26d1daa5ee6c21909ce16d18cd7b8
Binary files /dev/null and b/age_estimator/mivolo/mivolo/__pycache__/predictor.cpython-38.pyc differ
diff --git a/age_estimator/mivolo/mivolo/__pycache__/structures.cpython-38.pyc b/age_estimator/mivolo/mivolo/__pycache__/structures.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a681475e2b883e03d909b735d056830675f4eafc
Binary files /dev/null and b/age_estimator/mivolo/mivolo/__pycache__/structures.cpython-38.pyc differ
diff --git a/age_estimator/mivolo/mivolo/data/__init__.py b/age_estimator/mivolo/mivolo/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/age_estimator/mivolo/mivolo/data/__pycache__/__init__.cpython-38.pyc b/age_estimator/mivolo/mivolo/data/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eeb2030d2141b89a5df92b36f1aa01c64af25015
Binary files /dev/null and b/age_estimator/mivolo/mivolo/data/__pycache__/__init__.cpython-38.pyc differ
diff --git a/age_estimator/mivolo/mivolo/data/__pycache__/data_reader.cpython-38.pyc b/age_estimator/mivolo/mivolo/data/__pycache__/data_reader.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ad45d32e93ae108962d7a069fa694385a4452c3
Binary files /dev/null and b/age_estimator/mivolo/mivolo/data/__pycache__/data_reader.cpython-38.pyc differ
diff --git a/age_estimator/mivolo/mivolo/data/__pycache__/misc.cpython-38.pyc b/age_estimator/mivolo/mivolo/data/__pycache__/misc.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee1b9aac35447e19e2c44442e081a63e372bc0a1
Binary files /dev/null and b/age_estimator/mivolo/mivolo/data/__pycache__/misc.cpython-38.pyc differ
diff --git a/age_estimator/mivolo/mivolo/data/data_reader.py b/age_estimator/mivolo/mivolo/data/data_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..44b819c6736d6a12fa0327d57152878dca1ebb0f
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/data/data_reader.py
@@ -0,0 +1,125 @@
+import os
+from collections import defaultdict
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Dict, List, Optional, Tuple
+
+import pandas as pd
+
+IMAGES_EXT: Tuple = (".jpeg", ".jpg", ".png", ".webp", ".bmp", ".gif")
+VIDEO_EXT: Tuple = (".mp4", ".avi", ".mov", ".mkv", ".webm")
+
+
+@dataclass
+class PictureInfo:
+ image_path: str
+    age: Optional[str]  # age, an age range ("start;end" format), or "-1"
+    gender: Optional[str]  # "M" or "F" or "-1"
+ bbox: List[int] = field(default_factory=lambda: [-1, -1, -1, -1]) # face bbox: xyxy
+ person_bbox: List[int] = field(default_factory=lambda: [-1, -1, -1, -1]) # person bbox: xyxy
+
+ @property
+ def has_person_bbox(self) -> bool:
+ return any(coord != -1 for coord in self.person_bbox)
+
+ @property
+ def has_face_bbox(self) -> bool:
+ return any(coord != -1 for coord in self.bbox)
+
+ def has_gt(self, only_age: bool = False) -> bool:
+ if only_age:
+ return self.age != "-1"
+ else:
+ return not (self.age == "-1" and self.gender == "-1")
+
+ def clear_person_bbox(self):
+ self.person_bbox = [-1, -1, -1, -1]
+
+ def clear_face_bbox(self):
+ self.bbox = [-1, -1, -1, -1]
+
+
+class AnnotType(Enum):
+ ORIGINAL = "original"
+ PERSONS = "persons"
+ NONE = "none"
+
+ @classmethod
+ def _missing_(cls, value):
+ print(f"WARN: Unknown annotation type {value}.")
+ return AnnotType.NONE
+
+
+def get_all_files(path: str, extensions: Tuple = IMAGES_EXT):
+ files_all = []
+ for root, subFolders, files in os.walk(path):
+ for name in files:
+            # skip Linux ".directory" entries (they are regular files) and keep only
+            # files whose name contains one of the supported extensions
+            if "directory" not in name and sum([ext.lower() in name.lower() for ext in extensions]) > 0:
+ files_all.append(os.path.join(root, name))
+ return files_all
+
+
+class InputType(Enum):
+ Image = 0
+ Video = 1
+ VideoStream = 2
+
+
+def get_input_type(input_path: str) -> InputType:
+ if os.path.isdir(input_path):
+ print("Input is a folder, only images will be processed")
+ return InputType.Image
+ elif os.path.isfile(input_path):
+ if input_path.endswith(VIDEO_EXT):
+ return InputType.Video
+ if input_path.endswith(IMAGES_EXT):
+ return InputType.Image
+ else:
+ raise ValueError(
+ f"Unknown or unsupported input file format {input_path}, \
+ supported video formats: {VIDEO_EXT}, \
+ supported image formats: {IMAGES_EXT}"
+ )
+ elif input_path.startswith("http") and not input_path.endswith(IMAGES_EXT):
+ return InputType.VideoStream
+ else:
+ raise ValueError(f"Unknown input {input_path}")
+
+
+def read_csv_annotation_file(annotation_file: str, images_dir: str, ignore_without_gt=False):
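+    # Expected CSV columns: img_name, age, gender, face_x0, face_y0, face_x1, face_y1;
+    # "persons"-type annotations additionally provide person_x0, person_y0, person_x1, person_y1.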
+ bboxes_per_image: Dict[str, List[PictureInfo]] = defaultdict(list)
+
+ df = pd.read_csv(annotation_file, sep=",")
+
+ annot_type = AnnotType("persons") if "person_x0" in df.columns else AnnotType("original")
+ print(f"Reading {annotation_file} (type: {annot_type})...")
+
+ missing_images = 0
+ for index, row in df.iterrows():
+ img_path = os.path.join(images_dir, row["img_name"])
+ if not os.path.exists(img_path):
+ missing_images += 1
+ continue
+
+ face_x1, face_y1, face_x2, face_y2 = row["face_x0"], row["face_y0"], row["face_x1"], row["face_y1"]
+ age, gender = str(row["age"]), str(row["gender"])
+
+ if ignore_without_gt and (age == "-1" or gender == "-1"):
+ continue
+
+ if annot_type == AnnotType.PERSONS:
+ p_x1, p_y1, p_x2, p_y2 = row["person_x0"], row["person_y0"], row["person_x1"], row["person_y1"]
+ person_bbox = list(map(int, [p_x1, p_y1, p_x2, p_y2]))
+ else:
+ person_bbox = [-1, -1, -1, -1]
+
+ bbox = list(map(int, [face_x1, face_y1, face_x2, face_y2]))
+ pic_info = PictureInfo(img_path, age, gender, bbox, person_bbox)
+ assert isinstance(pic_info.person_bbox, list)
+
+ bboxes_per_image[img_path].append(pic_info)
+
+ if missing_images > 0:
+ print(f"WARNING: Missing images: {missing_images}/{len(df)}")
+ return bboxes_per_image, annot_type
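A minimal usage sketch for the reader above (not part of the diff): the paths are placeholders, and the CSV is assumed to contain the columns referenced in read_csv_annotation_file (img_name, face_x0..face_y1, age, gender, optionally person_*).

from mivolo.data.data_reader import read_csv_annotation_file

# Placeholder paths; the annotation file must follow the column layout used above.
bboxes_per_image, annot_type = read_csv_annotation_file("annotations/train.csv", images_dir="images")
for img_path, samples in bboxes_per_image.items():
    for info in samples:
        print(img_path, info.age, info.gender, info.has_face_bbox, info.has_person_bbox)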
diff --git a/age_estimator/mivolo/mivolo/data/dataset/__init__.py b/age_estimator/mivolo/mivolo/data/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c847b48bfb57251bd6e290a61059813c3b698a3
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/data/dataset/__init__.py
@@ -0,0 +1,66 @@
+from typing import Tuple
+
+import torch
+from mivolo.model.mi_volo import MiVOLO
+
+from .age_gender_dataset import AgeGenderDataset
+from .age_gender_loader import create_loader
+from .classification_dataset import AdienceDataset, FairFaceDataset
+
+DATASET_CLASS_MAP = {
+ "utk": AgeGenderDataset,
+ "lagenda": AgeGenderDataset,
+ "imdb": AgeGenderDataset,
+ "agedb": AgeGenderDataset,
+ "cacd": AgeGenderDataset,
+ "adience": AdienceDataset,
+ "fairface": FairFaceDataset,
+}
+
+
+def build(
+ name: str,
+ images_path: str,
+ annotations_path: str,
+ split: str,
+ mivolo_model: MiVOLO,
+ workers: int,
+ batch_size: int,
+) -> Tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader]:
+
+ dataset_class = DATASET_CLASS_MAP[name]
+
+ dataset: torch.utils.data.Dataset = dataset_class(
+ images_path=images_path,
+ annotations_path=annotations_path,
+ name=name,
+ split=split,
+ target_size=mivolo_model.input_size,
+ max_age=mivolo_model.meta.max_age,
+ min_age=mivolo_model.meta.min_age,
+ model_with_persons=mivolo_model.meta.with_persons_model,
+ use_persons=mivolo_model.meta.use_persons,
+ disable_faces=mivolo_model.meta.disable_faces,
+ only_age=mivolo_model.meta.only_age,
+ )
+
+ data_config = mivolo_model.data_config
+
+ in_chans = 3 if not mivolo_model.meta.with_persons_model else 6
+ input_size = (in_chans, mivolo_model.input_size, mivolo_model.input_size)
+
+ dataset_loader: torch.utils.data.DataLoader = create_loader(
+ dataset,
+ input_size=input_size,
+ batch_size=batch_size,
+ mean=data_config["mean"],
+ std=data_config["std"],
+ num_workers=workers,
+ crop_pct=data_config["crop_pct"],
+ crop_mode=data_config["crop_mode"],
+ pin_memory=False,
+ device=mivolo_model.device,
+ target_type=dataset.target_dtype,
+ )
+
+ return dataset, dataset_loader
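A hedged sketch of how build() might be called. The checkpoint path and dataset directories are placeholders, and the "utk" layout is only assumed to match the annotation format read above.

from mivolo.model.mi_volo import MiVOLO
from mivolo.data.dataset import build

# Placeholder checkpoint and data paths, CPU just for illustration.
mivolo = MiVOLO("pretrained/mivolo_checkpoint.pth.tar", device="cpu", half=False)
dataset, loader = build(
    name="utk",
    images_path="data/utk/images",
    annotations_path="data/utk/annotations",
    split="valid",
    mivolo_model=mivolo,
    workers=4,
    batch_size=32,
)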
diff --git a/age_estimator/mivolo/mivolo/data/dataset/age_gender_dataset.py b/age_estimator/mivolo/mivolo/data/dataset/age_gender_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cb3cb9922ac7591f9c01f70fc16245b40483871
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/data/dataset/age_gender_dataset.py
@@ -0,0 +1,194 @@
+import logging
+from typing import Any, List, Optional, Set
+
+import cv2
+import numpy as np
+import torch
+from mivolo.data.dataset.reader_age_gender import ReaderAgeGender
+from PIL import Image
+from torchvision import transforms
+
+_logger = logging.getLogger("AgeGenderDataset")
+
+
+class AgeGenderDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ images_path,
+ annotations_path,
+ name=None,
+ split="train",
+ load_bytes=False,
+ img_mode="RGB",
+ transform=None,
+ is_training=False,
+ seed=1234,
+ target_size=224,
+ min_age=None,
+ max_age=None,
+ model_with_persons=False,
+ use_persons=False,
+ disable_faces=False,
+ only_age=False,
+ ):
+ reader = ReaderAgeGender(
+ images_path,
+ annotations_path,
+ split=split,
+ seed=seed,
+ target_size=target_size,
+ with_persons=use_persons,
+ disable_faces=disable_faces,
+ only_age=only_age,
+ )
+
+ self.name = name
+ self.model_with_persons = model_with_persons
+ self.reader = reader
+ self.load_bytes = load_bytes
+ self.img_mode = img_mode
+ self.transform = transform
+ self._consecutive_errors = 0
+ self.is_training = is_training
+ self.random_flip = 0.0
+
+ # Setting up classes.
+ # If min and max ages are passed, use them so that validation gets the same preprocessing
+ self.max_age: float = None
+ self.min_age: float = None
+ self.avg_age: float = None
+ self.set_ages_min_max(min_age, max_age)
+
+ self.genders = ["M", "F"]
+ self.num_classes_gender = len(self.genders)
+
+ self.age_classes: Optional[List[str]] = self.set_age_classes()
+
+ self.num_classes_age = 1 if self.age_classes is None else len(self.age_classes)
+ self.num_classes: int = self.num_classes_age + self.num_classes_gender
+ self.target_dtype = torch.float32
+
+ def set_age_classes(self) -> Optional[List[str]]:
+ return None # for regression dataset
+
+ def set_ages_min_max(self, min_age: Optional[float], max_age: Optional[float]):
+
+ assert all(age is None for age in [min_age, max_age]) or all(
+ age is not None for age in [min_age, max_age]
+ ), "Both min and max age must be passed or none of them"
+
+ if max_age is not None and min_age is not None:
+ _logger.info(f"Received predefined min_age {min_age} and max_age {max_age}")
+ self.max_age = max_age
+ self.min_age = min_age
+ else:
+ # collect statistics from loaded dataset
+ all_ages_set: Set[int] = set()
+ for img_path, image_samples in self.reader._ann.items():
+ for image_sample_info in image_samples:
+ if image_sample_info.age == "-1":
+ continue
+ age = round(float(image_sample_info.age))
+ all_ages_set.add(age)
+
+ self.max_age = max(all_ages_set)
+ self.min_age = min(all_ages_set)
+
+ self.avg_age = (self.max_age + self.min_age) / 2.0
+
+ def _norm_age(self, age):
+ return (age - self.avg_age) / (self.max_age - self.min_age)
+
+ def parse_gender(self, _gender: str) -> float:
+ if _gender != "-1":
+ gender = float(0 if _gender == "M" or _gender == "0" else 1)
+ else:
+ gender = -1
+ return gender
+
+ def parse_target(self, _age: str, gender: str) -> List[Any]:
+ if _age != "-1":
+ age = round(float(_age))
+ age = self._norm_age(float(age))
+ else:
+ age = -1
+
+ target: List[float] = [age, self.parse_gender(gender)]
+ return target
+
+ @property
+ def transform(self):
+ return self._transform
+
+ @transform.setter
+ def transform(self, transform):
+ # Disable pretrained monkey-patched transforms
+ if not transform:
+ return
+
+ _trans = []
+ for trans in transform.transforms:
+ if "Resize" in str(trans):
+ continue
+ if "Crop" in str(trans):
+ continue
+ _trans.append(trans)
+ self._transform = transforms.Compose(_trans)
+
+ def apply_tranforms(self, image: Optional[np.ndarray]) -> np.ndarray:
+ if image is None:
+ return None
+
+ if self.transform is None:
+ return image
+
+ image = convert_to_pil(image, self.img_mode)
+ for trans in self.transform.transforms:
+ image = trans(image)
+ return image
+
+ def __getitem__(self, index):
+ # get preprocessed face and person crops (np.ndarray)
+ # resize + pad, for person crops: cut off other bboxes
+ images, target = self.reader[index]
+
+ target = self.parse_target(*target)
+
+ if self.model_with_persons:
+ face_image, person_image = images
+ person_image: np.ndarray = self.apply_tranforms(person_image)
+ else:
+ face_image = images[0]
+ person_image = None
+
+ face_image: np.ndarray = self.apply_tranforms(face_image)
+
+ if person_image is not None:
+ img = np.concatenate([face_image, person_image], axis=0)
+ else:
+ img = face_image
+
+ return img, target
+
+ def __len__(self):
+ return len(self.reader)
+
+ def filename(self, index, basename=False, absolute=False):
+ return self.reader.filename(index, basename, absolute)
+
+ def filenames(self, basename=False, absolute=False):
+ return self.reader.filenames(basename, absolute)
+
+
+def convert_to_pil(cv_im: Optional[np.ndarray], img_mode: str = "RGB") -> "Image":
+ if cv_im is None:
+ return None
+
+ if img_mode == "RGB":
+ cv_im = cv2.cvtColor(cv_im, cv2.COLOR_BGR2RGB)
+ else:
+ raise Exception("Incorrect image mode has been passed!")
+
+ cv_im = np.ascontiguousarray(cv_im)
+ pil_image = Image.fromarray(cv_im)
+ return pil_image
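The regression target above is a normalized age, (age - avg_age) / (max_age - min_age). A small self-contained sketch of this mapping and its inverse follows (the inverse is applied later at inference time in MiVOLO.fill_in_results); the age bounds are illustrative.

# Illustrative bounds; real values come from the dataset or the checkpoint meta.
min_age, max_age = 1.0, 95.0
avg_age = (min_age + max_age) / 2.0          # 48.0

def norm_age(age: float) -> float:           # matches AgeGenderDataset._norm_age
    return (age - avg_age) / (max_age - min_age)

def denorm_age(value: float) -> float:       # inverse used at inference time
    return value * (max_age - min_age) + avg_age

assert abs(denorm_age(norm_age(30.0)) - 30.0) < 1e-9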
diff --git a/age_estimator/mivolo/mivolo/data/dataset/age_gender_loader.py b/age_estimator/mivolo/mivolo/data/dataset/age_gender_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..260e310de7c4a33dc11c6b0809f000940199a5c3
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/data/dataset/age_gender_loader.py
@@ -0,0 +1,169 @@
+"""
+Code adapted from timm https://github.com/huggingface/pytorch-image-models
+
+Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
+"""
+
+import logging
+from contextlib import suppress
+from functools import partial
+from itertools import repeat
+
+import numpy as np
+import torch
+import torch.utils.data
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.data.dataset import IterableImageDataset
+from timm.data.loader import PrefetchLoader, _worker_init
+from timm.data.transforms_factory import create_transform
+
+_logger = logging.getLogger(__name__)
+
+
+def fast_collate(batch, target_dtype=torch.uint8):
+ """A fast collation function optimized for uint8 images (np array or torch) and target_dtype targets (labels)"""
+ assert isinstance(batch[0], tuple)
+ batch_size = len(batch)
+ if isinstance(batch[0][0], np.ndarray):
+ targets = torch.tensor([b[1] for b in batch], dtype=target_dtype)
+ assert len(targets) == batch_size
+ tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+ for i in range(batch_size):
+ tensor[i] += torch.from_numpy(batch[i][0])
+ return tensor, targets
+ else:
+ raise ValueError(f"Incorrect batch type: {type(batch[0][0])}")
+
+
+def adapt_to_chs(x, n):
+ if not isinstance(x, (tuple, list)):
+ x = tuple(repeat(x, n))
+ elif len(x) != n:
+ # doubled channels
+ if len(x) * 2 == n:
+ x = np.concatenate((x, x))
+ _logger.warning(f"Pretrained mean/std different shape than model (doubled channes), using concat: {x}.")
+ else:
+ x_mean = np.mean(x).item()
+ x = (x_mean,) * n
+ _logger.warning(f"Pretrained mean/std different shape than model, using avg value {x}.")
+ else:
+ assert len(x) == n, "normalization stats must match image channels"
+ return x
+
+
+class PrefetchLoaderForMultiInput(PrefetchLoader):
+ def __init__(
+ self,
+ loader,
+ mean=IMAGENET_DEFAULT_MEAN,
+ std=IMAGENET_DEFAULT_STD,
+ channels=3,
+ device=torch.device("cuda"),
+ img_dtype=torch.float32,
+ ):
+
+ mean = adapt_to_chs(mean, channels)
+ std = adapt_to_chs(std, channels)
+ normalization_shape = (1, channels, 1, 1)
+
+ self.loader = loader
+ self.device = device
+ self.img_dtype = img_dtype
+ self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
+ self.std = torch.tensor([x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
+
+ self.is_cuda = torch.cuda.is_available() and device.type == "cuda"
+
+ def __iter__(self):
+ first = True
+ if self.is_cuda:
+ stream = torch.cuda.Stream()
+ stream_context = partial(torch.cuda.stream, stream=stream)
+ else:
+ stream = None
+ stream_context = suppress
+
+ for next_input, next_target in self.loader:
+
+ with stream_context():
+ next_input = next_input.to(device=self.device, non_blocking=True)
+ next_target = next_target.to(device=self.device, non_blocking=True)
+ next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
+
+ if not first:
+ yield input, target # noqa: F823, F821
+ else:
+ first = False
+
+ if stream is not None:
+ torch.cuda.current_stream().wait_stream(stream)
+
+ input = next_input
+ target = next_target
+
+ yield input, target
+
+
+def create_loader(
+ dataset,
+ input_size,
+ batch_size,
+ mean=IMAGENET_DEFAULT_MEAN,
+ std=IMAGENET_DEFAULT_STD,
+ num_workers=1,
+ crop_pct=None,
+ crop_mode=None,
+ pin_memory=False,
+ img_dtype=torch.float32,
+ device=torch.device("cuda"),
+ persistent_workers=True,
+ worker_seeding="all",
+ target_type=torch.int64,
+):
+
+ transform = create_transform(
+ input_size,
+ is_training=False,
+ use_prefetcher=True,
+ mean=mean,
+ std=std,
+ crop_pct=crop_pct,
+ crop_mode=crop_mode,
+ )
+ dataset.transform = transform
+
+ if isinstance(dataset, IterableImageDataset):
+ # give Iterable datasets early knowledge of num_workers so that sample estimates
+ # are correct before worker processes are launched
+ dataset.set_loader_cfg(num_workers=num_workers)
+ raise ValueError("Incorrect dataset type: IterableImageDataset")
+
+ loader_class = torch.utils.data.DataLoader
+ loader_args = dict(
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers,
+ sampler=None,
+ collate_fn=lambda batch: fast_collate(batch, target_dtype=target_type),
+ pin_memory=pin_memory,
+ drop_last=False,
+ worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
+ persistent_workers=persistent_workers,
+ )
+ try:
+ loader = loader_class(dataset, **loader_args)
+ except TypeError:
+ loader_args.pop("persistent_workers") # only in Pytorch 1.7+
+ loader = loader_class(dataset, **loader_args)
+
+ loader = PrefetchLoaderForMultiInput(
+ loader,
+ mean=mean,
+ std=std,
+ channels=input_size[0],
+ device=device,
+ img_dtype=img_dtype,
+ )
+
+ return loader
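A self-contained sketch of fast_collate's contract: CHW uint8 numpy crops are stacked into a single uint8 image tensor while the targets are packed with the requested dtype. Shapes and values are illustrative.

import numpy as np
import torch
from mivolo.data.dataset.age_gender_loader import fast_collate

# Fake batch of 4 samples: uint8 CHW crop + [normalized_age, gender] target.
batch = [(np.zeros((3, 224, 224), dtype=np.uint8), [0.1, 1.0]) for _ in range(4)]
images, targets = fast_collate(batch, target_dtype=torch.float32)
print(images.shape, images.dtype)    # torch.Size([4, 3, 224, 224]) torch.uint8
print(targets.shape, targets.dtype)  # torch.Size([4, 2]) torch.float32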
diff --git a/age_estimator/mivolo/mivolo/data/dataset/classification_dataset.py b/age_estimator/mivolo/mivolo/data/dataset/classification_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2157dd548ef3749daa6704b597bfd4776781fe2c
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/data/dataset/classification_dataset.py
@@ -0,0 +1,47 @@
+from typing import Any, List, Optional
+
+import torch
+
+from .age_gender_dataset import AgeGenderDataset
+
+
+class ClassificationDataset(AgeGenderDataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.target_dtype = torch.int32
+
+ def set_age_classes(self) -> Optional[List[str]]:
+ raise NotImplementedError
+
+ def parse_target(self, age: str, gender: str) -> List[Any]:
+ assert self.age_classes is not None
+ if age != "-1":
+ assert age in self.age_classes, f"Unknown category in {self.name} dataset: {age}"
+ age_ind = self.age_classes.index(age)
+ else:
+ age_ind = -1
+
+ target: List[int] = [age_ind, int(self.parse_gender(gender))]
+ return target
+
+
+class FairFaceDataset(ClassificationDataset):
+ def set_age_classes(self) -> Optional[List[str]]:
+ age_classes = ["0;2", "3;9", "10;19", "20;29", "30;39", "40;49", "50;59", "60;69", "70;120"]
+ # a[i-1] <= v < a[i] => age_classes[i-1]
+ self._intervals = torch.tensor([0, 3, 10, 20, 30, 40, 50, 60, 70])
+ return age_classes
+
+
+class AdienceDataset(ClassificationDataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.target_dtype = torch.int32
+
+ def set_age_classes(self) -> Optional[List[str]]:
+ age_classes = ["0;2", "4;6", "8;12", "15;20", "25;32", "38;43", "48;53", "60;100"]
+ # a[i-1] <= v < a[i] => age_classes[i-1]
+ self._intervals = torch.tensor([0, 4, 7, 14, 24, 36, 46, 57])
+ return age_classes
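The _intervals comments above describe a half-open bucketing rule, a[i-1] <= v < a[i] maps to age_classes[i-1]. A self-contained illustration of that rule follows (the actual lookup lives elsewhere in the repo; this only demonstrates the mapping).

import torch

# FairFace boundaries and class labels, copied from set_age_classes above.
intervals = torch.tensor([0.0, 3.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0])
age_classes = ["0;2", "3;9", "10;19", "20;29", "30;39", "40;49", "50;59", "60;69", "70;120"]

age = torch.tensor([25.0])
idx = torch.bucketize(age, intervals, right=True) - 1   # a[i-1] <= v < a[i]
print(age_classes[idx.item()])                          # "20;29"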
diff --git a/age_estimator/mivolo/mivolo/data/dataset/reader_age_gender.py b/age_estimator/mivolo/mivolo/data/dataset/reader_age_gender.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4b7d82b25fceeb56791ca5a77f4d6e8db72c2f5
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/data/dataset/reader_age_gender.py
@@ -0,0 +1,492 @@
+import logging
+import os
+from functools import partial
+from multiprocessing.pool import ThreadPool
+from typing import Dict, List, Optional, Tuple
+
+import cv2
+import numpy as np
+from mivolo.data.data_reader import AnnotType, PictureInfo, get_all_files, read_csv_annotation_file
+from mivolo.data.misc import IOU, class_letterbox
+from timm.data.readers.reader import Reader
+from tqdm import tqdm
+
+CROP_ROUND_TOL = 0.3
+MIN_PERSON_SIZE = 100
+MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4
+
+_logger = logging.getLogger("ReaderAgeGender")
+
+
+class ReaderAgeGender(Reader):
+ """
+ Reader for the (almost original) cleaned imdb-wiki dataset.
+ Two changes:
+ 1. Annotations must be in the ./annotation subdir of the dataset root
+ 2. Images must be in the ./images subdir
+
+ """
+
+ def __init__(
+ self,
+ images_path,
+ annotations_path,
+ split="validation",
+ target_size=224,
+ min_size=5,
+ seed=1234,
+ with_persons=False,
+ min_person_size=MIN_PERSON_SIZE,
+ disable_faces=False,
+ only_age=False,
+ min_person_aftercut_ratio=MIN_PERSON_CROP_AFTERCUT_RATIO,
+ crop_round_tol=CROP_ROUND_TOL,
+ ):
+ super().__init__()
+
+ self.with_persons = with_persons
+ self.disable_faces = disable_faces
+ self.only_age = only_age
+
+ # crop-out color is limited to black for now, even though black interacts poorly with the later normalization
+ self.crop_out_color = (0, 0, 0)
+
+ self.empty_crop = np.ones((target_size, target_size, 3)) * self.crop_out_color
+ self.empty_crop = self.empty_crop.astype(np.uint8)
+
+ self.min_person_size = min_person_size
+ self.min_person_aftercut_ratio = min_person_aftercut_ratio
+ self.crop_round_tol = crop_round_tol
+
+ splits = split.split(",")
+ self.splits = [split.strip() for split in splits if len(split.strip())]
+ assert len(self.splits), "Incorrect split arg"
+
+ self.min_size = min_size
+ self.seed = seed
+ self.target_size = target_size
+
+ # Read annotations; annotations_path may be a directory containing multiple csv files
+ self._ann: Dict[str, List[PictureInfo]] = {} # list of samples for each image
+ self._associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}
+ self._faces_list: List[Tuple[str, int]] = [] # samples from this list will be loaded in __getitem__
+
+ self._read_annotations(images_path, annotations_path)
+ _logger.info(f"Dataset length: {len(self._faces_list)} crops")
+
+ def __getitem__(self, index):
+ return self._read_img_and_label(index)
+
+ def __len__(self):
+ return len(self._faces_list)
+
+ def _filename(self, index, basename=False, absolute=False):
+ img_p = self._faces_list[index][0]
+ return os.path.basename(img_p) if basename else img_p
+
+ def _read_annotations(self, images_path, csvs_path):
+ self._ann = {}
+ self._faces_list = []
+ self._associated_objects = {}
+
+ csvs = get_all_files(csvs_path, [".csv"])
+ csvs = [c for c in csvs if any(split_name in os.path.basename(c) for split_name in self.splits)]
+
+ # load annotations per image
+ for csv in csvs:
+ db, ann_type = read_csv_annotation_file(csv, images_path)
+ if self.with_persons and ann_type != AnnotType.PERSONS:
+ raise ValueError(
+ f"Annotation type in file {csv} contains no persons, "
+ f"but annotations with persons are requested."
+ )
+ self._ann.update(db)
+
+ if len(self._ann) == 0:
+ raise ValueError("Annotations are empty!")
+
+ self._ann, self._associated_objects = self.prepare_annotations()
+ images_list = list(self._ann.keys())
+
+ for img_path in images_list:
+ for index, image_sample_info in enumerate(self._ann[img_path]):
+ assert image_sample_info.has_gt(
+ self.only_age
+ ), "Annotations must be checked with self.prepare_annotations() func"
+ self._faces_list.append((img_path, index))
+
+ def _read_img_and_label(self, index):
+ if not isinstance(index, int):
+ raise TypeError("ReaderAgeGender expected index to be integer")
+
+ img_p, face_index = self._faces_list[index]
+ ann: PictureInfo = self._ann[img_p][face_index]
+ img = cv2.imread(img_p)
+
+ face_empty = True
+ if ann.has_face_bbox and not (self.with_persons and self.disable_faces):
+ face_crop, face_empty = self._get_crop(ann.bbox, img)
+
+ if not self.with_persons and face_empty:
+ # model without persons
+ raise ValueError("Annotations must be checked with self.prepare_annotations() func")
+
+ if face_empty:
+ face_crop = self.empty_crop
+
+ person_empty = True
+ if self.with_persons or self.disable_faces:
+ if ann.has_person_bbox:
+ # cut off all associated objects from person crop
+ objects = self._associated_objects[img_p][face_index]
+ person_crop, person_empty = self._get_crop(
+ ann.person_bbox,
+ img,
+ crop_out_color=self.crop_out_color,
+ asced_objects=objects,
+ )
+
+ if face_empty and person_empty:
+ raise ValueError("Annotations must be checked with self.prepare_annotations() func")
+
+ if person_empty:
+ person_crop = self.empty_crop
+
+ return (face_crop, person_crop), [ann.age, ann.gender]
+
+ def _get_crop(
+ self,
+ bbox,
+ img,
+ asced_objects=None,
+ crop_out_color=(0, 0, 0),
+ ) -> Tuple[np.ndarray, bool]:
+
+ empty_bbox = False
+
+ xmin, ymin, xmax, ymax = bbox
+ assert not (
+ ymax - ymin < self.min_size or xmax - xmin < self.min_size
+ ), "Annotations must be checked with self.prepare_annotations() func"
+
+ crop = img[ymin:ymax, xmin:xmax]
+
+ if asced_objects:
+ # cut off other objects for person crop
+ crop, empty_bbox = _cropout_asced_objs(
+ asced_objects,
+ bbox,
+ crop.copy(),
+ crop_out_color=crop_out_color,
+ min_person_size=self.min_person_size,
+ crop_round_tol=self.crop_round_tol,
+ min_person_aftercut_ratio=self.min_person_aftercut_ratio,
+ )
+ if empty_bbox:
+ crop = self.empty_crop
+
+ crop = class_letterbox(crop, new_shape=(self.target_size, self.target_size), color=crop_out_color)
+ return crop, empty_bbox
+
+ def prepare_annotations(self):
+
+ good_anns: Dict[str, List[PictureInfo]] = {}
+ all_associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}
+
+ if not self.with_persons:
+ # remove all persons
+ for img_path, bboxes in self._ann.items():
+ for sample in bboxes:
+ sample.clear_person_bbox()
+
+ # check dataset and collect associated_objects
+ verify_images_func = partial(
+ verify_images,
+ min_size=self.min_size,
+ min_person_size=self.min_person_size,
+ with_persons=self.with_persons,
+ disable_faces=self.disable_faces,
+ crop_round_tol=self.crop_round_tol,
+ min_person_aftercut_ratio=self.min_person_aftercut_ratio,
+ only_age=self.only_age,
+ )
+ num_threads = min(8, os.cpu_count())
+
+ all_msgs = []
+ broken = 0
+ skipped = 0
+ all_skipped_crops = 0
+ desc = "Check annotations..."
+ with ThreadPool(num_threads) as pool:
+ pbar = tqdm(
+ pool.imap_unordered(verify_images_func, list(self._ann.items())),
+ desc=desc,
+ total=len(self._ann),
+ )
+
+ for (img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops) in pbar:
+ broken += 1 if is_corrupted else 0
+ all_msgs.extend(msgs)
+ all_skipped_crops += skipped_crops
+ skipped += 1 if is_empty_annotations else 0
+ if img_info is not None:
+ img_path, img_samples = img_info
+ good_anns[img_path] = img_samples
+ all_associated_objects.update({img_path: associated_objects})
+
+ pbar.desc = (
+ f"{desc} {skipped} images skipped ({all_skipped_crops} crops are incorrect); "
+ f"{broken} images corrupted"
+ )
+
+ pbar.close()
+
+ for msg in all_msgs:
+ print(msg)
+ print(f"\nLeft images: {len(good_anns)}")
+
+ return good_anns, all_associated_objects
+
+
+def verify_images(
+ img_info,
+ min_size: int,
+ min_person_size: int,
+ with_persons: bool,
+ disable_faces: bool,
+ crop_round_tol: float,
+ min_person_aftercut_ratio: float,
+ only_age: bool,
+):
+ # Filter out the sample if its crop is too small, or if the image cannot be read or does not exist
+
+ disable_faces = disable_faces and with_persons
+ kwargs = dict(
+ min_person_size=min_person_size,
+ disable_faces=disable_faces,
+ with_persons=with_persons,
+ crop_round_tol=crop_round_tol,
+ min_person_aftercut_ratio=min_person_aftercut_ratio,
+ only_age=only_age,
+ )
+
+ def bbox_correct(bbox, min_size, im_h, im_w) -> Tuple[bool, List[int]]:
+ ymin, ymax, xmin, xmax = _correct_bbox(bbox, im_h, im_w)
+ crop_h, crop_w = ymax - ymin, xmax - xmin
+ if crop_h < min_size or crop_w < min_size:
+ return False, [-1, -1, -1, -1]
+ bbox = [xmin, ymin, xmax, ymax]
+ return True, bbox
+
+ msgs = []
+ skipped_crops = 0
+ is_corrupted = False
+ is_empty_annotations = False
+
+ img_path: str = img_info[0]
+ img_samples: List[PictureInfo] = img_info[1]
+ try:
+ im_cv = cv2.imread(img_path)
+ im_h, im_w = im_cv.shape[:2]
+ except Exception:
+ msgs.append(f"Can not load image {img_path}")
+ is_corrupted = True
+ return None, {}, msgs, is_corrupted, is_empty_annotations, skipped_crops
+
+ out_samples: List[PictureInfo] = []
+ for sample in img_samples:
+ # correct face bbox
+ if sample.has_face_bbox:
+ is_correct, sample.bbox = bbox_correct(sample.bbox, min_size, im_h, im_w)
+ if not is_correct and sample.has_gt(only_age):
+ msgs.append("Small face. Passing..")
+ skipped_crops += 1
+
+ # correct person bbox
+ if sample.has_person_bbox:
+ is_correct, sample.person_bbox = bbox_correct(
+ sample.person_bbox, max(min_person_size, min_size), im_h, im_w
+ )
+ if not is_correct and sample.has_gt(only_age):
+ msgs.append(f"Small person {img_path}. Passing..")
+ skipped_crops += 1
+
+ if sample.has_face_bbox or sample.has_person_bbox:
+ out_samples.append(sample)
+ elif sample.has_gt(only_age):
+ msgs.append("Sample has no face and no body. Passing..")
+ skipped_crops += 1
+
+ # sort so that samples with undefined age and gender come last
+ out_samples = sorted(out_samples, key=lambda sample: 1 if not sample.has_gt(only_age) else 0)
+
+ # for each person find other faces and persons bboxes, intersected with it
+ associated_objects: Dict[int, List[List[int]]] = find_associated_objects(out_samples, only_age=only_age)
+
+ out_samples, associated_objects, skipped_crops = filter_bad_samples(
+ out_samples, associated_objects, im_cv, msgs, skipped_crops, **kwargs
+ )
+
+ out_img_info: Optional[Tuple[str, List]] = (img_path, out_samples)
+ if len(out_samples) == 0:
+ out_img_info = None
+ is_empty_annotations = True
+
+ return out_img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops
+
+
+def filter_bad_samples(
+ out_samples: List[PictureInfo],
+ associated_objects: dict,
+ im_cv: np.ndarray,
+ msgs: List[str],
+ skipped_crops: int,
+ **kwargs,
+):
+ with_persons, disable_faces, min_person_size, crop_round_tol, min_person_aftercut_ratio, only_age = (
+ kwargs["with_persons"],
+ kwargs["disable_faces"],
+ kwargs["min_person_size"],
+ kwargs["crop_round_tol"],
+ kwargs["min_person_aftercut_ratio"],
+ kwargs["only_age"],
+ )
+
+ # keep only samples with annotations
+ inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_gt(only_age)]
+ out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
+
+ if kwargs["disable_faces"]:
+ # clear all faces
+ for ind, sample in enumerate(out_samples):
+ sample.clear_face_bbox()
+
+ # keep only samples with a person_bbox
+ inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_person_bbox]
+ out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
+
+ if with_persons or disable_faces:
+ # check that the preprocessing func _cropout_asced_objs()
+ # returns a non-empty person image for each output sample
+
+ inds = []
+ for ind, sample in enumerate(out_samples):
+ person_empty = True
+ if sample.has_person_bbox:
+ xmin, ymin, xmax, ymax = sample.person_bbox
+ crop = im_cv[ymin:ymax, xmin:xmax]
+ # cut off all associated objects from person crop
+ _, person_empty = _cropout_asced_objs(
+ associated_objects[ind],
+ sample.person_bbox,
+ crop.copy(),
+ min_person_size=min_person_size,
+ crop_round_tol=crop_round_tol,
+ min_person_aftercut_ratio=min_person_aftercut_ratio,
+ )
+
+ if person_empty and not sample.has_face_bbox:
+ msgs.append("Small person after preprocessing. Passing..")
+ skipped_crops += 1
+ else:
+ inds.append(ind)
+ out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
+
+ assert len(associated_objects) == len(out_samples)
+ return out_samples, associated_objects, skipped_crops
+
+
+def _filter_by_ind(out_samples, associated_objects, inds):
+ _associated_objects = {}
+ _out_samples = []
+ for ind, sample in enumerate(out_samples):
+ if ind in inds:
+ _associated_objects[len(_out_samples)] = associated_objects[ind]
+ _out_samples.append(sample)
+
+ return _out_samples, _associated_objects
+
+
+def find_associated_objects(
+ image_samples: List[PictureInfo], iou_thresh=0.0001, only_age=False
+) -> Dict[int, List[List[int]]]:
+ """
+ For each person (with gt age and gt gender), find other face and person bboxes that intersect it
+ """
+ associated_objects: Dict[int, List[List[int]]] = {}
+
+ for iindex, image_sample_info in enumerate(image_samples):
+ # add own face
+ associated_objects[iindex] = [image_sample_info.bbox] if image_sample_info.has_face_bbox else []
+
+ if not image_sample_info.has_person_bbox or not image_sample_info.has_gt(only_age):
+ # if the sample has no gt, it will not be used
+ continue
+
+ iperson_box = image_sample_info.person_bbox
+ for jindex, other_image_sample in enumerate(image_samples):
+ if iindex == jindex:
+ continue
+ if other_image_sample.has_face_bbox:
+ jface_bbox = other_image_sample.bbox
+ iou = _get_iou(jface_bbox, iperson_box)
+ if iou >= iou_thresh:
+ associated_objects[iindex].append(jface_bbox)
+ if other_image_sample.has_person_bbox:
+ jperson_bbox = other_image_sample.person_bbox
+ iou = _get_iou(jperson_bbox, iperson_box)
+ if iou >= iou_thresh:
+ associated_objects[iindex].append(jperson_bbox)
+
+ return associated_objects
+
+
+def _cropout_asced_objs(
+ asced_objects,
+ person_bbox,
+ crop,
+ min_person_size,
+ crop_round_tol,
+ min_person_aftercut_ratio,
+ crop_out_color=(0, 0, 0),
+):
+ empty = False
+ xmin, ymin, xmax, ymax = person_bbox
+
+ for a_obj in asced_objects:
+ aobj_xmin, aobj_ymin, aobj_xmax, aobj_ymax = a_obj
+
+ aobj_ymin = int(max(aobj_ymin - ymin, 0))
+ aobj_xmin = int(max(aobj_xmin - xmin, 0))
+ aobj_ymax = int(min(aobj_ymax - ymin, ymax - ymin))
+ aobj_xmax = int(min(aobj_xmax - xmin, xmax - xmin))
+
+ crop[aobj_ymin:aobj_ymax, aobj_xmin:aobj_xmax] = crop_out_color
+
+ # calc useful non-black area
+ remain_ratio = np.count_nonzero(crop) / (crop.shape[0] * crop.shape[1] * crop.shape[2])
+ if (crop.shape[0] < min_person_size or crop.shape[1] < min_person_size) or remain_ratio < min_person_aftercut_ratio:
+ crop = None
+ empty = True
+
+ return crop, empty
+
+
+def _correct_bbox(bbox, h, w):
+ xmin, ymin, xmax, ymax = bbox
+ ymin = min(max(ymin, 0), h)
+ ymax = min(max(ymax, 0), h)
+ xmin = min(max(xmin, 0), w)
+ xmax = min(max(xmax, 0), w)
+ return ymin, ymax, xmin, xmax
+
+
+def _get_iou(bbox1, bbox2):
+ xmin1, ymin1, xmax1, ymax1 = bbox1
+ xmin2, ymin2, xmax2, ymax2 = bbox2
+ iou = IOU(
+ [ymin1, xmin1, ymax1, xmax1],
+ [ymin2, xmin2, ymax2, xmax2],
+ )
+ return iou
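A hedged usage sketch for ReaderAgeGender; the dataset layout, split name and printed values are placeholders and depend on the annotations actually present on disk.

from mivolo.data.dataset.reader_age_gender import ReaderAgeGender

# Placeholder dataset layout: csv files under annotations/, images under images/.
reader = ReaderAgeGender(
    images_path="data/lagenda/images",
    annotations_path="data/lagenda/annotations",
    split="validation",
    with_persons=True,
    target_size=224,
)
(face_crop, person_crop), (age, gender) = reader[0]
print(face_crop.shape, person_crop.shape, age, gender)  # e.g. (224, 224, 3) (224, 224, 3) '34' 'F'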
diff --git a/age_estimator/mivolo/mivolo/data/misc.py b/age_estimator/mivolo/mivolo/data/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..603223b90a3b296cfd7f6da9044e1cb288e03621
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/data/misc.py
@@ -0,0 +1,246 @@
+import argparse
+import ast
+import re
+from typing import List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms.functional as F
+from scipy.optimize import linear_sum_assignment
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+CROP_ROUND_RATE = 0.1
+MIN_PERSON_CROP_NONZERO = 0.5
+
+
+def aggregate_votes_winsorized(ages, max_age_dist=6):
+ # Replace any age vote that is more than max_age_dist away from the median
+ # with median + max_age_dist if above, or median - max_age_dist if below
+ median = np.median(ages)
+ ages = np.clip(ages, median - max_age_dist, median + max_age_dist)
+ return np.mean(ages)
+
+
+def natural_key(string_):
+ """See http://www.codinghorror.com/blog/archives/001018.html"""
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
+
+
+def add_bool_arg(parser, name, default=False, help=""):
+ dest_name = name.replace("-", "_")
+ group = parser.add_mutually_exclusive_group(required=False)
+ group.add_argument("--" + name, dest=dest_name, action="store_true", help=help)
+ group.add_argument("--no-" + name, dest=dest_name, action="store_false", help=help)
+ parser.set_defaults(**{dest_name: default})
+
+
+def cumulative_score(pred_ages, gt_ages, L, tol=1e-6):
+ n = pred_ages.shape[0]
+ num_correct = torch.sum(torch.abs(pred_ages - gt_ages) <= L + tol)
+ cs_score = num_correct / n
+ return cs_score
+
+
+def cumulative_error(pred_ages, gt_ages, L, tol=1e-6):
+ n = pred_ages.shape[0]
+ num_correct = torch.sum(torch.abs(pred_ages - gt_ages) >= L + tol)
+ cs_score = num_correct / n
+ return cs_score
+
+
+class ParseKwargs(argparse.Action):
+ def __call__(self, parser, namespace, values, option_string=None):
+ kw = {}
+ for value in values:
+ key, value = value.split("=")
+ try:
+ kw[key] = ast.literal_eval(value)
+ except ValueError:
+ kw[key] = str(value) # fallback to string (avoid need to escape on command line)
+ setattr(namespace, self.dest, kw)
+
+
+def box_iou(box1, box2, over_second=False):
+ """
+ Return intersection-over-union (Jaccard index) of boxes.
+ If over_second == True, return mean(intersection-over-union, (inter / area2))
+
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+
+ Arguments:
+ box1 (Tensor[N, 4])
+ box2 (Tensor[M, 4])
+ Returns:
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
+ IoU values for every element in boxes1 and boxes2
+ """
+
+ def box_area(box):
+ # box = 4xn
+ return (box[2] - box[0]) * (box[3] - box[1])
+
+ area1 = box_area(box1.T)
+ area2 = box_area(box2.T)
+
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
+
+ iou = inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
+ if over_second:
+ return (inter / area2 + iou) / 2 # mean(inter / area2, iou)
+ else:
+ return iou
+
+
+def split_batch(bs: int, dev: int) -> Tuple[int, int]:
+ full_bs = (bs // dev) * dev
+ part_bs = bs - full_bs
+ return full_bs, part_bs
+
+
+def assign_faces(
+ persons_bboxes: List[torch.tensor], faces_bboxes: List[torch.tensor], iou_thresh: float = 0.0001
+) -> Tuple[List[Optional[int]], List[int]]:
+ """
+ Assign person to each face if it is possible.
+ Return:
+ - assigned_faces List[Optional[int]]: mapping of face_ind to person_ind
+ ( assigned_faces[face_ind] = person_ind ). person_ind can be None
+ - unassigned_persons_inds List[int]: persons indexes without any assigned face
+ """
+
+ assigned_faces: List[Optional[int]] = [None for _ in range(len(faces_bboxes))]
+ unassigned_persons_inds: List[int] = [p_ind for p_ind in range(len(persons_bboxes))]
+
+ if len(persons_bboxes) == 0 or len(faces_bboxes) == 0:
+ return assigned_faces, unassigned_persons_inds
+
+ cost_matrix = box_iou(torch.stack(persons_bboxes), torch.stack(faces_bboxes), over_second=True).cpu().numpy()
+ persons_indexes, face_indexes = [], []
+
+ if len(cost_matrix) > 0:
+ persons_indexes, face_indexes = linear_sum_assignment(cost_matrix, maximize=True)
+
+ matched_persons = set()
+ for person_idx, face_idx in zip(persons_indexes, face_indexes):
+ ciou = cost_matrix[person_idx][face_idx]
+ if ciou > iou_thresh:
+ if person_idx in matched_persons:
+ # Person can not be assigned twice, in reality this should not happen
+ continue
+ assigned_faces[face_idx] = person_idx
+ matched_persons.add(person_idx)
+
+ unassigned_persons_inds = [p_ind for p_ind in range(len(persons_bboxes)) if p_ind not in matched_persons]
+
+ return assigned_faces, unassigned_persons_inds
+
+
+def class_letterbox(im, new_shape=(640, 640), color=(0, 0, 0), scaleup=True):
+ # Resize and pad image while meeting stride-multiple constraints
+ shape = im.shape[:2] # current shape [height, width]
+ if isinstance(new_shape, int):
+ new_shape = (new_shape, new_shape)
+
+ if im.shape[0] == new_shape[0] and im.shape[1] == new_shape[1]:
+ return im
+
+ # Scale ratio (new / old)
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
+ r = min(r, 1.0)
+
+ # Compute padding
+ # ratio = r, r # width, height ratios
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
+
+ dw /= 2 # divide padding into 2 sides
+ dh /= 2
+
+ if shape[::-1] != new_unpad: # resize
+ im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+ im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
+ return im
+
+
+def prepare_classification_images(
+ img_list: List[Optional[np.ndarray]],
+ target_size: int = 224,
+ mean=IMAGENET_DEFAULT_MEAN,
+ std=IMAGENET_DEFAULT_STD,
+ device=None,
+) -> torch.tensor:
+
+ prepared_images: List[torch.tensor] = []
+
+ for img in img_list:
+ if img is None:
+ img = torch.zeros((3, target_size, target_size), dtype=torch.float32)
+ img = F.normalize(img, mean=mean, std=std)
+ img = img.unsqueeze(0)
+ prepared_images.append(img)
+ continue
+ img = class_letterbox(img, new_shape=(target_size, target_size))
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ img = img / 255.0
+ img = (img - mean) / std
+ img = img.astype(dtype=np.float32)
+
+ img = img.transpose((2, 0, 1))
+ img = np.ascontiguousarray(img)
+ img = torch.from_numpy(img)
+ img = img.unsqueeze(0)
+
+ prepared_images.append(img)
+
+ if len(prepared_images) == 0:
+ return None
+
+ prepared_input = torch.concat(prepared_images)
+
+ if device:
+ prepared_input = prepared_input.to(device)
+
+ return prepared_input
+
+
+def IOU(bb1: Union[tuple, list], bb2: Union[tuple, list], norm_second_bbox: bool = False) -> float:
+ # expects [ymin, xmin, ymax, xmax]; absolute or relative coordinates both work
+ assert bb1[1] < bb1[3]
+ assert bb1[0] < bb1[2]
+ assert bb2[1] < bb2[3]
+ assert bb2[0] < bb2[2]
+
+ # determine the coordinates of the intersection rectangle
+ x_left = max(bb1[1], bb2[1])
+ y_top = max(bb1[0], bb2[0])
+ x_right = min(bb1[3], bb2[3])
+ y_bottom = min(bb1[2], bb2[2])
+
+ if x_right < x_left or y_bottom < y_top:
+ return 0.0
+
+ # The intersection of two axis-aligned bounding boxes is always an
+ # axis-aligned bounding box
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
+ # compute the area of both AABBs
+ bb1_area = (bb1[3] - bb1[1]) * (bb1[2] - bb1[0])
+ bb2_area = (bb2[3] - bb2[1]) * (bb2[2] - bb2[0])
+ if not norm_second_bbox:
+ # compute the intersection over union by taking the intersection
+ # area and dividing it by the sum of prediction + ground-truth
+ # areas - the interesection area
+ iou = intersection_area / float(bb1_area + bb2_area - intersection_area)
+ else:
+ # for cases when we search if second bbox is inside first one
+ iou = intersection_area / float(bb2_area)
+
+ assert iou >= 0.0
+ assert iou <= 1.01
+
+ return iou
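A self-contained sketch of the two box utilities above: IOU takes [ymin, xmin, ymax, xmax] boxes, while assign_faces matches xyxy face tensors to person tensors via Hungarian assignment on the over_second IoU. Values are illustrative.

import torch
from mivolo.data.misc import IOU, assign_faces

print(IOU([0, 0, 10, 10], [5, 5, 15, 15]))           # 25 / 175 ≈ 0.1429
persons = [torch.tensor([0.0, 0.0, 100.0, 200.0])]   # one person box (x1, y1, x2, y2)
faces = [torch.tensor([20.0, 10.0, 60.0, 50.0])]     # one face box inside the person
assigned_faces, unassigned_persons = assign_faces(persons, faces)
print(assigned_faces, unassigned_persons)             # [0] []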
diff --git a/age_estimator/mivolo/mivolo/model/__init__.py b/age_estimator/mivolo/mivolo/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/age_estimator/mivolo/mivolo/model/__pycache__/__init__.cpython-38.pyc b/age_estimator/mivolo/mivolo/model/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d87c2c0392526ecce3b2f9b0c0d6f86757d907a7
Binary files /dev/null and b/age_estimator/mivolo/mivolo/model/__pycache__/__init__.cpython-38.pyc differ
diff --git a/age_estimator/mivolo/mivolo/model/__pycache__/create_timm_model.cpython-38.pyc b/age_estimator/mivolo/mivolo/model/__pycache__/create_timm_model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b757480d8cf16cf410324c62d0f115bcac4137b
Binary files /dev/null and b/age_estimator/mivolo/mivolo/model/__pycache__/create_timm_model.cpython-38.pyc differ
diff --git a/age_estimator/mivolo/mivolo/model/__pycache__/cross_bottleneck_attn.cpython-38.pyc b/age_estimator/mivolo/mivolo/model/__pycache__/cross_bottleneck_attn.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a475ed6f1765509d3fcfe86045d370e4ef762a8
Binary files /dev/null and b/age_estimator/mivolo/mivolo/model/__pycache__/cross_bottleneck_attn.cpython-38.pyc differ
diff --git a/age_estimator/mivolo/mivolo/model/__pycache__/mi_volo.cpython-38.pyc b/age_estimator/mivolo/mivolo/model/__pycache__/mi_volo.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4bb5fb506a9cc1d226c50f23512662b2db4b33e7
Binary files /dev/null and b/age_estimator/mivolo/mivolo/model/__pycache__/mi_volo.cpython-38.pyc differ
diff --git a/age_estimator/mivolo/mivolo/model/__pycache__/mivolo_model.cpython-38.pyc b/age_estimator/mivolo/mivolo/model/__pycache__/mivolo_model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c24ba1eef8ec2ae4717dd358a88be8362d05cea1
Binary files /dev/null and b/age_estimator/mivolo/mivolo/model/__pycache__/mivolo_model.cpython-38.pyc differ
diff --git a/age_estimator/mivolo/mivolo/model/__pycache__/yolo_detector.cpython-38.pyc b/age_estimator/mivolo/mivolo/model/__pycache__/yolo_detector.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..01af016e7ad0f0286ba1457da77811eea2a35c3d
Binary files /dev/null and b/age_estimator/mivolo/mivolo/model/__pycache__/yolo_detector.cpython-38.pyc differ
diff --git a/age_estimator/mivolo/mivolo/model/create_timm_model.py b/age_estimator/mivolo/mivolo/model/create_timm_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7545b01cd1a3c612787c25c06d926f6cade4d98
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/model/create_timm_model.py
@@ -0,0 +1,107 @@
+"""
+Code adapted from timm https://github.com/huggingface/pytorch-image-models
+
+Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
+"""
+
+import os
+from typing import Any, Dict, Optional, Union
+
+import timm
+
+# register new models
+from mivolo.model.mivolo_model import * # noqa: F403, F401
+from timm.layers import set_layer_config
+from timm.models._factory import parse_model_name
+from timm.models._helpers import load_state_dict, remap_checkpoint
+from timm.models._hub import load_model_config_from_hf
+from timm.models._pretrained import PretrainedCfg, split_model_name_tag
+from timm.models._registry import is_model, model_entrypoint
+
+
+def load_checkpoint(
+ model, checkpoint_path, use_ema=True, strict=True, remap=False, filter_keys=None, state_dict_map=None
+):
+ if os.path.splitext(checkpoint_path)[-1].lower() in (".npz", ".npy"):
+ # numpy checkpoint, try to load via model specific load_pretrained fn
+ if hasattr(model, "load_pretrained"):
+ timm.models._model_builder.load_pretrained(checkpoint_path)
+ else:
+ raise NotImplementedError("Model cannot load numpy checkpoint")
+ return
+ state_dict = load_state_dict(checkpoint_path, use_ema)
+ if remap:
+ state_dict = remap_checkpoint(model, state_dict)
+ if filter_keys:
+ for sd_key in list(state_dict.keys()):
+ for filter_key in filter_keys:
+ if filter_key in sd_key:
+ if sd_key in state_dict:
+ del state_dict[sd_key]
+
+ rep = []
+ if state_dict_map is not None:
+ # 'patch_embed.conv1.' : 'patch_embed.conv.'
+ for state_k in list(state_dict.keys()):
+ for target_k, target_v in state_dict_map.items():
+ if target_v in state_k:
+ target_name = state_k.replace(target_v, target_k)
+ state_dict[target_name] = state_dict[state_k]
+ rep.append(state_k)
+ for r in rep:
+ if r in state_dict:
+ del state_dict[r]
+
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict if filter_keys is None else False)
+ return incompatible_keys
+
+
+def create_model(
+ model_name: str,
+ pretrained: bool = False,
+ pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
+ pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
+ checkpoint_path: str = "",
+ scriptable: Optional[bool] = None,
+ exportable: Optional[bool] = None,
+ no_jit: Optional[bool] = None,
+ filter_keys=None,
+ state_dict_map=None,
+ **kwargs,
+):
+ """Create a model
+ Lookup model's entrypoint function and pass relevant args to create a new model.
+ """
+ # Parameters that aren't supported by all models or are intended to only override model defaults if set
+ # should default to None in command line args/cfg. Remove them if they are present and not set so that
+ # non-supporting models don't break and default args remain in effect.
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+ model_source, model_name = parse_model_name(model_name)
+ if model_source == "hf-hub":
+ assert not pretrained_cfg, "pretrained_cfg should not be set when sourcing model from Hugging Face Hub."
+ # For model names specified in the form `hf-hub:path/architecture_name@revision`,
+ # load model weights + pretrained_cfg from Hugging Face hub.
+ pretrained_cfg, model_name = load_model_config_from_hf(model_name)
+ else:
+ model_name, pretrained_tag = split_model_name_tag(model_name)
+ if not pretrained_cfg:
+ # a valid pretrained_cfg argument takes priority over tag in model name
+ pretrained_cfg = pretrained_tag
+
+ if not is_model(model_name):
+ raise RuntimeError("Unknown model (%s)" % model_name)
+
+ create_fn = model_entrypoint(model_name)
+ with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
+ model = create_fn(
+ pretrained=pretrained,
+ pretrained_cfg=pretrained_cfg,
+ pretrained_cfg_overlay=pretrained_cfg_overlay,
+ **kwargs,
+ )
+
+ if checkpoint_path:
+ load_checkpoint(model, checkpoint_path, filter_keys=filter_keys, state_dict_map=state_dict_map)
+
+ return model
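A hedged sketch of create_model() with a filtered checkpoint load. The checkpoint path is a placeholder, and "mivolo_d1_224" is assumed to be one of the entrypoints registered in mivolo_model.py below.

from mivolo.model.create_timm_model import create_model

model = create_model(
    "mivolo_d1_224",
    num_classes=3,          # 2 gender logits + 1 age output
    in_chans=6,             # face crop + person crop stacked channel-wise
    pretrained=False,
    checkpoint_path="pretrained/mivolo_checkpoint.pth.tar",  # placeholder path
    filter_keys=["fds."],   # drop checkpoint keys whose names contain "fds."
)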
diff --git a/age_estimator/mivolo/mivolo/model/cross_bottleneck_attn.py b/age_estimator/mivolo/mivolo/model/cross_bottleneck_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..44976bf39ed22280ce640d5f41b6627e21bd543a
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/model/cross_bottleneck_attn.py
@@ -0,0 +1,116 @@
+"""
+Code based on timm https://github.com/huggingface/pytorch-image-models
+
+Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
+"""
+
+import torch
+import torch.nn as nn
+from timm.layers.bottleneck_attn import PosEmbedRel
+from timm.layers.helpers import make_divisible
+from timm.layers.mlp import Mlp
+from timm.layers.trace_utils import _assert
+from timm.layers.weight_init import trunc_normal_
+
+
+class CrossBottleneckAttn(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_out=None,
+ feat_size=None,
+ stride=1,
+ num_heads=4,
+ dim_head=None,
+ qk_ratio=1.0,
+ qkv_bias=False,
+ scale_pos_embed=False,
+ ):
+ super().__init__()
+ assert feat_size is not None, "A concrete feature size matching expected input (H, W) is required"
+ dim_out = dim_out or dim
+ assert dim_out % num_heads == 0
+
+ self.num_heads = num_heads
+ self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
+ self.dim_head_v = dim_out // self.num_heads
+ self.dim_out_qk = num_heads * self.dim_head_qk
+ self.dim_out_v = num_heads * self.dim_head_v
+ self.scale = self.dim_head_qk**-0.5
+ self.scale_pos_embed = scale_pos_embed
+
+ self.qkv_f = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
+ self.qkv_p = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
+
+ # NOTE I'm only supporting relative pos embedding for now
+ self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale)
+
+ self.norm = nn.LayerNorm([self.dim_out_v * 2, *feat_size])
+ mlp_ratio = 4
+ self.mlp = Mlp(
+ in_features=self.dim_out_v * 2,
+ hidden_features=int(dim * mlp_ratio),
+ act_layer=nn.GELU,
+ out_features=dim_out,
+ drop=0,
+ use_conv=True,
+ )
+
+ self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ trunc_normal_(self.qkv_f.weight, std=self.qkv_f.weight.shape[1] ** -0.5) # fan-in
+ trunc_normal_(self.qkv_p.weight, std=self.qkv_p.weight.shape[1] ** -0.5) # fan-in
+ trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+ trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
+ def get_qkv(self, x, qvk_conv):
+ B, C, H, W = x.shape
+
+ x = qvk_conv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
+
+ q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)
+
+ q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
+ k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k
+ v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
+
+ return q, k, v
+
+ def apply_attn(self, q, k, v, B, H, W, dropout=None):
+ if self.scale_pos_embed:
+ attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W
+ else:
+ attn = (q @ k) * self.scale + self.pos_embed(q)
+ attn = attn.softmax(dim=-1)
+ if dropout:
+ attn = dropout(attn)
+
+ out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
+ return out
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+
+ dim = int(C / 2)
+ x1 = x[:, :dim, :, :]
+ x2 = x[:, dim:, :, :]
+
+ _assert(H == self.pos_embed.height, "")
+ _assert(W == self.pos_embed.width, "")
+
+ q_f, k_f, v_f = self.get_qkv(x1, self.qkv_f)
+ q_p, k_p, v_p = self.get_qkv(x2, self.qkv_p)
+
+ # person to face
+ out_f = self.apply_attn(q_f, k_p, v_p, B, H, W)
+ # face to person
+ out_p = self.apply_attn(q_p, k_f, v_f, B, H, W)
+
+ x_pf = torch.cat((out_f, out_p), dim=1) # B, dim_out * 2, H, W
+ x_pf = self.norm(x_pf)
+ x_pf = self.mlp(x_pf) # B, dim_out, H, W
+
+ out = self.pool(x_pf)
+ return out
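A self-contained shape sketch for CrossBottleneckAttn: the module expects the face and person feature maps stacked channel-wise (2 * dim input channels) and returns dim_out channels; the sizes below are illustrative.

import torch
from mivolo.model.cross_bottleneck_attn import CrossBottleneckAttn

attn = CrossBottleneckAttn(dim=64, feat_size=(7, 7), num_heads=4)
x = torch.randn(2, 128, 7, 7)      # 64 face channels + 64 person channels
out = attn(x)                      # cross-attention both ways, then norm + MLP
print(out.shape)                   # torch.Size([2, 64, 7, 7])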
diff --git a/age_estimator/mivolo/mivolo/model/mi_volo.py b/age_estimator/mivolo/mivolo/model/mi_volo.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d4c824f8f918dc0f2854000afef66fe11573965
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/model/mi_volo.py
@@ -0,0 +1,243 @@
+import logging
+from typing import Optional
+
+import numpy as np
+import torch
+from mivolo.data.misc import prepare_classification_images
+from mivolo.model.create_timm_model import create_model
+from mivolo.structures import PersonAndFaceCrops, PersonAndFaceResult
+from timm.data import resolve_data_config
+
+_logger = logging.getLogger("MiVOLO")
+has_compile = hasattr(torch, "compile")
+
+
+class Meta:
+ def __init__(self):
+ self.min_age = None
+ self.max_age = None
+ self.avg_age = None
+ self.num_classes = None
+
+ self.in_chans = 3
+ self.with_persons_model = False
+ self.disable_faces = False
+ self.use_persons = True
+ self.only_age = False
+
+ self.num_classes_gender = 2
+ self.input_size = 224
+
+ def load_from_ckpt(self, ckpt_path: str, disable_faces: bool = False, use_persons: bool = True) -> "Meta":
+
+ state = torch.load(ckpt_path, map_location="cpu")
+
+ self.min_age = state["min_age"]
+ self.max_age = state["max_age"]
+ self.avg_age = state["avg_age"]
+ self.only_age = state["no_gender"]
+
+ only_age = state["no_gender"]
+
+ self.disable_faces = disable_faces
+ if "with_persons_model" in state:
+ self.with_persons_model = state["with_persons_model"]
+ else:
+ self.with_persons_model = True if "patch_embed.conv1.0.weight" in state["state_dict"] else False
+
+ self.num_classes = 1 if only_age else 3
+ self.in_chans = 3 if not self.with_persons_model else 6
+ self.use_persons = use_persons and self.with_persons_model
+
+ if not self.with_persons_model and self.disable_faces:
+ raise ValueError("You can not use disable-faces for faces-only model")
+ if self.with_persons_model and self.disable_faces and not self.use_persons:
+ raise ValueError(
+ "You can not disable faces and persons together. "
+ "Set --with-persons if you want to run with --disable-faces"
+ )
+ self.input_size = state["state_dict"]["pos_embed"].shape[1] * 16
+ return self
+
+ def __str__(self):
+ attrs = vars(self)
+ attrs.update({"use_person_crops": self.use_person_crops, "use_face_crops": self.use_face_crops})
+ return ", ".join("%s: %s" % item for item in attrs.items())
+
+ @property
+ def use_person_crops(self) -> bool:
+ return self.with_persons_model and self.use_persons
+
+ @property
+ def use_face_crops(self) -> bool:
+ return not self.disable_faces or not self.with_persons_model
+
+
+class MiVOLO:
+ def __init__(
+ self,
+ ckpt_path: str,
+ device: str = "cuda",
+ half: bool = True,
+ disable_faces: bool = False,
+ use_persons: bool = True,
+ verbose: bool = False,
+ torchcompile: Optional[str] = None,
+ ):
+ self.verbose = verbose
+ self.device = torch.device(device)
+ self.half = half and self.device.type != "cpu"
+
+ self.meta: Meta = Meta().load_from_ckpt(ckpt_path, disable_faces, use_persons)
+ if self.verbose:
+ _logger.info(f"Model meta:\n{str(self.meta)}")
+
+ model_name = f"mivolo_d1_{self.meta.input_size}"
+ self.model = create_model(
+ model_name=model_name,
+ num_classes=self.meta.num_classes,
+ in_chans=self.meta.in_chans,
+ pretrained=False,
+ checkpoint_path=ckpt_path,
+ filter_keys=["fds."],
+ )
+ self.param_count = sum([m.numel() for m in self.model.parameters()])
+ _logger.info(f"Model {model_name} created, param count: {self.param_count}")
+
+ self.data_config = resolve_data_config(
+ model=self.model,
+ verbose=verbose,
+ use_test_size=True,
+ )
+
+ self.data_config["crop_pct"] = 1.0
+ c, h, w = self.data_config["input_size"]
+ assert h == w, "Incorrect data_config"
+ self.input_size = w
+
+ self.model = self.model.to(self.device)
+
+ if torchcompile:
+ assert has_compile, "A version of torch w/ torch.compile() is required for --compile, possibly a nightly."
+ torch._dynamo.reset()
+ self.model = torch.compile(self.model, backend=torchcompile)
+
+ self.model.eval()
+ if self.half:
+ self.model = self.model.half()
+
+ def warmup(self, batch_size: int, steps=10):
+ if self.meta.with_persons_model:
+ input_size = (6, self.input_size, self.input_size)
+ else:
+ input_size = self.data_config["input_size"]
+
+ input = torch.randn((batch_size,) + tuple(input_size)).to(self.device)
+
+ for _ in range(steps):
+ out = self.inference(input) # noqa: F841
+
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ def inference(self, model_input: torch.tensor) -> torch.tensor:
+
+ with torch.no_grad():
+ if self.half:
+ model_input = model_input.half()
+ output = self.model(model_input)
+ return output
+
+ def predict(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
+ if (
+ (detected_bboxes.n_objects == 0)
+ or (not self.meta.use_persons and detected_bboxes.n_faces == 0)
+ or (self.meta.disable_faces and detected_bboxes.n_persons == 0)
+ ):
+ # nothing to process
+ return
+
+ faces_input, person_input, faces_inds, bodies_inds = self.prepare_crops(image, detected_bboxes)
+
+ if faces_input is None and person_input is None:
+ # nothing to process
+ return
+
+ if self.meta.with_persons_model:
+ model_input = torch.cat((faces_input, person_input), dim=1)
+ else:
+ model_input = faces_input
+ output = self.inference(model_input)
+
+ # write gender and age results into detected_bboxes
+ self.fill_in_results(output, detected_bboxes, faces_inds, bodies_inds)
+
+ def fill_in_results(self, output, detected_bboxes, faces_inds, bodies_inds):
+ if self.meta.only_age:
+ age_output = output
+ gender_probs, gender_indx = None, None
+ else:
+ age_output = output[:, 2]
+ gender_output = output[:, :2].softmax(-1)
+ gender_probs, gender_indx = gender_output.topk(1)
+
+ assert output.shape[0] == len(faces_inds) == len(bodies_inds)
+
+ # per face
+ for index in range(output.shape[0]):
+ face_ind = faces_inds[index]
+ body_ind = bodies_inds[index]
+
+ # get_age
+ age = age_output[index].item()
+ age = age * (self.meta.max_age - self.meta.min_age) + self.meta.avg_age
+ age = round(age, 2)
+
+ detected_bboxes.set_age(face_ind, age)
+ detected_bboxes.set_age(body_ind, age)
+
+ _logger.info(f"\tage: {age}")
+
+ if gender_probs is not None:
+ gender = "male" if gender_indx[index].item() == 0 else "female"
+ gender_score = gender_probs[index].item()
+
+ _logger.info(f"\tgender: {gender} [{int(gender_score * 100)}%]")
+
+ detected_bboxes.set_gender(face_ind, gender, gender_score)
+ detected_bboxes.set_gender(body_ind, gender, gender_score)
+
+ def prepare_crops(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
+
+ if self.meta.use_person_crops and self.meta.use_face_crops:
+ detected_bboxes.associate_faces_with_persons()
+
+ crops: PersonAndFaceCrops = detected_bboxes.collect_crops(image)
+ (bodies_inds, bodies_crops), (faces_inds, faces_crops) = crops.get_faces_with_bodies(
+ self.meta.use_person_crops, self.meta.use_face_crops
+ )
+
+ if not self.meta.use_face_crops:
+ assert all(f is None for f in faces_crops)
+
+ faces_input = prepare_classification_images(
+ faces_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
+ )
+
+ if not self.meta.use_person_crops:
+ assert all(p is None for p in bodies_crops)
+
+ person_input = prepare_classification_images(
+ bodies_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
+ )
+
+ _logger.info(
+ f"faces_input: {faces_input.shape if faces_input is not None else None}, "
+ f"person_input: {person_input.shape if person_input is not None else None}"
+ )
+
+ return faces_input, person_input, faces_inds, bodies_inds
+
+
+if __name__ == "__main__":
+ model = MiVOLO("../pretrained/checkpoint-377.pth.tar", half=True, device="cuda:0")
diff --git a/age_estimator/mivolo/mivolo/model/mivolo_model.py b/age_estimator/mivolo/mivolo/model/mivolo_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e76c12a2a75a6688f8aa15a0e5f98033c3ef4544
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/model/mivolo_model.py
@@ -0,0 +1,404 @@
+"""
+Code adapted from timm https://github.com/huggingface/pytorch-image-models
+
+Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
+"""
+
+import torch
+import torch.nn as nn
+from mivolo.model.cross_bottleneck_attn import CrossBottleneckAttn
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import trunc_normal_
+from timm.models._builder import build_model_with_cfg
+from timm.models._registry import register_model
+from timm.models.volo import VOLO
+
+__all__ = ["MiVOLOModel"] # model_registry will add each entrypoint fn to this
+
+
+def _cfg(url="", **kwargs):
+ return {
+ "url": url,
+ "num_classes": 1000,
+ "input_size": (3, 224, 224),
+ "pool_size": None,
+ "crop_pct": 0.96,
+ "interpolation": "bicubic",
+ "fixed_input_size": True,
+ "mean": IMAGENET_DEFAULT_MEAN,
+ "std": IMAGENET_DEFAULT_STD,
+ "first_conv": None,
+ "classifier": ("head", "aux_head"),
+ **kwargs,
+ }
+
+
+default_cfgs = {
+ "mivolo_d1_224": _cfg(
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar", crop_pct=0.96
+ ),
+ "mivolo_d1_384": _cfg(
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar",
+ crop_pct=1.0,
+ input_size=(3, 384, 384),
+ ),
+ "mivolo_d2_224": _cfg(
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar", crop_pct=0.96
+ ),
+ "mivolo_d2_384": _cfg(
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar",
+ crop_pct=1.0,
+ input_size=(3, 384, 384),
+ ),
+ "mivolo_d3_224": _cfg(
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar", crop_pct=0.96
+ ),
+ "mivolo_d3_448": _cfg(
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar",
+ crop_pct=1.0,
+ input_size=(3, 448, 448),
+ ),
+ "mivolo_d4_224": _cfg(
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar", crop_pct=0.96
+ ),
+ "mivolo_d4_448": _cfg(
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar",
+ crop_pct=1.15,
+ input_size=(3, 448, 448),
+ ),
+ "mivolo_d5_224": _cfg(
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar", crop_pct=0.96
+ ),
+ "mivolo_d5_448": _cfg(
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar",
+ crop_pct=1.15,
+ input_size=(3, 448, 448),
+ ),
+ "mivolo_d5_512": _cfg(
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar",
+ crop_pct=1.15,
+ input_size=(3, 512, 512),
+ ),
+}
+
+
+def get_output_size(input_shape, conv_layer):
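+ # standard convolution output-size formula: out = (in + 2*pad - dilation*(kernel - 1) - 1) // stride + 1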
+ padding = conv_layer.padding
+ dilation = conv_layer.dilation
+ kernel_size = conv_layer.kernel_size
+ stride = conv_layer.stride
+
+ output_size = [
+ ((input_shape[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) // stride[i]) + 1 for i in range(2)
+ ]
+ return output_size
+
+
+def get_output_size_module(input_size, stem):
+ output_size = input_size
+
+ for module in stem:
+ if isinstance(module, nn.Conv2d):
+ output_size = [
+ (
+ (output_size[i] + 2 * module.padding[i] - module.dilation[i] * (module.kernel_size[i] - 1) - 1)
+ // module.stride[i]
+ )
+ + 1
+ for i in range(2)
+ ]
+
+ return output_size
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding."""
+
+ def __init__(
+ self, img_size=224, stem_conv=False, stem_stride=1, patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384
+ ):
+ super().__init__()
+ assert patch_size in [4, 8, 16]
+ assert in_chans in [3, 6]
+ self.with_persons_model = in_chans == 6
+ self.use_cross_attn = True
+
+ if stem_conv:
+ if not self.with_persons_model:
+ self.conv = self.create_stem(stem_stride, in_chans, hidden_dim)
+ else:
+ self.conv = True # just to match interface
+ # split
+ self.conv1 = self.create_stem(stem_stride, 3, hidden_dim)
+ self.conv2 = self.create_stem(stem_stride, 3, hidden_dim)
+ else:
+ self.conv = None
+
+ if self.with_persons_model:
+
+ self.proj1 = nn.Conv2d(
+ hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
+ )
+ self.proj2 = nn.Conv2d(
+ hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
+ )
+
+ stem_out_shape = get_output_size_module((img_size, img_size), self.conv1)
+ self.proj_output_size = get_output_size(stem_out_shape, self.proj1)
+
+ self.map = CrossBottleneckAttn(embed_dim, dim_out=embed_dim, num_heads=1, feat_size=self.proj_output_size)
+
+ else:
+ self.proj = nn.Conv2d(
+ hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
+ )
+
+ self.patch_dim = img_size // patch_size
+ self.num_patches = self.patch_dim**2
+
+ def create_stem(self, stem_stride, in_chans, hidden_dim):
+ return nn.Sequential(
+ nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False), # 112x112
+ nn.BatchNorm2d(hidden_dim),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
+ nn.BatchNorm2d(hidden_dim),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
+ nn.BatchNorm2d(hidden_dim),
+ nn.ReLU(inplace=True),
+ )
+
+ def forward(self, x):
+ if self.conv is not None:
+ if self.with_persons_model:
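+ # split the 6-channel input into a face stream (first 3 channels) and a person stream (last 3 channels)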
+ x1 = x[:, :3]
+ x2 = x[:, 3:]
+
+ x1 = self.conv1(x1)
+ x1 = self.proj1(x1)
+
+ x2 = self.conv2(x2)
+ x2 = self.proj2(x2)
+
+ x = torch.cat([x1, x2], dim=1)
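+ # fuse the two streams with cross-bottleneck attention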
+ x = self.map(x)
+ else:
+ x = self.conv(x)
+ x = self.proj(x) # B, C, H, W
+
+ return x
+
+
+class MiVOLOModel(VOLO):
+ """
+ Vision Outlooker, the main class of our model
+ """
+
+ def __init__(
+ self,
+ layers,
+ img_size=224,
+ in_chans=3,
+ num_classes=1000,
+ global_pool="token",
+ patch_size=8,
+ stem_hidden_dim=64,
+ embed_dims=None,
+ num_heads=None,
+ downsamples=(True, False, False, False),
+ outlook_attention=(True, False, False, False),
+ mlp_ratio=3.0,
+ qkv_bias=False,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.0,
+ norm_layer=nn.LayerNorm,
+ post_layers=("ca", "ca"),
+ use_aux_head=True,
+ use_mix_token=False,
+ pooling_scale=2,
+ ):
+ super().__init__(
+ layers,
+ img_size,
+ in_chans,
+ num_classes,
+ global_pool,
+ patch_size,
+ stem_hidden_dim,
+ embed_dims,
+ num_heads,
+ downsamples,
+ outlook_attention,
+ mlp_ratio,
+ qkv_bias,
+ drop_rate,
+ attn_drop_rate,
+ drop_path_rate,
+ norm_layer,
+ post_layers,
+ use_aux_head,
+ use_mix_token,
+ pooling_scale,
+ )
+
+ im_size = img_size[0] if isinstance(img_size, tuple) else img_size
+ self.patch_embed = PatchEmbed(
+ img_size=im_size,
+ stem_conv=True,
+ stem_stride=2,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ hidden_dim=stem_hidden_dim,
+ embed_dim=embed_dims[0],
+ )
+
+ trunc_normal_(self.pos_embed, std=0.02)
+ self.apply(self._init_weights)
+
+ def forward_features(self, x):
+ x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
+
+ # step2: tokens learning in the two stages
+ x = self.forward_tokens(x)
+
+ # step3: post network, apply class attention or not
+ if self.post_network is not None:
+ x = self.forward_cls(x)
+ x = self.norm(x)
+ return x
+
+ def forward_head(self, x, pre_logits: bool = False, targets=None, epoch=None):
+ if self.global_pool == "avg":
+ out = x.mean(dim=1)
+ elif self.global_pool == "token":
+ out = x[:, 0]
+ else:
+ out = x
+ if pre_logits:
+ return out
+
+ features = out
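+ # optional feature-smoothing hook: applied only if a trainer has attached an _fds_forward method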
+ fds_enabled = hasattr(self, "_fds_forward")
+ if fds_enabled:
+ features = self._fds_forward(features, targets, epoch)
+
+ out = self.head(features)
+ if self.aux_head is not None:
+ # generate classes in all feature tokens, see token labeling
+ aux = self.aux_head(x[:, 1:])
+ out = out + 0.5 * aux.max(1)[0]
+
+ return (out, features) if (fds_enabled and self.training) else out
+
+ def forward(self, x, targets=None, epoch=None):
+ """simplified forward (without mix token training)"""
+ x = self.forward_features(x)
+ x = self.forward_head(x, targets=targets, epoch=epoch)
+ return x
+
+
+def _create_mivolo(variant, pretrained=False, **kwargs):
+ if kwargs.get("features_only", None):
+ raise RuntimeError("features_only not implemented for Vision Transformer models.")
+ return build_model_with_cfg(MiVOLOModel, variant, pretrained, **kwargs)
+
+
+@register_model
+def mivolo_d1_224(pretrained=False, **kwargs):
+ model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
+ model = _create_mivolo("mivolo_d1_224", pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mivolo_d1_384(pretrained=False, **kwargs):
+ model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
+ model = _create_mivolo("mivolo_d1_384", pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mivolo_d2_224(pretrained=False, **kwargs):
+ model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
+ model = _create_mivolo("mivolo_d2_224", pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mivolo_d2_384(pretrained=False, **kwargs):
+ model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
+ model = _create_mivolo("mivolo_d2_384", pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mivolo_d3_224(pretrained=False, **kwargs):
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
+ model = _create_mivolo("mivolo_d3_224", pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mivolo_d3_448(pretrained=False, **kwargs):
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
+ model = _create_mivolo("mivolo_d3_448", pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mivolo_d4_224(pretrained=False, **kwargs):
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
+ model = _create_mivolo("mivolo_d4_224", pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mivolo_d4_448(pretrained=False, **kwargs):
+ """VOLO-D4 model, Params: 193M"""
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
+ model = _create_mivolo("mivolo_d4_448", pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mivolo_d5_224(pretrained=False, **kwargs):
+ model_args = dict(
+ layers=(12, 12, 20, 4),
+ embed_dims=(384, 768, 768, 768),
+ num_heads=(12, 16, 16, 16),
+ mlp_ratio=4,
+ stem_hidden_dim=128,
+ **kwargs
+ )
+ model = _create_mivolo("mivolo_d5_224", pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mivolo_d5_448(pretrained=False, **kwargs):
+ model_args = dict(
+ layers=(12, 12, 20, 4),
+ embed_dims=(384, 768, 768, 768),
+ num_heads=(12, 16, 16, 16),
+ mlp_ratio=4,
+ stem_hidden_dim=128,
+ **kwargs
+ )
+ model = _create_mivolo("mivolo_d5_448", pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mivolo_d5_512(pretrained=False, **kwargs):
+ model_args = dict(
+ layers=(12, 12, 20, 4),
+ embed_dims=(384, 768, 768, 768),
+ num_heads=(12, 16, 16, 16),
+ mlp_ratio=4,
+ stem_hidden_dim=128,
+ **kwargs
+ )
+ model = _create_mivolo("mivolo_d5_512", pretrained=pretrained, **model_args)
+ return model
diff --git a/age_estimator/mivolo/mivolo/model/yolo_detector.py b/age_estimator/mivolo/mivolo/model/yolo_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc22a1f142aa3f6bb9640055b5f81e1c553f4570
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/model/yolo_detector.py
@@ -0,0 +1,46 @@
+import os
+from typing import Dict, Union
+
+import numpy as np
+import PIL
+import torch
+from mivolo.structures import PersonAndFaceResult
+from ultralytics import YOLO
+from ultralytics.engine.results import Results
+
+# due to an ultralytics bug, CUBLAS_WORKSPACE_CONFIG must be unset after importing the module
+os.unsetenv("CUBLAS_WORKSPACE_CONFIG")
+
+
+class Detector:
+ def __init__(
+ self,
+ weights: str,
+ device: str = "cuda",
+ half: bool = True,
+ verbose: bool = False,
+ conf_thresh: float = 0.4,
+ iou_thresh: float = 0.7,
+ ):
+ self.yolo = YOLO(weights)
+ self.yolo.fuse()
+
+ self.device = torch.device(device)
+ self.half = half and self.device.type != "cpu"
+
+ if self.half:
+ self.yolo.model = self.yolo.model.half()
+
+ self.detector_names: Dict[int, str] = self.yolo.model.names
+
+ # init yolo.predictor
+ self.detector_kwargs = {"conf": conf_thresh, "iou": iou_thresh, "half": self.half, "verbose": verbose}
+ # self.yolo.predict(**self.detector_kwargs)
+
+ def predict(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult:
+ results: Results = self.yolo.predict(image, **self.detector_kwargs)[0]
+ return PersonAndFaceResult(results)
+
+ def track(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult:
+ results: Results = self.yolo.track(image, persist=True, **self.detector_kwargs)[0]
+ return PersonAndFaceResult(results)
diff --git a/age_estimator/mivolo/mivolo/predictor.py b/age_estimator/mivolo/mivolo/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..086515493dad7362d990127ae24e4657750c80ae
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/predictor.py
@@ -0,0 +1,80 @@
+from collections import defaultdict
+from typing import Dict, Generator, List, Optional, Tuple
+
+import cv2
+import numpy as np
+import tqdm
+from mivolo.model.mi_volo import MiVOLO
+from mivolo.model.yolo_detector import Detector
+from mivolo.structures import AGE_GENDER_TYPE, PersonAndFaceResult
+
+
+class Predictor:
+ def __init__(self, config, verbose: bool = False):
+ self.detector = Detector(config.detector_weights, config.device, verbose=verbose)
+ self.age_gender_model = MiVOLO(
+ config.checkpoint,
+ config.device,
+ half=True,
+ use_persons=config.with_persons,
+ disable_faces=config.disable_faces,
+ verbose=verbose,
+ )
+ self.draw = config.draw
+
+ def recognize(self, image: np.ndarray) -> Tuple[PersonAndFaceResult, Optional[np.ndarray], List[Optional[float]]]:
+ detected_objects: PersonAndFaceResult = self.detector.predict(image)
+ self.age_gender_model.predict(image, detected_objects)
+
+ # retrieve the per-object ages filled in by the age/gender model
+ ages = detected_objects.get_ages()
+
+ # optional debug printout, only if a `results` attribute with per-object dicts is present
+ if hasattr(detected_objects, 'results'):
+ for obj in detected_objects.results:
+ bbox = obj['bbox'] # bounding box of the person/face
+ label = obj['label'] # "person" or "face"
+ obj_age = obj.get('age', None)
+ obj_gender = obj.get('gender', None)
+ print(f"Detected {label} at {bbox} with age: {obj_age}, gender: {obj_gender}")
+
+ out_im = None
+ if self.draw:
+ out_im = detected_objects.plot()
+
+ return detected_objects, out_im, ages
+
+
+ def recognize_video(self, source: str) -> Generator:
+ video_capture = cv2.VideoCapture(source)
+ if not video_capture.isOpened():
+ raise ValueError(f"Failed to open video source {source}")
+
+ detected_objects_history: Dict[int, List[AGE_GENDER_TYPE]] = defaultdict(list)
+
+ total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
+ for _ in tqdm.tqdm(range(total_frames)):
+ ret, frame = video_capture.read()
+ if not ret:
+ break
+
+ detected_objects: PersonAndFaceResult = self.detector.track(frame)
+ self.age_gender_model.predict(frame, detected_objects)
+
+ current_frame_objs = detected_objects.get_results_for_tracking()
+ cur_persons: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[0]
+ cur_faces: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[1]
+
+ # add tr_persons and tr_faces to history
+ for guid, data in cur_persons.items():
+ # not useful for tracking :)
+ if None not in data:
+ detected_objects_history[guid].append(data)
+ for guid, data in cur_faces.items():
+ if None not in data:
+ detected_objects_history[guid].append(data)
+
+ detected_objects.set_tracked_age_gender(detected_objects_history)
+ if self.draw:
+ frame = detected_objects.plot()
+ yield detected_objects_history, frame
diff --git a/age_estimator/mivolo/mivolo/structures.py b/age_estimator/mivolo/mivolo/structures.py
new file mode 100644
index 0000000000000000000000000000000000000000..45277f6b770d7577f2459bcce679dd98dc813095
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/structures.py
@@ -0,0 +1,478 @@
+import math
+import os
+from copy import deepcopy
+from typing import Dict, List, Optional, Tuple
+
+import cv2
+import numpy as np
+import torch
+from mivolo.data.misc import aggregate_votes_winsorized, assign_faces, box_iou
+from ultralytics.engine.results import Results
+from ultralytics.utils.plotting import Annotator, colors
+
+# due to an ultralytics bug, CUBLAS_WORKSPACE_CONFIG must be unset after importing the module
+os.unsetenv("CUBLAS_WORKSPACE_CONFIG")
+
+AGE_GENDER_TYPE = Tuple[float, str]
+
+
+class PersonAndFaceCrops:
+ def __init__(self):
+ # int: index of person along results
+ self.crops_persons: Dict[int, np.ndarray] = {}
+
+ # int: index of face along results
+ self.crops_faces: Dict[int, np.ndarray] = {}
+
+ # int: index of face along results
+ self.crops_faces_wo_body: Dict[int, np.ndarray] = {}
+
+ # int: index of person along results
+ self.crops_persons_wo_face: Dict[int, np.ndarray] = {}
+
+ def _add_to_output(
+ self, crops: Dict[int, np.ndarray], out_crops: List[np.ndarray], out_crop_inds: List[Optional[int]]
+ ):
+ inds_to_add = list(crops.keys())
+ crops_to_add = list(crops.values())
+ out_crops.extend(crops_to_add)
+ out_crop_inds.extend(inds_to_add)
+
+ def _get_all_faces(
+ self, use_persons: bool, use_faces: bool
+ ) -> Tuple[List[Optional[int]], List[Optional[np.ndarray]]]:
+ """
+ Returns
+ if use_persons and use_faces
+ faces: faces_with_bodies + faces_without_bodies + [None] * len(crops_persons_wo_face)
+ if use_persons and not use_faces
+ faces: [None] * n_persons
+ if not use_persons and use_faces:
+ faces: faces_with_bodies + faces_without_bodies
+ """
+
+ def add_none_to_output(faces_inds, faces_crops, num):
+ faces_inds.extend([None for _ in range(num)])
+ faces_crops.extend([None for _ in range(num)])
+
+ faces_inds: List[Optional[int]] = []
+ faces_crops: List[Optional[np.ndarray]] = []
+
+ if not use_faces:
+ add_none_to_output(faces_inds, faces_crops, len(self.crops_persons) + len(self.crops_persons_wo_face))
+ return faces_inds, faces_crops
+
+ self._add_to_output(self.crops_faces, faces_crops, faces_inds)
+ self._add_to_output(self.crops_faces_wo_body, faces_crops, faces_inds)
+
+ if use_persons:
+ add_none_to_output(faces_inds, faces_crops, len(self.crops_persons_wo_face))
+
+ return faces_inds, faces_crops
+
+ def _get_all_bodies(
+ self, use_persons: bool, use_faces: bool
+ ) -> Tuple[List[Optional[int]], List[Optional[np.ndarray]]]:
+ """
+ Returns
+ if use_persons and use_faces
+ persons: bodies_with_faces + [None] * len(faces_without_bodies) + bodies_without_faces
+ if use_persons and not use_faces
+ persons: bodies_with_faces + bodies_without_faces
+ if not use_persons and use_faces
+ persons: [None] * n_faces
+ """
+
+ def add_none_to_output(bodies_inds, bodies_crops, num):
+ bodies_inds.extend([None for _ in range(num)])
+ bodies_crops.extend([None for _ in range(num)])
+
+ bodies_inds: List[Optional[int]] = []
+ bodies_crops: List[Optional[np.ndarray]] = []
+
+ if not use_persons:
+ add_none_to_output(bodies_inds, bodies_crops, len(self.crops_faces) + len(self.crops_faces_wo_body))
+ return bodies_inds, bodies_crops
+
+ self._add_to_output(self.crops_persons, bodies_crops, bodies_inds)
+ if use_faces:
+ add_none_to_output(bodies_inds, bodies_crops, len(self.crops_faces_wo_body))
+
+ self._add_to_output(self.crops_persons_wo_face, bodies_crops, bodies_inds)
+
+ return bodies_inds, bodies_crops
+
+ def get_faces_with_bodies(self, use_persons: bool, use_faces: bool):
+ """
+ Return
+ faces: faces_with_bodies, faces_without_bodies, [None] * len(crops_persons_wo_face)
+ persons: bodies_with_faces, [None] * len(faces_without_bodies), bodies_without_faces
+ """
+
+ bodies_inds, bodies_crops = self._get_all_bodies(use_persons, use_faces)
+ faces_inds, faces_crops = self._get_all_faces(use_persons, use_faces)
+
+ return (bodies_inds, bodies_crops), (faces_inds, faces_crops)
+
+ def save(self, out_dir="output"):
+ ind = 0
+ os.makedirs(out_dir, exist_ok=True)
+ for crops in [self.crops_persons, self.crops_faces, self.crops_faces_wo_body, self.crops_persons_wo_face]:
+ for crop in crops.values():
+ if crop is None:
+ continue
+ out_name = os.path.join(out_dir, f"{ind}_crop.jpg")
+ cv2.imwrite(out_name, crop)
+ ind += 1
+
+
+class PersonAndFaceResult:
+ def __init__(self, results: Results):
+
+ self.yolo_results = results
+ names = set(results.names.values())
+ assert "person" in names and "face" in names
+
+ # initially no faces and persons are associated to each other
+ self.face_to_person_map: Dict[int, Optional[int]] = {ind: None for ind in self.get_bboxes_inds("face")}
+ self.unassigned_persons_inds: List[int] = self.get_bboxes_inds("person")
+ n_objects = len(self.yolo_results.boxes)
+ self.ages: List[Optional[float]] = [None for _ in range(n_objects)]
+ self.genders: List[Optional[str]] = [None for _ in range(n_objects)]
+ self.gender_scores: List[Optional[float]] = [None for _ in range(n_objects)]
+
+ @property
+ def n_objects(self) -> int:
+ return len(self.yolo_results.boxes)
+
+ @property
+ def n_faces(self) -> int:
+ return len(self.get_bboxes_inds("face"))
+
+ @property
+ def n_persons(self) -> int:
+ return len(self.get_bboxes_inds("person"))
+
+ def get_bboxes_inds(self, category: str) -> List[int]:
+ bboxes: List[int] = []
+ for ind, det in enumerate(self.yolo_results.boxes):
+ name = self.yolo_results.names[int(det.cls)]
+ if name == category:
+ bboxes.append(ind)
+
+ return bboxes
+
+ def get_distance_to_center(self, bbox_ind: int) -> float:
+ """
+ Calculate the Euclidean distance between the bbox center and the image center.
+ """
+ im_h, im_w = self.yolo_results[bbox_ind].orig_shape
+ x1, y1, x2, y2 = self.get_bbox_by_ind(bbox_ind).cpu().numpy()
+ center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
+ dist = math.dist([center_x, center_y], [im_w / 2, im_h / 2])
+ return dist
+
+ def plot(
+ self,
+ conf=False,
+ line_width=None,
+ font_size=None,
+ font="Arial.ttf",
+ pil=False,
+ img=None,
+ labels=True,
+ boxes=True,
+ probs=True,
+ ages=True,
+ genders=True,
+ gender_probs=False,
+ ):
+ """
+ Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
+ Args:
+ conf (bool): Whether to plot the detection confidence score.
+ line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
+ font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
+ font (str): The font to use for the text.
+ pil (bool): Whether to return the image as a PIL Image.
+ img (numpy.ndarray): Plot onto this image instead of the original one. If None, plot onto the original image.
+ labels (bool): Whether to plot the label of bounding boxes.
+ boxes (bool): Whether to plot the bounding boxes.
+ probs (bool): Whether to plot classification probability
+ ages (bool): Whether to plot the age of bounding boxes.
+ genders (bool): Whether to plot the genders of bounding boxes.
+ gender_probs (bool): Whether to plot gender classification probability
+ Returns:
+ (numpy.ndarray): A numpy array of the annotated image.
+ """
+
+ # return self.yolo_results.plot()
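+ # a face and its associated person share one color; unassigned persons get color index 1, faces without a person get 0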
+ colors_by_ind = {}
+ for face_ind, person_ind in self.face_to_person_map.items():
+ if person_ind is not None:
+ colors_by_ind[face_ind] = face_ind + 2
+ colors_by_ind[person_ind] = face_ind + 2
+ else:
+ colors_by_ind[face_ind] = 0
+ for person_ind in self.unassigned_persons_inds:
+ colors_by_ind[person_ind] = 1
+
+ names = self.yolo_results.names
+ annotator = Annotator(
+ deepcopy(self.yolo_results.orig_img if img is None else img),
+ line_width,
+ font_size,
+ font,
+ pil,
+ example=names,
+ )
+ pred_boxes, show_boxes = self.yolo_results.boxes, boxes
+ pred_probs, show_probs = self.yolo_results.probs, probs
+
+ if pred_boxes and show_boxes:
+ for bb_ind, (d, age, gender, gender_score) in enumerate(
+ zip(pred_boxes, self.ages, self.genders, self.gender_scores)
+ ):
+ c, conf, guid = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
+ name = ("" if guid is None else f"id:{guid} ") + names[c]
+ label = (f"{name} {conf:.2f}" if conf else name) if labels else None
+ if ages and age is not None:
+ label += f" {age:.1f}"
+ if genders and gender is not None:
+ label += f" {'F' if gender == 'female' else 'M'}"
+ if gender_probs and gender_score is not None:
+ label += f" ({gender_score:.1f})"
+ annotator.box_label(d.xyxy.squeeze(), label, color=colors(colors_by_ind[bb_ind], True))
+
+ if pred_probs is not None and show_probs:
+ text = f"{', '.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)}, "
+ annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
+
+ return annotator.result()
+
+ def set_tracked_age_gender(self, tracked_objects: Dict[int, List[AGE_GENDER_TYPE]]):
+ """
+ Update age and gender for objects based on history from tracked_objects.
+ Args:
+ tracked_objects (dict[int, list[AGE_GENDER_TYPE]]): info about tracked objects by guid
+ """
+
+ for face_ind, person_ind in self.face_to_person_map.items():
+ pguid = self._get_id_by_ind(person_ind)
+ fguid = self._get_id_by_ind(face_ind)
+
+ if fguid == -1 and pguid == -1:
+ # YOLO might not assign ids for some objects in some cases:
+ # https://github.com/ultralytics/ultralytics/issues/3830
+ continue
+ age, gender = self._gather_tracking_result(tracked_objects, fguid, pguid)
+ if age is None or gender is None:
+ continue
+ self.set_age(face_ind, age)
+ self.set_gender(face_ind, gender, 1.0)
+ if pguid != -1:
+ self.set_gender(person_ind, gender, 1.0)
+ self.set_age(person_ind, age)
+
+ for person_ind in self.unassigned_persons_inds:
+ pid = self._get_id_by_ind(person_ind)
+ if pid == -1:
+ continue
+ age, gender = self._gather_tracking_result(tracked_objects, -1, pid)
+ if age is None or gender is None:
+ continue
+ self.set_gender(person_ind, gender, 1.0)
+ self.set_age(person_ind, age)
+
+ def _get_id_by_ind(self, ind: Optional[int] = None) -> int:
+ if ind is None:
+ return -1
+ obj_id = self.yolo_results.boxes[ind].id
+ if obj_id is None:
+ return -1
+ return obj_id.item()
+
+ def get_bbox_by_ind(self, ind: int, im_h: int = None, im_w: int = None) -> torch.tensor:
+ bb = self.yolo_results.boxes[ind].xyxy.squeeze().type(torch.int32)
+ if im_h is not None and im_w is not None:
+ bb[0] = torch.clamp(bb[0], min=0, max=im_w - 1)
+ bb[1] = torch.clamp(bb[1], min=0, max=im_h - 1)
+ bb[2] = torch.clamp(bb[2], min=0, max=im_w - 1)
+ bb[3] = torch.clamp(bb[3], min=0, max=im_h - 1)
+ return bb
+
+ def set_age(self, ind: Optional[int], age: float):
+ if ind is not None:
+ self.ages[ind] = age
+
+ def set_gender(self, ind: Optional[int], gender: str, gender_score: float):
+ if ind is not None:
+ self.genders[ind] = gender
+ self.gender_scores[ind] = gender_score
+
+ @staticmethod
+ def _gather_tracking_result(
+ tracked_objects: Dict[int, List[AGE_GENDER_TYPE]],
+ fguid: int = -1,
+ pguid: int = -1,
+ minimum_sample_size: int = 10,
+ ) -> AGE_GENDER_TYPE:
+
+ assert fguid != -1 or pguid != -1, "Incorrect tracking behaviour"
+
+ face_ages = [r[0] for r in tracked_objects[fguid] if r[0] is not None] if fguid in tracked_objects else []
+ face_genders = [r[1] for r in tracked_objects[fguid] if r[1] is not None] if fguid in tracked_objects else []
+ person_ages = [r[0] for r in tracked_objects[pguid] if r[0] is not None] if pguid in tracked_objects else []
+ person_genders = [r[1] for r in tracked_objects[pguid] if r[1] is not None] if pguid in tracked_objects else []
+
+ if not face_ages and not person_ages: # both empty
+ return None, None
+
+ # You can play here with different aggregation strategies
+ # Face ages - predictions based on face or face + person, depends on history of object
+ # Person ages - predictions based on person or face + person, depends on history of object
+
+ if len(person_ages + face_ages) >= minimum_sample_size:
+ age = aggregate_votes_winsorized(person_ages + face_ages)
+ else:
+ face_age = np.mean(face_ages) if face_ages else None
+ person_age = np.mean(person_ages) if person_ages else None
+ if face_age is None:
+ face_age = person_age
+ if person_age is None:
+ person_age = face_age
+ age = (face_age + person_age) / 2.0
+
+ genders = face_genders + person_genders
+ assert len(genders) > 0
+ # take mode of genders
+ gender = max(set(genders), key=genders.count)
+
+ return age, gender
+
+ def get_results_for_tracking(self) -> Tuple[Dict[int, AGE_GENDER_TYPE], Dict[int, AGE_GENDER_TYPE]]:
+ """
+ Get objects from current frame
+ """
+ persons: Dict[int, AGE_GENDER_TYPE] = {}
+ faces: Dict[int, AGE_GENDER_TYPE] = {}
+
+ names = self.yolo_results.names
+ pred_boxes = self.yolo_results.boxes
+ for _, (det, age, gender, _) in enumerate(zip(pred_boxes, self.ages, self.genders, self.gender_scores)):
+ if det.id is None:
+ continue
+ cat_id, _, guid = int(det.cls), float(det.conf), int(det.id.item())
+ name = names[cat_id]
+ if name == "person":
+ persons[guid] = (age, gender)
+ elif name == "face":
+ faces[guid] = (age, gender)
+
+ return persons, faces
+
+ def associate_faces_with_persons(self):
+ face_bboxes_inds: List[int] = self.get_bboxes_inds("face")
+ person_bboxes_inds: List[int] = self.get_bboxes_inds("person")
+
+ face_bboxes: List[torch.tensor] = [self.get_bbox_by_ind(ind) for ind in face_bboxes_inds]
+ person_bboxes: List[torch.tensor] = [self.get_bbox_by_ind(ind) for ind in person_bboxes_inds]
+
+ self.face_to_person_map = {ind: None for ind in face_bboxes_inds}
+ assigned_faces, unassigned_persons_inds = assign_faces(person_bboxes, face_bboxes)
+
+ for face_ind, person_ind in enumerate(assigned_faces):
+ face_ind = face_bboxes_inds[face_ind]
+ person_ind = person_bboxes_inds[person_ind] if person_ind is not None else None
+ self.face_to_person_map[face_ind] = person_ind
+
+ self.unassigned_persons_inds = [person_bboxes_inds[person_ind] for person_ind in unassigned_persons_inds]
+
+ def crop_object(
+ self, full_image: np.ndarray, ind: int, cut_other_classes: Optional[List[str]] = None
+ ) -> Optional[np.ndarray]:
+
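+ # thresholds: cut out any intersecting object, snap cut boxes near the crop border, drop tiny or mostly-occluded person crops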
+ IOU_THRESH = 0.000001
+ MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4
+ CROP_ROUND_RATE = 0.3
+ MIN_PERSON_SIZE = 50
+
+ obj_bbox = self.get_bbox_by_ind(ind, *full_image.shape[:2])
+ x1, y1, x2, y2 = obj_bbox
+ cur_cat = self.yolo_results.names[int(self.yolo_results.boxes[ind].cls)]
+ # get crop of face or person
+ obj_image = full_image[y1:y2, x1:x2].copy()
+ crop_h, crop_w = obj_image.shape[:2]
+
+ if cur_cat == "person" and (crop_h < MIN_PERSON_SIZE or crop_w < MIN_PERSON_SIZE):
+ return None
+
+ if not cut_other_classes:
+ return obj_image
+
+ # calc iou between obj_bbox and other bboxes
+ other_bboxes: List[torch.tensor] = [
+ self.get_bbox_by_ind(other_ind, *full_image.shape[:2]) for other_ind in range(len(self.yolo_results.boxes))
+ ]
+
+ iou_matrix = box_iou(torch.stack([obj_bbox]), torch.stack(other_bboxes)).cpu().numpy()[0]
+
+ # cut out other objects in case of intersection
+ for other_ind, (det, iou) in enumerate(zip(self.yolo_results.boxes, iou_matrix)):
+ other_cat = self.yolo_results.names[int(det.cls)]
+ if ind == other_ind or iou < IOU_THRESH or other_cat not in cut_other_classes:
+ continue
+ o_x1, o_y1, o_x2, o_y2 = det.xyxy.squeeze().type(torch.int32)
+
+ # remap current_person_bbox to reference_person_bbox coordinates
+ o_x1 = max(o_x1 - x1, 0)
+ o_y1 = max(o_y1 - y1, 0)
+ o_x2 = min(o_x2 - x1, crop_w)
+ o_y2 = min(o_y2 - y1, crop_h)
+
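+ # snap non-face boxes that lie close to the crop border out to the border before zeroing them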
+ if other_cat != "face":
+ if (o_y1 / crop_h) < CROP_ROUND_RATE:
+ o_y1 = 0
+ if ((crop_h - o_y2) / crop_h) < CROP_ROUND_RATE:
+ o_y2 = crop_h
+ if (o_x1 / crop_w) < CROP_ROUND_RATE:
+ o_x1 = 0
+ if ((crop_w - o_x2) / crop_w) < CROP_ROUND_RATE:
+ o_x2 = crop_w
+
+ obj_image[o_y1:o_y2, o_x1:o_x2] = 0
+
+ remain_ratio = np.count_nonzero(obj_image) / (obj_image.shape[0] * obj_image.shape[1] * obj_image.shape[2])
+ if remain_ratio < MIN_PERSON_CROP_AFTERCUT_RATIO:
+ return None
+
+ return obj_image
+
+ def collect_crops(self, image) -> PersonAndFaceCrops:
+
+ crops_data = PersonAndFaceCrops()
+ for face_ind, person_ind in self.face_to_person_map.items():
+ face_image = self.crop_object(image, face_ind, cut_other_classes=[])
+
+ if person_ind is None:
+ crops_data.crops_faces_wo_body[face_ind] = face_image
+ continue
+
+ person_image = self.crop_object(image, person_ind, cut_other_classes=["face", "person"])
+
+ crops_data.crops_faces[face_ind] = face_image
+ crops_data.crops_persons[person_ind] = person_image
+
+ for person_ind in self.unassigned_persons_inds:
+ person_image = self.crop_object(image, person_ind, cut_other_classes=["face", "person"])
+ crops_data.crops_persons_wo_face[person_ind] = person_image
+
+ # uncomment to save preprocessed crops
+ # crops_data.save()
+ return crops_data
+
+ def get_ages(self):
+ # per-object ages (None where no prediction was made), indexed like yolo_results.boxes
+ return self.ages
diff --git a/age_estimator/mivolo/mivolo/version.py b/age_estimator/mivolo/mivolo/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc321393ba68aa2bcb2b8ab828e95c89ddce7dc3
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/version.py
@@ -0,0 +1 @@
+__version__ = "0.6.0dev"
diff --git a/age_estimator/mivolo/models/model_imdb_cross_person_4.22_99.46.pth.tar b/age_estimator/mivolo/models/model_imdb_cross_person_4.22_99.46.pth.tar
new file mode 100644
index 0000000000000000000000000000000000000000..e62bb76ed79b900b87e7e752c0223d45d40c9345
--- /dev/null
+++ b/age_estimator/mivolo/models/model_imdb_cross_person_4.22_99.46.pth.tar
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc279b6914b3ee8be6a58139c06ecb24ca95751233cf6c07804b93184614eb17
+size 109777437
diff --git a/age_estimator/mivolo/models/yolov8x_person_face.pt b/age_estimator/mivolo/models/yolov8x_person_face.pt
new file mode 100644
index 0000000000000000000000000000000000000000..c4fea0e2a9545fe4db35fc88ad60804e7209da8b
--- /dev/null
+++ b/age_estimator/mivolo/models/yolov8x_person_face.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2620f45609a65f909eb876bd7401308b5a8f3843ad5a03cb7416066a3e492989
+size 136716488
diff --git a/age_estimator/mivolo/mypy.ini b/age_estimator/mivolo/mypy.ini
new file mode 100644
index 0000000000000000000000000000000000000000..c59a7d2d5eccf1b217f2c56c1f5b533d26ce5b10
--- /dev/null
+++ b/age_estimator/mivolo/mypy.ini
@@ -0,0 +1,5 @@
+[mypy]
+python_version = 3.8
+no_strict_optional = True
+ignore_missing_imports = True
+disallow_any_unimported = False
diff --git a/age_estimator/mivolo/requirements.txt b/age_estimator/mivolo/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4e30b002db2c8d41a750c03bdd99801fd54a7200
--- /dev/null
+++ b/age_estimator/mivolo/requirements.txt
@@ -0,0 +1,5 @@
+huggingface_hub
+ultralytics==8.1.0
+timm==0.8.13.dev0
+yt_dlp
+lapx>=0.5.2
diff --git a/age_estimator/mivolo/scripts/inference.sh b/age_estimator/mivolo/scripts/inference.sh
new file mode 100644
index 0000000000000000000000000000000000000000..45cd54633b1adc09c94504724ae27c521dd42c1d
--- /dev/null
+++ b/age_estimator/mivolo/scripts/inference.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+python3 demo.py \
+--input "jennifer_lawrence.jpg" \
+--output "output" \
+--detector-weights "pretrained/yolov8x_person_face.pt" \
+--checkpoint "pretrained/checkpoint-377.pth.tar" \
+--device "cuda:0" \
+--draw \
+--with-persons
+
+python3 demo.py \
+--input "https://www.youtube.com/shorts/pVh32k0hGEI" \
+--output "output" \
+--detector-weights "pretrained/yolov8x_person_face.pt" \
+--checkpoint "pretrained/checkpoint-377.pth.tar" \
+--device "cuda:0" \
+--draw \
+--with-persons
diff --git a/age_estimator/mivolo/scripts/valid_age_gender.sh b/age_estimator/mivolo/scripts/valid_age_gender.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b2edb722909e2a20228613fbe0e35760d3f0157b
--- /dev/null
+++ b/age_estimator/mivolo/scripts/valid_age_gender.sh
@@ -0,0 +1,48 @@
+#!/usr/bin/env bash
+
+# inference utk
+python3 eval_pretrained.py \
+ --dataset_images data/utk/images \
+ --dataset_annotations data/utk/annotation \
+ --dataset_name utk \
+ --batch-size 512 \
+ --checkpoint pretrained/model_imdb_cross_person_4.24_99.46.pth.tar \
+ --split valid \
+ --half \
+ --with-persons \
+ --device "cuda:0"
+
+# inference fairface
+python3 eval_pretrained.py \
+ --dataset_images data/FairFace/fairface-img-margin125-trainval \
+ --dataset_annotations data/FairFace/annotations \
+ --dataset_name fairface \
+ --batch-size 512 \
+ --checkpoint pretrained/model_imdb_cross_person_4.24_99.46.pth.tar \
+ --split val \
+ --half \
+ --with-persons \
+ --device "cuda:0"
+
+# inference adience
+python3 eval_pretrained.py \
+ --dataset_images data/adience/faces \
+ --dataset_annotations data/adience/annotations \
+ --dataset_name adience \
+ --batch-size 512 \
+ --checkpoint pretrained/model_imdb_cross_person_4.24_99.46.pth.tar \
+ --split adience \
+ --half \
+ --with-persons \
+ --device "cuda:0"
+
+# inference agedb
+python3 eval_pretrained.py \
+ --dataset_images data/agedb/AgeDB \
+ --dataset_annotations data/agedb/annotation \
+ --dataset_name agedb \
+ --batch-size 512 \
+ --checkpoint pretrained/model_imdb_cross_person_4.24_99.46.pth.tar \
+ --split 0,1,2,3,4,5,6,7,8,9 \
+ --half \
+ --device "cuda:0"
diff --git a/age_estimator/mivolo/setup.cfg b/age_estimator/mivolo/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..2cde6a494836f93681b73aaa7bcf0d0d487de469
--- /dev/null
+++ b/age_estimator/mivolo/setup.cfg
@@ -0,0 +1,56 @@
+# Project-wide configuration file, can be used for package metadata and other tool configurations
+# Example usage: global configuration for PEP8 (via flake8) settings or default pytest arguments
+# Local usage: pip install pre-commit, pre-commit run --all-files
+
+[metadata]
+license_files = LICENSE
+description_file = README.md
+
+[tool:pytest]
+norecursedirs =
+ .git
+ dist
+ build
+addopts =
+ --doctest-modules
+ --durations=25
+ --color=yes
+
+[flake8]
+max-line-length = 120
+exclude = .tox,*.egg,build,temp
+select = E,W,F
+doctests = True
+verbose = 2
+# https://pep8.readthedocs.io/en/latest/intro.html#error-codes
+format = pylint
+# see: https://www.flake8rules.com/
+ignore = E731,F405,E402,W504,E501
+ # E731: Do not assign a lambda expression, use a def
+ # F405: name may be undefined, or defined from star imports: module
+ # E402: module level import not at top of file
+ # W504: line break after binary operator
+ # E501: line too long
+ # removed:
+ # F401: module imported but unused
+ # E231: missing whitespace after ‘,’, ‘;’, or ‘:’
+ # E127: continuation line over-indented for visual indent
+ # F403: ‘from module import *’ used; unable to detect undefined names
+
+
+[isort]
+# https://pycqa.github.io/isort/docs/configuration/options.html
+line_length = 120
+# see: https://pycqa.github.io/isort/docs/configuration/multi_line_output_modes.html
+multi_line_output = 0
+
+[yapf]
+based_on_style = pep8
+spaces_before_comment = 2
+COLUMN_LIMIT = 120
+COALESCE_BRACKETS = True
+SPACES_AROUND_POWER_OPERATOR = True
+SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET = True
+SPLIT_BEFORE_CLOSING_BRACKET = False
+SPLIT_BEFORE_FIRST_ARGUMENT = False
+# EACH_DICT_ENTRY_ON_SEPARATE_LINE = False
diff --git a/age_estimator/mivolo/setup.py b/age_estimator/mivolo/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8bdbd35873b515f074af626a68ef51fda785fc8
--- /dev/null
+++ b/age_estimator/mivolo/setup.py
@@ -0,0 +1,50 @@
+# MiVOLO 🚀, Attribution-ShareAlike 4.0
+
+from pathlib import Path
+
+import pkg_resources as pkg
+from setuptools import find_packages, setup
+
+# Settings
+FILE = Path(__file__).resolve()
+PARENT = FILE.parent # root directory
+README = (PARENT / "README.md").read_text(encoding="utf-8")
+REQUIREMENTS = [f"{x.name}{x.specifier}" for x in pkg.parse_requirements((PARENT / "requirements.txt").read_text())]
+
+
+exec(open("mivolo/version.py").read())
+setup(
+ name="mivolo", # name of pypi package
+ version=__version__, # version of pypi package # noqa: F821
+ python_requires=">=3.8",
+ description="Layer MiVOLO for SOTA age and gender recognition",
+ long_description=README,
+ long_description_content_type="text/markdown",
+ url="https://github.com/WildChlamydia/MiVOLO",
+ project_urls={"Datasets": "https://wildchlamydia.github.io/lagenda/"},
+ author="Layer Team, SberDevices",
+ author_email="mvkuprashevich@gmail.com, irinakr4snova@gmail.com",
+ packages=find_packages(include=["mivolo", "mivolo.model", "mivolo.data", "mivolo.data.dataset"]), # required
+ include_package_data=True,
+ install_requires=REQUIREMENTS,
+ classifiers=[
+ "Development Status :: 4 - Beta",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Education",
+ "Intended Audience :: Science/Research",
+ "License :: Attribution-ShareAlike 4.0",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Topic :: Software Development",
+ "Topic :: Scientific/Engineering",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Scientific/Engineering :: Image Recognition",
+ "Operating System :: POSIX :: Linux",
+ "Operating System :: MacOS",
+ "Operating System :: Microsoft :: Windows",
+ ],
+ keywords="machine-learning, deep-learning, vision, ML, DL, AI, transformer, mivolo",
+)
diff --git a/age_estimator/mivolo/tools/dataset_visualization.py b/age_estimator/mivolo/tools/dataset_visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..08e0452f70ae3893c55e7a54408e6d4cb8f4b038
--- /dev/null
+++ b/age_estimator/mivolo/tools/dataset_visualization.py
@@ -0,0 +1,44 @@
+import argparse
+from typing import Dict, List
+
+import cv2
+from mivolo.data.data_reader import PictureInfo, read_csv_annotation_file
+from ultralytics.utils.plotting import Annotator, colors
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(description="Visualization")
+ parser.add_argument("--dataset_images", default="", type=str, required=True, help="path to images")
+ parser.add_argument("--annotation_file", default="", type=str, required=True, help="path to annotations")
+
+ return parser
+
+
+def visualize(images_dir, new_annotation_file):
+
+ bboxes_per_image: Dict[str, List[PictureInfo]] = read_csv_annotation_file(new_annotation_file, images_dir)[0]
+ print(f"Found {len(bboxes_per_image)} unique images")
+
+ for image_path, bboxes in bboxes_per_image.items():
+ im_cv = cv2.imread(image_path)
+ annotator = Annotator(im_cv)
+
+ for i, bbox_info in enumerate(bboxes):
+ label = f"{bbox_info.gender} Age: {bbox_info.age}"
+ if any(coord != -1 for coord in bbox_info.bbox):
+ # draw face bbox if it exists
+ annotator.box_label(bbox_info.bbox, label, color=colors(i, True))
+
+ if any(coord != -1 for coord in bbox_info.person_bbox):
+ # draw person bbox if it exists
+ annotator.box_label(bbox_info.person_bbox, "p " + label, color=colors(i, True))
+
+ im_cv = annotator.result()
+ cv2.imshow("image", im_cv)
+ cv2.waitKey(0)
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args = parser.parse_args()
+ visualize(args.dataset_images, args.annotation_file)
diff --git a/age_estimator/mivolo/tools/download_lagenda.py b/age_estimator/mivolo/tools/download_lagenda.py
new file mode 100644
index 0000000000000000000000000000000000000000..4130f9b9f1c8a722f64dcb8e36354bbb92802d61
--- /dev/null
+++ b/age_estimator/mivolo/tools/download_lagenda.py
@@ -0,0 +1,30 @@
+import os
+import zipfile
+
+# pip install gdown
+import gdown
+
+if __name__ == "__main__":
+ print("Download LAGENDA Age Gender Dataset... ")
+ out_dir = "LAGENDA"
+ os.makedirs(out_dir, exist_ok=True)
+
+ ids = ["1QXO0NlkABPZT6x1_0Uc2i6KAtdcrpTbG", "1mNYjYFb3MuKg-OL1UISoYsKObMUllbJx"]
+ dests = [f"{out_dir}/lagenda_benchmark_images.zip", f"{out_dir}/lagenda_annotation.csv"]
+
+ for file_id, destination in zip(ids, dests):
+ url = f"https://drive.google.com/uc?id={file_id}"
+ gdown.download(url, destination, quiet=False)
+
+ if not os.path.exists(destination):
+ print(f"ERROR: Can not download {destination}")
+ continue
+
+ if not destination.endswith(".zip"):
+ continue
+
+ print(f"Extracting {destination} ... ")
+ with zipfile.ZipFile(destination) as zf:
+ zip_dir = zf.namelist()[0]
+ zf.extractall(f"./{out_dir}/")
+ os.remove(destination)
diff --git a/age_estimator/mivolo/tools/preparation_utils.py b/age_estimator/mivolo/tools/preparation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa769f028ef95019f4af9348f8d49d92450f6c5c
--- /dev/null
+++ b/age_estimator/mivolo/tools/preparation_utils.py
@@ -0,0 +1,140 @@
+from typing import Dict, List, Optional, Tuple
+
+import pandas as pd
+import torch
+from mivolo.data.data_reader import PictureInfo
+from mivolo.data.misc import assign_faces, box_iou
+from mivolo.model.yolo_detector import PersonAndFaceResult
+
+
+def save_annotations(images: List[PictureInfo], images_dir: str, out_file: str):
+ def get_age_str(age: Optional[str]) -> str:
+ age = "-1" if age is None else age.replace("(", "").replace(")", "").replace(" ", "").replace(",", ";")
+ return age
+
+ def get_gender_str(gender: Optional[str]) -> str:
+ gender = "-1" if gender is None else gender
+ return gender
+
+ headers = [
+ "img_name",
+ "age",
+ "gender",
+ "face_x0",
+ "face_y0",
+ "face_x1",
+ "face_y1",
+ "person_x0",
+ "person_y0",
+ "person_x1",
+ "person_y1",
+ ]
+ output_data = []
+ for image_info in images:
+ relative_image_path = image_info.image_path.replace(f"{images_dir}/", "")
+ face_x0, face_y0, face_x1, face_y1 = image_info.bbox
+ p_x0, p_y0, p_x1, p_y1 = image_info.person_bbox
+ output_data.append(
+ {
+ "img_name": relative_image_path,
+ "age": get_age_str(image_info.age),
+ "gender": get_gender_str(image_info.gender),
+ "face_x0": face_x0,
+ "face_y0": face_y0,
+ "face_x1": face_x1,
+ "face_y1": face_y1,
+ "person_x0": p_x0,
+ "person_y0": p_y0,
+ "person_x1": p_x1,
+ "person_y1": p_y1,
+ }
+ )
+ output_df = pd.DataFrame(output_data, columns=headers)
+ output_df.to_csv(out_file, sep=",", index=False)
+ print(f"Saved annotations for {len(images)} images to {out_file}")
+
+
+def get_main_face(
+ detected_objects: PersonAndFaceResult, coarse_bbox: Optional[List[int]] = None, coarse_thresh: float = 0.2
+) -> Tuple[Optional[List[int]], List[int]]:
+ """
+ Returns:
+ main_bbox (Optional[List[int]]): The most centered face bbox
+ other_bboxes (List[int]): indexes of other faces
+ """
+ face_bboxes_inds: List[int] = detected_objects.get_bboxes_inds("face")
+ if len(face_bboxes_inds) == 0:
+ return None, []
+
+ # sort found faces
+ face_bboxes_inds = sorted(face_bboxes_inds, key=lambda bb_ind: detected_objects.get_distance_to_center(bb_ind))
+ most_centered_bbox_ind = face_bboxes_inds[0]
+ main_bbox = detected_objects.get_bbox_by_ind(most_centered_bbox_ind).cpu().numpy().tolist()
+
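+ # with no coarse bbox, assume the most centered face is the match (IoU 1.0) and treat the rest as other faces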
+ iou_matrix: List[float] = [1.0] + [0.0] * (len(face_bboxes_inds) - 1)
+
+ if coarse_bbox is not None:
+ # calc iou between coarse_bbox and found bboxes
+ found_bboxes: List[torch.tensor] = [
+ detected_objects.get_bbox_by_ind(other_ind) for other_ind in face_bboxes_inds
+ ]
+
+ iou_matrix = (
+ box_iou(torch.stack([torch.tensor(coarse_bbox)]), torch.stack(found_bboxes).cpu()).numpy()[0].tolist()
+ )
+
+ if iou_matrix[0] < coarse_thresh:
+ # to avoid fp detections
+ main_bbox = None
+ other_bboxes = [ind for i, ind in enumerate(face_bboxes_inds[1:]) if iou_matrix[i] < coarse_thresh]
+ else:
+ other_bboxes = face_bboxes_inds[1:]
+
+ return main_bbox, other_bboxes
+
+
+def get_additional_bboxes(
+ detected_objects: PersonAndFaceResult, other_bboxes_inds: List[int], image_path: str, **kwargs
+) -> List[PictureInfo]:
+ is_person = "is_person" in kwargs
+ is_face = not is_person
+
+ additional_data: List[PictureInfo] = []
+ # extend other faces
+ for other_ind in other_bboxes_inds:
+ other_box: List[int] = detected_objects.get_bbox_by_ind(other_ind).cpu().numpy().tolist()
+ if is_face:
+ additional_data.append(PictureInfo(image_path, None, None, other_box))
+ elif is_person:
+ additional_data.append(PictureInfo(image_path, None, None, person_bbox=other_box))
+ return additional_data
+
+
+def associate_persons(face_bboxes: List[torch.tensor], detected_objects: PersonAndFaceResult):
+ person_bboxes_inds: List[int] = detected_objects.get_bboxes_inds("person")
+ person_bboxes: List[torch.tensor] = [detected_objects.get_bbox_by_ind(ind) for ind in person_bboxes_inds]
+
+ face_to_person_map: Dict[int, Optional[int]] = {ind: None for ind in range(len(face_bboxes))}
+
+ if len(person_bboxes) == 0:
+ return face_to_person_map, []
+
+ assigned_faces, unassigned_persons_inds = assign_faces(person_bboxes, face_bboxes)
+
+ for face_ind, person_ind in enumerate(assigned_faces):
+ person_ind = person_bboxes_inds[person_ind] if person_ind is not None else None
+ face_to_person_map[face_ind] = person_ind
+
+ unassigned_persons_inds = [person_bboxes_inds[person_ind] for person_ind in unassigned_persons_inds]
+ return face_to_person_map, unassigned_persons_inds
+
+
+def assign_persons(
+ faces_info: List[PictureInfo], faces_persons_map: Dict[int, int], detected_objects: PersonAndFaceResult
+):
+ for face_ind, person_ind in faces_persons_map.items():
+ if person_ind is None:
+ continue
+
+ person_bbox = detected_objects.get_bbox_by_ind(person_ind).cpu().numpy().tolist()
+ faces_info[face_ind].person_bbox = person_bbox
diff --git a/age_estimator/mivolo/tools/prepare_adience.py b/age_estimator/mivolo/tools/prepare_adience.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5e20788c2b6647d8c2b98875f65c01b52843569
--- /dev/null
+++ b/age_estimator/mivolo/tools/prepare_adience.py
@@ -0,0 +1,216 @@
+import argparse
+import os
+from collections import defaultdict
+from typing import Dict, List, Optional
+
+import cv2
+import pandas as pd
+import tqdm
+from mivolo.data.data_reader import PictureInfo, get_all_files
+from mivolo.model.yolo_detector import Detector, PersonAndFaceResult
+from preparation_utils import get_additional_bboxes, get_main_face, save_annotations
+
+
+def read_adience_annotations(annotations_files):
+ annotations_per_image = {}
+ stat_per_fold = defaultdict(int)
+ cols = ["user_id", "original_image", "face_id", "age", "gender"]
+ for file in annotations_files:
+ fold_name = os.path.basename(file).split(".")[0]
+ df = pd.read_csv(file, sep="\t", usecols=cols)
+ for index, row in df.iterrows():
+ face_id, img_name, user_id = row["face_id"], row["original_image"], row["user_id"]
+ aligned_face_path = f"faces/{user_id}/coarse_tilt_aligned_face.{face_id}.{img_name}"
+
+ age, gender = row["age"], row["gender"]
+ gender = gender.upper() if isinstance(gender, str) and gender != "u" else None
+ age = age if isinstance(age, str) else None
+
+ annotations_per_image[aligned_face_path] = {"age": age, "gender": gender, "fold": fold_name}
+ stat_per_fold[fold_name] += 1
+
+ print(f"Per fold images: {stat_per_fold}")
+ return annotations_per_image
+
+
+def read_data(images_dir, annotations_files, data_dir) -> List[PictureInfo]:
+ dataset_pictures: List[PictureInfo] = []
+
+ all_images = get_all_files(images_dir)
+ annotations_per_file = read_adience_annotations(annotations_files)
+
+ total, missed = 0, 0
+ stat_per_gender: Dict[str, int] = defaultdict(int)
+ missed_gender, missed_age, missed_gender_and_age = 0, 0, 0
+ stat_per_ages: Dict[str, int] = defaultdict(int)
+
+ # final age classes: '0;2', "4;6", "8;12", "15;20", "25;32", "38;43", "48;53", "60;100"
+
+ age_map = {
+ "2": "(0, 2)",
+ "3": "(0, 2)",
+ "13": "(8, 12)",
+ "(8, 23)": "(8, 12)",
+ "22": "(15, 20)",
+ "23": "(25, 32)",
+ "29": "(25, 32)",
+ "(27, 32)": "(25, 32)",
+ "32": "(25, 32)",
+ "34": "(25, 32)",
+ "35": "(25, 32)",
+ "36": "(38, 43)",
+ "(38, 42)": "(38, 43)",
+ "(38, 48)": "(38, 43)",
+ "42": "(38, 43)",
+ "45": "(38, 43)",
+ "46": "(48, 53)",
+ "55": "(48, 53)",
+ "56": "(48, 53)",
+ "57": "(60, 100)",
+ "58": "(60, 100)",
+ }
+ for image_path in all_images:
+ total += 1
+ relative_path = image_path.replace(f"{data_dir}/", "")
+ if relative_path not in annotations_per_file:
+ missed += 1
+ print("Can not find annotation for ", relative_path)
+ else:
+ annot = annotations_per_file[relative_path]
+ age, gender = annot["age"], annot["gender"]
+
+ if gender is None and age is not None:
+ missed_gender += 1
+ elif age is None and gender is not None:
+ missed_age += 1
+ elif gender is None and age is None:
+ missed_gender_and_age += 1
+ # skip such image
+ continue
+
+ if gender is not None:
+ stat_per_gender[gender] += 1
+
+ if age is not None:
+ age = age_map[age] if age in age_map else age
+ stat_per_ages[age] += 1
+
+ dataset_pictures.append(PictureInfo(image_path, age, gender))
+
+ print(f"Missed annots for images: {missed}/{total}")
+ print(f"Missed genders: {missed_gender}")
+ print(f"Missed ages: {missed_age}")
+ print(f"Missed ages and gender: {missed_gender_and_age}")
+ print(f"\nPer gender images: {stat_per_gender}")
+ ages = list(stat_per_ages.keys())
+ print(f"Per ages categories ({len(ages)} cats) :")
+ ages = sorted(ages, key=lambda x: int(x.split("(")[-1].split(",")[0].strip()))
+ for age in ages:
+ print(f"Age: {age} Count: {stat_per_ages[age]}")
+
+ return dataset_pictures
+
+
+def main(faces_dir: str, annotations: List[str], data_dir: str, detector_cfg: dict = None):
+ """
+ Generate a .txt annotation file with columns:
+ ["img_name", "age", "gender",
+ "face_x0", "face_y0", "face_x1", "face_y1",
+ "person_x0", "person_y0", "person_x1", "person_y1"]
+
+ All person bboxes here will be set to [-1, -1, -1, -1]
+
+ If detector_cfg is set, each face bbox will be refined using the detector.
+ Also, other detected faces will be written to the txt file (needed for further preprocessing).
+ """
+ # out directory for annotations
+ out_dir = os.path.join(data_dir, "annotations")
+ os.makedirs(out_dir, exist_ok=True)
+
+ # load annotations
+ images: List[PictureInfo] = read_data(faces_dir, annotations, data_dir)
+
+ if detector_cfg:
+ # detect faces with yolo detector
+ faces_not_found, images_with_other_faces = 0, 0
+ other_faces: List[PictureInfo] = []
+
+ detector_weights, device = detector_cfg["weights"], detector_cfg["device"]
+ detector = Detector(detector_weights, device, verbose=False, conf_thresh=0.1, iou_thresh=0.2)
+ for image_info in tqdm.tqdm(images, desc="Detecting faces: "):
+ cv_im = cv2.imread(image_info.image_path)
+ im_h, im_w = cv_im.shape[:2]
+
+ detected_objects: PersonAndFaceResult = detector.predict(cv_im)
+ main_bbox, other_bboxes_inds = get_main_face(detected_objects)
+
+ if main_bbox is None:
+ # use a full image as face bbox
+ faces_not_found += 1
+ image_info.bbox = [0, 0, im_w, im_h]
+ else:
+ image_info.bbox = main_bbox
+
+ if len(other_bboxes_inds):
+ images_with_other_faces += 1
+
+ additional_faces = get_additional_bboxes(detected_objects, other_bboxes_inds, image_info.image_path)
+ other_faces.extend(additional_faces)
+
+ print(f"Faces not detected: {faces_not_found}/{len(images)}")
+ print(f"Images with other faces: {images_with_other_faces}/{len(images)}")
+ print(f"Other faces: {len(other_faces)}")
+
+ images = images + other_faces
+
+ else:
+ # use a full image as face bbox
+ for image_info in tqdm.tqdm(images, desc="Collect face bboxes: "):
+ cv_im = cv2.imread(image_info.image_path)
+ im_h, im_w = cv_im.shape[:2]
+ image_info.bbox = [0, 0, im_w, im_h] # xyxy
+
+ save_annotations(images, faces_dir, out_file=os.path.join(out_dir, "adience_annotations.csv"))
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(description="Adience")
+ parser.add_argument(
+ "--dataset_path",
+ default="data/adience",
+ type=str,
+ required=True,
+ help="path to dataset with faces/ and fold_{i}_data.txt files",
+ )
+ parser.add_argument(
+ "--detector_weights", default=None, type=str, required=False, help="path to face and person detector"
+ )
+ parser.add_argument("--device", default="cuda:0", type=str, required=False, help="device to inference detector")
+
+ return parser
+
+
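+# Example invocation (illustrative; assumes this script lives at tools/prepare_adience.py
+# and the weights path is a placeholder):
+#   python tools/prepare_adience.py --dataset_path data/adience \
+#       --detector_weights models/yolov8x_person_face.pt --device cuda:0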
+if __name__ == "__main__":
+
+ parser = get_parser()
+ args = parser.parse_args()
+
+    data_dir = args.dataset_path
+    if data_dir[-1] == "/":
+        data_dir = data_dir[:-1]
+
+    faces_dir = os.path.join(data_dir, "faces")
+
+ annotations = [
+ os.path.join(data_dir, "fold_0_data.txt"),
+ os.path.join(data_dir, "fold_1_data.txt"),
+ os.path.join(data_dir, "fold_2_data.txt"),
+ os.path.join(data_dir, "fold_3_data.txt"),
+ os.path.join(data_dir, "fold_4_data.txt"),
+ ]
+
+ detector_cfg: Optional[Dict[str, str]] = None
+ if args.detector_weights is not None:
+        detector_cfg = {"weights": args.detector_weights, "device": args.device}
+
+ main(faces_dir, annotations, data_dir, detector_cfg)
diff --git a/age_estimator/mivolo/tools/prepare_agedb.py b/age_estimator/mivolo/tools/prepare_agedb.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc25801a6699ea44cff347361692031a6ba2838e
--- /dev/null
+++ b/age_estimator/mivolo/tools/prepare_agedb.py
@@ -0,0 +1,57 @@
+import argparse
+import json
+import os
+from typing import Dict, Optional
+
+from prepare_cacd import collect_faces
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(description="AgeDB")
+ parser.add_argument(
+ "--dataset_path",
+ default="data/AgeDB",
+ type=str,
+ required=True,
+ help="path to dataset with AgeDB folder",
+ )
+ parser.add_argument(
+ "--detector_weights", default=None, type=str, required=False, help="path to face and person detector"
+ )
+ parser.add_argument("--device", default="cuda:0", type=str, required=False, help="device to inference detector")
+
+ return parser
+
+
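+# Example invocation (the weights path is a placeholder):
+#   python tools/prepare_agedb.py --dataset_path data/AgeDB \
+#       --detector_weights models/yolov8x_person_face.pt --device cuda:0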
+if __name__ == "__main__":
+
+ parser = get_parser()
+ args = parser.parse_args()
+
+ data_dir = args.dataset_path
+ if data_dir[-1] == "/":
+ data_dir = data_dir[:-1]
+
+ faces_dir = os.path.join(data_dir, "AgeDB")
+
+ # https://github.com/paplhjak/Facial-Age-Estimation-Benchmark-Databases/tree/main
+ json_path = os.path.join(data_dir, "AgeDB.json")
+ with open(json_path, "r") as stream:
+ annotations = json.load(stream)
+
+ detector_cfg: Optional[Dict[str, str]] = None
+ if args.detector_weights is not None:
+        detector_cfg = {"weights": args.detector_weights, "device": args.device}
+
+ splits = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
+ collect_faces(
+ faces_dir,
+ annotations,
+ data_dir,
+ detector_cfg,
+ padding=0.1,
+ splits=splits,
+ db_name="agedb",
+ find_persons=True,
+ use_coarse_faces=True,
+ )
diff --git a/age_estimator/mivolo/tools/prepare_cacd.py b/age_estimator/mivolo/tools/prepare_cacd.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddff3c29f3eb804eb94a20caa7a72bddde24a02e
--- /dev/null
+++ b/age_estimator/mivolo/tools/prepare_cacd.py
@@ -0,0 +1,254 @@
+import argparse
+import json
+import os
+from collections import defaultdict
+from typing import Dict, List, Optional
+
+import cv2
+import tqdm
+from mivolo.data.data_reader import PictureInfo, get_all_files
+from mivolo.modeling.yolo_detector import Detector, PersonAndFaceResult
+from preparation_utils import get_additional_bboxes, get_main_face, save_annotations
+from prepare_fairface import find_persons_on_image
+
+
+def get_im_name(img_path):
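+    # Some CACD file names contain accented characters; the seemingly duplicated replace()
+    # pairs below are intended to cover both the precomposed and the decomposed
+    # (combining-accent) Unicode forms of each character.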
+ im_name = img_path.split("/")[-1]
+ im_name = im_name.replace("é", "e").replace("é", "e")
+ im_name = im_name.replace("ó", "o").replace("ó", "o")
+ im_name = im_name.replace("å", "a").replace("å", "a")
+ im_name = im_name.replace("ñ", "n").replace("ñ", "n")
+ im_name = im_name.replace("ö", "o").replace("ö", "o")
+ im_name = im_name.replace("ä", "a").replace("ä", "a")
+ im_name = im_name.replace("ü", "u").replace("ü", "u")
+ im_name = im_name.replace("á", "a").replace("á", "a")
+ im_name = im_name.replace("ë", "e").replace("ë", "e")
+ im_name = im_name.replace("í", "i").replace("í", "i")
+
+ return im_name
+
+
+def read_json_annotations(annotations: List[str], splits: List[str]) -> Dict[str, dict]:
+ print("Parsing annotations")
+ annotations_per_image = {}
+ stat_per_split: Dict[str, int] = defaultdict(int)
+
+ missed = 0
+ for item_id, face in tqdm.tqdm(enumerate(annotations), total=len(annotations)):
+ im_name = get_im_name(face["img_path"])
+ split = splits[int(face["folder"])]
+
+ stat_per_split[split] += 1
+
+ gender = face["gender"] if "gender" in face else None
+ if "alignment_source" in face and face["alignment_source"] == "file not found":
+ missed += 1
+
+ annotations_per_image[im_name] = {"age": str(face["age"]), "gender": gender, "split": split}
+
+ print("missed annots: ", missed)
+
+ print(f"Per split images: {stat_per_split}")
+ print(f"Found {len(annotations_per_image)} annotations")
+ return annotations_per_image
+
+
+def read_data(images_dir, annotations, splits) -> Dict[str, List[PictureInfo]]:
+ dataset: Dict[str, List[PictureInfo]] = defaultdict(list)
+ all_images = get_all_files(images_dir)
+ print(f"Found {len(all_images)} images")
+
+ annotations_per_file: Dict[str, dict] = read_json_annotations(annotations, splits)
+
+ total, missed = 0, 0
+ missed_gender_and_age = 0
+ stat_per_ages: Dict[str, int] = defaultdict(int)
+ stat_per_gender: Dict[str, int] = defaultdict(int)
+
+ for image_path in all_images:
+ total += 1
+ image_name = get_im_name(image_path)
+
+ if image_name not in annotations_per_file:
+ missed += 1
+ print(f"Can not find annotation for {image_name}")
+ else:
+ annot = annotations_per_file[image_name]
+ age, gender, split = annot["age"], annot["gender"], annot["split"]
+
+ if gender is None and age is None:
+ missed_gender_and_age += 1
+ # skip such image
+ continue
+
+ if age is not None:
+ stat_per_ages[age] += 1
+ if gender is not None:
+ stat_per_gender[gender] += 1
+
+ info = PictureInfo(image_path, age, gender)
+ dataset[split].append(info)
+
+ print(f"Missed annots for images: {missed}/{total}")
+ print(f"Missed ages and gender: {missed_gender_and_age}")
+ ages = list(stat_per_ages.keys())
+ print(f"Per gender stat: {stat_per_gender}")
+ print(f"Per ages categories ({len(ages)} cats) :")
+ ages = sorted(ages, key=lambda x: int(x.split("(")[-1].split(",")[0].strip()))
+ for age in ages:
+ print(f"Age: {age} Count: {stat_per_ages[age]}")
+
+ return dataset
+
+
+def collect_faces(
+ faces_dir: str,
+ annotations: List[dict],
+ data_dir: str,
+    detector_cfg: Optional[dict] = None,
+ padding: float = 0.1,
+ splits: List[str] = [],
+ db_name: str = "",
+ use_coarse_persons: bool = False,
+ find_persons: bool = False,
+ person_padding: float = 0.0,
+ use_coarse_faces: bool = False,
+):
+ """
+ Generate train, val, test .txt annotation files with columns:
+ ["img_name", "age", "gender",
+ "face_x0", "face_y0", "face_x1", "face_y1",
+ "person_x0", "person_y0", "person_x1", "person_y1"]
+
+ All person bboxes here will be set to [-1, -1, -1, -1]
+
+ If detector_cfg is set, for each face bbox will be refined using detector.
+ Also, other detected faces wil be written to txt file (needed for further preprocessing)
+ """
+
+ # out directory for annotations
+ out_dir = os.path.join(data_dir, "annotations")
+ os.makedirs(out_dir, exist_ok=True)
+
+ # load annotations
+ images_per_split: Dict[str, List[PictureInfo]] = read_data(faces_dir, annotations, splits)
+
+ for split_ind, (split, images) in enumerate(images_per_split.items()):
+ print(f"Processing {split} split ({split_ind}/{len(images_per_split)})...")
+ if detector_cfg:
+ # detect faces with yolo detector
+ faces_not_found, images_with_other_faces = 0, 0
+ other_faces: List[PictureInfo] = []
+
+ detector_weights, device = detector_cfg["weights"], detector_cfg["device"]
+ detector = Detector(detector_weights, device, verbose=False, conf_thresh=0.1, iou_thresh=0.2)
+ for image_info in tqdm.tqdm(images, desc="Detecting faces: "):
+ cv_im = cv2.imread(image_info.image_path)
+ im_h, im_w = cv_im.shape[:2]
+
+ pad_x, pad_y = int(padding * im_w), int(padding * im_h)
+ coarse_face_bbox = [pad_x, pad_y, im_w - pad_x, im_h - pad_y] # xyxy
+
+ detected_objects: PersonAndFaceResult = detector.predict(cv_im)
+ main_bbox, other_faces_inds = get_main_face(detected_objects, coarse_face_bbox)
+
+ if len(other_faces_inds):
+ images_with_other_faces += 1
+
+ if main_bbox is None:
+                    # fall back to the coarse (padded full-image) face bbox
+ faces_not_found += 1
+ main_bbox = coarse_face_bbox
+ elif use_coarse_faces:
+ main_bbox = coarse_face_bbox
+ image_info.bbox = main_bbox
+
+ if find_persons:
+ additional_faces, additional_persons = find_persons_on_image(
+ image_info, main_bbox, detected_objects, other_faces_inds, device
+ )
+ # add all additional faces
+ other_faces.extend(additional_faces)
+ # add persons with empty faces
+ other_faces.extend(additional_persons)
+ else:
+ additional_faces = get_additional_bboxes(detected_objects, other_faces_inds, image_info.image_path)
+ other_faces.extend(additional_faces)
+ # full image as a person bbox
+ coarse_person_bbox = [0, 0, im_w, im_h] # xyxy
+ if find_persons:
+ image_info.person_bbox = coarse_person_bbox
+
+ print(f"Faces not detected: {faces_not_found}/{len(images)}")
+ print(f"Images with other faces: {images_with_other_faces}/{len(images)}")
+ print(f"Other faces: {len(other_faces)}")
+
+ images = images + other_faces
+
+ else:
+ for image_info in tqdm.tqdm(images, desc="Collect face bboxes: "):
+
+ cv_im = cv2.imread(image_info.image_path)
+ im_h, im_w = cv_im.shape[:2]
+
+                # use the padded full image as the face bbox
+ pad_x, pad_y = int(padding * im_w), int(padding * im_h)
+ image_info.bbox = [pad_x, pad_y, im_w - pad_x, im_h - pad_y] # xyxy
+
+ if use_coarse_persons or find_persons:
+ # full image as a person bbox
+ pad_x_p, pad_y_p = int(person_padding * im_w), int(person_padding * im_h)
+ image_info.person_bbox = [pad_x_p, pad_y_p, im_w - pad_x_p, im_h] # xyxy
+
+ save_annotations(images, faces_dir, out_file=os.path.join(out_dir, f"{db_name}_{split}_annotations.csv"))
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(description="CACD")
+ parser.add_argument(
+ "--dataset_path",
+ default="data/CACD",
+ type=str,
+ required=True,
+ help="path to dataset with CACD200 folder",
+ )
+ parser.add_argument(
+ "--detector_weights", default=None, type=str, required=False, help="path to face and person detector"
+ )
+ parser.add_argument("--device", default="cuda:0", type=str, required=False, help="device to inference detector")
+
+ return parser
+
+
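+# Example invocation (the weights path is a placeholder):
+#   python tools/prepare_cacd.py --dataset_path data/CACD \
+#       --detector_weights models/yolov8x_person_face.pt --device cuda:0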
+if __name__ == "__main__":
+
+ parser = get_parser()
+ args = parser.parse_args()
+
+ data_dir = args.dataset_path
+ if data_dir[-1] == "/":
+ data_dir = data_dir[:-1]
+
+ faces_dir = os.path.join(data_dir, "CACD2000")
+
+ # https://github.com/paplhjak/Facial-Age-Estimation-Benchmark-Databases/tree/main
+ json_path = os.path.join(data_dir, "CACD2000.json")
+ with open(json_path, "r") as stream:
+ annotations = json.load(stream)
+
+ detector_cfg: Optional[Dict[str, str]] = None
+ if args.detector_weights is not None:
+        detector_cfg = {"weights": args.detector_weights, "device": args.device}
+
+ splits = ["train", "valid", "test"]
+ collect_faces(
+ faces_dir,
+ annotations,
+ data_dir,
+ detector_cfg,
+ padding=0.2,
+ splits=splits,
+ db_name="cacd",
+ find_persons=True,
+ use_coarse_faces=True,
+ )
diff --git a/age_estimator/mivolo/tools/prepare_fairface.py b/age_estimator/mivolo/tools/prepare_fairface.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3e975bdeb2f3825c60ac49317e62f140d2d21aa
--- /dev/null
+++ b/age_estimator/mivolo/tools/prepare_fairface.py
@@ -0,0 +1,205 @@
+import argparse
+import os
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple
+
+import cv2
+import pandas as pd
+import torch
+import tqdm
+from mivolo.data.data_reader import PictureInfo, get_all_files
+from mivolo.modeling.yolo_detector import Detector, PersonAndFaceResult
+from preparation_utils import assign_persons, associate_persons, get_additional_bboxes, get_main_face, save_annotations
+
+
+def read_fairface_annotations(annotations_files):
+ annotations_per_image = {}
+ cols = ["file", "age", "gender"]
+
+ for file in annotations_files:
+ split_name = os.path.basename(file).split(".")[0].split("_")[-1]
+ df = pd.read_csv(file, sep=",", usecols=cols)
+ for index, row in df.iterrows():
+ aligned_face_path = row["file"]
+
+ age, gender = row["age"], row["gender"]
+ # M or F
+ gender = gender[0].upper() if isinstance(gender, str) else None
+ age = age.replace("-", ";") if isinstance(age, str) else None
+
+ annotations_per_image[aligned_face_path] = {"age": age, "gender": gender, "split": split_name}
+ return annotations_per_image
+
+
+def read_data(images_dir, annotations_files) -> Tuple[List[PictureInfo], List[PictureInfo]]:
+ dataset_pictures_train: List[PictureInfo] = []
+ dataset_pictures_val: List[PictureInfo] = []
+
+ all_images = get_all_files(images_dir)
+ annotations_per_file = read_fairface_annotations(annotations_files)
+
+ SPLIT_TYPE = Dict[str, Dict[str, int]]
+ splits_stat_per_gender: SPLIT_TYPE = defaultdict(lambda: defaultdict(int))
+ splits_stat_per_ages: SPLIT_TYPE = defaultdict(lambda: defaultdict(int))
+
+ age_map = {"more than 70": "70;120"}
+ for image_path in all_images:
+ relative_path = image_path.replace(f"{images_dir}/", "")
+
+ annot = annotations_per_file[relative_path]
+ split = annot["split"]
+ age, gender = annot["age"], annot["gender"]
+ age = age_map[age] if age in age_map else age
+
+ splits_stat_per_gender[split][gender] += 1
+ splits_stat_per_ages[split][age] += 1
+
+ if split == "train":
+ dataset_pictures_train.append(PictureInfo(image_path, age, gender))
+ elif split == "val":
+ dataset_pictures_val.append(PictureInfo(image_path, age, gender))
+ else:
+ raise ValueError(f"Unknown split name: {split}")
+
+ print(f"Found train/val images: {len(dataset_pictures_train)}/{len(dataset_pictures_val)}")
+ for split, stat_per_gender in splits_stat_per_gender.items():
+ print(f"\n{split} Per gender images: {stat_per_gender}")
+
+ for split, stat_per_ages in splits_stat_per_ages.items():
+ ages = list(stat_per_ages.keys())
+ print(f"\n{split} Per ages categories ({len(ages)} cats) :")
+ ages = sorted(ages, key=lambda x: int(x.split(";")[0].strip()))
+ for age in ages:
+ print(f"Age: {age} Count: {stat_per_ages[age]}")
+
+ return dataset_pictures_train, dataset_pictures_val
+
+
+def find_persons_on_image(image_info, main_bbox, detected_objects, other_faces_inds, device):
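+    """
+    Associate each detected face (the main face plus other_faces_inds) with a detected person bbox.
+
+    Returns a tuple (additional_faces, additional_persons): extra face PictureInfo entries found on
+    the image, and person-only PictureInfo entries that were not matched to any face.
+    """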
+ # find person_ind for each face (main + other_faces)
+    all_faces: List[torch.Tensor] = [torch.tensor(main_bbox).to(device)] + [
+ detected_objects.get_bbox_by_ind(ind) for ind in other_faces_inds
+ ]
+ faces_persons_map, other_persons_inds = associate_persons(all_faces, detected_objects)
+
+ additional_faces: List[PictureInfo] = get_additional_bboxes(
+ detected_objects, other_faces_inds, image_info.image_path
+ )
+
+ # set person bboxes for all faces (main + additional_faces)
+ assign_persons([image_info] + additional_faces, faces_persons_map, detected_objects)
+ if faces_persons_map[0] is not None:
+ assert all(coord != -1 for coord in image_info.person_bbox)
+
+ additional_persons: List[PictureInfo] = get_additional_bboxes(
+ detected_objects, other_persons_inds, image_info.image_path, is_person=True
+ )
+
+ return additional_faces, additional_persons
+
+
+def main(faces_dir: str, annotations: List[str], data_dir: str, detector_cfg: Optional[dict] = None):
+    """
+    Generate train/val annotation files with columns:
+    ["img_name", "age", "gender",
+    "face_x0", "face_y0", "face_x1", "face_y1",
+    "person_x0", "person_y0", "person_x1", "person_y1"]
+
+    If detector_cfg is set, each face bbox is refined with the detector and a person bbox
+    is assigned to each face.
+    Other detected faces and persons are also written to the annotation files (needed for
+    further preprocessing).
+    """
+ # out directory for txt annotations
+ out_dir = os.path.join(data_dir, "annotations")
+ os.makedirs(out_dir, exist_ok=True)
+
+ # load annotations
+ dataset_pictures_train, dataset_pictures_val = read_data(faces_dir, annotations)
+
+ for images, split_name in zip([dataset_pictures_train, dataset_pictures_val], ["train", "val"]):
+
+ if detector_cfg:
+ # detect faces with yolo detector
+ faces_not_found, images_with_other_faces = 0, 0
+ other_faces: List[PictureInfo] = []
+
+ detector_weights, device = detector_cfg["weights"], detector_cfg["device"]
+ detector = Detector(detector_weights, device, verbose=False, conf_thresh=0.1, iou_thresh=0.2)
+ for image_info in tqdm.tqdm(images, desc=f"Detecting {split_name} faces: "):
+ cv_im = cv2.imread(image_info.image_path)
+ im_h, im_w = cv_im.shape[:2]
+ # all images are 448x448 and with 125 padding
+ coarse_bbox = [125, 125, im_w - 125, im_h - 125] # xyxy
+
+ detected_objects: PersonAndFaceResult = detector.predict(cv_im)
+ main_bbox, other_faces_inds = get_main_face(detected_objects, coarse_bbox)
+ if len(other_faces_inds):
+ images_with_other_faces += 1
+
+ if main_bbox is None:
+                    # fall back to the coarse (margin-cropped) face bbox
+ faces_not_found += 1
+ main_bbox = coarse_bbox
+ image_info.bbox = main_bbox
+
+ additional_faces, additional_persons = find_persons_on_image(
+ image_info, main_bbox, detected_objects, other_faces_inds, device
+ )
+
+ # add all additional faces
+ other_faces.extend(additional_faces)
+
+ # add persons with empty faces
+ other_faces.extend(additional_persons)
+
+ print(f"Faces not detected: {faces_not_found}/{len(images)}")
+ print(f"Images with other faces: {images_with_other_faces}/{len(images)}")
+ print(f"Other bboxes (faces/persons): {len(other_faces)}")
+
+ images = images + other_faces
+
+ else:
+ for image_info in tqdm.tqdm(images, desc="Collect face bboxes: "):
+ cv_im = cv2.imread(image_info.image_path)
+ im_h, im_w = cv_im.shape[:2]
+                # all FairFace images are 448x448 with a 125 px margin on each side
+ image_info.bbox = [125, 125, im_w - 125, im_h - 125] # xyxy
+
+ save_annotations(images, faces_dir, out_file=os.path.join(out_dir, f"{split_name}_annotations.csv"))
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(description="FairFace")
+ parser.add_argument(
+ "--dataset_path",
+ default="data/FairFace",
+ type=str,
+ required=True,
+ help="path to folder with fairface-img-margin125-trainval/ and fairface_label_{split}.csv",
+ )
+ parser.add_argument(
+ "--detector_weights", default=None, type=str, required=False, help="path to face and person detector"
+ )
+ parser.add_argument("--device", default="cuda:0", type=str, required=False, help="device to inference detector")
+
+ return parser
+
+
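+# Example invocation (the weights path is a placeholder):
+#   python tools/prepare_fairface.py --dataset_path data/FairFace \
+#       --detector_weights models/yolov8x_person_face.pt --device cuda:0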
+if __name__ == "__main__":
+
+ parser = get_parser()
+ args = parser.parse_args()
+
+    data_dir = args.dataset_path
+    if data_dir[-1] == "/":
+        data_dir = data_dir[:-1]
+
+    faces_dir = os.path.join(data_dir, "fairface-img-margin125-trainval")
+
+ annotations = [os.path.join(data_dir, "fairface_label_train.csv"), os.path.join(data_dir, "fairface_label_val.csv")]
+
+ detector_cfg: Optional[Dict[str, str]] = None
+ if args.detector_weights is not None:
+        detector_cfg = {"weights": args.detector_weights, "device": args.device}
+
+ main(faces_dir, annotations, data_dir, detector_cfg)
diff --git a/age_estimator/models.py b/age_estimator/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..71a836239075aa6e6e4ecb700e9c42c95c022d91
--- /dev/null
+++ b/age_estimator/models.py
@@ -0,0 +1,3 @@
+from django.db import models
+
+# Create your models here.
diff --git a/age_estimator/tests.py b/age_estimator/tests.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ce503c2dd97ba78597f6ff6e4393132753573f6
--- /dev/null
+++ b/age_estimator/tests.py
@@ -0,0 +1,3 @@
+from django.test import TestCase
+
+# Create your tests here.
diff --git a/age_estimator/urls.py b/age_estimator/urls.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b67fb5fb3645ac05f1978ca1f807206d024e526
--- /dev/null
+++ b/age_estimator/urls.py
@@ -0,0 +1,12 @@
+from django.urls import path
+
+from .views import AgeEstimation
+
+urlpatterns = [
+ path('age_estimator/', AgeEstimation.as_view(), name='age_estimation_view'),
+]
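+
+# The final URL depends on how this module is wired into the project's ROOT_URLCONF
+# (that include() is not part of this diff); if the app is mounted at the site root,
+# the endpoint is POST /age_estimator/.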
diff --git a/age_estimator/views.py b/age_estimator/views.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebf2a40ac9245bd365ac40ca765ca914f613b25e
--- /dev/null
+++ b/age_estimator/views.py
@@ -0,0 +1,62 @@
+import os
+
+from django.core.files.storage import default_storage
+from django.http import JsonResponse
+from rest_framework.views import APIView
+
+from age_estimator.mivolo.demo_copy import main
+
+
+class AgeEstimation(APIView):
+    def post(self, request):
+        try:
+            # Read the uploaded video file from the multipart form data
+            video_file = request.FILES['video_file']
+
+            # Hardcoded output location, model weights and inference settings
+            output_folder = 'output'
+            detector_weights = 'age_estimator/mivolo/models/yolov8x_person_face.pt'
+            checkpoint = 'age_estimator/mivolo/models/model_imdb_cross_person_4.22_99.46.pth.tar'
+            device = 'cpu'
+            with_persons = True
+            disable_faces = False
+            draw = True
+
+ file_name = default_storage.save(video_file.name, video_file)
+ video_file_path = os.path.join(default_storage.location, file_name)
+
+            # Sanity check that the saved video path and the model paths are available
+            if not video_file_path or not detector_weights or not checkpoint:
+                return JsonResponse(
+                    {"error": "Missing required fields: 'video_file', 'detector_weights', or 'checkpoint'"},
+                    status=400,
+                )
+
+            # Run the MiVOLO inference entry point (demo_copy.main)
+ absolute_age, lower_bound, upper_bound = main(
+ video_file_path, output_folder, detector_weights, checkpoint, device, with_persons, disable_faces, draw
+ )
+
+            # Return the estimated age range as a JSON response
+            return JsonResponse({
+                "age_range": f"{lower_bound} - {upper_bound}"
+            })
+
+        except Exception as e:
+            return JsonResponse({"error": str(e)}, status=500)
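+
+
+# Example request against a local development server (host, port and file name are
+# placeholders; assumes this app's urls.py is included at the project root):
+#   curl -X POST -F "video_file=@sample.mp4" http://127.0.0.1:8000/age_estimator/
+# On success the response looks like: {"age_range": "<lower_bound> - <upper_bound>"}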