diff --git a/age_estimator/.DS_Store b/age_estimator/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..c7012a9446f53640aa9f7828d81d4df21da7d4bc Binary files /dev/null and b/age_estimator/.DS_Store differ diff --git a/age_estimator/__init__.py b/age_estimator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/age_estimator/__pycache__/__init__.cpython-38.pyc b/age_estimator/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee8f9f068487da04cd70ddab3bc83f13ad294f9d Binary files /dev/null and b/age_estimator/__pycache__/__init__.cpython-38.pyc differ diff --git a/age_estimator/__pycache__/admin.cpython-38.pyc b/age_estimator/__pycache__/admin.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bfc1a54189550326c801a79eaa39a7b253a33aa Binary files /dev/null and b/age_estimator/__pycache__/admin.cpython-38.pyc differ diff --git a/age_estimator/__pycache__/apps.cpython-38.pyc b/age_estimator/__pycache__/apps.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..674e8aaed47e31e1c6de9176165d290cbf0315f6 Binary files /dev/null and b/age_estimator/__pycache__/apps.cpython-38.pyc differ diff --git a/age_estimator/__pycache__/models.cpython-38.pyc b/age_estimator/__pycache__/models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1223fb1df6113661748a85bc509b30106d745578 Binary files /dev/null and b/age_estimator/__pycache__/models.cpython-38.pyc differ diff --git a/age_estimator/__pycache__/urls.cpython-38.pyc b/age_estimator/__pycache__/urls.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a63c98a4178e55b8cf65b65e6b3fca723c138f0 Binary files /dev/null and b/age_estimator/__pycache__/urls.cpython-38.pyc differ diff --git a/age_estimator/__pycache__/views.cpython-38.pyc b/age_estimator/__pycache__/views.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3af0cfdbb41d85f501d0818b27d7e1fdd7db27f2 Binary files /dev/null and b/age_estimator/__pycache__/views.cpython-38.pyc differ diff --git a/age_estimator/admin.py b/age_estimator/admin.py new file mode 100644 index 0000000000000000000000000000000000000000..8c38f3f3dad51e4585f3984282c2a4bec5349c1e --- /dev/null +++ b/age_estimator/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. 
diff --git a/age_estimator/apps.py b/age_estimator/apps.py new file mode 100644 index 0000000000000000000000000000000000000000..bd67634fc0752a859ed6a8cba63dd9b80d97e8bb --- /dev/null +++ b/age_estimator/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class AgeEstimatorConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'age_estimator' diff --git a/age_estimator/migrations/__init__.py b/age_estimator/migrations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/age_estimator/migrations/__pycache__/__init__.cpython-38.pyc b/age_estimator/migrations/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..586cdd5989324468cc02ad0c0f839bd7604f3116 Binary files /dev/null and b/age_estimator/migrations/__pycache__/__init__.cpython-38.pyc differ diff --git a/age_estimator/mivolo/.DS_Store b/age_estimator/mivolo/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..c43ee62a22d53cc75b181ee2096e26474c200495 Binary files /dev/null and b/age_estimator/mivolo/.DS_Store differ diff --git a/age_estimator/mivolo/.flake8 b/age_estimator/mivolo/.flake8 new file mode 100644 index 0000000000000000000000000000000000000000..c61c9b5694ff9580094261468d7e7e6c87abd51c --- /dev/null +++ b/age_estimator/mivolo/.flake8 @@ -0,0 +1,5 @@ +[flake8] +max-line-length = 120 +inline-quotes = " +multiline-quotes = " +ignore = E203,W503 diff --git a/age_estimator/mivolo/.gitignore b/age_estimator/mivolo/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..8b048b39eff5d43df1234e5731f045960381751c --- /dev/null +++ b/age_estimator/mivolo/.gitignore @@ -0,0 +1,85 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +*.DS_Store + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# PyTorch weights +*.tar +*.pth +*.pt +*.torch +*.gz +Untitled.ipynb +Testing notebook.ipynb + +# Root dir exclusions +/*.csv +/*.yaml +/*.json +/*.jpg +/*.png +/*.zip +/*.tar.* +*.jpg +*.png +*.avi +*.mp4 +*.svg + +.mypy_cache/ +.vscode/ +.idea + +output/ +input/ + +run.sh diff --git a/age_estimator/mivolo/.isort.cfg b/age_estimator/mivolo/.isort.cfg new file mode 100644 index 0000000000000000000000000000000000000000..63135f1820165af7421d82b767e964f8f2daa464 --- /dev/null +++ b/age_estimator/mivolo/.isort.cfg @@ -0,0 +1,5 @@ +[settings] +profile = black +line_length = 120 +src_paths = ["mivolo", "scripts", "tools"] +filter_files = true diff --git a/age_estimator/mivolo/.pre-commit-config.yaml b/age_estimator/mivolo/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f722055e25bcc3685762d14109f68518abfce1ff --- /dev/null +++ b/age_estimator/mivolo/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.2.0 + hooks: + - id: check-yaml + args: ['--unsafe'] + - id: check-toml + - id: debug-statements + - id: end-of-file-fixer + 
exclude: poetry.lock + - id: trailing-whitespace +- repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + args: [ "--profile", "black", "--filter-files" ] +- repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black + args: ["--line-length", "120"] +- repo: https://github.com/PyCQA/flake8 + rev: 3.9.2 + hooks: + - id: flake8 + args: [ "--config", ".flake8" ] + additional_dependencies: [ "flake8-quotes" ] +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.942 + hooks: + - id: mypy diff --git a/age_estimator/mivolo/CHANGELOG.md b/age_estimator/mivolo/CHANGELOG.md new file mode 100644 index 0000000000000000000000000000000000000000..91904fb73d6f47d730b38ebf97113ecf729abc59 --- /dev/null +++ b/age_estimator/mivolo/CHANGELOG.md @@ -0,0 +1,16 @@ + +## 0.4.1dev (15.08.2023) + +### Added +- Support for video streams, including YouTube URLs +- Instructions and explanations for various export types. + +### Changed +- Removed CutOff operation. It has been proven to be ineffective for inference time and quite costly at the same time. Now it is only used during training. + +## 0.4.2dev (22.09.2023) + +### Added + +- Script for AgeDB dataset convertation to csv format +- Additional metrics were added to README diff --git a/age_estimator/mivolo/README.md b/age_estimator/mivolo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7fafe2cd6b32ca98dcf7bedc3c4fdaa1276cc83d --- /dev/null +++ b/age_estimator/mivolo/README.md @@ -0,0 +1,417 @@ +
+ + + +## MiVOLO: Multi-input Transformer for Age and Gender Estimation + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/mivolo-multi-input-transformer-for-age-and/age-estimation-on-utkface)](https://paperswithcode.com/sota/age-estimation-on-utkface?p=mivolo-multi-input-transformer-for-age-and) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-specialization-assessing-the-1/age-estimation-on-imdb-clean)](https://paperswithcode.com/sota/age-estimation-on-imdb-clean?p=beyond-specialization-assessing-the-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-specialization-assessing-the-1/facial-attribute-classification-on-fairface)](https://paperswithcode.com/sota/facial-attribute-classification-on-fairface?p=beyond-specialization-assessing-the-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-specialization-assessing-the-1/age-and-gender-classification-on-adience)](https://paperswithcode.com/sota/age-and-gender-classification-on-adience?p=beyond-specialization-assessing-the-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-specialization-assessing-the-1/age-and-gender-classification-on-adience-age)](https://paperswithcode.com/sota/age-and-gender-classification-on-adience-age?p=beyond-specialization-assessing-the-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-specialization-assessing-the-1/age-and-gender-estimation-on-lagenda-age)](https://paperswithcode.com/sota/age-and-gender-estimation-on-lagenda-age?p=beyond-specialization-assessing-the-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-specialization-assessing-the-1/gender-prediction-on-lagenda)](https://paperswithcode.com/sota/gender-prediction-on-lagenda?p=beyond-specialization-assessing-the-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/mivolo-multi-input-transformer-for-age-and/age-estimation-on-agedb)](https://paperswithcode.com/sota/age-estimation-on-agedb?p=mivolo-multi-input-transformer-for-age-and) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/mivolo-multi-input-transformer-for-age-and/gender-prediction-on-agedb)](https://paperswithcode.com/sota/gender-prediction-on-agedb?p=mivolo-multi-input-transformer-for-age-and) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-specialization-assessing-the-1/age-estimation-on-cacd)](https://paperswithcode.com/sota/age-estimation-on-cacd?p=beyond-specialization-assessing-the-1) + +> [**MiVOLO: Multi-input Transformer for Age and Gender Estimation**](https://arxiv.org/abs/2307.04616), +> Maksim Kuprashevich, Irina Tolstykh, +> *2023 [arXiv 2307.04616](https://arxiv.org/abs/2307.04616)* + +> [**Beyond Specialization: Assessing the Capabilities of MLLMs in Age and Gender Estimation**](https://arxiv.org/abs/2403.02302), +> Maksim Kuprashevich, Grigorii Alekseenko, Irina Tolstykh +> *2024 [arXiv 2403.02302](https://arxiv.org/abs/2403.02302)* + +[[`Paper 2023`](https://arxiv.org/abs/2307.04616)] [[`Paper 2024`](https://arxiv.org/abs/2403.02302)] [[`Demo`](https://huggingface.co/spaces/iitolstykh/age_gender_estimation_demo)] [[`Telegram Bot`](https://t.me/AnyAgeBot)] [[`BibTex`](#citing)] [[`Data`](https://wildchlamydia.github.io/lagenda/)] + + +## MiVOLO pretrained models + +Gender & Age recognition performance. 
+
+| Model | Type | Dataset (train and test) | Age MAE | Age CS@5 | Gender Accuracy | download |
+| ----- | ---- | ------------------------ | ------- | -------- | --------------- | -------- |
+| volo_d1 | face_only, age | IMDB-cleaned | 4.29 | 67.71 | - | checkpoint |
+| volo_d1 | face_only, age, gender | IMDB-cleaned | 4.22 | 68.68 | 99.38 | checkpoint |
+| mivolo_d1 | face_body, age, gender | IMDB-cleaned | 4.24 [face+body]<br>6.87 [body] | 68.32 [face+body]<br>46.32 [body] | 99.46 [face+body]<br>96.48 [body] | model_imdb_cross_person_4.24_99.46.pth.tar |
+| volo_d1 | face_only, age | UTKFace | 4.23 | 69.72 | - | checkpoint |
+| volo_d1 | face_only, age, gender | UTKFace | 4.23 | 69.78 | 97.69 | checkpoint |
+| mivolo_d1 | face_body, age, gender | Lagenda | 3.99 [face+body] | 71.27 [face+body] | 97.36 [face+body] | demo |
+| mivolov2_d1_384x384 | face_body, age, gender | Lagenda | 3.65 [face+body] | 74.48 [face+body] | 97.99 [face+body] | telegram bot |
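+
+In these tables, CS@5 denotes the cumulative score: the percentage of test samples whose absolute age error does not exceed 5 years. A minimal sketch of the definition (the repository computes it via `mivolo.data.misc.cumulative_score`; the standalone function below is only an illustration, not the library code):
+
+```python
+import torch
+
+def cumulative_score(pred_age: torch.Tensor, true_age: torch.Tensor, l: int = 5) -> torch.Tensor:
+    # Share of predictions within `l` years of the ground truth, as a percentage.
+    return (torch.abs(pred_age - true_age) <= l).float().mean() * 100
+```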
+
+## MiVOLO regression benchmarks
+
+Gender & Age recognition performance.
+
+Use [valid_age_gender.sh](scripts/valid_age_gender.sh) to reproduce results with our checkpoints.
+
+| Model | Type | Train Dataset | Test Dataset | Age MAE | Age CS@5 | Gender Accuracy | download |
+| ----- | ---- | ------------- | ------------ | ------- | -------- | --------------- | -------- |
+| mivolo_d1 | face_body, age, gender | Lagenda | AgeDB | 5.55 [face] | 55.08 [face] | 98.3 [face] | demo |
+| mivolo_d1 | face_body, age, gender | IMDB-cleaned | AgeDB | 5.58 [face] | 55.54 [face] | 97.93 [face] | model_imdb_cross_person_4.24_99.46.pth.tar |
+
+## MiVOLO classification benchmarks
+
+Gender & Age recognition performance.
+
+| Model | Type | Train Dataset | Test Dataset | Age Accuracy | Gender Accuracy |
+| ----- | ---- | ------------- | ------------ | ------------ | --------------- |
+| mivolo_d1 | face_body, age, gender | Lagenda | FairFace | 61.07 [face+body] | 95.73 [face+body] |
+| mivolo_d1 | face_body, age, gender | Lagenda | Adience | 68.69 [face] | 96.51 [face] |
+| mivolov2_d1_384 | face_body, age, gender | Lagenda | Adience | 69.43 [face] | 97.39 [face] |
+
+## Dataset
+
+**Please, [cite our papers](#citing) if you use any of this data!**
+
+- Lagenda dataset: [images](https://drive.google.com/file/d/1QXO0NlkABPZT6x1_0Uc2i6KAtdcrpTbG/view?usp=sharing) and [annotation](https://drive.google.com/file/d/1mNYjYFb3MuKg-OL1UISoYsKObMUllbJx/view?usp=sharing).
+- IMDB-clean: follow [these instructions](https://github.com/yiminglin-ai/imdb-clean) to get images and [download](https://drive.google.com/file/d/17uEqyU3uQ5trWZ5vRJKzh41yeuDe5hyL/view?usp=sharing) our annotations.
+- UTK dataset: [original full images](https://susanqq.github.io/UTKFace/) and our annotations: [split from the article](https://drive.google.com/file/d/1Fo1vPWrKtC5bPtnnVWNTdD4ZTKRXL9kv/view?usp=sharing), [random full split](https://drive.google.com/file/d/177AV631C3SIfi5nrmZA8CEihIt29cznJ/view?usp=sharing).
+- Adience dataset: follow [these instructions](https://talhassner.github.io/home/projects/Adience/Adience-data.html) to get images and [download](https://drive.google.com/file/d/1wS1Q4FpksxnCR88A1tGLsLIr91xHwcVv/view?usp=sharing) our annotations.
+
+ Click to expand! + + After downloading them, your `data` directory should look something like this: + + ```console + data + └── Adience + ├── annotations (folder with our annotations) + ├── aligned (will not be used) + ├── faces + ├── fold_0_data.txt + ├── fold_1_data.txt + ├── fold_2_data.txt + ├── fold_3_data.txt + └── fold_4_data.txt + ``` + + We use coarse aligned images from `faces/` dir. + + Using our detector we found a face bbox for each image (see [tools/prepare_adience.py](tools/prepare_adience.py)). + + This dataset has five folds. The performance metric is accuracy on five-fold cross validation. + + | images before removal | fold 0 | fold 1 | fold 2 | fold 3 | fold 4 | + | --------------------- | ------ | ------ | ------ | ------ | ------ | + | 19,370 | 4,484 | 3,730 | 3,894 | 3,446 | 3,816 | + + Not complete data + + | only age not found | only gender not found | SUM | + | ------------------ | --------------------- | ------------- | + | 40 | 1170 | 1,210 (6.2 %) | + + Removed data + + | failed to process image | age and gender not found | SUM | + | ----------------------- | ------------------------ | ----------- | + | 0 | 708 | 708 (3.6 %) | + + Genders + + | female | male | + | ------ | ----- | + | 9,372 | 8,120 | + + Ages (8 classes) after mapping to not intersected ages intervals + + | 0-2 | 4-6 | 8-12 | 15-20 | 25-32 | 38-43 | 48-53 | 60-100 | + | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ------ | + | 2,509 | 2,140 | 2,293 | 1,791 | 5,589 | 2,490 | 909 | 901 | + +
+ +- FairFace dataset: follow [these instructions](https://github.com/joojs/fairface) to get images and [download](https://drive.google.com/file/d/1EdY30A1SQmox96Y39VhBxdgALYhbkzdm/view?usp=drive_link) our annotations. +
+ Click to expand! + + After downloading them, your `data` directory should look something like this: + + ```console + data + └── FairFace + ├── annotations (folder with our annotations) + ├── fairface-img-margin025-trainval (will not be used) + ├── train + ├── val + ├── fairface-img-margin125-trainval + ├── train + ├── val + ├── fairface_label_train.csv + ├── fairface_label_val.csv + + ``` + + We use aligned images from `fairface-img-margin125-trainval/` dir. + + Using our detector we found a face bbox for each image and added a person bbox if it was possible (see [tools/prepare_fairface.py](tools/prepare_fairface.py)). + + This dataset has 2 splits: train and val. The performance metric is accuracy on validation. + + | images train | images val | + | ------------ | ---------- | + | 86,744 | 10,954 | + + Genders for **validation** + + | female | male | + | ------ | ----- | + | 5,162 | 5,792 | + + Ages for **validation** (9 classes): + + | 0-2 | 3-9 | 10-19 | 20-29 | 30-39 | 40-49 | 50-59 | 60-69 | 70+ | + | --- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | --- | + | 199 | 1,356 | 1,181 | 3,300 | 2,330 | 1,353 | 796 | 321 | 118 | + +
+- AgeDB dataset: follow [these instructions](https://ibug.doc.ic.ac.uk/resources/agedb/) to get images and [download](https://drive.google.com/file/d/1Dp72BUlAsyUKeSoyE_DOsFRS1x6ZBJen/view) our annotations. +
+  Click to expand!
+
+  **Ages**: 1 - 101
+
+  **Genders**: 9788 faces of `M`, 6700 faces of `F`
+
+  | images 0 | images 1 | images 2 | images 3 | images 4 | images 5 | images 6 | images 7 | images 8 | images 9 |
+  |----------|----------|----------|----------|----------|----------|----------|----------|----------|----------|
+  | 1701 | 1721 | 1615 | 1619 | 1626 | 1643 | 1634 | 1596 | 1676 | 1657 |
+
+  Data splits were taken from [here](https://github.com/paplhjak/Facial-Age-Estimation-Benchmark-Databases).
+
+  **All splits (the whole dataset) were used for model evaluation.**
+
+
+## Install
+
+Install PyTorch 1.13+ and the other requirements:
+
+```
+pip install -r requirements.txt
+pip install .
+```
+
+
+## Demo
+
+1. [Download](https://drive.google.com/file/d/1CGNCkZQNj5WkP3rLpENWAOgrBQkUWRdw/view) the body + face detector model to `models/yolov8x_person_face.pt`
+2. [Download](https://drive.google.com/file/d/11i8pKctxz3wVkDBlWKvhYIh7kpVFXSZ4/view) the mivolo checkpoint to `models/mivolo_imbd.pth.tar`
+
+```bash
+wget https://variety.com/wp-content/uploads/2023/04/MCDNOHA_SP001.jpg -O jennifer_lawrence.jpg
+
+python3 demo.py \
+--input "jennifer_lawrence.jpg" \
+--output "output" \
+--detector-weights "models/yolov8x_person_face.pt" \
+--checkpoint "models/mivolo_imbd.pth.tar" \
+--device "cuda:0" \
+--with-persons \
+--draw
+```
+
+To run the demo on a YouTube video:
+```bash
+python3 demo.py \
+--input "https://www.youtube.com/shorts/pVh32k0hGEI" \
+--output "output" \
+--detector-weights "models/yolov8x_person_face.pt" \
+--checkpoint "models/mivolo_imbd.pth.tar" \
+--device "cuda:0" \
+--draw \
+--with-persons
+```
+
+
+## Validation
+
+To reproduce the validation metrics:
+
+1. Download the prepared annotations for imdb-clean / utk / adience / lagenda / fairface.
+2. Download a checkpoint.
+3. Run validation:
+
+```bash
+python3 eval_pretrained.py \
+  --dataset_images /path/to/dataset/utk/images \
+  --dataset_annotations /path/to/dataset/utk/annotation \
+  --dataset_name utk \
+  --split valid \
+  --batch-size 512 \
+  --checkpoint models/mivolo_imbd.pth.tar \
+  --half \
+  --with-persons \
+  --device "cuda:0"
+```
+
+Supported dataset names: "utk", "imdb", "lagenda", "fairface", "adience".
+
+
+## Changelog
+
+[CHANGELOG.md](CHANGELOG.md)
+
+## ONNX and TensorRT export
+
+As of now (11.08.2023), ONNX export is technically feasible but not advisable, because the resulting model performs poorly with batch processing.
+**TensorRT** and **OpenVINO** export is currently impossible because neither supports the col2im operation.
+
+If you are nevertheless committed to ONNX export, refer to [these instructions](https://github.com/WildChlamydia/MiVOLO/issues/14#issuecomment-1675245889).
+
+The recommended export method at present **is TorchScript**. You can achieve this with a single line of code:
+```python
+torch.jit.trace(model, example_input)
+```
+This gives you a model that keeps its original speed and only requires a single file for usage, eliminating the need for additional code.
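+
+A fuller tracing sketch (an illustration, not part of the upstream instructions). It builds the model the same way `measure_time.py` does and assumes the face+body age+gender checkpoint `models/mivolo_imbd.pth.tar` with a 6-channel 224x224 input; a face-only, age-only checkpoint would use `in_chans=3` and `num_classes=1` instead:
+
+```python
+import torch
+from mivolo.model.create_timm_model import create_model
+
+# Assumption: face+body age+gender checkpoint (6 input channels; 3 outputs:
+# 2 gender logits + 1 age value). Adjust for face-only / age-only checkpoints.
+model = create_model(
+    "mivolo_d1_224",
+    num_classes=3,
+    in_chans=6,
+    pretrained=False,
+    checkpoint_path="models/mivolo_imbd.pth.tar",
+    filter_keys=["fds."],
+)
+model.eval()
+
+# Trace with a representative input and save a self-contained TorchScript file.
+example_input = torch.randn(1, 6, 224, 224)
+traced = torch.jit.trace(model, example_input)
+traced.save("models/mivolo_traced.pt")
+
+# The traced module can later be loaded without the original Python sources.
+loaded = torch.jit.load("models/mivolo_traced.pt")
+print(loaded(example_input).shape)
+```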
+ +## License + +Please, see [here](./license) + + +## Citing + +If you use our models, code or dataset, we kindly request you to cite the following paper and give repository a :star: + +```bibtex +@article{mivolo2023, + Author = {Maksim Kuprashevich and Irina Tolstykh}, + Title = {MiVOLO: Multi-input Transformer for Age and Gender Estimation}, + Year = {2023}, + Eprint = {arXiv:2307.04616}, +} +``` +```bibtex +@article{mivolo2024, + Author = {Maksim Kuprashevich and Grigorii Alekseenko and Irina Tolstykh}, + Title = {Beyond Specialization: Assessing the Capabilities of MLLMs in Age and Gender Estimation}, + Year = {2024}, + Eprint = {arXiv:2403.02302}, +} +``` diff --git a/age_estimator/mivolo/__pycache__/demo_copy.cpython-38.pyc b/age_estimator/mivolo/__pycache__/demo_copy.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32763331326c804b18d5cc804cf3f584cff1550a Binary files /dev/null and b/age_estimator/mivolo/__pycache__/demo_copy.cpython-38.pyc differ diff --git a/age_estimator/mivolo/demo.py b/age_estimator/mivolo/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..51e367b5035c7ce0bc21a1f57340bba4db95eba4 --- /dev/null +++ b/age_estimator/mivolo/demo.py @@ -0,0 +1,145 @@ +import argparse +import logging +import os +import random + +import cv2 +import torch +import yt_dlp +from mivolo.data.data_reader import InputType, get_all_files, get_input_type +from mivolo.predictor import Predictor +from timm.utils import setup_default_logging + +_logger = logging.getLogger("inference") + + +def get_direct_video_url(video_url): + ydl_opts = { + "format": "bestvideo", + "quiet": True, # Suppress terminal output + } + + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info_dict = ydl.extract_info(video_url, download=False) + + if "url" in info_dict: + direct_url = info_dict["url"] + resolution = (info_dict["width"], info_dict["height"]) + fps = info_dict["fps"] + yid = info_dict["id"] + return direct_url, resolution, fps, yid + + return None, None, None, None + + +def get_local_video_info(vid_uri): + cap = cv2.VideoCapture(vid_uri) + if not cap.isOpened(): + raise ValueError(f"Failed to open video source {vid_uri}") + res = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))) + fps = cap.get(cv2.CAP_PROP_FPS) + return res, fps + + +def get_random_frames(cap, num_frames): + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + frame_indices = random.sample(range(total_frames), num_frames) + + frames = [] + for idx in frame_indices: + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + ret, frame = cap.read() + if ret: + frames.append(frame) + return frames + + +def get_parser(): + parser = argparse.ArgumentParser(description="PyTorch MiVOLO Inference") + parser.add_argument("--input", type=str, default=None, required=True, help="image file or folder with images") + parser.add_argument("--output", type=str, default=None, required=True, help="folder for output results") + parser.add_argument("--detector-weights", type=str, default=None, required=True, help="Detector weights (YOLOv8).") + parser.add_argument("--checkpoint", default="", type=str, required=True, help="path to mivolo checkpoint") + + parser.add_argument( + "--with-persons", action="store_true", default=False, help="If set model will run with persons, if available" + ) + parser.add_argument( + "--disable-faces", action="store_true", default=False, help="If set model will use only persons if available" + ) + + parser.add_argument("--draw", action="store_true", 
default=False, help="If set, resulted images will be drawn") + parser.add_argument("--device", default="cuda", type=str, help="Device (accelerator) to use.") + + return parser + + +def main(): + parser = get_parser() + setup_default_logging() + args = parser.parse_args() + + if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + os.makedirs(args.output, exist_ok=True) + + predictor = Predictor(args, verbose=True) + + input_type = get_input_type(args.input) + + if input_type == InputType.Video or input_type == InputType.VideoStream: + if "youtube" in args.input: + args.input, res, fps, yid = get_direct_video_url(args.input) + if not args.input: + raise ValueError(f"Failed to get direct video url {args.input}") + else: + cap = cv2.VideoCapture(args.input) + if not cap.isOpened(): + raise ValueError(f"Failed to open video source {args.input}") + + # Extract 4-5 random frames from the video + random_frames = get_random_frames(cap, num_frames=5) + + age_list = [] + for frame in random_frames: + detected_objects, out_im, age = predictor.recognize(frame) + age_list.append(age[0]) + + if args.draw: + bname = os.path.splitext(os.path.basename(args.input))[0] + filename = os.path.join(args.output, f"out_{bname}.jpg") + cv2.imwrite(filename, out_im) + _logger.info(f"Saved result to {filename}") + + # Calculate and print average age + avg_age = sum(age_list) / len(age_list) if age_list else 0 + print(f"Age list: {age_list}") + print(f"Average age: {avg_age:.2f}") + absolute_age = round(abs(avg_age)) + # Define the range + lower_bound = absolute_age - 2 + upper_bound = absolute_age + 2 + + + return absolute_age, lower_bound, upper_bound + + elif input_type == InputType.Image: + image_files = get_all_files(args.input) if os.path.isdir(args.input) else [args.input] + + for img_p in image_files: + img = cv2.imread(img_p) + detected_objects, out_im, age = predictor.recognize(img) + + if args.draw: + bname = os.path.splitext(os.path.basename(img_p))[0] + filename = os.path.join(args.output, f"out_{bname}.jpg") + cv2.imwrite(filename, out_im) + _logger.info(f"Saved result to {filename}") + + +if __name__ == "__main__": + absolute_age, lower_bound, upper_bound = main() + # Output the results in the desired format + print(f"Absolute Age: {absolute_age}") + print(f"Range: {lower_bound} - {upper_bound}") \ No newline at end of file diff --git a/age_estimator/mivolo/demo_copy.py b/age_estimator/mivolo/demo_copy.py new file mode 100644 index 0000000000000000000000000000000000000000..a915cb1daae47b4d4fa8e72d531b4cb459b661ab --- /dev/null +++ b/age_estimator/mivolo/demo_copy.py @@ -0,0 +1,144 @@ +import argparse +import logging +import os +import random + +import cv2 +import torch +import yt_dlp +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '././'))) + +from mivolo.data.data_reader import InputType, get_all_files, get_input_type +from mivolo.predictor import Predictor +from timm.utils import setup_default_logging + +_logger = logging.getLogger("inference") + + +def get_direct_video_url(video_url): + ydl_opts = { + "format": "bestvideo", + "quiet": True, # Suppress terminal output + } + + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info_dict = ydl.extract_info(video_url, download=False) + + if "url" in info_dict: + direct_url = info_dict["url"] + resolution = (info_dict["width"], info_dict["height"]) + fps = info_dict["fps"] + yid = info_dict["id"] + return direct_url, resolution, fps, yid + + return None, None, None, 
None + + +def get_random_frames(cap, num_frames): + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + frame_indices = random.sample(range(total_frames), num_frames) + + frames = [] + for idx in frame_indices: + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + ret, frame = cap.read() + if ret: + frames.append(frame) + return frames + + +def get_parser(): + parser = argparse.ArgumentParser(description="PyTorch MiVOLO Inference") + parser.add_argument("--input", type=str, default=None, required=True, help="image file or folder with images") + parser.add_argument("--output", type=str, default=None, required=True, help="folder for output results") + parser.add_argument("--detector-weights", type=str, default=None, required=True, help="Detector weights (YOLOv8).") + parser.add_argument("--checkpoint", default="", type=str, required=True, help="path to mivolo checkpoint") + + parser.add_argument( + "--with_persons", action="store_true", default=False, help="If set model will run with persons, if available" + ) + parser.add_argument( + "--disable_faces", action="store_true", default=False, help="If set model will use only persons if available" + ) + + parser.add_argument("--draw", action="store_true", default=False, help="If set, resulted images will be drawn") + parser.add_argument("--device", default="cpu", type=str, help="Device (accelerator) to use.") + + return parser + + +def main(video_path, output_folder, detector_weights, checkpoint, device, with_persons, disable_faces,draw=False): + setup_default_logging() + + if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + + os.makedirs(output_folder, exist_ok=True) + + # Initialize predictor + args = argparse.Namespace( + input=video_path, + output=output_folder, + detector_weights=detector_weights, + checkpoint=checkpoint, + draw=draw, + device=device, + with_persons=with_persons, + disable_faces=disable_faces + ) + + predictor = Predictor(args, verbose=True) + + if "youtube" in video_path: + video_path, res, fps, yid = get_direct_video_url(video_path) + if not video_path: + raise ValueError(f"Failed to get direct video url {video_path}") + + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise ValueError(f"Failed to open video source {video_path}") + + # Extract 4-5 random frames from the video + random_frames = get_random_frames(cap, num_frames=10) + age_list = [] + + for frame in random_frames: + detected_objects, out_im, age = predictor.recognize(frame) + try: + age_list.append(age[0]) # Attempt to access the first element of age + if draw: + bname = os.path.splitext(os.path.basename(video_path))[0] + filename = os.path.join(output_folder, f"out_{bname}.jpg") + cv2.imwrite(filename, out_im) + _logger.info(f"Saved result to {filename}") + except IndexError: + continue + + if len(age_list)==0: + raise ValueError("No person was detected in the frame. 
Please upload a proper face video.") + + + + # Calculate and print average age + avg_age = sum(age_list) / len(age_list) if age_list else 0 + print(f"Age list: {age_list}") + print(f"Average age: {avg_age:.2f}") + absolute_age = round(abs(avg_age)) + + # Define the range + lower_bound = absolute_age - 2 + upper_bound = absolute_age + 2 + + return absolute_age, lower_bound, upper_bound + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + + absolute_age, lower_bound, upper_bound = main(args.input, args.output, args.detector_weights, args.checkpoint, args.device, args.with_persons, args.disable_faces ,args.draw) + # Output the results in the desired format + print(f"Absolute Age: {absolute_age}") + print(f"Range: {lower_bound} - {upper_bound}") diff --git a/age_estimator/mivolo/eval_pretrained.py b/age_estimator/mivolo/eval_pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..567b5d5a464d4e8ca3234ba11e08148ecd2373ff --- /dev/null +++ b/age_estimator/mivolo/eval_pretrained.py @@ -0,0 +1,232 @@ +import argparse +import json +import logging +from typing import Tuple + +import matplotlib.pyplot as plt +import seaborn as sns +import torch +from eval_tools import Metrics, time_sync, write_results +from mivolo.data.dataset import build as build_data +from mivolo.model.mi_volo import MiVOLO +from timm.utils import setup_default_logging + +_logger = logging.getLogger("inference") +LOG_FREQUENCY = 10 + + +def get_parser(): + parser = argparse.ArgumentParser(description="PyTorch MiVOLO Validation") + parser.add_argument("--dataset_images", default="", type=str, required=True, help="path to images") + parser.add_argument("--dataset_annotations", default="", type=str, required=True, help="path to annotations") + parser.add_argument( + "--dataset_name", + default=None, + type=str, + required=True, + choices=["utk", "imdb", "lagenda", "fairface", "adience", "agedb", "cacd"], + help="dataset name", + ) + parser.add_argument("--split", default="validation", help="dataset splits separated by comma (default: validation)") + parser.add_argument("--checkpoint", default="", type=str, required=True, help="path to mivolo checkpoint") + + parser.add_argument("--batch-size", default=64, type=int, help="batch size") + parser.add_argument( + "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)" + ) + parser.add_argument("--device", default="cuda", type=str, help="Device (accelerator) to use.") + parser.add_argument("--l-for-cs", type=int, default=5, help="L for CS (cumulative score)") + + parser.add_argument("--half", action="store_true", default=False, help="use half-precision model") + parser.add_argument( + "--with-persons", action="store_true", default=False, help="If the model will run with persons, if available" + ) + parser.add_argument( + "--disable-faces", action="store_true", default=False, help="If the model will use only persons if available" + ) + + parser.add_argument("--draw-hist", action="store_true", help="Draws the hist of error by age") + parser.add_argument( + "--results-file", + default="", + type=str, + metavar="FILENAME", + help="Output csv file for validation results (summary)", + ) + parser.add_argument( + "--results-format", default="csv", type=str, help="Format for results file one of (csv, json) (default: csv)." 
+ ) + + return parser + + +def process_batch( + mivolo_model: MiVOLO, + input: torch.tensor, + target: torch.tensor, + num_classes_gender: int = 2, +): + + start = time_sync() + output = mivolo_model.inference(input) + # target with age == -1 and gender == -1 marks that sample is not valid + assert not (all(target[:, 0] == -1) and all(target[:, 1] == -1)) + + if not mivolo_model.meta.only_age: + gender_out = output[:, :num_classes_gender] + gender_target = target[:, 1] + age_out = output[:, num_classes_gender:] + else: + age_out = output + gender_out, gender_target = None, None + + # measure elapsed time + process_time = time_sync() - start + + age_target = target[:, 0].unsqueeze(1) + + return age_out, age_target, gender_out, gender_target, process_time + + +def _filter_invalid_target(out: torch.tensor, target: torch.tensor): + # exclude samples where target gt == -1, that marks sample is not valid + mask = target != -1 + return out[mask], target[mask] + + +def postprocess_gender(gender_out: torch.tensor, gender_target: torch.tensor) -> Tuple[torch.tensor, torch.tensor]: + if gender_target is None: + return gender_out, gender_target + return _filter_invalid_target(gender_out, gender_target) + + +def postprocess_age(age_out: torch.tensor, age_target: torch.tensor, dataset) -> Tuple[torch.tensor, torch.tensor]: + # Revert _norm_age() operation. Output is 2 float tensors + + age_out, age_target = _filter_invalid_target(age_out, age_target) + + age_out = age_out * (dataset.max_age - dataset.min_age) + dataset.avg_age + # clamp to 0 because age can be below zero + age_out = torch.clamp(age_out, min=0) + + if dataset.age_classes is not None: + # classification case + age_out = torch.round(age_out) + if dataset._intervals.device != age_out.device: + dataset._intervals = dataset._intervals.to(age_out.device) + age_inds = torch.searchsorted(dataset._intervals, age_out, side="right") - 1 + age_out = age_inds + else: + age_target = age_target * (dataset.max_age - dataset.min_age) + dataset.avg_age + return age_out, age_target + + +def validate(args): + + if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + + mivolo_model = MiVOLO( + args.checkpoint, + args.device, + half=args.half, + use_persons=args.with_persons, + disable_faces=args.disable_faces, + verbose=True, + ) + + dataset, loader = build_data( + name=args.dataset_name, + images_path=args.dataset_images, + annotations_path=args.dataset_annotations, + split=args.split, + mivolo_model=mivolo_model, # to get meta information from model + workers=args.workers, + batch_size=args.batch_size, + ) + + d_stat = Metrics(args.l_for_cs, args.draw_hist, dataset.age_classes) + + # warmup, reduce variability of first batch time, especially for comparing torchscript vs non + mivolo_model.warmup(args.batch_size) + + preproc_end = time_sync() + for batch_idx, (input, target) in enumerate(loader): + + preprocess_time = time_sync() - preproc_end + # get output and calculate loss + age_out, age_target, gender_out, gender_target, process_time = process_batch( + mivolo_model, input, target, dataset.num_classes_gender + ) + + gender_out, gender_target = postprocess_gender(gender_out, gender_target) + age_out, age_target = postprocess_age(age_out, age_target, dataset) + + d_stat.update_gender_accuracy(gender_out, gender_target) + if d_stat.is_regression: + d_stat.update_regression_age_metrics(age_out, age_target) + else: + d_stat.update_age_accuracy(age_out, age_target) + 
d_stat.update_time(process_time, preprocess_time, input.shape[0]) + + if batch_idx % LOG_FREQUENCY == 0: + _logger.info( + "Test: [{0:>4d}/{1}] " "{2}".format(batch_idx, len(loader), d_stat.get_info_str(input.size(0))) + ) + + preproc_end = time_sync() + + # model info + results = dict( + model=args.checkpoint, + dataset_name=args.dataset_name, + param_count=round(mivolo_model.param_count / 1e6, 2), + img_size=mivolo_model.input_size, + use_faces=mivolo_model.meta.use_face_crops, + use_persons=mivolo_model.meta.use_persons, + in_chans=mivolo_model.meta.in_chans, + batch=args.batch_size, + ) + # metrics info + results.update(d_stat.get_result()) + return results + + +def main(): + parser = get_parser() + setup_default_logging() + args = parser.parse_args() + + if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + + results = validate(args) + + result_str = " * Age Acc@1 {:.3f} ({:.3f})".format(results["agetop1"], results["agetop1_err"]) + if "gendertop1" in results: + result_str += " Gender Acc@1 1 {:.3f} ({:.3f})".format(results["gendertop1"], results["gendertop1_err"]) + result_str += " Mean inference time {:.3f} ms Mean preprocessing time {:.3f}".format( + results["mean_inference_time"], results["mean_preprocessing_time"] + ) + _logger.info(result_str) + + if args.draw_hist and "per_age_error" in results: + err = [sum(v) / len(v) for k, v in results["per_age_error"].items()] + ages = list(results["per_age_error"].keys()) + sns.scatterplot(x=ages, y=err, hue=err) + plt.legend([], [], frameon=False) + plt.xlabel("Age") + plt.ylabel("MAE") + plt.savefig("age_error.png", dpi=300) + + if args.results_file: + write_results(args.results_file, results, format=args.results_format) + + # output results in JSON to stdout w/ delimiter for runner script + print(f"--result\n{json.dumps(results, indent=4)}") + + +if __name__ == "__main__": + main() diff --git a/age_estimator/mivolo/eval_tools.py b/age_estimator/mivolo/eval_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..db9a29973cc68796f98c80559f158ef21f812a65 --- /dev/null +++ b/age_estimator/mivolo/eval_tools.py @@ -0,0 +1,149 @@ +import csv +import json +import time +from collections import OrderedDict, defaultdict + +import torch +from mivolo.data.misc import cumulative_error, cumulative_score +from timm.utils import AverageMeter, accuracy + + +def time_sync(): + # pytorch-accurate time + if torch.cuda.is_available(): + torch.cuda.synchronize() + return time.time() + + +def write_results(results_file, results, format="csv"): + with open(results_file, mode="w") as cf: + if format == "json": + json.dump(results, cf, indent=4) + else: + if not isinstance(results, (list, tuple)): + results = [results] + if not results: + return + dw = csv.DictWriter(cf, fieldnames=results[0].keys()) + dw.writeheader() + for r in results: + dw.writerow(r) + cf.flush() + + +class Metrics: + def __init__(self, l_for_cs, draw_hist, age_classes=None): + self.batch_time = AverageMeter() + self.preproc_batch_time = AverageMeter() + self.seen = 0 + + self.losses = AverageMeter() + self.top1_m_gender = AverageMeter() + self.top1_m_age = AverageMeter() + + if age_classes is None: + self.is_regression = True + self.av_csl_age = AverageMeter() + self.max_error = AverageMeter() + self.per_age_error = defaultdict(list) + self.l_for_cs = l_for_cs + else: + self.is_regression = False + + self.draw_hist = draw_hist + + def update_regression_age_metrics(self, age_out, age_target): + batch_size 
= age_out.size(0) + + age_abs_err = torch.abs(age_out - age_target) + age_acc1 = torch.sum(age_abs_err) / age_out.shape[0] + age_csl = cumulative_score(age_out, age_target, self.l_for_cs) + me = cumulative_error(age_out, age_target, 20) + + self.top1_m_age.update(age_acc1.item(), batch_size) + self.av_csl_age.update(age_csl.item(), batch_size) + self.max_error.update(me.item(), batch_size) + + if self.draw_hist: + for i in range(age_out.shape[0]): + self.per_age_error[int(age_target[i].item())].append(age_abs_err[i].item()) + + def update_age_accuracy(self, age_out, age_target): + batch_size = age_out.size(0) + if batch_size == 0: + return + correct = torch.sum(age_out == age_target) + age_acc1 = correct * 100.0 / batch_size + self.top1_m_age.update(age_acc1.item(), batch_size) + + def update_gender_accuracy(self, gender_out, gender_target): + if gender_out is None or gender_out.size(0) == 0: + return + batch_size = gender_out.size(0) + gender_acc1 = accuracy(gender_out, gender_target, topk=(1,))[0] + if gender_acc1 is not None: + self.top1_m_gender.update(gender_acc1.item(), batch_size) + + def update_loss(self, loss, batch_size): + self.losses.update(loss.item(), batch_size) + + def update_time(self, process_time, preprocess_time, batch_size): + self.seen += batch_size + self.batch_time.update(process_time) + self.preproc_batch_time.update(preprocess_time) + + def get_info_str(self, batch_size): + avg_time = (self.preproc_batch_time.sum + self.batch_time.sum) / self.batch_time.count + cur_time = self.batch_time.val + self.preproc_batch_time.val + middle_info = ( + "Time: {cur_time:.3f}s ({avg_time:.3f}s, {rate_avg:>7.2f}/s) " + "Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) " + "Gender Acc: {top1gender.val:>7.2f} ({top1gender.avg:>7.2f}) ".format( + cur_time=cur_time, + avg_time=avg_time, + rate_avg=batch_size / avg_time, + loss=self.losses, + top1gender=self.top1_m_gender, + ) + ) + + if self.is_regression: + age_info = ( + "Age CS@{l_for_cs}: {csl.val:>7.4f} ({csl.avg:>7.4f}) " + "Age CE@20: {max_error.val:>7.4f} ({max_error.avg:>7.4f}) " + "Age ME: {top1age.val:>7.2f} ({top1age.avg:>7.2f})".format( + top1age=self.top1_m_age, csl=self.av_csl_age, max_error=self.max_error, l_for_cs=self.l_for_cs + ) + ) + else: + age_info = "Age Acc: {top1age.val:>7.2f} ({top1age.avg:>7.2f})".format(top1age=self.top1_m_age) + + return middle_info + age_info + + def get_result(self): + age_top1a = self.top1_m_age.avg + gender_top1 = self.top1_m_gender.avg if self.top1_m_gender.count > 0 else None + + mean_per_image_time = self.batch_time.sum / self.seen + mean_preprocessing_time = self.preproc_batch_time.sum / self.seen + + results = OrderedDict( + mean_inference_time=mean_per_image_time * 1e3, + mean_preprocessing_time=mean_preprocessing_time * 1e3, + agetop1=round(age_top1a, 4), + agetop1_err=round(100 - age_top1a, 4), + ) + + if self.is_regression: + results.update( + dict( + max_error=self.max_error.avg, + csl=self.av_csl_age.avg, + per_age_error=self.per_age_error, + ) + ) + + if gender_top1 is not None: + results.update(dict(gendertop1=round(gender_top1, 4), gendertop1_err=round(100 - gender_top1, 4))) + + return results diff --git a/age_estimator/mivolo/images/MiVOLO.jpg b/age_estimator/mivolo/images/MiVOLO.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4dbe50c3eb596321671ebc3443cf01467deb2593 Binary files /dev/null and b/age_estimator/mivolo/images/MiVOLO.jpg differ diff --git a/age_estimator/mivolo/infer.py b/age_estimator/mivolo/infer.py new file mode 100644 index 
0000000000000000000000000000000000000000..7e3d79c79d3cd4e4444b2f2afc8428f19c99d711 --- /dev/null +++ b/age_estimator/mivolo/infer.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +import pathlib +import os +import huggingface_hub +import numpy as np +import argparse +from dataclasses import dataclass +from mivolo.predictor import Predictor +from PIL import Image + +@dataclass +class Cfg: + detector_weights: str + checkpoint: str + device: str = "cpu" + with_persons: bool = True + disable_faces: bool = False + draw: bool = True + + +def load_models(): + detector_path = huggingface_hub.hf_hub_download('iitolstykh/demo_yolov8_detector', + 'yolov8x_person_face.pt') + + age_gender_path_v1 = 'age_estimator/MiVOLO-main/models/model_imdb_cross_person_4.22_99.46.pth.tar' + predictor_cfg_v1 = Cfg(detector_path, age_gender_path_v1) + + predictor_v1 = Predictor(predictor_cfg_v1) + + return predictor_v1 + +def detect(image: np.ndarray, score_threshold: float, iou_threshold: float, mode: str, predictor: Predictor) -> np.ndarray: + predictor.detector.detector_kwargs['conf'] = score_threshold + predictor.detector.detector_kwargs['iou'] = iou_threshold + + if mode == "Use persons and faces": + use_persons = True + disable_faces = False + elif mode == "Use persons only": + use_persons = True + disable_faces = True + elif mode == "Use faces only": + use_persons = False + disable_faces = False + + predictor.age_gender_model.meta.use_persons = use_persons + predictor.age_gender_model.meta.disable_faces = disable_faces + + image = image[:, :, ::-1] # RGB -> BGR for OpenCV + detected_objects, out_im = predictor.recognize(image) + return out_im[:, :, ::-1] # BGR -> RGB + +def load_image(image_path: str): + image = Image.open(image_path) + image_np = np.array(image) + return image_np + +def main(args): + # Load models + predictor_v1 = load_models() + + # Set parameters from args + score_threshold = args.score_threshold + iou_threshold = args.iou_threshold + mode = args.mode + + # Load and process image + image_np = load_image(args.image_path) + + # Predict with model + result = detect(image_np, score_threshold, iou_threshold, mode, predictor_v1) + + output_image = Image.fromarray(result) + output_image.save(args.output_path) + output_image.show() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Object Detection with YOLOv8 and Age/Gender Prediction') + parser.add_argument('--image_path', type=str, required=True, help='Path to the input image') + parser.add_argument('--output_path', type=str, default='output_image.jpg', help='Path to save the output image') + parser.add_argument('--score_threshold', type=float, default=0.4, help='Score threshold for detection') + parser.add_argument('--iou_threshold', type=float, default=0.7, help='IoU threshold for detection') + parser.add_argument('--mode', type=str, choices=["Use persons and faces", "Use persons only", "Use faces only"], + default="Use persons and faces", help='Detection mode') + + args = parser.parse_args() + main(args) + diff --git a/age_estimator/mivolo/license/en_us.pdf b/age_estimator/mivolo/license/en_us.pdf new file mode 100644 index 0000000000000000000000000000000000000000..bbe714b77de6ce80dd7e09a086151688f33b6aff Binary files /dev/null and b/age_estimator/mivolo/license/en_us.pdf differ diff --git a/age_estimator/mivolo/license/ru.pdf b/age_estimator/mivolo/license/ru.pdf new file mode 100644 index 0000000000000000000000000000000000000000..a22946d5e9e6010cec1b21c0e219249df4c44fd2 Binary files /dev/null and 
b/age_estimator/mivolo/license/ru.pdf differ diff --git a/age_estimator/mivolo/measure_time.py b/age_estimator/mivolo/measure_time.py new file mode 100644 index 0000000000000000000000000000000000000000..c01f74ac729cde64e237c11b3183a380e6d3082a --- /dev/null +++ b/age_estimator/mivolo/measure_time.py @@ -0,0 +1,77 @@ +import pandas as pd +import torch +import tqdm +from eval_tools import time_sync +from mivolo.model.create_timm_model import create_model + +if __name__ == "__main__": + + face_person_ckpt_path = "/data/dataset/iikrasnova/age_gender/pretrained/checkpoint-377.pth.tar" + face_person_input_size = [6, 224, 224] + + face_age_ckpt_path = "/data/dataset/iikrasnova/age_gender/pretrained/model_only_age_imdb_4.32.pth.tar" + face_input_size = [3, 224, 224] + + model_names = ["face_body_model", "face_model"] + # batch_size = 16 + steps = 1000 + warmup_steps = 10 + device = torch.device("cuda:1") + + df_data = [] + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + + for ckpt_path, input_size, model_name, num_classes in zip( + [face_person_ckpt_path, face_age_ckpt_path], [face_person_input_size, face_input_size], model_names, [3, 1] + ): + + in_chans = input_size[0] + print(f"Collecting stat for {ckpt_path} ...") + model = create_model( + "mivolo_d1_224", + num_classes=num_classes, + in_chans=in_chans, + pretrained=False, + checkpoint_path=ckpt_path, + filter_keys=["fds."], + ) + model = model.to(device) + model.eval() + model = model.half() + + time_per_batch = {} + for batch_size in batch_sizes: + create_t0 = time_sync() + for _ in range(steps): + inputs = torch.randn((batch_size,) + tuple(input_size)).to(device).half() + create_t1 = time_sync() + create_taken = create_t1 - create_t0 + + with torch.no_grad(): + inputs = torch.randn((batch_size,) + tuple(input_size)).to(device).half() + for _ in range(warmup_steps): + out = model(inputs) + + all_time = 0 + for _ in tqdm.tqdm(range(steps), desc=f"{model_name} batch {batch_size}"): + start = time_sync() + inputs = torch.randn((batch_size,) + tuple(input_size)).to(device).half() + out = model(inputs) + out += 1 + end = time_sync() + all_time += end - start + + time_taken = (all_time - create_taken) * 1000 / steps / batch_size + print(f"Inference {inputs.shape}, steps: {steps}. 
Mean time taken {time_taken} ms / image") + + time_per_batch[str(batch_size)] = f"{time_taken:.2f}" + df_data.append(time_per_batch) + + headers = list(map(str, batch_sizes)) + output_df = pd.DataFrame(df_data, columns=headers) + output_df.index = model_names + + df2_transposed = output_df.T + out_file = "batch_sizes.csv" + df2_transposed.to_csv(out_file, sep=",") + print(f"Saved time stat for {len(df2_transposed)} batches to {out_file}") diff --git a/age_estimator/mivolo/mivolo/__init__.py b/age_estimator/mivolo/mivolo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/age_estimator/mivolo/mivolo/__pycache__/__init__.cpython-38.pyc b/age_estimator/mivolo/mivolo/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0551922d63de57d3a5afbe87a5c0c959a104ce8d Binary files /dev/null and b/age_estimator/mivolo/mivolo/__pycache__/__init__.cpython-38.pyc differ diff --git a/age_estimator/mivolo/mivolo/__pycache__/predictor.cpython-38.pyc b/age_estimator/mivolo/mivolo/__pycache__/predictor.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f94779610f26d1daa5ee6c21909ce16d18cd7b8 Binary files /dev/null and b/age_estimator/mivolo/mivolo/__pycache__/predictor.cpython-38.pyc differ diff --git a/age_estimator/mivolo/mivolo/__pycache__/structures.cpython-38.pyc b/age_estimator/mivolo/mivolo/__pycache__/structures.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a681475e2b883e03d909b735d056830675f4eafc Binary files /dev/null and b/age_estimator/mivolo/mivolo/__pycache__/structures.cpython-38.pyc differ diff --git a/age_estimator/mivolo/mivolo/data/__init__.py b/age_estimator/mivolo/mivolo/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/age_estimator/mivolo/mivolo/data/__pycache__/__init__.cpython-38.pyc b/age_estimator/mivolo/mivolo/data/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eeb2030d2141b89a5df92b36f1aa01c64af25015 Binary files /dev/null and b/age_estimator/mivolo/mivolo/data/__pycache__/__init__.cpython-38.pyc differ diff --git a/age_estimator/mivolo/mivolo/data/__pycache__/data_reader.cpython-38.pyc b/age_estimator/mivolo/mivolo/data/__pycache__/data_reader.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ad45d32e93ae108962d7a069fa694385a4452c3 Binary files /dev/null and b/age_estimator/mivolo/mivolo/data/__pycache__/data_reader.cpython-38.pyc differ diff --git a/age_estimator/mivolo/mivolo/data/__pycache__/misc.cpython-38.pyc b/age_estimator/mivolo/mivolo/data/__pycache__/misc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee1b9aac35447e19e2c44442e081a63e372bc0a1 Binary files /dev/null and b/age_estimator/mivolo/mivolo/data/__pycache__/misc.cpython-38.pyc differ diff --git a/age_estimator/mivolo/mivolo/data/data_reader.py b/age_estimator/mivolo/mivolo/data/data_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..44b819c6736d6a12fa0327d57152878dca1ebb0f --- /dev/null +++ b/age_estimator/mivolo/mivolo/data/data_reader.py @@ -0,0 +1,125 @@ +import os +from collections import defaultdict +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, List, Optional, Tuple + +import pandas as pd + +IMAGES_EXT: Tuple = 
(".jpeg", ".jpg", ".png", ".webp", ".bmp", ".gif") +VIDEO_EXT: Tuple = (".mp4", ".avi", ".mov", ".mkv", ".webm") + + +@dataclass +class PictureInfo: + image_path: str + age: Optional[str] # age or age range(start;end format) or "-1" + gender: Optional[str] # "M" of "F" or "-1" + bbox: List[int] = field(default_factory=lambda: [-1, -1, -1, -1]) # face bbox: xyxy + person_bbox: List[int] = field(default_factory=lambda: [-1, -1, -1, -1]) # person bbox: xyxy + + @property + def has_person_bbox(self) -> bool: + return any(coord != -1 for coord in self.person_bbox) + + @property + def has_face_bbox(self) -> bool: + return any(coord != -1 for coord in self.bbox) + + def has_gt(self, only_age: bool = False) -> bool: + if only_age: + return self.age != "-1" + else: + return not (self.age == "-1" and self.gender == "-1") + + def clear_person_bbox(self): + self.person_bbox = [-1, -1, -1, -1] + + def clear_face_bbox(self): + self.bbox = [-1, -1, -1, -1] + + +class AnnotType(Enum): + ORIGINAL = "original" + PERSONS = "persons" + NONE = "none" + + @classmethod + def _missing_(cls, value): + print(f"WARN: Unknown annotation type {value}.") + return AnnotType.NONE + + +def get_all_files(path: str, extensions: Tuple = IMAGES_EXT): + files_all = [] + for root, subFolders, files in os.walk(path): + for name in files: + # linux tricks with .directory that still is file + if "directory" not in name and sum([ext.lower() in name.lower() for ext in extensions]) > 0: + files_all.append(os.path.join(root, name)) + return files_all + + +class InputType(Enum): + Image = 0 + Video = 1 + VideoStream = 2 + + +def get_input_type(input_path: str) -> InputType: + if os.path.isdir(input_path): + print("Input is a folder, only images will be processed") + return InputType.Image + elif os.path.isfile(input_path): + if input_path.endswith(VIDEO_EXT): + return InputType.Video + if input_path.endswith(IMAGES_EXT): + return InputType.Image + else: + raise ValueError( + f"Unknown or unsupported input file format {input_path}, \ + supported video formats: {VIDEO_EXT}, \ + supported image formats: {IMAGES_EXT}" + ) + elif input_path.startswith("http") and not input_path.endswith(IMAGES_EXT): + return InputType.VideoStream + else: + raise ValueError(f"Unknown input {input_path}") + + +def read_csv_annotation_file(annotation_file: str, images_dir: str, ignore_without_gt=False): + bboxes_per_image: Dict[str, List[PictureInfo]] = defaultdict(list) + + df = pd.read_csv(annotation_file, sep=",") + + annot_type = AnnotType("persons") if "person_x0" in df.columns else AnnotType("original") + print(f"Reading {annotation_file} (type: {annot_type})...") + + missing_images = 0 + for index, row in df.iterrows(): + img_path = os.path.join(images_dir, row["img_name"]) + if not os.path.exists(img_path): + missing_images += 1 + continue + + face_x1, face_y1, face_x2, face_y2 = row["face_x0"], row["face_y0"], row["face_x1"], row["face_y1"] + age, gender = str(row["age"]), str(row["gender"]) + + if ignore_without_gt and (age == "-1" or gender == "-1"): + continue + + if annot_type == AnnotType.PERSONS: + p_x1, p_y1, p_x2, p_y2 = row["person_x0"], row["person_y0"], row["person_x1"], row["person_y1"] + person_bbox = list(map(int, [p_x1, p_y1, p_x2, p_y2])) + else: + person_bbox = [-1, -1, -1, -1] + + bbox = list(map(int, [face_x1, face_y1, face_x2, face_y2])) + pic_info = PictureInfo(img_path, age, gender, bbox, person_bbox) + assert isinstance(pic_info.person_bbox, list) + + bboxes_per_image[img_path].append(pic_info) + + if missing_images > 0: + 
print(f"WARNING: Missing images: {missing_images}/{len(df)}") + return bboxes_per_image, annot_type diff --git a/age_estimator/mivolo/mivolo/data/dataset/__init__.py b/age_estimator/mivolo/mivolo/data/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c847b48bfb57251bd6e290a61059813c3b698a3 --- /dev/null +++ b/age_estimator/mivolo/mivolo/data/dataset/__init__.py @@ -0,0 +1,66 @@ +from typing import Tuple + +import torch +from mivolo.model.mi_volo import MiVOLO + +from .age_gender_dataset import AgeGenderDataset +from .age_gender_loader import create_loader +from .classification_dataset import AdienceDataset, FairFaceDataset + +DATASET_CLASS_MAP = { + "utk": AgeGenderDataset, + "lagenda": AgeGenderDataset, + "imdb": AgeGenderDataset, + "agedb": AgeGenderDataset, + "cacd": AgeGenderDataset, + "adience": AdienceDataset, + "fairface": FairFaceDataset, +} + + +def build( + name: str, + images_path: str, + annotations_path: str, + split: str, + mivolo_model: MiVOLO, + workers: int, + batch_size: int, +) -> Tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader]: + + dataset_class = DATASET_CLASS_MAP[name] + + dataset: torch.utils.data.Dataset = dataset_class( + images_path=images_path, + annotations_path=annotations_path, + name=name, + split=split, + target_size=mivolo_model.input_size, + max_age=mivolo_model.meta.max_age, + min_age=mivolo_model.meta.min_age, + model_with_persons=mivolo_model.meta.with_persons_model, + use_persons=mivolo_model.meta.use_persons, + disable_faces=mivolo_model.meta.disable_faces, + only_age=mivolo_model.meta.only_age, + ) + + data_config = mivolo_model.data_config + + in_chans = 3 if not mivolo_model.meta.with_persons_model else 6 + input_size = (in_chans, mivolo_model.input_size, mivolo_model.input_size) + + dataset_loader: torch.utils.data.DataLoader = create_loader( + dataset, + input_size=input_size, + batch_size=batch_size, + mean=data_config["mean"], + std=data_config["std"], + num_workers=workers, + crop_pct=data_config["crop_pct"], + crop_mode=data_config["crop_mode"], + pin_memory=False, + device=mivolo_model.device, + target_type=dataset.target_dtype, + ) + + return dataset, dataset_loader diff --git a/age_estimator/mivolo/mivolo/data/dataset/age_gender_dataset.py b/age_estimator/mivolo/mivolo/data/dataset/age_gender_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3cb3cb9922ac7591f9c01f70fc16245b40483871 --- /dev/null +++ b/age_estimator/mivolo/mivolo/data/dataset/age_gender_dataset.py @@ -0,0 +1,194 @@ +import logging +from typing import Any, List, Optional, Set + +import cv2 +import numpy as np +import torch +from mivolo.data.dataset.reader_age_gender import ReaderAgeGender +from PIL import Image +from torchvision import transforms + +_logger = logging.getLogger("AgeGenderDataset") + + +class AgeGenderDataset(torch.utils.data.Dataset): + def __init__( + self, + images_path, + annotations_path, + name=None, + split="train", + load_bytes=False, + img_mode="RGB", + transform=None, + is_training=False, + seed=1234, + target_size=224, + min_age=None, + max_age=None, + model_with_persons=False, + use_persons=False, + disable_faces=False, + only_age=False, + ): + reader = ReaderAgeGender( + images_path, + annotations_path, + split=split, + seed=seed, + target_size=target_size, + with_persons=use_persons, + disable_faces=disable_faces, + only_age=only_age, + ) + + self.name = name + self.model_with_persons = model_with_persons + self.reader = reader + self.load_bytes = load_bytes + 
self.img_mode = img_mode + self.transform = transform + self._consecutive_errors = 0 + self.is_training = is_training + self.random_flip = 0.0 + + # Setting up classes. + # If min and max classes are passed - use them to have the same preprocessing for validation + self.max_age: float = None + self.min_age: float = None + self.avg_age: float = None + self.set_ages_min_max(min_age, max_age) + + self.genders = ["M", "F"] + self.num_classes_gender = len(self.genders) + + self.age_classes: Optional[List[str]] = self.set_age_classes() + + self.num_classes_age = 1 if self.age_classes is None else len(self.age_classes) + self.num_classes: int = self.num_classes_age + self.num_classes_gender + self.target_dtype = torch.float32 + + def set_age_classes(self) -> Optional[List[str]]: + return None # for regression dataset + + def set_ages_min_max(self, min_age: Optional[float], max_age: Optional[float]): + + assert all(age is None for age in [min_age, max_age]) or all( + age is not None for age in [min_age, max_age] + ), "Both min and max age must be passed or none of them" + + if max_age is not None and min_age is not None: + _logger.info(f"Received predefined min_age {min_age} and max_age {max_age}") + self.max_age = max_age + self.min_age = min_age + else: + # collect statistics from loaded dataset + all_ages_set: Set[int] = set() + for img_path, image_samples in self.reader._ann.items(): + for image_sample_info in image_samples: + if image_sample_info.age == "-1": + continue + age = round(float(image_sample_info.age)) + all_ages_set.add(age) + + self.max_age = max(all_ages_set) + self.min_age = min(all_ages_set) + + self.avg_age = (self.max_age + self.min_age) / 2.0 + + def _norm_age(self, age): + return (age - self.avg_age) / (self.max_age - self.min_age) + + def parse_gender(self, _gender: str) -> float: + if _gender != "-1": + gender = float(0 if _gender == "M" or _gender == "0" else 1) + else: + gender = -1 + return gender + + def parse_target(self, _age: str, gender: str) -> List[Any]: + if _age != "-1": + age = round(float(_age)) + age = self._norm_age(float(age)) + else: + age = -1 + + target: List[float] = [age, self.parse_gender(gender)] + return target + + @property + def transform(self): + return self._transform + + @transform.setter + def transform(self, transform): + # Disable pretrained monkey-patched transforms + if not transform: + return + + _trans = [] + for trans in transform.transforms: + if "Resize" in str(trans): + continue + if "Crop" in str(trans): + continue + _trans.append(trans) + self._transform = transforms.Compose(_trans) + + def apply_tranforms(self, image: Optional[np.ndarray]) -> np.ndarray: + if image is None: + return None + + if self.transform is None: + return image + + image = convert_to_pil(image, self.img_mode) + for trans in self.transform.transforms: + image = trans(image) + return image + + def __getitem__(self, index): + # get preprocessed face and person crops (np.ndarray) + # resize + pad, for person crops: cut off other bboxes + images, target = self.reader[index] + + target = self.parse_target(*target) + + if self.model_with_persons: + face_image, person_image = images + person_image: np.ndarray = self.apply_tranforms(person_image) + else: + face_image = images[0] + person_image = None + + face_image: np.ndarray = self.apply_tranforms(face_image) + + if person_image is not None: + img = np.concatenate([face_image, person_image], axis=0) + else: + img = face_image + + return img, target + + def __len__(self): + return len(self.reader) + + def 
filename(self, index, basename=False, absolute=False): + return self.reader.filename(index, basename, absolute) + + def filenames(self, basename=False, absolute=False): + return self.reader.filenames(basename, absolute) + + +def convert_to_pil(cv_im: Optional[np.ndarray], img_mode: str = "RGB") -> "Image": + if cv_im is None: + return None + + if img_mode == "RGB": + cv_im = cv2.cvtColor(cv_im, cv2.COLOR_BGR2RGB) + else: + raise Exception("Incorrect image mode has been passed!") + + cv_im = np.ascontiguousarray(cv_im) + pil_image = Image.fromarray(cv_im) + return pil_image diff --git a/age_estimator/mivolo/mivolo/data/dataset/age_gender_loader.py b/age_estimator/mivolo/mivolo/data/dataset/age_gender_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..260e310de7c4a33dc11c6b0809f000940199a5c3 --- /dev/null +++ b/age_estimator/mivolo/mivolo/data/dataset/age_gender_loader.py @@ -0,0 +1,169 @@ +""" +Code adapted from timm https://github.com/huggingface/pytorch-image-models + +Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich +""" + +import logging +from contextlib import suppress +from functools import partial +from itertools import repeat + +import numpy as np +import torch +import torch.utils.data +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data.dataset import IterableImageDataset +from timm.data.loader import PrefetchLoader, _worker_init +from timm.data.transforms_factory import create_transform + +_logger = logging.getLogger(__name__) + + +def fast_collate(batch, target_dtype=torch.uint8): + """A fast collation function optimized for uint8 images (np array or torch) and target_dtype targets (labels)""" + assert isinstance(batch[0], tuple) + batch_size = len(batch) + if isinstance(batch[0][0], np.ndarray): + targets = torch.tensor([b[1] for b in batch], dtype=target_dtype) + assert len(targets) == batch_size + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + tensor[i] += torch.from_numpy(batch[i][0]) + return tensor, targets + else: + raise ValueError(f"Incorrect batch type: {type(batch[0][0])}") + + +def adapt_to_chs(x, n): + if not isinstance(x, (tuple, list)): + x = tuple(repeat(x, n)) + elif len(x) != n: + # doubled channels + if len(x) * 2 == n: + x = np.concatenate((x, x)) + _logger.warning(f"Pretrained mean/std different shape than model (doubled channes), using concat: {x}.") + else: + x_mean = np.mean(x).item() + x = (x_mean,) * n + _logger.warning(f"Pretrained mean/std different shape than model, using avg value {x}.") + else: + assert len(x) == n, "normalization stats must match image channels" + return x + + +class PrefetchLoaderForMultiInput(PrefetchLoader): + def __init__( + self, + loader, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + channels=3, + device=torch.device("cuda"), + img_dtype=torch.float32, + ): + + mean = adapt_to_chs(mean, channels) + std = adapt_to_chs(std, channels) + normalization_shape = (1, channels, 1, 1) + + self.loader = loader + self.device = device + self.img_dtype = img_dtype + self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape) + self.std = torch.tensor([x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape) + + self.is_cuda = torch.cuda.is_available() and device.type == "cuda" + + def __iter__(self): + first = True + if self.is_cuda: + stream = torch.cuda.Stream() + stream_context = 
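A worked example of the regression target encoding used by AgeGenderDataset above, assuming the collected statistics came out as min_age=1 and max_age=95 (illustrative values, not taken from any particular annotation file).

min_age, max_age = 1, 95
avg_age = (max_age + min_age) / 2.0            # 48.0

def norm_age(age):
    return (age - avg_age) / (max_age - min_age)

norm_age(30)    # (30 - 48) / 94 ≈ -0.1915
norm_age(95)    # 0.5, the upper end of the range
# gender is appended as a separate target: "M"/"0" -> 0.0, "F"/"1" -> 1.0, "-1" -> -1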
partial(torch.cuda.stream, stream=stream) + else: + stream = None + stream_context = suppress + + for next_input, next_target in self.loader: + + with stream_context(): + next_input = next_input.to(device=self.device, non_blocking=True) + next_target = next_target.to(device=self.device, non_blocking=True) + next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std) + + if not first: + yield input, target # noqa: F823, F821 + else: + first = False + + if stream is not None: + torch.cuda.current_stream().wait_stream(stream) + + input = next_input + target = next_target + + yield input, target + + +def create_loader( + dataset, + input_size, + batch_size, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_workers=1, + crop_pct=None, + crop_mode=None, + pin_memory=False, + img_dtype=torch.float32, + device=torch.device("cuda"), + persistent_workers=True, + worker_seeding="all", + target_type=torch.int64, +): + + transform = create_transform( + input_size, + is_training=False, + use_prefetcher=True, + mean=mean, + std=std, + crop_pct=crop_pct, + crop_mode=crop_mode, + ) + dataset.transform = transform + + if isinstance(dataset, IterableImageDataset): + # give Iterable datasets early knowledge of num_workers so that sample estimates + # are correct before worker processes are launched + dataset.set_loader_cfg(num_workers=num_workers) + raise ValueError("Incorrect dataset type: IterableImageDataset") + + loader_class = torch.utils.data.DataLoader + loader_args = dict( + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=None, + collate_fn=lambda batch: fast_collate(batch, target_dtype=target_type), + pin_memory=pin_memory, + drop_last=False, + worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding), + persistent_workers=persistent_workers, + ) + try: + loader = loader_class(dataset, **loader_args) + except TypeError: + loader_args.pop("persistent_workers") # only in Pytorch 1.7+ + loader = loader_class(dataset, **loader_args) + + loader = PrefetchLoaderForMultiInput( + loader, + mean=mean, + std=std, + channels=input_size[0], + device=device, + img_dtype=img_dtype, + ) + + return loader diff --git a/age_estimator/mivolo/mivolo/data/dataset/classification_dataset.py b/age_estimator/mivolo/mivolo/data/dataset/classification_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2157dd548ef3749daa6704b597bfd4776781fe2c --- /dev/null +++ b/age_estimator/mivolo/mivolo/data/dataset/classification_dataset.py @@ -0,0 +1,47 @@ +from typing import Any, List, Optional + +import torch + +from .age_gender_dataset import AgeGenderDataset + + +class ClassificationDataset(AgeGenderDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.target_dtype = torch.int32 + + def set_age_classes(self) -> Optional[List[str]]: + raise NotImplementedError + + def parse_target(self, age: str, gender: str) -> List[Any]: + assert self.age_classes is not None + if age != "-1": + assert age in self.age_classes, f"Unknown category in {self.name} dataset: {age}" + age_ind = self.age_classes.index(age) + else: + age_ind = -1 + + target: List[int] = [age_ind, int(self.parse_gender(gender))] + return target + + +class FairFaceDataset(ClassificationDataset): + def set_age_classes(self) -> Optional[List[str]]: + age_classes = ["0;2", "3;9", "10;19", "20;29", "30;39", "40;49", "50;59", "60;69", "70;120"] + # a[i-1] <= v < a[i] => age_classes[i-1] + self._intervals = torch.tensor([0, 3, 10, 20, 30, 40, 50, 60, 
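The prefetch loader above keeps images as uint8 through collation and folds the /255 into its mean/std tensors; a small sketch checking that this matches the conventional float pipeline for one ImageNet channel.

import torch

mean, std = 0.485, 0.229                      # one ImageNet channel, for illustration
x = torch.randint(0, 256, (1, 1, 4, 4)).float()

a = (x - mean * 255) / (std * 255)            # prefetch-loader style
b = (x / 255 - mean) / std                    # conventional style
assert torch.allclose(a, b, atol=1e-5)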
70]) + return age_classes + + +class AdienceDataset(ClassificationDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.target_dtype = torch.int32 + + def set_age_classes(self) -> Optional[List[str]]: + age_classes = ["0;2", "4;6", "8;12", "15;20", "25;32", "38;43", "48;53", "60;100"] + # a[i-1] <= v < a[i] => age_classes[i-1] + self._intervals = torch.tensor([0, 4, 7, 14, 24, 36, 46, 57]) + return age_classes diff --git a/age_estimator/mivolo/mivolo/data/dataset/reader_age_gender.py b/age_estimator/mivolo/mivolo/data/dataset/reader_age_gender.py new file mode 100644 index 0000000000000000000000000000000000000000..a4b7d82b25fceeb56791ca5a77f4d6e8db72c2f5 --- /dev/null +++ b/age_estimator/mivolo/mivolo/data/dataset/reader_age_gender.py @@ -0,0 +1,492 @@ +import logging +import os +from functools import partial +from multiprocessing.pool import ThreadPool +from typing import Dict, List, Optional, Tuple + +import cv2 +import numpy as np +from mivolo.data.data_reader import AnnotType, PictureInfo, get_all_files, read_csv_annotation_file +from mivolo.data.misc import IOU, class_letterbox +from timm.data.readers.reader import Reader +from tqdm import tqdm + +CROP_ROUND_TOL = 0.3 +MIN_PERSON_SIZE = 100 +MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4 + +_logger = logging.getLogger("ReaderAgeGender") + + +class ReaderAgeGender(Reader): + """ + Reader for almost original imdb-wiki cleaned dataset. + Two changes: + 1. Your annotation must be in ./annotation subdir of dataset root + 2. Images must be in images subdir + + """ + + def __init__( + self, + images_path, + annotations_path, + split="validation", + target_size=224, + min_size=5, + seed=1234, + with_persons=False, + min_person_size=MIN_PERSON_SIZE, + disable_faces=False, + only_age=False, + min_person_aftercut_ratio=MIN_PERSON_CROP_AFTERCUT_RATIO, + crop_round_tol=CROP_ROUND_TOL, + ): + super().__init__() + + self.with_persons = with_persons + self.disable_faces = disable_faces + self.only_age = only_age + + # can be only black for now, even though it's not very good with further normalization + self.crop_out_color = (0, 0, 0) + + self.empty_crop = np.ones((target_size, target_size, 3)) * self.crop_out_color + self.empty_crop = self.empty_crop.astype(np.uint8) + + self.min_person_size = min_person_size + self.min_person_aftercut_ratio = min_person_aftercut_ratio + self.crop_round_tol = crop_round_tol + + splits = split.split(",") + self.splits = [split.strip() for split in splits if len(split.strip())] + assert len(self.splits), "Incorrect split arg" + + self.min_size = min_size + self.seed = seed + self.target_size = target_size + + # Reading annotations. 
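A sketch of how interval boundaries like self._intervals above map an integer age onto a class label; the torch.bucketize call is an illustration of the stated rule a[i-1] <= v < a[i], not necessarily how the evaluation code consumes these tensors.

import torch

intervals = torch.tensor([0, 4, 7, 14, 24, 36, 46, 57])        # Adience boundaries
age_classes = ["0;2", "4;6", "8;12", "15;20", "25;32", "38;43", "48;53", "60;100"]

idx = torch.bucketize(torch.tensor(5), intervals, right=True) - 1
age_classes[idx.item()]    # "4;6", since 4 <= 5 < 7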
Can be multiple files if annotations_path dir + self._ann: Dict[str, List[PictureInfo]] = {} # list of samples for each image + self._associated_objects: Dict[str, Dict[int, List[List[int]]]] = {} + self._faces_list: List[Tuple[str, int]] = [] # samples from this list will be loaded in __getitem__ + + self._read_annotations(images_path, annotations_path) + _logger.info(f"Dataset length: {len(self._faces_list)} crops") + + def __getitem__(self, index): + return self._read_img_and_label(index) + + def __len__(self): + return len(self._faces_list) + + def _filename(self, index, basename=False, absolute=False): + img_p = self._faces_list[index][0] + return os.path.basename(img_p) if basename else img_p + + def _read_annotations(self, images_path, csvs_path): + self._ann = {} + self._faces_list = [] + self._associated_objects = {} + + csvs = get_all_files(csvs_path, [".csv"]) + csvs = [c for c in csvs if any(split_name in os.path.basename(c) for split_name in self.splits)] + + # load annotations per image + for csv in csvs: + db, ann_type = read_csv_annotation_file(csv, images_path) + if self.with_persons and ann_type != AnnotType.PERSONS: + raise ValueError( + f"Annotation type in file {csv} contains no persons, " + f"but annotations with persons are requested." + ) + self._ann.update(db) + + if len(self._ann) == 0: + raise ValueError("Annotations are empty!") + + self._ann, self._associated_objects = self.prepare_annotations() + images_list = list(self._ann.keys()) + + for img_path in images_list: + for index, image_sample_info in enumerate(self._ann[img_path]): + assert image_sample_info.has_gt( + self.only_age + ), "Annotations must be checked with self.prepare_annotations() func" + self._faces_list.append((img_path, index)) + + def _read_img_and_label(self, index): + if not isinstance(index, int): + raise TypeError("ReaderAgeGender expected index to be integer") + + img_p, face_index = self._faces_list[index] + ann: PictureInfo = self._ann[img_p][face_index] + img = cv2.imread(img_p) + + face_empty = True + if ann.has_face_bbox and not (self.with_persons and self.disable_faces): + face_crop, face_empty = self._get_crop(ann.bbox, img) + + if not self.with_persons and face_empty: + # model without persons + raise ValueError("Annotations must be checked with self.prepare_annotations() func") + + if face_empty: + face_crop = self.empty_crop + + person_empty = True + if self.with_persons or self.disable_faces: + if ann.has_person_bbox: + # cut off all associated objects from person crop + objects = self._associated_objects[img_p][face_index] + person_crop, person_empty = self._get_crop( + ann.person_bbox, + img, + crop_out_color=self.crop_out_color, + asced_objects=objects, + ) + + if face_empty and person_empty: + raise ValueError("Annotations must be checked with self.prepare_annotations() func") + + if person_empty: + person_crop = self.empty_crop + + return (face_crop, person_crop), [ann.age, ann.gender] + + def _get_crop( + self, + bbox, + img, + asced_objects=None, + crop_out_color=(0, 0, 0), + ) -> Tuple[np.ndarray, bool]: + + empty_bbox = False + + xmin, ymin, xmax, ymax = bbox + assert not ( + ymax - ymin < self.min_size or xmax - xmin < self.min_size + ), "Annotations must be checked with self.prepare_annotations() func" + + crop = img[ymin:ymax, xmin:xmax] + + if asced_objects: + # cut off other objects for person crop + crop, empty_bbox = _cropout_asced_objs( + asced_objects, + bbox, + crop.copy(), + crop_out_color=crop_out_color, + min_person_size=self.min_person_size, + 
crop_round_tol=self.crop_round_tol, + min_person_aftercut_ratio=self.min_person_aftercut_ratio, + ) + if empty_bbox: + crop = self.empty_crop + + crop = class_letterbox(crop, new_shape=(self.target_size, self.target_size), color=crop_out_color) + return crop, empty_bbox + + def prepare_annotations(self): + + good_anns: Dict[str, List[PictureInfo]] = {} + all_associated_objects: Dict[str, Dict[int, List[List[int]]]] = {} + + if not self.with_persons: + # remove all persons + for img_path, bboxes in self._ann.items(): + for sample in bboxes: + sample.clear_person_bbox() + + # check dataset and collect associated_objects + verify_images_func = partial( + verify_images, + min_size=self.min_size, + min_person_size=self.min_person_size, + with_persons=self.with_persons, + disable_faces=self.disable_faces, + crop_round_tol=self.crop_round_tol, + min_person_aftercut_ratio=self.min_person_aftercut_ratio, + only_age=self.only_age, + ) + num_threads = min(8, os.cpu_count()) + + all_msgs = [] + broken = 0 + skipped = 0 + all_skipped_crops = 0 + desc = "Check annotations..." + with ThreadPool(num_threads) as pool: + pbar = tqdm( + pool.imap_unordered(verify_images_func, list(self._ann.items())), + desc=desc, + total=len(self._ann), + ) + + for (img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops) in pbar: + broken += 1 if is_corrupted else 0 + all_msgs.extend(msgs) + all_skipped_crops += skipped_crops + skipped += 1 if is_empty_annotations else 0 + if img_info is not None: + img_path, img_samples = img_info + good_anns[img_path] = img_samples + all_associated_objects.update({img_path: associated_objects}) + + pbar.desc = ( + f"{desc} {skipped} images skipped ({all_skipped_crops} crops are incorrect); " + f"{broken} images corrupted" + ) + + pbar.close() + + for msg in all_msgs: + print(msg) + print(f"\nLeft images: {len(good_anns)}") + + return good_anns, all_associated_objects + + +def verify_images( + img_info, + min_size: int, + min_person_size: int, + with_persons: bool, + disable_faces: bool, + crop_round_tol: float, + min_person_aftercut_ratio: float, + only_age: bool, +): + # If crop is too small, if image can not be read or if image does not exist + # then filter out this sample + + disable_faces = disable_faces and with_persons + kwargs = dict( + min_person_size=min_person_size, + disable_faces=disable_faces, + with_persons=with_persons, + crop_round_tol=crop_round_tol, + min_person_aftercut_ratio=min_person_aftercut_ratio, + only_age=only_age, + ) + + def bbox_correct(bbox, min_size, im_h, im_w) -> Tuple[bool, List[int]]: + ymin, ymax, xmin, xmax = _correct_bbox(bbox, im_h, im_w) + crop_h, crop_w = ymax - ymin, xmax - xmin + if crop_h < min_size or crop_w < min_size: + return False, [-1, -1, -1, -1] + bbox = [xmin, ymin, xmax, ymax] + return True, bbox + + msgs = [] + skipped_crops = 0 + is_corrupted = False + is_empty_annotations = False + + img_path: str = img_info[0] + img_samples: List[PictureInfo] = img_info[1] + try: + im_cv = cv2.imread(img_path) + im_h, im_w = im_cv.shape[:2] + except Exception: + msgs.append(f"Can not load image {img_path}") + is_corrupted = True + return None, {}, msgs, is_corrupted, is_empty_annotations, skipped_crops + + out_samples: List[PictureInfo] = [] + for sample in img_samples: + # correct face bbox + if sample.has_face_bbox: + is_correct, sample.bbox = bbox_correct(sample.bbox, min_size, im_h, im_w) + if not is_correct and sample.has_gt(only_age): + msgs.append("Small face. 
Passing..") + skipped_crops += 1 + + # correct person bbox + if sample.has_person_bbox: + is_correct, sample.person_bbox = bbox_correct( + sample.person_bbox, max(min_person_size, min_size), im_h, im_w + ) + if not is_correct and sample.has_gt(only_age): + msgs.append(f"Small person {img_path}. Passing..") + skipped_crops += 1 + + if sample.has_face_bbox or sample.has_person_bbox: + out_samples.append(sample) + elif sample.has_gt(only_age): + msgs.append("Sample has no face and no body. Passing..") + skipped_crops += 1 + + # sort that samples with undefined age and gender be the last + out_samples = sorted(out_samples, key=lambda sample: 1 if not sample.has_gt(only_age) else 0) + + # for each person find other faces and persons bboxes, intersected with it + associated_objects: Dict[int, List[List[int]]] = find_associated_objects(out_samples, only_age=only_age) + + out_samples, associated_objects, skipped_crops = filter_bad_samples( + out_samples, associated_objects, im_cv, msgs, skipped_crops, **kwargs + ) + + out_img_info: Optional[Tuple[str, List]] = (img_path, out_samples) + if len(out_samples) == 0: + out_img_info = None + is_empty_annotations = True + + return out_img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops + + +def filter_bad_samples( + out_samples: List[PictureInfo], + associated_objects: dict, + im_cv: np.ndarray, + msgs: List[str], + skipped_crops: int, + **kwargs, +): + with_persons, disable_faces, min_person_size, crop_round_tol, min_person_aftercut_ratio, only_age = ( + kwargs["with_persons"], + kwargs["disable_faces"], + kwargs["min_person_size"], + kwargs["crop_round_tol"], + kwargs["min_person_aftercut_ratio"], + kwargs["only_age"], + ) + + # left only samples with annotations + inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_gt(only_age)] + out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds) + + if kwargs["disable_faces"]: + # clear all faces + for ind, sample in enumerate(out_samples): + sample.clear_face_bbox() + + # left only samples with person_bbox + inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_person_bbox] + out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds) + + if with_persons or disable_faces: + # check that preprocessing func + # _cropout_asced_objs() return not empty person_image for each out sample + + inds = [] + for ind, sample in enumerate(out_samples): + person_empty = True + if sample.has_person_bbox: + xmin, ymin, xmax, ymax = sample.person_bbox + crop = im_cv[ymin:ymax, xmin:xmax] + # cut off all associated objects from person crop + _, person_empty = _cropout_asced_objs( + associated_objects[ind], + sample.person_bbox, + crop.copy(), + min_person_size=min_person_size, + crop_round_tol=crop_round_tol, + min_person_aftercut_ratio=min_person_aftercut_ratio, + ) + + if person_empty and not sample.has_face_bbox: + msgs.append("Small person after preprocessing. 
Passing..") + skipped_crops += 1 + else: + inds.append(ind) + out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds) + + assert len(associated_objects) == len(out_samples) + return out_samples, associated_objects, skipped_crops + + +def _filter_by_ind(out_samples, associated_objects, inds): + _associated_objects = {} + _out_samples = [] + for ind, sample in enumerate(out_samples): + if ind in inds: + _associated_objects[len(_out_samples)] = associated_objects[ind] + _out_samples.append(sample) + + return _out_samples, _associated_objects + + +def find_associated_objects( + image_samples: List[PictureInfo], iou_thresh=0.0001, only_age=False +) -> Dict[int, List[List[int]]]: + """ + For each person (which has gt age and gt gender) find other faces and persons bboxes, intersected with it + """ + associated_objects: Dict[int, List[List[int]]] = {} + + for iindex, image_sample_info in enumerate(image_samples): + # add own face + associated_objects[iindex] = [image_sample_info.bbox] if image_sample_info.has_face_bbox else [] + + if not image_sample_info.has_person_bbox or not image_sample_info.has_gt(only_age): + # if sample has not gt => not be used + continue + + iperson_box = image_sample_info.person_bbox + for jindex, other_image_sample in enumerate(image_samples): + if iindex == jindex: + continue + if other_image_sample.has_face_bbox: + jface_bbox = other_image_sample.bbox + iou = _get_iou(jface_bbox, iperson_box) + if iou >= iou_thresh: + associated_objects[iindex].append(jface_bbox) + if other_image_sample.has_person_bbox: + jperson_bbox = other_image_sample.person_bbox + iou = _get_iou(jperson_bbox, iperson_box) + if iou >= iou_thresh: + associated_objects[iindex].append(jperson_bbox) + + return associated_objects + + +def _cropout_asced_objs( + asced_objects, + person_bbox, + crop, + min_person_size, + crop_round_tol, + min_person_aftercut_ratio, + crop_out_color=(0, 0, 0), +): + empty = False + xmin, ymin, xmax, ymax = person_bbox + + for a_obj in asced_objects: + aobj_xmin, aobj_ymin, aobj_xmax, aobj_ymax = a_obj + + aobj_ymin = int(max(aobj_ymin - ymin, 0)) + aobj_xmin = int(max(aobj_xmin - xmin, 0)) + aobj_ymax = int(min(aobj_ymax - ymin, ymax - ymin)) + aobj_xmax = int(min(aobj_xmax - xmin, xmax - xmin)) + + crop[aobj_ymin:aobj_ymax, aobj_xmin:aobj_xmax] = crop_out_color + + # calc useful non-black area + remain_ratio = np.count_nonzero(crop) / (crop.shape[0] * crop.shape[1] * crop.shape[2]) + if (crop.shape[0] < min_person_size or crop.shape[1] < min_person_size) or remain_ratio < min_person_aftercut_ratio: + crop = None + empty = True + + return crop, empty + + +def _correct_bbox(bbox, h, w): + xmin, ymin, xmax, ymax = bbox + ymin = min(max(ymin, 0), h) + ymax = min(max(ymax, 0), h) + xmin = min(max(xmin, 0), w) + xmax = min(max(xmax, 0), w) + return ymin, ymax, xmin, xmax + + +def _get_iou(bbox1, bbox2): + xmin1, ymin1, xmax1, ymax1 = bbox1 + xmin2, ymin2, xmax2, ymax2 = bbox2 + iou = IOU( + [ymin1, xmin1, ymax1, xmax1], + [ymin2, xmin2, ymax2, xmax2], + ) + return iou diff --git a/age_estimator/mivolo/mivolo/data/misc.py b/age_estimator/mivolo/mivolo/data/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..603223b90a3b296cfd7f6da9044e1cb288e03621 --- /dev/null +++ b/age_estimator/mivolo/mivolo/data/misc.py @@ -0,0 +1,246 @@ +import argparse +import ast +import re +from typing import List, Optional, Tuple, Union + +import cv2 +import numpy as np +import torch +import torchvision.transforms.functional as F +from 
scipy.optimize import linear_sum_assignment +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + +CROP_ROUND_RATE = 0.1 +MIN_PERSON_CROP_NONZERO = 0.5 + + +def aggregate_votes_winsorized(ages, max_age_dist=6): + # Replace any annotation that is more than a max_age_dist away from the median + # with the median + max_age_dist if higher or max_age_dist - max_age_dist if below + median = np.median(ages) + ages = np.clip(ages, median - max_age_dist, median + max_age_dist) + return np.mean(ages) + + +def natural_key(string_): + """See http://www.codinghorror.com/blog/archives/001018.html""" + return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] + + +def add_bool_arg(parser, name, default=False, help=""): + dest_name = name.replace("-", "_") + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument("--" + name, dest=dest_name, action="store_true", help=help) + group.add_argument("--no-" + name, dest=dest_name, action="store_false", help=help) + parser.set_defaults(**{dest_name: default}) + + +def cumulative_score(pred_ages, gt_ages, L, tol=1e-6): + n = pred_ages.shape[0] + num_correct = torch.sum(torch.abs(pred_ages - gt_ages) <= L + tol) + cs_score = num_correct / n + return cs_score + + +def cumulative_error(pred_ages, gt_ages, L, tol=1e-6): + n = pred_ages.shape[0] + num_correct = torch.sum(torch.abs(pred_ages - gt_ages) >= L + tol) + cs_score = num_correct / n + return cs_score + + +class ParseKwargs(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + kw = {} + for value in values: + key, value = value.split("=") + try: + kw[key] = ast.literal_eval(value) + except ValueError: + kw[key] = str(value) # fallback to string (avoid need to escape on command line) + setattr(namespace, self.dest, kw) + + +def box_iou(box1, box2, over_second=False): + """ + Return intersection-over-union (Jaccard index) of boxes. + If over_second == True, return mean(intersection-over-union, (inter / area2)) + + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. + + Arguments: + box1 (Tensor[N, 4]) + box2 (Tensor[M, 4]) + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise + IoU values for every element in boxes1 and boxes2 + """ + + def box_area(box): + # box = 4xn + return (box[2] - box[0]) * (box[3] - box[1]) + + area1 = box_area(box1.T) + area2 = box_area(box2.T) + + # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) + inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) + + iou = inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter) + if over_second: + return (inter / area2 + iou) / 2 # mean(inter / area2, iou) + else: + return iou + + +def split_batch(bs: int, dev: int) -> Tuple[int, int]: + full_bs = (bs // dev) * dev + part_bs = bs - full_bs + return full_bs, part_bs + + +def assign_faces( + persons_bboxes: List[torch.tensor], faces_bboxes: List[torch.tensor], iou_thresh: float = 0.0001 +) -> Tuple[List[Optional[int]], List[int]]: + """ + Assign person to each face if it is possible. + Return: + - assigned_faces List[Optional[int]]: mapping of face_ind to person_ind + ( assigned_faces[face_ind] = person_ind ). 
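A toy example of the cumulative score metric defined above: CS@5 is the fraction of samples whose absolute age error is at most 5 years (the values below are made up).

import torch

pred = torch.tensor([23.0, 41.0, 60.0, 18.0])
gt = torch.tensor([25.0, 50.0, 58.0, 18.0])

cumulative_score(pred, gt, L=5)   # tensor(0.7500) — only the 41 vs 50 sample misses the band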
person_ind can be None + - unassigned_persons_inds List[int]: persons indexes without any assigned face + """ + + assigned_faces: List[Optional[int]] = [None for _ in range(len(faces_bboxes))] + unassigned_persons_inds: List[int] = [p_ind for p_ind in range(len(persons_bboxes))] + + if len(persons_bboxes) == 0 or len(faces_bboxes) == 0: + return assigned_faces, unassigned_persons_inds + + cost_matrix = box_iou(torch.stack(persons_bboxes), torch.stack(faces_bboxes), over_second=True).cpu().numpy() + persons_indexes, face_indexes = [], [] + + if len(cost_matrix) > 0: + persons_indexes, face_indexes = linear_sum_assignment(cost_matrix, maximize=True) + + matched_persons = set() + for person_idx, face_idx in zip(persons_indexes, face_indexes): + ciou = cost_matrix[person_idx][face_idx] + if ciou > iou_thresh: + if person_idx in matched_persons: + # Person can not be assigned twice, in reality this should not happen + continue + assigned_faces[face_idx] = person_idx + matched_persons.add(person_idx) + + unassigned_persons_inds = [p_ind for p_ind in range(len(persons_bboxes)) if p_ind not in matched_persons] + + return assigned_faces, unassigned_persons_inds + + +def class_letterbox(im, new_shape=(640, 640), color=(0, 0, 0), scaleup=True): + # Resize and pad image while meeting stride-multiple constraints + shape = im.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + if im.shape[0] == new_shape[0] and im.shape[1] == new_shape[1]: + return im + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better val mAP) + r = min(r, 1.0) + + # Compute padding + # ratio = r, r # width, height ratios + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + return im + + +def prepare_classification_images( + img_list: List[Optional[np.ndarray]], + target_size: int = 224, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + device=None, +) -> torch.tensor: + + prepared_images: List[torch.tensor] = [] + + for img in img_list: + if img is None: + img = torch.zeros((3, target_size, target_size), dtype=torch.float32) + img = F.normalize(img, mean=mean, std=std) + img = img.unsqueeze(0) + prepared_images.append(img) + continue + img = class_letterbox(img, new_shape=(target_size, target_size)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + img = img / 255.0 + img = (img - mean) / std + img = img.astype(dtype=np.float32) + + img = img.transpose((2, 0, 1)) + img = np.ascontiguousarray(img) + img = torch.from_numpy(img) + img = img.unsqueeze(0) + + prepared_images.append(img) + + if len(prepared_images) == 0: + return None + + prepared_input = torch.concat(prepared_images) + + if device: + prepared_input = prepared_input.to(device) + + return prepared_input + + +def IOU(bb1: Union[tuple, list], bb2: Union[tuple, list], norm_second_bbox: bool = False) -> float: + # expects [ymin, xmin, ymax, xmax], doesnt matter absolute or relative + assert bb1[1] < bb1[3] + assert bb1[0] < bb1[2] + assert 
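A toy sketch of the face-to-person assignment above: one person box that fully contains the only face, plus a second person with no face nearby (all coordinates are made up).

import torch

persons = [torch.tensor([0.0, 0.0, 100.0, 200.0]), torch.tensor([300.0, 0.0, 400.0, 200.0])]
faces = [torch.tensor([20.0, 10.0, 60.0, 60.0])]

assigned_faces, unassigned_persons = assign_faces(persons, faces)
assigned_faces          # [0]  — the face is matched to the first person
unassigned_persons      # [1]  — the second person gets no face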
bb2[1] < bb2[3] + assert bb2[0] < bb2[2] + + # determine the coordinates of the intersection rectangle + x_left = max(bb1[1], bb2[1]) + y_top = max(bb1[0], bb2[0]) + x_right = min(bb1[3], bb2[3]) + y_bottom = min(bb1[2], bb2[2]) + + if x_right < x_left or y_bottom < y_top: + return 0.0 + + # The intersection of two axis-aligned bounding boxes is always an + # axis-aligned bounding box + intersection_area = (x_right - x_left) * (y_bottom - y_top) + # compute the area of both AABBs + bb1_area = (bb1[3] - bb1[1]) * (bb1[2] - bb1[0]) + bb2_area = (bb2[3] - bb2[1]) * (bb2[2] - bb2[0]) + if not norm_second_bbox: + # compute the intersection over union by taking the intersection + # area and dividing it by the sum of prediction + ground-truth + # areas - the interesection area + iou = intersection_area / float(bb1_area + bb2_area - intersection_area) + else: + # for cases when we search if second bbox is inside first one + iou = intersection_area / float(bb2_area) + + assert iou >= 0.0 + assert iou <= 1.01 + + return iou diff --git a/age_estimator/mivolo/mivolo/model/__init__.py b/age_estimator/mivolo/mivolo/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/age_estimator/mivolo/mivolo/model/__pycache__/__init__.cpython-38.pyc b/age_estimator/mivolo/mivolo/model/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d87c2c0392526ecce3b2f9b0c0d6f86757d907a7 Binary files /dev/null and b/age_estimator/mivolo/mivolo/model/__pycache__/__init__.cpython-38.pyc differ diff --git a/age_estimator/mivolo/mivolo/model/__pycache__/create_timm_model.cpython-38.pyc b/age_estimator/mivolo/mivolo/model/__pycache__/create_timm_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b757480d8cf16cf410324c62d0f115bcac4137b Binary files /dev/null and b/age_estimator/mivolo/mivolo/model/__pycache__/create_timm_model.cpython-38.pyc differ diff --git a/age_estimator/mivolo/mivolo/model/__pycache__/cross_bottleneck_attn.cpython-38.pyc b/age_estimator/mivolo/mivolo/model/__pycache__/cross_bottleneck_attn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a475ed6f1765509d3fcfe86045d370e4ef762a8 Binary files /dev/null and b/age_estimator/mivolo/mivolo/model/__pycache__/cross_bottleneck_attn.cpython-38.pyc differ diff --git a/age_estimator/mivolo/mivolo/model/__pycache__/mi_volo.cpython-38.pyc b/age_estimator/mivolo/mivolo/model/__pycache__/mi_volo.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bb5fb506a9cc1d226c50f23512662b2db4b33e7 Binary files /dev/null and b/age_estimator/mivolo/mivolo/model/__pycache__/mi_volo.cpython-38.pyc differ diff --git a/age_estimator/mivolo/mivolo/model/__pycache__/mivolo_model.cpython-38.pyc b/age_estimator/mivolo/mivolo/model/__pycache__/mivolo_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c24ba1eef8ec2ae4717dd358a88be8362d05cea1 Binary files /dev/null and b/age_estimator/mivolo/mivolo/model/__pycache__/mivolo_model.cpython-38.pyc differ diff --git a/age_estimator/mivolo/mivolo/model/__pycache__/yolo_detector.cpython-38.pyc b/age_estimator/mivolo/mivolo/model/__pycache__/yolo_detector.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01af016e7ad0f0286ba1457da77811eea2a35c3d Binary files /dev/null and 
b/age_estimator/mivolo/mivolo/model/__pycache__/yolo_detector.cpython-38.pyc differ diff --git a/age_estimator/mivolo/mivolo/model/create_timm_model.py b/age_estimator/mivolo/mivolo/model/create_timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d7545b01cd1a3c612787c25c06d926f6cade4d98 --- /dev/null +++ b/age_estimator/mivolo/mivolo/model/create_timm_model.py @@ -0,0 +1,107 @@ +""" +Code adapted from timm https://github.com/huggingface/pytorch-image-models + +Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich +""" + +import os +from typing import Any, Dict, Optional, Union + +import timm + +# register new models +from mivolo.model.mivolo_model import * # noqa: F403, F401 +from timm.layers import set_layer_config +from timm.models._factory import parse_model_name +from timm.models._helpers import load_state_dict, remap_checkpoint +from timm.models._hub import load_model_config_from_hf +from timm.models._pretrained import PretrainedCfg, split_model_name_tag +from timm.models._registry import is_model, model_entrypoint + + +def load_checkpoint( + model, checkpoint_path, use_ema=True, strict=True, remap=False, filter_keys=None, state_dict_map=None +): + if os.path.splitext(checkpoint_path)[-1].lower() in (".npz", ".npy"): + # numpy checkpoint, try to load via model specific load_pretrained fn + if hasattr(model, "load_pretrained"): + timm.models._model_builder.load_pretrained(checkpoint_path) + else: + raise NotImplementedError("Model cannot load numpy checkpoint") + return + state_dict = load_state_dict(checkpoint_path, use_ema) + if remap: + state_dict = remap_checkpoint(model, state_dict) + if filter_keys: + for sd_key in list(state_dict.keys()): + for filter_key in filter_keys: + if filter_key in sd_key: + if sd_key in state_dict: + del state_dict[sd_key] + + rep = [] + if state_dict_map is not None: + # 'patch_embed.conv1.' : 'patch_embed.conv.' + for state_k in list(state_dict.keys()): + for target_k, target_v in state_dict_map.items(): + if target_v in state_k: + target_name = state_k.replace(target_v, target_k) + state_dict[target_name] = state_dict[state_k] + rep.append(state_k) + for r in rep: + if r in state_dict: + del state_dict[r] + + incompatible_keys = model.load_state_dict(state_dict, strict=strict if filter_keys is None else False) + return incompatible_keys + + +def create_model( + model_name: str, + pretrained: bool = False, + pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None, + pretrained_cfg_overlay: Optional[Dict[str, Any]] = None, + checkpoint_path: str = "", + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + filter_keys=None, + state_dict_map=None, + **kwargs, +): + """Create a model + Lookup model's entrypoint function and pass relevant args to create a new model. + """ + # Parameters that aren't supported by all models or are intended to only override model defaults if set + # should default to None in command line args/cfg. Remove them if they are present and not set so that + # non-supporting models don't break and default args remain in effect. + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + model_source, model_name = parse_model_name(model_name) + if model_source == "hf-hub": + assert not pretrained_cfg, "pretrained_cfg should not be set when sourcing model from Hugging Face Hub." 
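A sketch of calling the create_model wrapper defined here the way the MiVOLO class later in this patch does; the checkpoint path is hypothetical, and filter_keys/state_dict_map are the MiVOLO-specific additions on top of the timm factory.

model = create_model(
    "mivolo_d1_224",                              # registered by mivolo.model.mivolo_model
    num_classes=3,                                # 1 age output + 2 gender logits
    in_chans=3,                                   # 6 for the face+person variant
    pretrained=False,
    checkpoint_path="pretrained/checkpoint-377.pth.tar",
    filter_keys=["fds."],                         # drop auxiliary keys before loading
)
model.eval()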
+ # For model names specified in the form `hf-hub:path/architecture_name@revision`, + # load model weights + pretrained_cfg from Hugging Face hub. + pretrained_cfg, model_name = load_model_config_from_hf(model_name) + else: + model_name, pretrained_tag = split_model_name_tag(model_name) + if not pretrained_cfg: + # a valid pretrained_cfg argument takes priority over tag in model name + pretrained_cfg = pretrained_tag + + if not is_model(model_name): + raise RuntimeError("Unknown model (%s)" % model_name) + + create_fn = model_entrypoint(model_name) + with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): + model = create_fn( + pretrained=pretrained, + pretrained_cfg=pretrained_cfg, + pretrained_cfg_overlay=pretrained_cfg_overlay, + **kwargs, + ) + + if checkpoint_path: + load_checkpoint(model, checkpoint_path, filter_keys=filter_keys, state_dict_map=state_dict_map) + + return model diff --git a/age_estimator/mivolo/mivolo/model/cross_bottleneck_attn.py b/age_estimator/mivolo/mivolo/model/cross_bottleneck_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..44976bf39ed22280ce640d5f41b6627e21bd543a --- /dev/null +++ b/age_estimator/mivolo/mivolo/model/cross_bottleneck_attn.py @@ -0,0 +1,116 @@ +""" +Code based on timm https://github.com/huggingface/pytorch-image-models + +Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich +""" + +import torch +import torch.nn as nn +from timm.layers.bottleneck_attn import PosEmbedRel +from timm.layers.helpers import make_divisible +from timm.layers.mlp import Mlp +from timm.layers.trace_utils import _assert +from timm.layers.weight_init import trunc_normal_ + + +class CrossBottleneckAttn(nn.Module): + def __init__( + self, + dim, + dim_out=None, + feat_size=None, + stride=1, + num_heads=4, + dim_head=None, + qk_ratio=1.0, + qkv_bias=False, + scale_pos_embed=False, + ): + super().__init__() + assert feat_size is not None, "A concrete feature size matching expected input (H, W) is required" + dim_out = dim_out or dim + assert dim_out % num_heads == 0 + + self.num_heads = num_heads + self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads + self.dim_head_v = dim_out // self.num_heads + self.dim_out_qk = num_heads * self.dim_head_qk + self.dim_out_v = num_heads * self.dim_head_v + self.scale = self.dim_head_qk**-0.5 + self.scale_pos_embed = scale_pos_embed + + self.qkv_f = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) + self.qkv_p = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) + + # NOTE I'm only supporting relative pos embedding for now + self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale) + + self.norm = nn.LayerNorm([self.dim_out_v * 2, *feat_size]) + mlp_ratio = 4 + self.mlp = Mlp( + in_features=self.dim_out_v * 2, + hidden_features=int(dim * mlp_ratio), + act_layer=nn.GELU, + out_features=dim_out, + drop=0, + use_conv=True, + ) + + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + self.reset_parameters() + + def reset_parameters(self): + trunc_normal_(self.qkv_f.weight, std=self.qkv_f.weight.shape[1] ** -0.5) # fan-in + trunc_normal_(self.qkv_p.weight, std=self.qkv_p.weight.shape[1] ** -0.5) # fan-in + trunc_normal_(self.pos_embed.height_rel, std=self.scale) + trunc_normal_(self.pos_embed.width_rel, std=self.scale) + + def get_qkv(self, x, qvk_conv): + B, C, H, W = x.shape + + x = qvk_conv(x) # B, (2 * dim_head_qk + dim_head_v) * 
num_heads, H, W + + q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1) + + q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2) + k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k + v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2) + + return q, k, v + + def apply_attn(self, q, k, v, B, H, W, dropout=None): + if self.scale_pos_embed: + attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W + else: + attn = (q @ k) * self.scale + self.pos_embed(q) + attn = attn.softmax(dim=-1) + if dropout: + attn = dropout(attn) + + out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W + return out + + def forward(self, x): + B, C, H, W = x.shape + + dim = int(C / 2) + x1 = x[:, :dim, :, :] + x2 = x[:, dim:, :, :] + + _assert(H == self.pos_embed.height, "") + _assert(W == self.pos_embed.width, "") + + q_f, k_f, v_f = self.get_qkv(x1, self.qkv_f) + q_p, k_p, v_p = self.get_qkv(x2, self.qkv_p) + + # person to face + out_f = self.apply_attn(q_f, k_p, v_p, B, H, W) + # face to person + out_p = self.apply_attn(q_p, k_f, v_f, B, H, W) + + x_pf = torch.cat((out_f, out_p), dim=1) # B, dim_out * 2, H, W + x_pf = self.norm(x_pf) + x_pf = self.mlp(x_pf) # B, dim_out, H, W + + out = self.pool(x_pf) + return out diff --git a/age_estimator/mivolo/mivolo/model/mi_volo.py b/age_estimator/mivolo/mivolo/model/mi_volo.py new file mode 100644 index 0000000000000000000000000000000000000000..3d4c824f8f918dc0f2854000afef66fe11573965 --- /dev/null +++ b/age_estimator/mivolo/mivolo/model/mi_volo.py @@ -0,0 +1,243 @@ +import logging +from typing import Optional + +import numpy as np +import torch +from mivolo.data.misc import prepare_classification_images +from mivolo.model.create_timm_model import create_model +from mivolo.structures import PersonAndFaceCrops, PersonAndFaceResult +from timm.data import resolve_data_config + +_logger = logging.getLogger("MiVOLO") +has_compile = hasattr(torch, "compile") + + +class Meta: + def __init__(self): + self.min_age = None + self.max_age = None + self.avg_age = None + self.num_classes = None + + self.in_chans = 3 + self.with_persons_model = False + self.disable_faces = False + self.use_persons = True + self.only_age = False + + self.num_classes_gender = 2 + self.input_size = 224 + + def load_from_ckpt(self, ckpt_path: str, disable_faces: bool = False, use_persons: bool = True) -> "Meta": + + state = torch.load(ckpt_path, map_location="cpu") + + self.min_age = state["min_age"] + self.max_age = state["max_age"] + self.avg_age = state["avg_age"] + self.only_age = state["no_gender"] + + only_age = state["no_gender"] + + self.disable_faces = disable_faces + if "with_persons_model" in state: + self.with_persons_model = state["with_persons_model"] + else: + self.with_persons_model = True if "patch_embed.conv1.0.weight" in state["state_dict"] else False + + self.num_classes = 1 if only_age else 3 + self.in_chans = 3 if not self.with_persons_model else 6 + self.use_persons = use_persons and self.with_persons_model + + if not self.with_persons_model and self.disable_faces: + raise ValueError("You can not use disable-faces for faces-only model") + if self.with_persons_model and self.disable_faces and not self.use_persons: + raise ValueError( + "You can not disable faces and persons together. 
" + "Set --with-persons if you want to run with --disable-faces" + ) + self.input_size = state["state_dict"]["pos_embed"].shape[1] * 16 + return self + + def __str__(self): + attrs = vars(self) + attrs.update({"use_person_crops": self.use_person_crops, "use_face_crops": self.use_face_crops}) + return ", ".join("%s: %s" % item for item in attrs.items()) + + @property + def use_person_crops(self) -> bool: + return self.with_persons_model and self.use_persons + + @property + def use_face_crops(self) -> bool: + return not self.disable_faces or not self.with_persons_model + + +class MiVOLO: + def __init__( + self, + ckpt_path: str, + device: str = "cuda", + half: bool = True, + disable_faces: bool = False, + use_persons: bool = True, + verbose: bool = False, + torchcompile: Optional[str] = None, + ): + self.verbose = verbose + self.device = torch.device(device) + self.half = half and self.device.type != "cpu" + + self.meta: Meta = Meta().load_from_ckpt(ckpt_path, disable_faces, use_persons) + if self.verbose: + _logger.info(f"Model meta:\n{str(self.meta)}") + + model_name = f"mivolo_d1_{self.meta.input_size}" + self.model = create_model( + model_name=model_name, + num_classes=self.meta.num_classes, + in_chans=self.meta.in_chans, + pretrained=False, + checkpoint_path=ckpt_path, + filter_keys=["fds."], + ) + self.param_count = sum([m.numel() for m in self.model.parameters()]) + _logger.info(f"Model {model_name} created, param count: {self.param_count}") + + self.data_config = resolve_data_config( + model=self.model, + verbose=verbose, + use_test_size=True, + ) + + self.data_config["crop_pct"] = 1.0 + c, h, w = self.data_config["input_size"] + assert h == w, "Incorrect data_config" + self.input_size = w + + self.model = self.model.to(self.device) + + if torchcompile: + assert has_compile, "A version of torch w/ torch.compile() is required for --compile, possibly a nightly." 
+ torch._dynamo.reset() + self.model = torch.compile(self.model, backend=torchcompile) + + self.model.eval() + if self.half: + self.model = self.model.half() + + def warmup(self, batch_size: int, steps=10): + if self.meta.with_persons_model: + input_size = (6, self.input_size, self.input_size) + else: + input_size = self.data_config["input_size"] + + input = torch.randn((batch_size,) + tuple(input_size)).to(self.device) + + for _ in range(steps): + out = self.inference(input) # noqa: F841 + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + def inference(self, model_input: torch.tensor) -> torch.tensor: + + with torch.no_grad(): + if self.half: + model_input = model_input.half() + output = self.model(model_input) + return output + + def predict(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult): + if ( + (detected_bboxes.n_objects == 0) + or (not self.meta.use_persons and detected_bboxes.n_faces == 0) + or (self.meta.disable_faces and detected_bboxes.n_persons == 0) + ): + # nothing to process + return + + faces_input, person_input, faces_inds, bodies_inds = self.prepare_crops(image, detected_bboxes) + + if faces_input is None and person_input is None: + # nothing to process + return + + if self.meta.with_persons_model: + model_input = torch.cat((faces_input, person_input), dim=1) + else: + model_input = faces_input + output = self.inference(model_input) + + # write gender and age results into detected_bboxes + self.fill_in_results(output, detected_bboxes, faces_inds, bodies_inds) + + def fill_in_results(self, output, detected_bboxes, faces_inds, bodies_inds): + if self.meta.only_age: + age_output = output + gender_probs, gender_indx = None, None + else: + age_output = output[:, 2] + gender_output = output[:, :2].softmax(-1) + gender_probs, gender_indx = gender_output.topk(1) + + assert output.shape[0] == len(faces_inds) == len(bodies_inds) + + # per face + for index in range(output.shape[0]): + face_ind = faces_inds[index] + body_ind = bodies_inds[index] + + # get_age + age = age_output[index].item() + age = age * (self.meta.max_age - self.meta.min_age) + self.meta.avg_age + age = round(age, 2) + + detected_bboxes.set_age(face_ind, age) + detected_bboxes.set_age(body_ind, age) + + _logger.info(f"\tage: {age}") + + if gender_probs is not None: + gender = "male" if gender_indx[index].item() == 0 else "female" + gender_score = gender_probs[index].item() + + _logger.info(f"\tgender: {gender} [{int(gender_score * 100)}%]") + + detected_bboxes.set_gender(face_ind, gender, gender_score) + detected_bboxes.set_gender(body_ind, gender, gender_score) + + def prepare_crops(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult): + + if self.meta.use_person_crops and self.meta.use_face_crops: + detected_bboxes.associate_faces_with_persons() + + crops: PersonAndFaceCrops = detected_bboxes.collect_crops(image) + (bodies_inds, bodies_crops), (faces_inds, faces_crops) = crops.get_faces_with_bodies( + self.meta.use_person_crops, self.meta.use_face_crops + ) + + if not self.meta.use_face_crops: + assert all(f is None for f in faces_crops) + + faces_input = prepare_classification_images( + faces_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device + ) + + if not self.meta.use_person_crops: + assert all(p is None for p in bodies_crops) + + person_input = prepare_classification_images( + bodies_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device + ) + + _logger.info( + f"faces_input: 
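A worked example of the decoding in fill_in_results above, assuming the checkpoint stored min_age=1, max_age=95 and avg_age=48.0 (the same illustrative statistics as in the dataset example earlier).

import torch

output = torch.tensor([[2.1, -0.3, -0.1915]])   # [male logit, female logit, normalized age]

age = output[0, 2].item() * (95 - 1) + 48.0     # ≈ 30.0 years
gender_probs = output[0, :2].softmax(-1)        # male ≈ 0.92, female ≈ 0.08
gender = "male" if gender_probs.argmax().item() == 0 else "female"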
{faces_input.shape if faces_input is not None else None}, " + f"person_input: {person_input.shape if person_input is not None else None}" + ) + + return faces_input, person_input, faces_inds, bodies_inds + + +if __name__ == "__main__": + model = MiVOLO("../pretrained/checkpoint-377.pth.tar", half=True, device="cuda:0") diff --git a/age_estimator/mivolo/mivolo/model/mivolo_model.py b/age_estimator/mivolo/mivolo/model/mivolo_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e76c12a2a75a6688f8aa15a0e5f98033c3ef4544 --- /dev/null +++ b/age_estimator/mivolo/mivolo/model/mivolo_model.py @@ -0,0 +1,404 @@ +""" +Code adapted from timm https://github.com/huggingface/pytorch-image-models + +Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich +""" + +import torch +import torch.nn as nn +from mivolo.model.cross_bottleneck_attn import CrossBottleneckAttn +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import trunc_normal_ +from timm.models._builder import build_model_with_cfg +from timm.models._registry import register_model +from timm.models.volo import VOLO + +__all__ = ["MiVOLOModel"] # model_registry will add each entrypoint fn to this + + +def _cfg(url="", **kwargs): + return { + "url": url, + "num_classes": 1000, + "input_size": (3, 224, 224), + "pool_size": None, + "crop_pct": 0.96, + "interpolation": "bicubic", + "fixed_input_size": True, + "mean": IMAGENET_DEFAULT_MEAN, + "std": IMAGENET_DEFAULT_STD, + "first_conv": None, + "classifier": ("head", "aux_head"), + **kwargs, + } + + +default_cfgs = { + "mivolo_d1_224": _cfg( + url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar", crop_pct=0.96 + ), + "mivolo_d1_384": _cfg( + url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar", + crop_pct=1.0, + input_size=(3, 384, 384), + ), + "mivolo_d2_224": _cfg( + url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar", crop_pct=0.96 + ), + "mivolo_d2_384": _cfg( + url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar", + crop_pct=1.0, + input_size=(3, 384, 384), + ), + "mivolo_d3_224": _cfg( + url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar", crop_pct=0.96 + ), + "mivolo_d3_448": _cfg( + url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar", + crop_pct=1.0, + input_size=(3, 448, 448), + ), + "mivolo_d4_224": _cfg( + url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar", crop_pct=0.96 + ), + "mivolo_d4_448": _cfg( + url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar", + crop_pct=1.15, + input_size=(3, 448, 448), + ), + "mivolo_d5_224": _cfg( + url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar", crop_pct=0.96 + ), + "mivolo_d5_448": _cfg( + url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar", + crop_pct=1.15, + input_size=(3, 448, 448), + ), + "mivolo_d5_512": _cfg( + url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar", + crop_pct=1.15, + input_size=(3, 512, 512), + ), +} + + +def get_output_size(input_shape, conv_layer): + padding = conv_layer.padding + dilation = conv_layer.dilation + kernel_size = conv_layer.kernel_size + stride = conv_layer.stride + + output_size = [ + ((input_shape[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) // stride[i]) + 1 for i in 
range(2) + ] + return output_size + + +def get_output_size_module(input_size, stem): + output_size = input_size + + for module in stem: + if isinstance(module, nn.Conv2d): + output_size = [ + ( + (output_size[i] + 2 * module.padding[i] - module.dilation[i] * (module.kernel_size[i] - 1) - 1) + // module.stride[i] + ) + + 1 + for i in range(2) + ] + + return output_size + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding.""" + + def __init__( + self, img_size=224, stem_conv=False, stem_stride=1, patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384 + ): + super().__init__() + assert patch_size in [4, 8, 16] + assert in_chans in [3, 6] + self.with_persons_model = in_chans == 6 + self.use_cross_attn = True + + if stem_conv: + if not self.with_persons_model: + self.conv = self.create_stem(stem_stride, in_chans, hidden_dim) + else: + self.conv = True # just to match interface + # split + self.conv1 = self.create_stem(stem_stride, 3, hidden_dim) + self.conv2 = self.create_stem(stem_stride, 3, hidden_dim) + else: + self.conv = None + + if self.with_persons_model: + + self.proj1 = nn.Conv2d( + hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride + ) + self.proj2 = nn.Conv2d( + hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride + ) + + stem_out_shape = get_output_size_module((img_size, img_size), self.conv1) + self.proj_output_size = get_output_size(stem_out_shape, self.proj1) + + self.map = CrossBottleneckAttn(embed_dim, dim_out=embed_dim, num_heads=1, feat_size=self.proj_output_size) + + else: + self.proj = nn.Conv2d( + hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride + ) + + self.patch_dim = img_size // patch_size + self.num_patches = self.patch_dim**2 + + def create_stem(self, stem_stride, in_chans, hidden_dim): + return nn.Sequential( + nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False), # 112x112 + nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112 + nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112 + nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + if self.conv is not None: + if self.with_persons_model: + x1 = x[:, :3] + x2 = x[:, 3:] + + x1 = self.conv1(x1) + x1 = self.proj1(x1) + + x2 = self.conv2(x2) + x2 = self.proj2(x2) + + x = torch.cat([x1, x2], dim=1) + x = self.map(x) + else: + x = self.conv(x) + x = self.proj(x) # B, C, H, W + + return x + + +class MiVOLOModel(VOLO): + """ + Vision Outlooker, the main class of our model + """ + + def __init__( + self, + layers, + img_size=224, + in_chans=3, + num_classes=1000, + global_pool="token", + patch_size=8, + stem_hidden_dim=64, + embed_dims=None, + num_heads=None, + downsamples=(True, False, False, False), + outlook_attention=(True, False, False, False), + mlp_ratio=3.0, + qkv_bias=False, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + post_layers=("ca", "ca"), + use_aux_head=True, + use_mix_token=False, + pooling_scale=2, + ): + super().__init__( + layers, + img_size, + in_chans, + num_classes, + global_pool, + patch_size, + stem_hidden_dim, + embed_dims, + num_heads, + downsamples, + outlook_attention, + mlp_ratio, + qkv_bias, + drop_rate, + attn_drop_rate, + drop_path_rate, + norm_layer, + 
post_layers, + use_aux_head, + use_mix_token, + pooling_scale, + ) + + im_size = img_size[0] if isinstance(img_size, tuple) else img_size + self.patch_embed = PatchEmbed( + img_size=im_size, + stem_conv=True, + stem_stride=2, + patch_size=patch_size, + in_chans=in_chans, + hidden_dim=stem_hidden_dim, + embed_dim=embed_dims[0], + ) + + trunc_normal_(self.pos_embed, std=0.02) + self.apply(self._init_weights) + + def forward_features(self, x): + x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C + + # step2: tokens learning in the two stages + x = self.forward_tokens(x) + + # step3: post network, apply class attention or not + if self.post_network is not None: + x = self.forward_cls(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False, targets=None, epoch=None): + if self.global_pool == "avg": + out = x.mean(dim=1) + elif self.global_pool == "token": + out = x[:, 0] + else: + out = x + if pre_logits: + return out + + features = out + fds_enabled = hasattr(self, "_fds_forward") + if fds_enabled: + features = self._fds_forward(features, targets, epoch) + + out = self.head(features) + if self.aux_head is not None: + # generate classes in all feature tokens, see token labeling + aux = self.aux_head(x[:, 1:]) + out = out + 0.5 * aux.max(1)[0] + + return (out, features) if (fds_enabled and self.training) else out + + def forward(self, x, targets=None, epoch=None): + """simplified forward (without mix token training)""" + x = self.forward_features(x) + x = self.forward_head(x, targets=targets, epoch=epoch) + return x + + +def _create_mivolo(variant, pretrained=False, **kwargs): + if kwargs.get("features_only", None): + raise RuntimeError("features_only not implemented for Vision Transformer models.") + return build_model_with_cfg(MiVOLOModel, variant, pretrained, **kwargs) + + +@register_model +def mivolo_d1_224(pretrained=False, **kwargs): + model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs) + model = _create_mivolo("mivolo_d1_224", pretrained=pretrained, **model_args) + return model + + +@register_model +def mivolo_d1_384(pretrained=False, **kwargs): + model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs) + model = _create_mivolo("mivolo_d1_384", pretrained=pretrained, **model_args) + return model + + +@register_model +def mivolo_d2_224(pretrained=False, **kwargs): + model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) + model = _create_mivolo("mivolo_d2_224", pretrained=pretrained, **model_args) + return model + + +@register_model +def mivolo_d2_384(pretrained=False, **kwargs): + model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) + model = _create_mivolo("mivolo_d2_384", pretrained=pretrained, **model_args) + return model + + +@register_model +def mivolo_d3_224(pretrained=False, **kwargs): + model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) + model = _create_mivolo("mivolo_d3_224", pretrained=pretrained, **model_args) + return model + + +@register_model +def mivolo_d3_448(pretrained=False, **kwargs): + model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs) + model = _create_mivolo("mivolo_d3_448", pretrained=pretrained, **model_args) + return model + + +@register_model +def mivolo_d4_224(pretrained=False, 
**kwargs): + model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs) + model = _create_mivolo("mivolo_d4_224", pretrained=pretrained, **model_args) + return model + + +@register_model +def mivolo_d4_448(pretrained=False, **kwargs): + """VOLO-D4 model, Params: 193M""" + model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs) + model = _create_mivolo("mivolo_d4_448", pretrained=pretrained, **model_args) + return model + + +@register_model +def mivolo_d5_224(pretrained=False, **kwargs): + model_args = dict( + layers=(12, 12, 20, 4), + embed_dims=(384, 768, 768, 768), + num_heads=(12, 16, 16, 16), + mlp_ratio=4, + stem_hidden_dim=128, + **kwargs + ) + model = _create_mivolo("mivolo_d5_224", pretrained=pretrained, **model_args) + return model + + +@register_model +def mivolo_d5_448(pretrained=False, **kwargs): + model_args = dict( + layers=(12, 12, 20, 4), + embed_dims=(384, 768, 768, 768), + num_heads=(12, 16, 16, 16), + mlp_ratio=4, + stem_hidden_dim=128, + **kwargs + ) + model = _create_mivolo("mivolo_d5_448", pretrained=pretrained, **model_args) + return model + + +@register_model +def mivolo_d5_512(pretrained=False, **kwargs): + model_args = dict( + layers=(12, 12, 20, 4), + embed_dims=(384, 768, 768, 768), + num_heads=(12, 16, 16, 16), + mlp_ratio=4, + stem_hidden_dim=128, + **kwargs + ) + model = _create_mivolo("mivolo_d5_512", pretrained=pretrained, **model_args) + return model diff --git a/age_estimator/mivolo/mivolo/model/yolo_detector.py b/age_estimator/mivolo/mivolo/model/yolo_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..dc22a1f142aa3f6bb9640055b5f81e1c553f4570 --- /dev/null +++ b/age_estimator/mivolo/mivolo/model/yolo_detector.py @@ -0,0 +1,46 @@ +import os +from typing import Dict, Union + +import numpy as np +import PIL +import torch +from mivolo.structures import PersonAndFaceResult +from ultralytics import YOLO +from ultralytics.engine.results import Results + +# because of ultralytics bug it is important to unset CUBLAS_WORKSPACE_CONFIG after the module importing +os.unsetenv("CUBLAS_WORKSPACE_CONFIG") + + +class Detector: + def __init__( + self, + weights: str, + device: str = "cuda", + half: bool = True, + verbose: bool = False, + conf_thresh: float = 0.4, + iou_thresh: float = 0.7, + ): + self.yolo = YOLO(weights) + self.yolo.fuse() + + self.device = torch.device(device) + self.half = half and self.device.type != "cpu" + + if self.half: + self.yolo.model = self.yolo.model.half() + + self.detector_names: Dict[int, str] = self.yolo.model.names + + # init yolo.predictor + self.detector_kwargs = {"conf": conf_thresh, "iou": iou_thresh, "half": self.half, "verbose": verbose} + # self.yolo.predict(**self.detector_kwargs) + + def predict(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult: + results: Results = self.yolo.predict(image, **self.detector_kwargs)[0] + return PersonAndFaceResult(results) + + def track(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult: + results: Results = self.yolo.track(image, persist=True, **self.detector_kwargs)[0] + return PersonAndFaceResult(results) diff --git a/age_estimator/mivolo/mivolo/predictor.py b/age_estimator/mivolo/mivolo/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..086515493dad7362d990127ae24e4657750c80ae --- /dev/null +++ b/age_estimator/mivolo/mivolo/predictor.py @@ -0,0 +1,80 @@ +from collections import 
defaultdict
+from typing import Dict, Generator, List, Optional, Tuple
+
+import cv2
+import numpy as np
+import tqdm
+from mivolo.model.mi_volo import MiVOLO
+from mivolo.model.yolo_detector import Detector
+from mivolo.structures import AGE_GENDER_TYPE, PersonAndFaceResult
+
+
+class Predictor:
+    def __init__(self, config, verbose: bool = False):
+        self.detector = Detector(config.detector_weights, config.device, verbose=verbose)
+        self.age_gender_model = MiVOLO(
+            config.checkpoint,
+            config.device,
+            half=True,
+            use_persons=config.with_persons,
+            disable_faces=config.disable_faces,
+            verbose=verbose,
+        )
+        self.draw = config.draw
+
+    def recognize(self, image: np.ndarray) -> Tuple[PersonAndFaceResult, Optional[np.ndarray], List[Optional[float]]]:
+        detected_objects: PersonAndFaceResult = self.detector.predict(image)
+        self.age_gender_model.predict(image, detected_objects)
+
+        # Retrieve the per-detection ages filled in by the age/gender model
+        ages = detected_objects.get_ages()
+
+        # Log every detection together with the attributes predicted for it
+        names = detected_objects.yolo_results.names
+        boxes = detected_objects.yolo_results.boxes
+        for det, obj_age, obj_gender in zip(boxes, detected_objects.ages, detected_objects.genders):
+            bbox = det.xyxy.squeeze().tolist()  # bounding box for person/face
+            label = names[int(det.cls)]  # "person" or "face"
+            print(f"Detected {label} at {bbox} with age: {obj_age}, gender: {obj_gender}")
+
+        out_im = None
+        if self.draw:
+            out_im = detected_objects.plot()
+
+        return detected_objects, out_im, ages
+
+    def recognize_video(self, source: str) -> Generator:
+        video_capture = cv2.VideoCapture(source)
+        if not video_capture.isOpened():
+            raise ValueError(f"Failed to open video source {source}")
+
+        detected_objects_history: Dict[int, List[AGE_GENDER_TYPE]] = defaultdict(list)
+
+        total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
+        for _ in tqdm.tqdm(range(total_frames)):
+            ret, frame = video_capture.read()
+            if not ret:
+                break
+
+            detected_objects: PersonAndFaceResult = self.detector.track(frame)
+            self.age_gender_model.predict(frame, detected_objects)
+
+            current_frame_objs = detected_objects.get_results_for_tracking()
+            cur_persons: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[0]
+            cur_faces: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[1]
+
+            # add cur_persons and cur_faces to the aggregation history
+            for guid, data in cur_persons.items():
+                # skip entries with a missing age or gender
+                if None not in data:
+                    detected_objects_history[guid].append(data)
+            for guid, data in cur_faces.items():
+                if None not in data:
+                    detected_objects_history[guid].append(data)
+
+            detected_objects.set_tracked_age_gender(detected_objects_history)
+            if self.draw:
+                frame = detected_objects.plot()
+            yield detected_objects_history, frame
diff --git a/age_estimator/mivolo/mivolo/structures.py b/age_estimator/mivolo/mivolo/structures.py
new file mode 100644
index 0000000000000000000000000000000000000000..45277f6b770d7577f2459bcce679dd98dc813095
--- /dev/null
+++ b/age_estimator/mivolo/mivolo/structures.py
@@ -0,0 +1,478 @@
+import math
+import os
+from copy import deepcopy
+from typing import Dict, List, Optional, Tuple
+
+import cv2
+import numpy as np
+import torch
+from mivolo.data.misc import aggregate_votes_winsorized, assign_faces, box_iou
+from ultralytics.engine.results import Results
+from ultralytics.utils.plotting import Annotator, colors
+
+# because of ultralytics bug it is important to unset CUBLAS_WORKSPACE_CONFIG after the module importing
+os.unsetenv("CUBLAS_WORKSPACE_CONFIG")
+
+AGE_GENDER_TYPE = Tuple[float, str]
+
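+# Usage sketch (illustrative only, not part of the module): how AGE_GENDER_TYPE entries end up
+# on a PersonAndFaceResult and how they are read back. The SimpleNamespace config and the file
+# paths below are assumptions for the example; the field names mirror what Predictor.__init__
+# actually reads, and the weights match the files shipped under age_estimator/mivolo/models/.
+#
+#   from types import SimpleNamespace
+#   import cv2
+#   from mivolo.predictor import Predictor
+#
+#   cfg = SimpleNamespace(
+#       detector_weights="models/yolov8x_person_face.pt",
+#       checkpoint="models/model_imdb_cross_person_4.22_99.46.pth.tar",
+#       device="cuda:0",
+#       with_persons=True,
+#       disable_faces=False,
+#       draw=False,
+#   )
+#   predictor = Predictor(cfg)
+#   detected, _, ages = predictor.recognize(cv2.imread("photo.jpg"))
+#   # detected.ages / detected.genders are indexed per detection, e.g. age 34.5, gender "female"
+
+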
+class PersonAndFaceCrops: + def __init__(self): + # int: index of person along results + self.crops_persons: Dict[int, np.ndarray] = {} + + # int: index of face along results + self.crops_faces: Dict[int, np.ndarray] = {} + + # int: index of face along results + self.crops_faces_wo_body: Dict[int, np.ndarray] = {} + + # int: index of person along results + self.crops_persons_wo_face: Dict[int, np.ndarray] = {} + + def _add_to_output( + self, crops: Dict[int, np.ndarray], out_crops: List[np.ndarray], out_crop_inds: List[Optional[int]] + ): + inds_to_add = list(crops.keys()) + crops_to_add = list(crops.values()) + out_crops.extend(crops_to_add) + out_crop_inds.extend(inds_to_add) + + def _get_all_faces( + self, use_persons: bool, use_faces: bool + ) -> Tuple[List[Optional[int]], List[Optional[np.ndarray]]]: + """ + Returns + if use_persons and use_faces + faces: faces_with_bodies + faces_without_bodies + [None] * len(crops_persons_wo_face) + if use_persons and not use_faces + faces: [None] * n_persons + if not use_persons and use_faces: + faces: faces_with_bodies + faces_without_bodies + """ + + def add_none_to_output(faces_inds, faces_crops, num): + faces_inds.extend([None for _ in range(num)]) + faces_crops.extend([None for _ in range(num)]) + + faces_inds: List[Optional[int]] = [] + faces_crops: List[Optional[np.ndarray]] = [] + + if not use_faces: + add_none_to_output(faces_inds, faces_crops, len(self.crops_persons) + len(self.crops_persons_wo_face)) + return faces_inds, faces_crops + + self._add_to_output(self.crops_faces, faces_crops, faces_inds) + self._add_to_output(self.crops_faces_wo_body, faces_crops, faces_inds) + + if use_persons: + add_none_to_output(faces_inds, faces_crops, len(self.crops_persons_wo_face)) + + return faces_inds, faces_crops + + def _get_all_bodies( + self, use_persons: bool, use_faces: bool + ) -> Tuple[List[Optional[int]], List[Optional[np.ndarray]]]: + """ + Returns + if use_persons and use_faces + persons: bodies_with_faces + [None] * len(faces_without_bodies) + bodies_without_faces + if use_persons and not use_faces + persons: bodies_with_faces + bodies_without_faces + if not use_persons and use_faces + persons: [None] * n_faces + """ + + def add_none_to_output(bodies_inds, bodies_crops, num): + bodies_inds.extend([None for _ in range(num)]) + bodies_crops.extend([None for _ in range(num)]) + + bodies_inds: List[Optional[int]] = [] + bodies_crops: List[Optional[np.ndarray]] = [] + + if not use_persons: + add_none_to_output(bodies_inds, bodies_crops, len(self.crops_faces) + len(self.crops_faces_wo_body)) + return bodies_inds, bodies_crops + + self._add_to_output(self.crops_persons, bodies_crops, bodies_inds) + if use_faces: + add_none_to_output(bodies_inds, bodies_crops, len(self.crops_faces_wo_body)) + + self._add_to_output(self.crops_persons_wo_face, bodies_crops, bodies_inds) + + return bodies_inds, bodies_crops + + def get_faces_with_bodies(self, use_persons: bool, use_faces: bool): + """ + Return + faces: faces_with_bodies, faces_without_bodies, [None] * len(crops_persons_wo_face) + persons: bodies_with_faces, [None] * len(faces_without_bodies), bodies_without_faces + """ + + bodies_inds, bodies_crops = self._get_all_bodies(use_persons, use_faces) + faces_inds, faces_crops = self._get_all_faces(use_persons, use_faces) + + return (bodies_inds, bodies_crops), (faces_inds, faces_crops) + + def save(self, out_dir="output"): + ind = 0 + os.makedirs(out_dir, exist_ok=True) + for crops in [self.crops_persons, self.crops_faces, self.crops_faces_wo_body, 
self.crops_persons_wo_face]: + for crop in crops.values(): + if crop is None: + continue + out_name = os.path.join(out_dir, f"{ind}_crop.jpg") + cv2.imwrite(out_name, crop) + ind += 1 + + + + +class PersonAndFaceResult: + def __init__(self, results: Results): + + self.yolo_results = results + names = set(results.names.values()) + assert "person" in names and "face" in names + + # initially no faces and persons are associated to each other + self.face_to_person_map: Dict[int, Optional[int]] = {ind: None for ind in self.get_bboxes_inds("face")} + self.unassigned_persons_inds: List[int] = self.get_bboxes_inds("person") + n_objects = len(self.yolo_results.boxes) + self.ages: List[Optional[float]] = [None for _ in range(n_objects)] + self.genders: List[Optional[str]] = [None for _ in range(n_objects)] + self.gender_scores: List[Optional[float]] = [None for _ in range(n_objects)] + + @property + def n_objects(self) -> int: + return len(self.yolo_results.boxes) + + @property + def n_faces(self) -> int: + return len(self.get_bboxes_inds("face")) + + @property + def n_persons(self) -> int: + return len(self.get_bboxes_inds("person")) + + def get_bboxes_inds(self, category: str) -> List[int]: + bboxes: List[int] = [] + for ind, det in enumerate(self.yolo_results.boxes): + name = self.yolo_results.names[int(det.cls)] + if name == category: + bboxes.append(ind) + + return bboxes + + def get_distance_to_center(self, bbox_ind: int) -> float: + """ + Calculate euclidian distance between bbox center and image center. + """ + im_h, im_w = self.yolo_results[bbox_ind].orig_shape + x1, y1, x2, y2 = self.get_bbox_by_ind(bbox_ind).cpu().numpy() + center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2 + dist = math.dist([center_x, center_y], [im_w / 2, im_h / 2]) + return dist + + def plot( + self, + conf=False, + line_width=None, + font_size=None, + font="Arial.ttf", + pil=False, + img=None, + labels=True, + boxes=True, + probs=True, + ages=True, + genders=True, + gender_probs=False, + ): + """ + Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image. + Args: + conf (bool): Whether to plot the detection confidence score. + line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size. + font_size (float, optional): The font size of the text. If None, it is scaled to the image size. + font (str): The font to use for the text. + pil (bool): Whether to return the image as a PIL Image. + img (numpy.ndarray): Plot to another image. if not, plot to original image. + labels (bool): Whether to plot the label of bounding boxes. + boxes (bool): Whether to plot the bounding boxes. + probs (bool): Whether to plot classification probability + ages (bool): Whether to plot the age of bounding boxes. + genders (bool): Whether to plot the genders of bounding boxes. + gender_probs (bool): Whether to plot gender classification probability + Returns: + (numpy.ndarray): A numpy array of the annotated image. 
+ """ + + # return self.yolo_results.plot() + colors_by_ind = {} + for face_ind, person_ind in self.face_to_person_map.items(): + if person_ind is not None: + colors_by_ind[face_ind] = face_ind + 2 + colors_by_ind[person_ind] = face_ind + 2 + else: + colors_by_ind[face_ind] = 0 + for person_ind in self.unassigned_persons_inds: + colors_by_ind[person_ind] = 1 + + names = self.yolo_results.names + annotator = Annotator( + deepcopy(self.yolo_results.orig_img if img is None else img), + line_width, + font_size, + font, + pil, + example=names, + ) + pred_boxes, show_boxes = self.yolo_results.boxes, boxes + pred_probs, show_probs = self.yolo_results.probs, probs + + if pred_boxes and show_boxes: + for bb_ind, (d, age, gender, gender_score) in enumerate( + zip(pred_boxes, self.ages, self.genders, self.gender_scores) + ): + c, conf, guid = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item()) + name = ("" if guid is None else f"id:{guid} ") + names[c] + label = (f"{name} {conf:.2f}" if conf else name) if labels else None + if ages and age is not None: + label += f" {age:.1f}" + if genders and gender is not None: + label += f" {'F' if gender == 'female' else 'M'}" + if gender_probs and gender_score is not None: + label += f" ({gender_score:.1f})" + annotator.box_label(d.xyxy.squeeze(), label, color=colors(colors_by_ind[bb_ind], True)) + + if pred_probs is not None and show_probs: + text = f"{', '.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)}, " + annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors + + return annotator.result() + + def set_tracked_age_gender(self, tracked_objects: Dict[int, List[AGE_GENDER_TYPE]]): + """ + Update age and gender for objects based on history from tracked_objects. 
+ Args: + tracked_objects (dict[int, list[AGE_GENDER_TYPE]]): info about tracked objects by guid + """ + + for face_ind, person_ind in self.face_to_person_map.items(): + pguid = self._get_id_by_ind(person_ind) + fguid = self._get_id_by_ind(face_ind) + + if fguid == -1 and pguid == -1: + # YOLO might not assign ids for some objects in some cases: + # https://github.com/ultralytics/ultralytics/issues/3830 + continue + age, gender = self._gather_tracking_result(tracked_objects, fguid, pguid) + if age is None or gender is None: + continue + self.set_age(face_ind, age) + self.set_gender(face_ind, gender, 1.0) + if pguid != -1: + self.set_gender(person_ind, gender, 1.0) + self.set_age(person_ind, age) + + for person_ind in self.unassigned_persons_inds: + pid = self._get_id_by_ind(person_ind) + if pid == -1: + continue + age, gender = self._gather_tracking_result(tracked_objects, -1, pid) + if age is None or gender is None: + continue + self.set_gender(person_ind, gender, 1.0) + self.set_age(person_ind, age) + + def _get_id_by_ind(self, ind: Optional[int] = None) -> int: + if ind is None: + return -1 + obj_id = self.yolo_results.boxes[ind].id + if obj_id is None: + return -1 + return obj_id.item() + + def get_bbox_by_ind(self, ind: int, im_h: int = None, im_w: int = None) -> torch.tensor: + bb = self.yolo_results.boxes[ind].xyxy.squeeze().type(torch.int32) + if im_h is not None and im_w is not None: + bb[0] = torch.clamp(bb[0], min=0, max=im_w - 1) + bb[1] = torch.clamp(bb[1], min=0, max=im_h - 1) + bb[2] = torch.clamp(bb[2], min=0, max=im_w - 1) + bb[3] = torch.clamp(bb[3], min=0, max=im_h - 1) + return bb + + def set_age(self, ind: Optional[int], age: float): + if ind is not None: + self.ages[ind] = age + + def set_gender(self, ind: Optional[int], gender: str, gender_score: float): + if ind is not None: + self.genders[ind] = gender + self.gender_scores[ind] = gender_score + + @staticmethod + def _gather_tracking_result( + tracked_objects: Dict[int, List[AGE_GENDER_TYPE]], + fguid: int = -1, + pguid: int = -1, + minimum_sample_size: int = 10, + ) -> AGE_GENDER_TYPE: + + assert fguid != -1 or pguid != -1, "Incorrect tracking behaviour" + + face_ages = [r[0] for r in tracked_objects[fguid] if r[0] is not None] if fguid in tracked_objects else [] + face_genders = [r[1] for r in tracked_objects[fguid] if r[1] is not None] if fguid in tracked_objects else [] + person_ages = [r[0] for r in tracked_objects[pguid] if r[0] is not None] if pguid in tracked_objects else [] + person_genders = [r[1] for r in tracked_objects[pguid] if r[1] is not None] if pguid in tracked_objects else [] + + if not face_ages and not person_ages: # both empty + return None, None + + # You can play here with different aggregation strategies + # Face ages - predictions based on face or face + person, depends on history of object + # Person ages - predictions based on person or face + person, depends on history of object + + if len(person_ages + face_ages) >= minimum_sample_size: + age = aggregate_votes_winsorized(person_ages + face_ages) + else: + face_age = np.mean(face_ages) if face_ages else None + person_age = np.mean(person_ages) if person_ages else None + if face_age is None: + face_age = person_age + if person_age is None: + person_age = face_age + age = (face_age + person_age) / 2.0 + + genders = face_genders + person_genders + assert len(genders) > 0 + # take mode of genders + gender = max(set(genders), key=genders.count) + + return age, gender + + def get_results_for_tracking(self) -> Tuple[Dict[int, AGE_GENDER_TYPE], 
Dict[int, AGE_GENDER_TYPE]]: + """ + Get objects from current frame + """ + persons: Dict[int, AGE_GENDER_TYPE] = {} + faces: Dict[int, AGE_GENDER_TYPE] = {} + + names = self.yolo_results.names + pred_boxes = self.yolo_results.boxes + for _, (det, age, gender, _) in enumerate(zip(pred_boxes, self.ages, self.genders, self.gender_scores)): + if det.id is None: + continue + cat_id, _, guid = int(det.cls), float(det.conf), int(det.id.item()) + name = names[cat_id] + if name == "person": + persons[guid] = (age, gender) + elif name == "face": + faces[guid] = (age, gender) + + return persons, faces + + def associate_faces_with_persons(self): + face_bboxes_inds: List[int] = self.get_bboxes_inds("face") + person_bboxes_inds: List[int] = self.get_bboxes_inds("person") + + face_bboxes: List[torch.tensor] = [self.get_bbox_by_ind(ind) for ind in face_bboxes_inds] + person_bboxes: List[torch.tensor] = [self.get_bbox_by_ind(ind) for ind in person_bboxes_inds] + + self.face_to_person_map = {ind: None for ind in face_bboxes_inds} + assigned_faces, unassigned_persons_inds = assign_faces(person_bboxes, face_bboxes) + + for face_ind, person_ind in enumerate(assigned_faces): + face_ind = face_bboxes_inds[face_ind] + person_ind = person_bboxes_inds[person_ind] if person_ind is not None else None + self.face_to_person_map[face_ind] = person_ind + + self.unassigned_persons_inds = [person_bboxes_inds[person_ind] for person_ind in unassigned_persons_inds] + + def crop_object( + self, full_image: np.ndarray, ind: int, cut_other_classes: Optional[List[str]] = None + ) -> Optional[np.ndarray]: + + IOU_THRESH = 0.000001 + MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4 + CROP_ROUND_RATE = 0.3 + MIN_PERSON_SIZE = 50 + + obj_bbox = self.get_bbox_by_ind(ind, *full_image.shape[:2]) + x1, y1, x2, y2 = obj_bbox + cur_cat = self.yolo_results.names[int(self.yolo_results.boxes[ind].cls)] + # get crop of face or person + obj_image = full_image[y1:y2, x1:x2].copy() + crop_h, crop_w = obj_image.shape[:2] + + if cur_cat == "person" and (crop_h < MIN_PERSON_SIZE or crop_w < MIN_PERSON_SIZE): + return None + + if not cut_other_classes: + return obj_image + + # calc iou between obj_bbox and other bboxes + other_bboxes: List[torch.tensor] = [ + self.get_bbox_by_ind(other_ind, *full_image.shape[:2]) for other_ind in range(len(self.yolo_results.boxes)) + ] + + iou_matrix = box_iou(torch.stack([obj_bbox]), torch.stack(other_bboxes)).cpu().numpy()[0] + + # cut out other objects in case of intersection + for other_ind, (det, iou) in enumerate(zip(self.yolo_results.boxes, iou_matrix)): + other_cat = self.yolo_results.names[int(det.cls)] + if ind == other_ind or iou < IOU_THRESH or other_cat not in cut_other_classes: + continue + o_x1, o_y1, o_x2, o_y2 = det.xyxy.squeeze().type(torch.int32) + + # remap current_person_bbox to reference_person_bbox coordinates + o_x1 = max(o_x1 - x1, 0) + o_y1 = max(o_y1 - y1, 0) + o_x2 = min(o_x2 - x1, crop_w) + o_y2 = min(o_y2 - y1, crop_h) + + if other_cat != "face": + if (o_y1 / crop_h) < CROP_ROUND_RATE: + o_y1 = 0 + if ((crop_h - o_y2) / crop_h) < CROP_ROUND_RATE: + o_y2 = crop_h + if (o_x1 / crop_w) < CROP_ROUND_RATE: + o_x1 = 0 + if ((crop_w - o_x2) / crop_w) < CROP_ROUND_RATE: + o_x2 = crop_w + + obj_image[o_y1:o_y2, o_x1:o_x2] = 0 + + remain_ratio = np.count_nonzero(obj_image) / (obj_image.shape[0] * obj_image.shape[1] * obj_image.shape[2]) + if remain_ratio < MIN_PERSON_CROP_AFTERCUT_RATIO: + return None + + return obj_image + + def collect_crops(self, image) -> PersonAndFaceCrops: + + crops_data = 
PersonAndFaceCrops() + for face_ind, person_ind in self.face_to_person_map.items(): + face_image = self.crop_object(image, face_ind, cut_other_classes=[]) + + if person_ind is None: + crops_data.crops_faces_wo_body[face_ind] = face_image + continue + + person_image = self.crop_object(image, person_ind, cut_other_classes=["face", "person"]) + + crops_data.crops_faces[face_ind] = face_image + crops_data.crops_persons[person_ind] = person_image + + for person_ind in self.unassigned_persons_inds: + person_image = self.crop_object(image, person_ind, cut_other_classes=["face", "person"]) + crops_data.crops_persons_wo_face[person_ind] = person_image + + # uncomment to save preprocessed crops + # crops_data.save() + return crops_data + + def get_ages(self): + # Assuming ages are stored in some internal attribute like `_ages` + return self.ages # Replace with the correct attribute where ages are stored diff --git a/age_estimator/mivolo/mivolo/version.py b/age_estimator/mivolo/mivolo/version.py new file mode 100644 index 0000000000000000000000000000000000000000..cc321393ba68aa2bcb2b8ab828e95c89ddce7dc3 --- /dev/null +++ b/age_estimator/mivolo/mivolo/version.py @@ -0,0 +1 @@ +__version__ = "0.6.0dev" diff --git a/age_estimator/mivolo/models/model_imdb_cross_person_4.22_99.46.pth.tar b/age_estimator/mivolo/models/model_imdb_cross_person_4.22_99.46.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..e62bb76ed79b900b87e7e752c0223d45d40c9345 --- /dev/null +++ b/age_estimator/mivolo/models/model_imdb_cross_person_4.22_99.46.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc279b6914b3ee8be6a58139c06ecb24ca95751233cf6c07804b93184614eb17 +size 109777437 diff --git a/age_estimator/mivolo/models/yolov8x_person_face.pt b/age_estimator/mivolo/models/yolov8x_person_face.pt new file mode 100644 index 0000000000000000000000000000000000000000..c4fea0e2a9545fe4db35fc88ad60804e7209da8b --- /dev/null +++ b/age_estimator/mivolo/models/yolov8x_person_face.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2620f45609a65f909eb876bd7401308b5a8f3843ad5a03cb7416066a3e492989 +size 136716488 diff --git a/age_estimator/mivolo/mypy.ini b/age_estimator/mivolo/mypy.ini new file mode 100644 index 0000000000000000000000000000000000000000..c59a7d2d5eccf1b217f2c56c1f5b533d26ce5b10 --- /dev/null +++ b/age_estimator/mivolo/mypy.ini @@ -0,0 +1,5 @@ +[mypy] +python_version = 3.8 +no_strict_optional = True +ignore_missing_imports = True +disallow_any_unimported = False diff --git a/age_estimator/mivolo/requirements.txt b/age_estimator/mivolo/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4e30b002db2c8d41a750c03bdd99801fd54a7200 --- /dev/null +++ b/age_estimator/mivolo/requirements.txt @@ -0,0 +1,5 @@ +huggingface_hub +ultralytics==8.1.0 +timm==0.8.13.dev0 +yt_dlp +lapx>=0.5.2 diff --git a/age_estimator/mivolo/scripts/inference.sh b/age_estimator/mivolo/scripts/inference.sh new file mode 100644 index 0000000000000000000000000000000000000000..45cd54633b1adc09c94504724ae27c521dd42c1d --- /dev/null +++ b/age_estimator/mivolo/scripts/inference.sh @@ -0,0 +1,18 @@ + +python3 demo.py \ +--input "jennifer_lawrence.jpg" \ +--output "output" \ +--detector-weights "pretrained/yolov8x_person_face.pt" \ +--checkpoint "pretrained/checkpoint-377.pth.tar" \ +--device "cuda:0" \ +--draw \ +--with-persons + +python3 demo.py \ +--input "https://www.youtube.com/shorts/pVh32k0hGEI" \ +--output "output" \ +--detector-weights 
"pretrained/yolov8x_person_face.pt" \ +--checkpoint "pretrained/checkpoint-377.pth.tar" \ +--device "cuda:0" \ +--draw \ +--with-persons diff --git a/age_estimator/mivolo/scripts/valid_age_gender.sh b/age_estimator/mivolo/scripts/valid_age_gender.sh new file mode 100644 index 0000000000000000000000000000000000000000..b2edb722909e2a20228613fbe0e35760d3f0157b --- /dev/null +++ b/age_estimator/mivolo/scripts/valid_age_gender.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash + +# inference utk +python3 eval_pretrained.py \ + --dataset_images data/utk/images \ + --dataset_annotations data/utk/annotation \ + --dataset_name utk \ + --batch-size 512 \ + --checkpoint pretrained/model_imdb_cross_person_4.24_99.46.pth.tar \ + --split valid \ + --half \ + --with-persons \ + --device "cuda:0" + +# inference fairface +python3 eval_pretrained.py \ + --dataset_images data/FairFace/fairface-img-margin125-trainval \ + --dataset_annotations data/FairFace/annotations \ + --dataset_name fairface \ + --batch-size 512 \ + --checkpoint pretrained/model_imdb_cross_person_4.24_99.46.pth.tar \ + --split val \ + --half \ + --with-persons \ + --device "cuda:0" + +# inference adience +python3 eval_pretrained.py \ + --dataset_images data/adience/faces \ + --dataset_annotations data/adience/annotations \ + --dataset_name adience \ + --batch-size 512 \ + --checkpoint pretrained/model_imdb_cross_person_4.24_99.46.pth.tar \ + --split adience \ + --half \ + --with-persons \ + --device "cuda:0" + +# inference agedb +python3 eval_pretrained.py \ + --dataset_images data/agedb/AgeDB \ + --dataset_annotations data/agedb/annotation \ + --dataset_name agedb \ + --batch-size 512 \ + --checkpoint pretrained/model_imdb_cross_person_4.24_99.46.pth.tar \ + --split 0,1,2,3,4,5,6,7,8,9 \ + --half \ + --device "cuda:0" diff --git a/age_estimator/mivolo/setup.cfg b/age_estimator/mivolo/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..2cde6a494836f93681b73aaa7bcf0d0d487de469 --- /dev/null +++ b/age_estimator/mivolo/setup.cfg @@ -0,0 +1,56 @@ +# Project-wide configuration file, can be used for package metadata and other toll configurations +# Example usage: global configuration for PEP8 (via flake8) setting or default pytest arguments +# Local usage: pip install pre-commit, pre-commit run --all-files + +[metadata] +license_files = LICENSE +description_file = README.md + +[tool:pytest] +norecursedirs = + .git + dist + build +addopts = + --doctest-modules + --durations=25 + --color=yes + +[flake8] +max-line-length = 120 +exclude = .tox,*.egg,build,temp +select = E,W,F +doctests = True +verbose = 2 +# https://pep8.readthedocs.io/en/latest/intro.html#error-codes +format = pylint +# see: https://www.flake8rules.com/ +ignore = E731,F405,E402,W504,E501 + # E731: Do not assign a lambda expression, use a def + # F405: name may be undefined, or defined from star imports: module + # E402: module level import not at top of file + # W504: line break after binary operator + # E501: line too long + # removed: + # F401: module imported but unused + # E231: missing whitespace after ‘,’, ‘;’, or ‘:’ + # E127: continuation line over-indented for visual indent + # F403: ‘from module import *’ used; unable to detect undefined names + + +[isort] +# https://pycqa.github.io/isort/docs/configuration/options.html +line_length = 120 +# see: https://pycqa.github.io/isort/docs/configuration/multi_line_output_modes.html +multi_line_output = 0 + +[yapf] +based_on_style = pep8 +spaces_before_comment = 2 +COLUMN_LIMIT = 120 +COALESCE_BRACKETS = True 
+SPACES_AROUND_POWER_OPERATOR = True +SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET = True +SPLIT_BEFORE_CLOSING_BRACKET = False +SPLIT_BEFORE_FIRST_ARGUMENT = False +# EACH_DICT_ENTRY_ON_SEPARATE_LINE = False diff --git a/age_estimator/mivolo/setup.py b/age_estimator/mivolo/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..a8bdbd35873b515f074af626a68ef51fda785fc8 --- /dev/null +++ b/age_estimator/mivolo/setup.py @@ -0,0 +1,50 @@ +# MiVOLO 🚀, Attribution-ShareAlike 4.0 + +from pathlib import Path + +import pkg_resources as pkg +from setuptools import find_packages, setup + +# Settings +FILE = Path(__file__).resolve() +PARENT = FILE.parent # root directory +README = (PARENT / "README.md").read_text(encoding="utf-8") +REQUIREMENTS = [f"{x.name}{x.specifier}" for x in pkg.parse_requirements((PARENT / "requirements.txt").read_text())] + + +exec(open("mivolo/version.py").read()) +setup( + name="mivolo", # name of pypi package + version=__version__, # version of pypi package # noqa: F821 + python_requires=">=3.8", + description="Layer MiVOLO for SOTA age and gender recognition", + long_description=README, + long_description_content_type="text/markdown", + url="https://github.com/WildChlamydia/MiVOLO", + project_urls={"Datasets": "https://wildchlamydia.github.io/lagenda/"}, + author="Layer Team, SberDevices", + author_email="mvkuprashevich@gmail.com, irinakr4snova@gmail.com", + packages=find_packages(include=["mivolo", "mivolo.model", "mivolo.data", "mivolo.data.dataset"]), # required + include_package_data=True, + install_requires=REQUIREMENTS, + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: Attribution-ShareAlike 4.0", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Software Development", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Image Recognition", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS", + "Operating System :: Microsoft :: Windows", + ], + keywords="machine-learning, deep-learning, vision, ML, DL, AI, transformer, mivolo", +) diff --git a/age_estimator/mivolo/tools/dataset_visualization.py b/age_estimator/mivolo/tools/dataset_visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..08e0452f70ae3893c55e7a54408e6d4cb8f4b038 --- /dev/null +++ b/age_estimator/mivolo/tools/dataset_visualization.py @@ -0,0 +1,44 @@ +import argparse +from typing import Dict, List + +import cv2 +from mivolo.data.data_reader import PictureInfo, read_csv_annotation_file +from ultralytics.yolo.utils.plotting import Annotator, colors + + +def get_parser(): + parser = argparse.ArgumentParser(description="Visualization") + parser.add_argument("--dataset_images", default="", type=str, required=True, help="path to images") + parser.add_argument("--annotation_file", default="", type=str, required=True, help="path to annotations") + + return parser + + +def visualize(images_dir, new_annotation_file): + + bboxes_per_image: Dict[str, List[PictureInfo]] = read_csv_annotation_file(new_annotation_file, images_dir)[0] + print(f"Found {len(bboxes_per_image)} unique images") + + for image_path, bboxes in bboxes_per_image.items(): + im_cv = 
cv2.imread(image_path) + annotator = Annotator(im_cv) + + for i, bbox_info in enumerate(bboxes): + label = f"{bbox_info.gender} Age: {bbox_info.age}" + if any(coord != -1 for coord in bbox_info.bbox): + # draw face bbox if exist + annotator.box_label(bbox_info.bbox, label, color=colors(i, True)) + + if any(coord != -1 for coord in bbox_info.person_bbox): + # draw person bbox if exist + annotator.box_label(bbox_info.person_bbox, "p " + label, color=colors(i, True)) + + im_cv = annotator.result() + cv2.imshow("image", im_cv) + cv2.waitKey(0) + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + visualize(args.dataset_images, args.annotation_file) diff --git a/age_estimator/mivolo/tools/download_lagenda.py b/age_estimator/mivolo/tools/download_lagenda.py new file mode 100644 index 0000000000000000000000000000000000000000..4130f9b9f1c8a722f64dcb8e36354bbb92802d61 --- /dev/null +++ b/age_estimator/mivolo/tools/download_lagenda.py @@ -0,0 +1,30 @@ +import os +import zipfile + +# pip install gdown +import gdown + +if __name__ == "__main__": + print("Download LAGENDA Age Gender Dataset... ") + out_dir = "LAGENDA" + os.makedirs(out_dir, exist_ok=True) + + ids = ["1QXO0NlkABPZT6x1_0Uc2i6KAtdcrpTbG", "1mNYjYFb3MuKg-OL1UISoYsKObMUllbJx"] + dests = [f"{out_dir}/lagenda_benchmark_images.zip", f"{out_dir}/lagenda_annotation.csv"] + + for file_id, destination in zip(ids, dests): + url = f"https://drive.google.com/uc?id={file_id}" + gdown.download(url, destination, quiet=False) + + if not os.path.exists(destination): + print(f"ERROR: Can not download {destination}") + continue + + if os.path.basename(destination).split(".")[-1] != ".zip": + continue + + print(f"Extracting {destination} ... ") + with zipfile.ZipFile(destination) as zf: + zip_dir = zf.namelist()[0] + zf.extractall(f"./{out_dir}/") + os.remove(destination) diff --git a/age_estimator/mivolo/tools/preparation_utils.py b/age_estimator/mivolo/tools/preparation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aa769f028ef95019f4af9348f8d49d92450f6c5c --- /dev/null +++ b/age_estimator/mivolo/tools/preparation_utils.py @@ -0,0 +1,140 @@ +from typing import Dict, List, Optional, Tuple + +import pandas as pd +import torch +from mivolo.data.data_reader import PictureInfo +from mivolo.data.misc import assign_faces, box_iou +from mivolo.model.yolo_detector import PersonAndFaceResult + + +def save_annotations(images: List[PictureInfo], images_dir: str, out_file: str): + def get_age_str(age: Optional[str]) -> str: + age = "-1" if age is None else age.replace("(", "").replace(")", "").replace(" ", "").replace(",", ";") + return age + + def get_gender_str(gender: Optional[str]) -> str: + gender = "-1" if gender is None else gender + return gender + + headers = [ + "img_name", + "age", + "gender", + "face_x0", + "face_y0", + "face_x1", + "face_y1", + "person_x0", + "person_y0", + "person_x1", + "person_y1", + ] + output_data = [] + for image_info in images: + relative_image_path = image_info.image_path.replace(f"{images_dir}/", "") + face_x0, face_y0, face_x1, face_y1 = image_info.bbox + p_x0, p_y0, p_x1, p_y1 = image_info.person_bbox + output_data.append( + { + "img_name": relative_image_path, + "age": get_age_str(image_info.age), + "gender": get_gender_str(image_info.gender), + "face_x0": face_x0, + "face_y0": face_y0, + "face_x1": face_x1, + "face_y1": face_y1, + "person_x0": p_x0, + "person_y0": p_y0, + "person_x1": p_x1, + "person_y1": p_y1, + } + ) + output_df = pd.DataFrame(output_data, 
columns=headers) + output_df.to_csv(out_file, sep=",", index=False) + print(f"Saved annotations for {len(images)} images to {out_file}") + + +def get_main_face( + detected_objects: PersonAndFaceResult, coarse_bbox: Optional[List[int]] = None, coarse_thresh: float = 0.2 +) -> Tuple[Optional[List[int]], List[int]]: + """ + Returns: + main_bbox (Optional[List[int]]): The most cenetered face bbox + other_bboxes (List[int]): indexes of other faces + """ + face_bboxes_inds: List[int] = detected_objects.get_bboxes_inds("face") + if len(face_bboxes_inds) == 0: + return None, [] + + # sort found faces + face_bboxes_inds = sorted(face_bboxes_inds, key=lambda bb_ind: detected_objects.get_distance_to_center(bb_ind)) + most_centered_bbox_ind = face_bboxes_inds[0] + main_bbox = detected_objects.get_bbox_by_ind(most_centered_bbox_ind).cpu().numpy().tolist() + + iou_matrix: List[float] = [1.0] + [0.0] * (len(face_bboxes_inds) - 1) + + if coarse_bbox is not None: + # calc iou between coarse_bbox and found bboxes + found_bboxes: List[torch.tensor] = [ + detected_objects.get_bbox_by_ind(other_ind) for other_ind in face_bboxes_inds + ] + + iou_matrix = ( + box_iou(torch.stack([torch.tensor(coarse_bbox)]), torch.stack(found_bboxes).cpu()).numpy()[0].tolist() + ) + + if iou_matrix[0] < coarse_thresh: + # to avoid fp detections + main_bbox = None + other_bboxes = [ind for i, ind in enumerate(face_bboxes_inds[1:]) if iou_matrix[i] < coarse_thresh] + else: + other_bboxes = face_bboxes_inds[1:] + + return main_bbox, other_bboxes + + +def get_additional_bboxes( + detected_objects: PersonAndFaceResult, other_bboxes_inds: List[int], image_path: str, **kwargs +) -> List[PictureInfo]: + is_face = True if "is_person" not in kwargs else False + is_person = False if "is_person" not in kwargs else True + + additional_data: List[PictureInfo] = [] + # extend other faces + for other_ind in other_bboxes_inds: + other_box: List[int] = detected_objects.get_bbox_by_ind(other_ind).cpu().numpy().tolist() + if is_face: + additional_data.append(PictureInfo(image_path, None, None, other_box)) + elif is_person: + additional_data.append(PictureInfo(image_path, None, None, person_bbox=other_box)) + return additional_data + + +def associate_persons(face_bboxes: List[torch.tensor], detected_objects: PersonAndFaceResult): + person_bboxes_inds: List[int] = detected_objects.get_bboxes_inds("person") + person_bboxes: List[torch.tensor] = [detected_objects.get_bbox_by_ind(ind) for ind in person_bboxes_inds] + + face_to_person_map: Dict[int, Optional[int]] = {ind: None for ind in range(len(face_bboxes))} + + if len(person_bboxes) == 0: + return face_to_person_map, [] + + assigned_faces, unassigned_persons_inds = assign_faces(person_bboxes, face_bboxes) + + for face_ind, person_ind in enumerate(assigned_faces): + person_ind = person_bboxes_inds[person_ind] if person_ind is not None else None + face_to_person_map[face_ind] = person_ind + + unassigned_persons_inds = [person_bboxes_inds[person_ind] for person_ind in unassigned_persons_inds] + return face_to_person_map, unassigned_persons_inds + + +def assign_persons( + faces_info: List[PictureInfo], faces_persons_map: Dict[int, int], detected_objects: PersonAndFaceResult +): + for face_ind, person_ind in faces_persons_map.items(): + if person_ind is None: + continue + + person_bbox = detected_objects.get_bbox_by_ind(person_ind).cpu().numpy().tolist() + faces_info[face_ind].person_bbox = person_bbox diff --git a/age_estimator/mivolo/tools/prepare_adience.py 
b/age_estimator/mivolo/tools/prepare_adience.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e20788c2b6647d8c2b98875f65c01b52843569 --- /dev/null +++ b/age_estimator/mivolo/tools/prepare_adience.py @@ -0,0 +1,216 @@ +import argparse +import os +from collections import defaultdict +from typing import Dict, List, Optional + +import cv2 +import pandas as pd +import tqdm +from mivolo.data.data_reader import PictureInfo, get_all_files +from mivolo.model.yolo_detector import Detector, PersonAndFaceResult +from preparation_utils import get_additional_bboxes, get_main_face, save_annotations + + +def read_adience_annotations(annotations_files): + annotations_per_image = {} + stat_per_fold = defaultdict(int) + cols = ["user_id", "original_image", "face_id", "age", "gender"] + for file in annotations_files: + fold_name = os.path.basename(file).split(".")[0] + df = pd.read_csv(file, sep="\t", usecols=cols) + for index, row in df.iterrows(): + face_id, img_name, user_id = row["face_id"], row["original_image"], row["user_id"] + aligned_face_path = f"faces/{user_id}/coarse_tilt_aligned_face.{face_id}.{img_name}" + + age, gender = row["age"], row["gender"] + gender = gender.upper() if isinstance(gender, str) and gender != "u" else None + age = age if isinstance(age, str) else None + + annotations_per_image[aligned_face_path] = {"age": age, "gender": gender, "fold": fold_name} + stat_per_fold[fold_name] += 1 + + print(f"Per fold images: {stat_per_fold}") + return annotations_per_image + + +def read_data(images_dir, annotations_files, data_dir) -> List[PictureInfo]: + dataset_pictures: List[PictureInfo] = [] + + all_images = get_all_files(images_dir) + annotations_per_file = read_adience_annotations(annotations_files) + + total, missed = 0, 0 + stat_per_gender: Dict[str, int] = defaultdict(int) + missed_gender, missed_age, missed_gender_and_age = 0, 0, 0 + stat_per_ages: Dict[str, int] = defaultdict(int) + + # final age classes: '0;2', "4;6", "8;12", "15;20", "25;32", "38;43", "48;53", "60;100" + + age_map = { + "2": "(0, 2)", + "3": "(0, 2)", + "13": "(8, 12)", + "(8, 23)": "(8, 12)", + "22": "(15, 20)", + "23": "(25, 32)", + "29": "(25, 32)", + "(27, 32)": "(25, 32)", + "32": "(25, 32)", + "34": "(25, 32)", + "35": "(25, 32)", + "36": "(38, 43)", + "(38, 42)": "(38, 43)", + "(38, 48)": "(38, 43)", + "42": "(38, 43)", + "45": "(38, 43)", + "46": "(48, 53)", + "55": "(48, 53)", + "56": "(48, 53)", + "57": "(60, 100)", + "58": "(60, 100)", + } + for image_path in all_images: + total += 1 + relative_path = image_path.replace(f"{data_dir}/", "") + if relative_path not in annotations_per_file: + missed += 1 + print("Can not find annotation for ", relative_path) + else: + annot = annotations_per_file[relative_path] + age, gender = annot["age"], annot["gender"] + + if gender is None and age is not None: + missed_gender += 1 + elif age is None and gender is not None: + missed_age += 1 + elif gender is None and age is None: + missed_gender_and_age += 1 + # skip such image + continue + + if gender is not None: + stat_per_gender[gender] += 1 + + if age is not None: + age = age_map[age] if age in age_map else age + stat_per_ages[age] += 1 + + dataset_pictures.append(PictureInfo(image_path, age, gender)) + + print(f"Missed annots for images: {missed}/{total}") + print(f"Missed genders: {missed_gender}") + print(f"Missed ages: {missed_age}") + print(f"Missed ages and gender: {missed_gender_and_age}") + print(f"\nPer gender images: {stat_per_gender}") + ages = list(stat_per_ages.keys()) + 
print(f"Per ages categories ({len(ages)} cats) :") + ages = sorted(ages, key=lambda x: int(x.split("(")[-1].split(",")[0].strip())) + for age in ages: + print(f"Age: {age} Count: {stat_per_ages[age]}") + + return dataset_pictures + + +def main(faces_dir: str, annotations: List[str], data_dir: str, detector_cfg: dict = None): + """ + Generate a .txt annotation file with columns: + ["img_name", "age", "gender", + "face_x0", "face_y0", "face_x1", "face_y1", + "person_x0", "person_y0", "person_x1", "person_y1"] + + All person bboxes here will be set to [-1, -1, -1, -1] + + If detector_cfg is set, for each face bbox will be refined using detector. + Also, other detected faces wil be written to txt file (needed for further preprocessing) + """ + # out directory for annotations + out_dir = os.path.join(data_dir, "annotations") + os.makedirs(out_dir, exist_ok=True) + + # load annotations + images: List[PictureInfo] = read_data(faces_dir, annotations, data_dir) + + if detector_cfg: + # detect faces with yolo detector + faces_not_found, images_with_other_faces = 0, 0 + other_faces: List[PictureInfo] = [] + + detector_weights, device = detector_cfg["weights"], detector_cfg["device"] + detector = Detector(detector_weights, device, verbose=False, conf_thresh=0.1, iou_thresh=0.2) + for image_info in tqdm.tqdm(images, desc="Detecting faces: "): + cv_im = cv2.imread(image_info.image_path) + im_h, im_w = cv_im.shape[:2] + + detected_objects: PersonAndFaceResult = detector.predict(cv_im) + main_bbox, other_bboxes_inds = get_main_face(detected_objects) + + if main_bbox is None: + # use a full image as face bbox + faces_not_found += 1 + image_info.bbox = [0, 0, im_w, im_h] + else: + image_info.bbox = main_bbox + + if len(other_bboxes_inds): + images_with_other_faces += 1 + + additional_faces = get_additional_bboxes(detected_objects, other_bboxes_inds, image_info.image_path) + other_faces.extend(additional_faces) + + print(f"Faces not detected: {faces_not_found}/{len(images)}") + print(f"Images with other faces: {images_with_other_faces}/{len(images)}") + print(f"Other faces: {len(other_faces)}") + + images = images + other_faces + + else: + # use a full image as face bbox + for image_info in tqdm.tqdm(images, desc="Collect face bboxes: "): + cv_im = cv2.imread(image_info.image_path) + im_h, im_w = cv_im.shape[:2] + image_info.bbox = [0, 0, im_w, im_h] # xyxy + + save_annotations(images, faces_dir, out_file=os.path.join(out_dir, "adience_annotations.csv")) + + +def get_parser(): + parser = argparse.ArgumentParser(description="Adience") + parser.add_argument( + "--dataset_path", + default="data/adience", + type=str, + required=True, + help="path to dataset with faces/ and fold_{i}_data.txt files", + ) + parser.add_argument( + "--detector_weights", default=None, type=str, required=False, help="path to face and person detector" + ) + parser.add_argument("--device", default="cuda:0", type=str, required=False, help="device to inference detector") + + return parser + + +if __name__ == "__main__": + + parser = get_parser() + args = parser.parse_args() + + data_dir = args.dataset_path + faces_dir = os.path.join(data_dir, "faces") + + if data_dir[-1] == "/": + data_dir = data_dir[:-1] + + annotations = [ + os.path.join(data_dir, "fold_0_data.txt"), + os.path.join(data_dir, "fold_1_data.txt"), + os.path.join(data_dir, "fold_2_data.txt"), + os.path.join(data_dir, "fold_3_data.txt"), + os.path.join(data_dir, "fold_4_data.txt"), + ] + + detector_cfg: Optional[Dict[str, str]] = None + if args.detector_weights is not None: + 
detector_cfg = {"weights": args.detector_weights, "device": "cuda:0"} + + main(faces_dir, annotations, data_dir, detector_cfg) diff --git a/age_estimator/mivolo/tools/prepare_agedb.py b/age_estimator/mivolo/tools/prepare_agedb.py new file mode 100644 index 0000000000000000000000000000000000000000..dc25801a6699ea44cff347361692031a6ba2838e --- /dev/null +++ b/age_estimator/mivolo/tools/prepare_agedb.py @@ -0,0 +1,57 @@ +import argparse +import json +import os +from typing import Dict, Optional + +from prepare_cacd import collect_faces + + +def get_parser(): + parser = argparse.ArgumentParser(description="AgeDB") + parser.add_argument( + "--dataset_path", + default="data/AgeDB", + type=str, + required=True, + help="path to dataset with AgeDB folder", + ) + parser.add_argument( + "--detector_weights", default=None, type=str, required=False, help="path to face and person detector" + ) + parser.add_argument("--device", default="cuda:0", type=str, required=False, help="device to inference detector") + + return parser + + +if __name__ == "__main__": + + parser = get_parser() + args = parser.parse_args() + + data_dir = args.dataset_path + if data_dir[-1] == "/": + data_dir = data_dir[:-1] + + faces_dir = os.path.join(data_dir, "AgeDB") + + # https://github.com/paplhjak/Facial-Age-Estimation-Benchmark-Databases/tree/main + json_path = os.path.join(data_dir, "AgeDB.json") + with open(json_path, "r") as stream: + annotations = json.load(stream) + + detector_cfg: Optional[Dict[str, str]] = None + if args.detector_weights is not None: + detector_cfg = {"weights": args.detector_weights, "device": "cuda:0"} + + splits = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] + collect_faces( + faces_dir, + annotations, + data_dir, + detector_cfg, + padding=0.1, + splits=splits, + db_name="agedb", + find_persons=True, + use_coarse_faces=True, + ) diff --git a/age_estimator/mivolo/tools/prepare_cacd.py b/age_estimator/mivolo/tools/prepare_cacd.py new file mode 100644 index 0000000000000000000000000000000000000000..ddff3c29f3eb804eb94a20caa7a72bddde24a02e --- /dev/null +++ b/age_estimator/mivolo/tools/prepare_cacd.py @@ -0,0 +1,254 @@ +import argparse +import json +import os +from collections import defaultdict +from typing import Dict, List, Optional + +import cv2 +import tqdm +from mivolo.data.data_reader import PictureInfo, get_all_files +from mivolo.modeling.yolo_detector import Detector, PersonAndFaceResult +from preparation_utils import get_additional_bboxes, get_main_face, save_annotations +from prepare_fairface import find_persons_on_image + + +def get_im_name(img_path): + im_name = img_path.split("/")[-1] + im_name = im_name.replace("é", "e").replace("é", "e") + im_name = im_name.replace("ó", "o").replace("ó", "o") + im_name = im_name.replace("å", "a").replace("å", "a") + im_name = im_name.replace("ñ", "n").replace("ñ", "n") + im_name = im_name.replace("ö", "o").replace("ö", "o") + im_name = im_name.replace("ä", "a").replace("ä", "a") + im_name = im_name.replace("ü", "u").replace("ü", "u") + im_name = im_name.replace("á", "a").replace("á", "a") + im_name = im_name.replace("ë", "e").replace("ë", "e") + im_name = im_name.replace("í", "i").replace("í", "i") + + return im_name + + +def read_json_annotations(annotations: List[str], splits: List[str]) -> Dict[str, dict]: + print("Parsing annotations") + annotations_per_image = {} + stat_per_split: Dict[str, int] = defaultdict(int) + + missed = 0 + for item_id, face in tqdm.tqdm(enumerate(annotations), total=len(annotations)): + im_name = 
get_im_name(face["img_path"]) + split = splits[int(face["folder"])] + + stat_per_split[split] += 1 + + gender = face["gender"] if "gender" in face else None + if "alignment_source" in face and face["alignment_source"] == "file not found": + missed += 1 + + annotations_per_image[im_name] = {"age": str(face["age"]), "gender": gender, "split": split} + + print("missed annots: ", missed) + + print(f"Per split images: {stat_per_split}") + print(f"Found {len(annotations_per_image)} annotations") + return annotations_per_image + + +def read_data(images_dir, annotations, splits) -> Dict[str, List[PictureInfo]]: + dataset: Dict[str, List[PictureInfo]] = defaultdict(list) + all_images = get_all_files(images_dir) + print(f"Found {len(all_images)} images") + + annotations_per_file: Dict[str, dict] = read_json_annotations(annotations, splits) + + total, missed = 0, 0 + missed_gender_and_age = 0 + stat_per_ages: Dict[str, int] = defaultdict(int) + stat_per_gender: Dict[str, int] = defaultdict(int) + + for image_path in all_images: + total += 1 + image_name = get_im_name(image_path) + + if image_name not in annotations_per_file: + missed += 1 + print(f"Can not find annotation for {image_name}") + else: + annot = annotations_per_file[image_name] + age, gender, split = annot["age"], annot["gender"], annot["split"] + + if gender is None and age is None: + missed_gender_and_age += 1 + # skip such image + continue + + if age is not None: + stat_per_ages[age] += 1 + if gender is not None: + stat_per_gender[gender] += 1 + + info = PictureInfo(image_path, age, gender) + dataset[split].append(info) + + print(f"Missed annots for images: {missed}/{total}") + print(f"Missed ages and gender: {missed_gender_and_age}") + ages = list(stat_per_ages.keys()) + print(f"Per gender stat: {stat_per_gender}") + print(f"Per ages categories ({len(ages)} cats) :") + ages = sorted(ages, key=lambda x: int(x.split("(")[-1].split(",")[0].strip())) + for age in ages: + print(f"Age: {age} Count: {stat_per_ages[age]}") + + return dataset + + +def collect_faces( + faces_dir: str, + annotations: List[dict], + data_dir: str, + detector_cfg: dict = None, + padding: float = 0.1, + splits: List[str] = [], + db_name: str = "", + use_coarse_persons: bool = False, + find_persons: bool = False, + person_padding: float = 0.0, + use_coarse_faces: bool = False, +): + """ + Generate train, val, test .txt annotation files with columns: + ["img_name", "age", "gender", + "face_x0", "face_y0", "face_x1", "face_y1", + "person_x0", "person_y0", "person_x1", "person_y1"] + + All person bboxes here will be set to [-1, -1, -1, -1] + + If detector_cfg is set, for each face bbox will be refined using detector. 
+ Also, other detected faces wil be written to txt file (needed for further preprocessing) + """ + + # out directory for annotations + out_dir = os.path.join(data_dir, "annotations") + os.makedirs(out_dir, exist_ok=True) + + # load annotations + images_per_split: Dict[str, List[PictureInfo]] = read_data(faces_dir, annotations, splits) + + for split_ind, (split, images) in enumerate(images_per_split.items()): + print(f"Processing {split} split ({split_ind}/{len(images_per_split)})...") + if detector_cfg: + # detect faces with yolo detector + faces_not_found, images_with_other_faces = 0, 0 + other_faces: List[PictureInfo] = [] + + detector_weights, device = detector_cfg["weights"], detector_cfg["device"] + detector = Detector(detector_weights, device, verbose=False, conf_thresh=0.1, iou_thresh=0.2) + for image_info in tqdm.tqdm(images, desc="Detecting faces: "): + cv_im = cv2.imread(image_info.image_path) + im_h, im_w = cv_im.shape[:2] + + pad_x, pad_y = int(padding * im_w), int(padding * im_h) + coarse_face_bbox = [pad_x, pad_y, im_w - pad_x, im_h - pad_y] # xyxy + + detected_objects: PersonAndFaceResult = detector.predict(cv_im) + main_bbox, other_faces_inds = get_main_face(detected_objects, coarse_face_bbox) + + if len(other_faces_inds): + images_with_other_faces += 1 + + if main_bbox is None: + # use a full image as a face bbox + faces_not_found += 1 + main_bbox = coarse_face_bbox + elif use_coarse_faces: + main_bbox = coarse_face_bbox + image_info.bbox = main_bbox + + if find_persons: + additional_faces, additional_persons = find_persons_on_image( + image_info, main_bbox, detected_objects, other_faces_inds, device + ) + # add all additional faces + other_faces.extend(additional_faces) + # add persons with empty faces + other_faces.extend(additional_persons) + else: + additional_faces = get_additional_bboxes(detected_objects, other_faces_inds, image_info.image_path) + other_faces.extend(additional_faces) + # full image as a person bbox + coarse_person_bbox = [0, 0, im_w, im_h] # xyxy + if find_persons: + image_info.person_bbox = coarse_person_bbox + + print(f"Faces not detected: {faces_not_found}/{len(images)}") + print(f"Images with other faces: {images_with_other_faces}/{len(images)}") + print(f"Other faces: {len(other_faces)}") + + images = images + other_faces + + else: + for image_info in tqdm.tqdm(images, desc="Collect face bboxes: "): + + cv_im = cv2.imread(image_info.image_path) + im_h, im_w = cv_im.shape[:2] + + # use a full image as a face bbox + pad_x, pad_y = int(padding * im_w), int(padding * im_h) + image_info.bbox = [pad_x, pad_y, im_w - pad_x, im_h - pad_y] # xyxy + + if use_coarse_persons or find_persons: + # full image as a person bbox + pad_x_p, pad_y_p = int(person_padding * im_w), int(person_padding * im_h) + image_info.person_bbox = [pad_x_p, pad_y_p, im_w - pad_x_p, im_h] # xyxy + + save_annotations(images, faces_dir, out_file=os.path.join(out_dir, f"{db_name}_{split}_annotations.csv")) + + +def get_parser(): + parser = argparse.ArgumentParser(description="CACD") + parser.add_argument( + "--dataset_path", + default="data/CACD", + type=str, + required=True, + help="path to dataset with CACD200 folder", + ) + parser.add_argument( + "--detector_weights", default=None, type=str, required=False, help="path to face and person detector" + ) + parser.add_argument("--device", default="cuda:0", type=str, required=False, help="device to inference detector") + + return parser + + +if __name__ == "__main__": + + parser = get_parser() + args = parser.parse_args() + + data_dir = 
args.dataset_path + if data_dir[-1] == "/": + data_dir = data_dir[:-1] + + faces_dir = os.path.join(data_dir, "CACD2000") + + # https://github.com/paplhjak/Facial-Age-Estimation-Benchmark-Databases/tree/main + json_path = os.path.join(data_dir, "CACD2000.json") + with open(json_path, "r") as stream: + annotations = json.load(stream) + + detector_cfg: Optional[Dict[str, str]] = None + if args.detector_weights is not None: + detector_cfg = {"weights": args.detector_weights, "device": "cuda:0"} + + splits = ["train", "valid", "test"] + collect_faces( + faces_dir, + annotations, + data_dir, + detector_cfg, + padding=0.2, + splits=splits, + db_name="cacd", + find_persons=True, + use_coarse_faces=True, + ) diff --git a/age_estimator/mivolo/tools/prepare_fairface.py b/age_estimator/mivolo/tools/prepare_fairface.py new file mode 100644 index 0000000000000000000000000000000000000000..b3e975bdeb2f3825c60ac49317e62f140d2d21aa --- /dev/null +++ b/age_estimator/mivolo/tools/prepare_fairface.py @@ -0,0 +1,205 @@ +import argparse +import os +from collections import defaultdict +from typing import Dict, List, Optional, Tuple + +import cv2 +import pandas as pd +import torch +import tqdm +from mivolo.data.data_reader import PictureInfo, get_all_files +from mivolo.modeling.yolo_detector import Detector, PersonAndFaceResult +from preparation_utils import assign_persons, associate_persons, get_additional_bboxes, get_main_face, save_annotations + + +def read_fairface_annotations(annotations_files): + annotations_per_image = {} + cols = ["file", "age", "gender"] + + for file in annotations_files: + split_name = os.path.basename(file).split(".")[0].split("_")[-1] + df = pd.read_csv(file, sep=",", usecols=cols) + for index, row in df.iterrows(): + aligned_face_path = row["file"] + + age, gender = row["age"], row["gender"] + # M or F + gender = gender[0].upper() if isinstance(gender, str) else None + age = age.replace("-", ";") if isinstance(age, str) else None + + annotations_per_image[aligned_face_path] = {"age": age, "gender": gender, "split": split_name} + return annotations_per_image + + +def read_data(images_dir, annotations_files) -> Tuple[List[PictureInfo], List[PictureInfo]]: + dataset_pictures_train: List[PictureInfo] = [] + dataset_pictures_val: List[PictureInfo] = [] + + all_images = get_all_files(images_dir) + annotations_per_file = read_fairface_annotations(annotations_files) + + SPLIT_TYPE = Dict[str, Dict[str, int]] + splits_stat_per_gender: SPLIT_TYPE = defaultdict(lambda: defaultdict(int)) + splits_stat_per_ages: SPLIT_TYPE = defaultdict(lambda: defaultdict(int)) + + age_map = {"more than 70": "70;120"} + for image_path in all_images: + relative_path = image_path.replace(f"{images_dir}/", "") + + annot = annotations_per_file[relative_path] + split = annot["split"] + age, gender = annot["age"], annot["gender"] + age = age_map[age] if age in age_map else age + + splits_stat_per_gender[split][gender] += 1 + splits_stat_per_ages[split][age] += 1 + + if split == "train": + dataset_pictures_train.append(PictureInfo(image_path, age, gender)) + elif split == "val": + dataset_pictures_val.append(PictureInfo(image_path, age, gender)) + else: + raise ValueError(f"Unknown split name: {split}") + + print(f"Found train/val images: {len(dataset_pictures_train)}/{len(dataset_pictures_val)}") + for split, stat_per_gender in splits_stat_per_gender.items(): + print(f"\n{split} Per gender images: {stat_per_gender}") + + for split, stat_per_ages in splits_stat_per_ages.items(): + ages = list(stat_per_ages.keys()) + 
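        # FairFace age labels are ";"-separated buckets (e.g. "0;2", "3;9", and
        # "70;120" after the age_map remap above), so sorting by the bucket's
        # lower bound below yields an ordered, readable summary.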
print(f"\n{split} Per ages categories ({len(ages)} cats) :") + ages = sorted(ages, key=lambda x: int(x.split(";")[0].strip())) + for age in ages: + print(f"Age: {age} Count: {stat_per_ages[age]}") + + return dataset_pictures_train, dataset_pictures_val + + +def find_persons_on_image(image_info, main_bbox, detected_objects, other_faces_inds, device): + # find person_ind for each face (main + other_faces) + all_faces: List[torch.tensor] = [torch.tensor(main_bbox).to(device)] + [ + detected_objects.get_bbox_by_ind(ind) for ind in other_faces_inds + ] + faces_persons_map, other_persons_inds = associate_persons(all_faces, detected_objects) + + additional_faces: List[PictureInfo] = get_additional_bboxes( + detected_objects, other_faces_inds, image_info.image_path + ) + + # set person bboxes for all faces (main + additional_faces) + assign_persons([image_info] + additional_faces, faces_persons_map, detected_objects) + if faces_persons_map[0] is not None: + assert all(coord != -1 for coord in image_info.person_bbox) + + additional_persons: List[PictureInfo] = get_additional_bboxes( + detected_objects, other_persons_inds, image_info.image_path, is_person=True + ) + + return additional_faces, additional_persons + + +def main(faces_dir: str, annotations: List[str], data_dir: str, detector_cfg: dict = None): + """ + Generate a .txt annotation file with columns: + ["img_name", "age", "gender", + "face_x0", "face_y0", "face_x1", "face_y1", + "person_x0", "person_y0", "person_x1", "person_y1"] + + If detector_cfg is set, for each face bbox will be refined using detector. + Person bbox will be assigned for each face. + Also, other detected faces and persons wil be written to txt file (needed for further preprocessing) + """ + # out directory for txt annotations + out_dir = os.path.join(data_dir, "annotations") + os.makedirs(out_dir, exist_ok=True) + + # load annotations + dataset_pictures_train, dataset_pictures_val = read_data(faces_dir, annotations) + + for images, split_name in zip([dataset_pictures_train, dataset_pictures_val], ["train", "val"]): + + if detector_cfg: + # detect faces with yolo detector + faces_not_found, images_with_other_faces = 0, 0 + other_faces: List[PictureInfo] = [] + + detector_weights, device = detector_cfg["weights"], detector_cfg["device"] + detector = Detector(detector_weights, device, verbose=False, conf_thresh=0.1, iou_thresh=0.2) + for image_info in tqdm.tqdm(images, desc=f"Detecting {split_name} faces: "): + cv_im = cv2.imread(image_info.image_path) + im_h, im_w = cv_im.shape[:2] + # all images are 448x448 and with 125 padding + coarse_bbox = [125, 125, im_w - 125, im_h - 125] # xyxy + + detected_objects: PersonAndFaceResult = detector.predict(cv_im) + main_bbox, other_faces_inds = get_main_face(detected_objects, coarse_bbox) + if len(other_faces_inds): + images_with_other_faces += 1 + + if main_bbox is None: + # use a full image as face bbox + faces_not_found += 1 + main_bbox = coarse_bbox + image_info.bbox = main_bbox + + additional_faces, additional_persons = find_persons_on_image( + image_info, main_bbox, detected_objects, other_faces_inds, device + ) + + # add all additional faces + other_faces.extend(additional_faces) + + # add persons with empty faces + other_faces.extend(additional_persons) + + print(f"Faces not detected: {faces_not_found}/{len(images)}") + print(f"Images with other faces: {images_with_other_faces}/{len(images)}") + print(f"Other bboxes (faces/persons): {len(other_faces)}") + + images = images + other_faces + + else: + for image_info in 
tqdm.tqdm(images, desc="Collect face bboxes: "): + cv_im = cv2.imread(image_info.image_path) + im_h, im_w = cv_im.shape[:2] + # all images are 448x448 and with 125 padding + image_info.bbox = [125, 125, im_w - 125, im_h - 125] # xyxy + + save_annotations(images, faces_dir, out_file=os.path.join(out_dir, f"{split_name}_annotations.csv")) + + +def get_parser(): + parser = argparse.ArgumentParser(description="FairFace") + parser.add_argument( + "--dataset_path", + default="data/FairFace", + type=str, + required=True, + help="path to folder with fairface-img-margin125-trainval/ and fairface_label_{split}.csv", + ) + parser.add_argument( + "--detector_weights", default=None, type=str, required=False, help="path to face and person detector" + ) + parser.add_argument("--device", default="cuda:0", type=str, required=False, help="device to inference detector") + + return parser + + +if __name__ == "__main__": + + parser = get_parser() + args = parser.parse_args() + + data_dir = args.dataset_path + faces_dir = os.path.join(data_dir, "fairface-img-margin125-trainval") + + if data_dir[-1] == "/": + data_dir = data_dir[:-1] + + annotations = [os.path.join(data_dir, "fairface_label_train.csv"), os.path.join(data_dir, "fairface_label_val.csv")] + + detector_cfg: Optional[Dict[str, str]] = None + if args.detector_weights is not None: + detector_cfg = {"weights": args.detector_weights, "device": "cuda:0"} + + main(faces_dir, annotations, data_dir, detector_cfg) diff --git a/age_estimator/models.py b/age_estimator/models.py new file mode 100644 index 0000000000000000000000000000000000000000..71a836239075aa6e6e4ecb700e9c42c95c022d91 --- /dev/null +++ b/age_estimator/models.py @@ -0,0 +1,3 @@ +from django.db import models + +# Create your models here. diff --git a/age_estimator/tests.py b/age_estimator/tests.py new file mode 100644 index 0000000000000000000000000000000000000000..7ce503c2dd97ba78597f6ff6e4393132753573f6 --- /dev/null +++ b/age_estimator/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/age_estimator/urls.py b/age_estimator/urls.py new file mode 100644 index 0000000000000000000000000000000000000000..6b67fb5fb3645ac05f1978ca1f807206d024e526 --- /dev/null +++ b/age_estimator/urls.py @@ -0,0 +1,12 @@ +from django.urls import path +from . import views + +from .views import AgeEstimation +# from bpm_app.views import calculate_heart_rate + +# urlpatterns = [ +# path('calculate-bpm/', , name='calculate_bpm'), +# ] +urlpatterns = [ + path('age_estimator/', AgeEstimation.as_view(), name='age_estimation_view'), +] diff --git a/age_estimator/views.py b/age_estimator/views.py new file mode 100644 index 0000000000000000000000000000000000000000..ebf2a40ac9245bd365ac40ca765ca914f613b25e --- /dev/null +++ b/age_estimator/views.py @@ -0,0 +1,62 @@ +from django.shortcuts import render + +# Create your views here. 
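# The view below expects a multipart/form-data POST with a single "video_file"
# field; the upload is stored via default_storage, passed through the MiVOLO
# demo pipeline, and the estimated age range is returned as JSON. Model paths
# and the device are hardcoded in the view; in a typical Django project they
# would likely live in settings.py instead (a suggestion, not the original code).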
+from django.http import JsonResponse
+from rest_framework.views import APIView
+from django.core.files.storage import default_storage
+from age_estimator.mivolo.demo_copy import main
+import os
+
+
+class AgeEstimation(APIView):
+    """Estimate an age range from an uploaded video."""
+
+    def post(self, request):
+        try:
+            # Save the uploaded video file to default storage
+            video_file = request.FILES['video_file']
+            file_name = default_storage.save(video_file.name, video_file)
+            video_file_path = os.path.join(default_storage.location, file_name)
+
+            # Inference configuration (paths are relative to the project root)
+            output_folder = 'output'
+            detector_weights = 'age_estimator/mivolo/models/yolov8x_person_face.pt'
+            checkpoint = 'age_estimator/mivolo/models/model_imdb_cross_person_4.22_99.46.pth.tar'
+            device = 'cpu'
+            with_persons = True
+            disable_faces = False
+            draw = True
+
+            # Check for required parameters
+            if not video_file_path or not detector_weights or not checkpoint:
+                return JsonResponse(
+                    {"error": "Missing required fields: 'video_path', 'detector_weights', or 'checkpoint'"},
+                    status=400,
+                )
+
+            # Run the inference pipeline from demo_copy.py
+            absolute_age, lower_bound, upper_bound = main(
+                video_file_path, output_folder, detector_weights, checkpoint, device, with_persons, disable_faces, draw
+            )
+
+            # Return the estimated age range as a JSON response
+            return JsonResponse({"age_range": f"{lower_bound} - {upper_bound}"})
+
+        except Exception as e:
+            return JsonResponse({"error": str(e)}, status=500)
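For a quick check of the endpoint registered in age_estimator/urls.py, a minimal client could look like the sketch below. The host, port and URL prefix are assumptions: they depend on how the project-level URL configuration includes age_estimator.urls, which is not part of this diff, and "sample_video.mp4" is a placeholder file name.

import requests

# Assumed routing: the project urls.py includes age_estimator.urls at the site
# root, so the endpoint resolves to /age_estimator/ on the development server.
url = "http://127.0.0.1:8000/age_estimator/"

# Placeholder input; any short video with a visible face should work.
with open("sample_video.mp4", "rb") as f:
    response = requests.post(url, files={"video_file": f})

print(response.status_code)
print(response.json())  # e.g. {"age_range": "25 - 32"}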