diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4de25099f95cd1379ff38fe48d5590fa66348840 --- /dev/null +++ b/.gitignore @@ -0,0 +1,134 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +data/ +*.iml +.idea/ +*.pt \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a3112b2df0ba4c714b2d9b429e29250832cddc88 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Pramook Khungurn + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index a334ed4f2130cf528282ac607a1ecb56913c19b9..2cbaa908c9f52c2a17d2a3d2b9be4aaac70c2eb9 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,239 @@ ---- -title: Talking Head Anime 3 -emoji: 🐠 -colorFrom: pink -colorTo: red -sdk: gradio -sdk_version: 3.18.0 -app_file: app.py -pinned: false -license: mit ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# Demo Code for "Talking Head(?) Anime from A Single Image 3: Now the Body Too" + +This repository contains demo programs for the [Talking Head(?) Anime from a Single Image 3: Now the Body Too](https://pkhungurn.github.io/talking-head-anime-3/index.html) project. As the name implies, the project allows you to animate anime characters, and you only need a single image of that character to do so. There are two demo programs: + +* The ``manual_poser`` lets you manipulate a character's facial expression, head rotation, body rotation, and chest expansion due to breathing through a graphical user interface. +* ``ifacialmocap_puppeteer`` lets you transfer your facial motion to an anime character. + +## Try the Manual Poser on Google Colab + +If you do not have the required hardware (discussed below) or do not want to download the code and set up an environment to run it, click [![this link](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pkhungurn/talking-head-anime-3-demo/blob/master/colab.ipynb) to try running the manual poser on [Google Colab](https://research.google.com/colaboratory/faq.html). + +## Hardware Requirements + +Both programs require a recent and powerful Nvidia GPU to run. I could personally ran them at good speed with the Nvidia Titan RTX. However, I think recent high-end gaming GPUs such as the RTX 2080, the RTX 3080, or better would do just as well. + +The `ifacialmocap_puppeteer` requires an iOS device that is capable of computing [blend shape parameters](https://developer.apple.com/documentation/arkit/arfaceanchor/2928251-blendshapes) from a video feed. This means that the device must be able to run iOS 11.0 or higher and must have a TrueDepth front-facing camera. (See [this page](https://developer.apple.com/documentation/arkit/content_anchors/tracking_and_visualizing_faces) for more info.) In other words, if you have the iPhone X or something better, you should be all set. Personally, I have used an iPhone 12 mini. + +## Software Requirements + +### GPU Related Software + +Please update your GPU's device driver and install the [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit) that is compatible with your GPU and is newer than the version you will be installing in the next subsection. + +### Python Environment + +Both ``manual_poser`` and ``ifacialmocap_puppeteer`` are available as desktop applications. To run them, you need to set up an environment for running programs written in the [Python](http://www.python.org) language. The environment needs to have the following software packages: + +* Python >= 3.8 +* PyTorch >= 1.11.0 with CUDA support +* SciPY >= 1.7.3 +* wxPython >= 4.1.1 +* Matplotlib >= 3.5.1 + +One way to do so is to install [Anaconda](https://www.anaconda.com/) and run the following commands in your shell: + +``` +> conda create -n talking-head-anime-3-demo python=3.8 +> conda activate talking-head-anime-3-demo +> conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch +> conda install scipy +> pip install wxpython +> conda install matplotlib +``` + +#### Caveat 1: Do not use Python 3.10 on Windows + +As of June 2006, you cannot use [wxPython](https://www.wxpython.org/) with Python 3.10 on Windows. As a result, do not use Python 3.10 until [this bug](https://github.com/wxWidgets/Phoenix/issues/2024) is fixed. This means you should not set ``python=3.10`` in the first ``conda`` command in the listing above. + +#### Caveat 2: Adjust versions of Python and CUDA Toolkit as needed + +The environment created by the commands above gives you Python version 3.8 and an installation of [PyTorch](http://pytorch.org) that was compiled with CUDA Toolkit version 11.3. This particular setup might not work in the future because you may find that this particular PyTorch package does not work with your new computer. The solution is to: + +1. Change the Python version in the first command to a recent one that works for your OS. (That is, do not use 3.10 if you are using Windows.) +2. Change the version of CUDA toolkit in the third command to one that the PyTorch's website says is available. In particular, scroll to the "Install PyTorch" section and use the chooser there to pick the right command for your computer. Use that command to install PyTorch instead of the third command above. + +![The command to install PyTorch](docs/pytorch-install-command.png "The command to install PyTorch") + +### Jupyter Environment + +The ``manual_poser`` is also available as a [Jupyter Nootbook](http://jupyter.org). To run it on your local machines, you also need to install: + +* Jupyter Notebook >= 7.3.4 +* IPywidgets >= 7.7.0 + +In some case, you will also need to enable the ``widgetsnbextension`` as well. So, run + +``` +> jupyter nbextension enable --py widgetsnbextension +``` + +After installing the above two packages. Using Anaconda, I managed to do the above with the following commands: + +``` +> conda install -c conda-forge notebook +> conda install -c conda-forge ipywidgets +> jupyter nbextension enable --py widgetsnbextension +``` + +### Automatic Environment Construction with Anaconda + +You can also use Anaconda to download and install all Python packages in one command. Open your shell, change the directory to where you clone the repository, and run: + +``` +> conda env create -f environment.yml +``` + +This will create an environment called ``talking-head-anime-3-demo`` containing all the required Python packages. + +### iFacialMocap + +If you want to use ``ifacialmocap_puppeteer``, you will also need to an iOS software called [iFacialMocap](https://www.ifacialmocap.com/) (a 980 yen purchase in the App Store). You do not need to download the paired application this time. Your iOS and your computer must use the same network. For example, you may connect them to the same wireless router. + +## Download the Models + +Before running the programs, you need to download the model files from this [Dropbox link](https://www.dropbox.com/s/y7b8jl4n2euv8xe/talking-head-anime-3-models.zip?dl=0) and unzip it to the ``data/models`` folder under the repository's root directory. In the end, the data folder should look like: + +``` ++ data + + images + - crypko_00.png + - crypko_01.png + : + - crypko_07.png + - lambda_00.png + - lambda_01.png + + models + + separable_float + - editor.pt + - eyebrow_decomposer.pt + - eyebrow_morphing_combiner.pt + - face_morpher.pt + - two_algo_face_body_rotator.pt + + separable_half + - editor.pt + : + - two_algo_face_body_rotator.pt + + standard_float + - editor.pt + : + - two_algo_face_body_rotator.pt + + standard_half + - editor.pt + : + - two_algo_face_body_rotator.pt +``` + +The model files are distributed with the +[Creative Commons Attribution 4.0 International License](https://creativecommons.org/licenses/by/4.0/legalcode), which +means that you can use them for commercial purposes. However, if you distribute them, you must, among other things, say +that I am the creator. + +## Running the `manual_poser` Desktop Application + +Open a shell. Change your working directory to the repository's root directory. Then, run: + +``` +> python tha3/app/manual_poser.py +``` + +Note that before running the command above, you might have to activate the Python environment that contains the required +packages. If you created an environment using Anaconda as was discussed above, you need to run + +``` +> conda activate talking-head-anime-3-demo +``` + +if you have not already activated the environment. + +### Choosing System Variant to Use + +As noted in the [project's write-up](http://pkhungurn.github.io/talking-head-anime-3/index.html), I created 4 variants of the neural network system. They are called ``standard_float``, ``separable_float``, ``standard_half``, and ``separable_half``. All of them have the same functionalities, but they differ in their sizes, RAM usage, speed, and accuracy. You can specify which variant that the ``manual_poser`` program uses through the ``--model`` command line option. + +``` +> python tha3/app/manual_poser --model +``` + +where ```` must be one of the 4 names above. If no variant is specified, the ``standard_float`` variant (which is the largest, slowest, and most accurate) will be used. + +## Running the `manual_poser` Jupyter Notebook + +Open a shell. Activate the environment. Change your working directory to the repository's root directory. Then, run: + +``` +> jupyter notebook +``` + +A browser window should open. In it, open `manual_poser.ipynb`. Once you have done so, you should see that it has two cells. Run the two cells in order. Then, scroll down to the end of the document, and you'll see the GUI there. + +You can choose the system variant to use by changing the ``MODEL_NAME`` variable in the first cell. If you do, you will need to rerun both cells in order for the variant to be loaded and the GUI to be properly updated to use it. + +## Running the `ifacialmocap_poser` + +First, run iFacialMocap on your iOS device. It should show you the device's IP address. Jot it down. Keep the app open. + +![IP address in iFacialMocap screen](docs/ifacialmocap_ip.jpg "IP address in iFacialMocap screen") + +Open a shell. Activate the Python environment. Change your working directory to the repository's root directory. Then, run: + +``` +> python tha3/app/ifacialmocap_puppeteer.py +``` + +You will see a text box with label "Capture Device IP." Write the iOS device's IP address that you jotted down there. + +![Write IP address of your iOS device in the 'Capture Device IP' text box.](docs/ifacialmocap_puppeteer_ip_address_box.png "Write IP address of your iOS device in the 'Capture Device IP' text box.") + +Click the "START CAPTURE!" button to the right. + +![Click the 'START CAPTURE!' button.](docs/ifacialmocap_puppeteer_click_start_capture.png "Click the 'START CAPTURE!' button.") + +If the programs are connected properly, you should see the numbers in the bottom part of the window change when you move your head. + +![The numbers in the bottom part of the window should change when you move your head.](docs/ifacialmocap_puppeteer_numbers.png "The numbers in the bottom part of the window should change when you move your head.") + +Now, you can load an image of a character, and it should follow your facial movement. + +## Contraints on Input Images + +In order for the system to work well, the input image must obey the following constraints: + +* It should be of resolution 512 x 512. (If the demo programs receives an input image of any other size, they will resize the image to this resolution and also output at this resolution.) +* It must have an alpha channel. +* It must contain only one humanoid character. +* The character should be standing upright and facing forward. +* The character's hands should be below and far from the head. +* The head of the character should roughly be contained in the 128 x 128 box in the middle of the top half of the image. +* The alpha channels of all pixels that do not belong to the character (i.e., background pixels) must be 0. + +![An example of an image that conforms to the above criteria](docs/input_spec.png "An example of an image that conforms to the above criteria") + +See the project's [write-up](http://pkhungurn.github.io/talking-head-anime-3/full.html#sec:problem-spec) for more details on the input image. + +## Citation + +If your academic work benefits from the code in this repository, please cite the project's web page as follows: + +> Pramook Khungurn. **Talking Head(?) Anime from a Single Image 3: Now the Body Too.** http://pkhungurn.github.io/talking-head-anime-3/, 2022. Accessed: YYYY-MM-DD. + +You can also used the following BibTex entry: + +``` +@misc{Khungurn:2022, + author = {Pramook Khungurn}, + title = {Talking Head(?) Anime from a Single Image 3: Now the Body Too}, + howpublished = {\url{http://pkhungurn.github.io/talking-head-anime-3/}}, + year = 2022, + note = {Accessed: YYYY-MM-DD}, +} +``` + +## Disclaimer + +While the author is an employee of [Google Japan](https://careers.google.com/locations/tokyo/), this software is not Google's product and is not supported by Google. + +The copyright of this software belongs to me as I have requested it using the [IARC process](https://opensource.google/documentation/reference/releasing#iarc). However, Google might claim the rights to the intellectual +property of this invention. + +The code is released under the [MIT license](https://github.com/pkhungurn/talking-head-anime-2-demo/blob/master/LICENSE). +The model is released under the [Creative Commons Attribution 4.0 International License](https://creativecommons.org/licenses/by/4.0/legalcode). Please see the README.md file in the ``data/images`` directory for the licenses for the images there. diff --git a/colab.ipynb b/colab.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..235609c6c5d60e605b95a28a1a703c71ca9a133f --- /dev/null +++ b/colab.ipynb @@ -0,0 +1,542 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1027b46a", + "metadata": {}, + "source": [ + "# Talking Head(?) Anime from a Single Image 3: Now the Body Too (Manual Poser Tool)\n", + "\n", + "**Instruction**\n", + "\n", + "1. Run the four cells below, one by one, in order by clicking the \"Play\" button to the left of it. Wait for each cell to finish before going to the next one.\n", + "2. Scroll down to the end of the last cell, and play with the GUI.\n", + "\n", + "**Links**\n", + "\n", + "* Github repository: http://github.com/pkhungurn/talking-head-anime-3-demo\n", + "* Project writeup: http://pkhungurn.github.io/talking-head-anime-3/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54cc96d7", + "metadata": {}, + "outputs": [], + "source": [ + "# Clone the repository\n", + "%cd /content\n", + "!git clone https://github.com/pkhungurn/talking-head-anime-3-demo.git" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77f2016c", + "metadata": {}, + "outputs": [], + "source": [ + "# CD into the repository directory.\n", + "%cd /content/talking-head-anime-3-demo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1771c927", + "metadata": {}, + "outputs": [], + "source": [ + "# Download model files\n", + "!mkdir -p data/models/standard_float\n", + "!wget -O data/models/standard_float/editor.pt https://www.dropbox.com/s/zp3e5ox57sdws3y/editor.pt?dl=0\n", + "!wget -O data/models/standard_float/eyebrow_decomposer.pt https://www.dropbox.com/s/bcp42knbrk7egk8/eyebrow_decomposer.pt?dl=0\n", + "!wget -O data/models/standard_float/eyebrow_morphing_combiner.pt https://www.dropbox.com/s/oywaiio2s53lc57/eyebrow_morphing_combiner.pt?dl=0\n", + "!wget -O data/models/standard_float/face_morpher.pt https://www.dropbox.com/s/8qvo0u5lw7hqvtq/face_morpher.pt?dl=0\n", + "!wget -O data/models/standard_float/two_algo_face_body_rotator.pt https://www.dropbox.com/s/qmq1dnxrmzsxb4h/two_algo_face_body_rotator.pt?dl=0\n", + "\n", + "!mkdir -p data/models/standard_half\n", + "!wget -O data/models/standard_half/editor.pt https://www.dropbox.com/s/g21ps8gfuvz4kbo/editor.pt?dl=0\n", + "!wget -O data/models/standard_half/eyebrow_decomposer.pt https://www.dropbox.com/s/nwwwevzpmxiilgn/eyebrow_decomposer.pt?dl=0\n", + "!wget -O data/models/standard_half/eyebrow_morphing_combiner.pt https://www.dropbox.com/s/z5v0amgqif7yup1/eyebrow_morphing_combiner.pt?dl=0\n", + "!wget -O data/models/standard_half/face_morpher.pt https://www.dropbox.com/s/g03sfnd5yfs0m65/face_morpher.pt?dl=0\n", + "!wget -O data/models/standard_half/two_algo_face_body_rotator.pt https://www.dropbox.com/s/c5lrn7z34x12317/two_algo_face_body_rotator.pt?dl=0\n", + "\n", + "!mkdir -p data/models/separable_float \n", + "!wget -O data/models/separable_float/editor.pt https://www.dropbox.com/s/nwdxhrpa9fy19r4/editor.pt?dl=0\n", + "!wget -O data/models/separable_float/eyebrow_decomposer.pt https://www.dropbox.com/s/hfzjcu9cqr9wm3i/eyebrow_decomposer.pt?dl=0\n", + "!wget -O data/models/separable_float/eyebrow_morphing_combiner.pt https://www.dropbox.com/s/g04dyyyavh5o1e2/eyebrow_morphing_combiner.pt?dl=0\n", + "!wget -O data/models/separable_float/face_morpher.pt https://www.dropbox.com/s/vgi9dsj95y0rrwv/face_morpher.pt?dl=0\n", + "!wget -O data/models/separable_float/two_algo_face_body_rotator.pt https://www.dropbox.com/s/8u0qond8po34l24/two_algo_face_body_rotator.pt?dl=0\n", + "\n", + "!mkdir -p data/models/separable_half\n", + "!wget -O data/models/separable_half/editor.pt https://www.dropbox.com/s/on8kn6z9fj95j0h/editor.pt?dl=0\n", + "!wget -O data/models/separable_half/eyebrow_decomposer.pt https://www.dropbox.com/s/0hxu8opu1hmghqe/eyebrow_decomposer.pt?dl=0\n", + "!wget -O data/models/separable_half/eyebrow_morphing_combiner.pt https://www.dropbox.com/s/bgz02afp0xojqfs/eyebrow_morphing_combiner.pt?dl=0\n", + "!wget -O data/models/separable_half/face_morpher.pt https://www.dropbox.com/s/bgz02afp0xojqfs/eyebrow_morphing_combiner.pt?dl=0\n", + "!wget -O data/models/separable_half/two_algo_face_body_rotator.pt https://www.dropbox.com/s/vr8h2xxltszhw7w/two_algo_face_body_rotator.pt?dl=0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "062014f7", + "metadata": { + "id": "breeding-extra" + }, + "outputs": [], + "source": [ + "# Set this constant to specify which system variant to use.\n", + "MODEL_NAME = \"standard_float\" \n", + "\n", + "# Load the models.\n", + "import torch\n", + "DEVICE_NAME = 'cuda'\n", + "device = torch.device(DEVICE_NAME)\n", + "\n", + "def load_poser(model: str, device: torch.device):\n", + " print(\"Using the %s model.\" % model)\n", + " if model == \"standard_float\":\n", + " from tha3.poser.modes.standard_float import create_poser\n", + " return create_poser(device)\n", + " elif model == \"standard_half\":\n", + " from tha3.poser.modes.standard_half import create_poser\n", + " return create_poser(device)\n", + " elif model == \"separable_float\":\n", + " from tha3.poser.modes.separable_float import create_poser\n", + " return create_poser(device)\n", + " elif model == \"separable_half\":\n", + " from tha3.poser.modes.separable_half import create_poser\n", + " return create_poser(device)\n", + " else:\n", + " raise RuntimeError(\"Invalid model: '%s'\" % model)\n", + " \n", + "poser = load_poser(MODEL_NAME, DEVICE_NAME)\n", + "poser.get_modules();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "breeding-extra", + "metadata": { + "id": "breeding-extra" + }, + "outputs": [], + "source": [ + "# Create the GUI for manipulating character images.\n", + "import PIL.Image\n", + "import io\n", + "from io import StringIO, BytesIO\n", + "import IPython.display\n", + "import numpy\n", + "import ipywidgets\n", + "import time\n", + "import threading\n", + "import torch\n", + "from tha3.util import resize_PIL_image, extract_PIL_image_from_filelike, \\\n", + " extract_pytorch_image_from_PIL_image, convert_output_image_from_torch_to_numpy\n", + "\n", + "FRAME_RATE = 30.0\n", + "\n", + "last_torch_input_image = None\n", + "torch_input_image = None\n", + "\n", + "def show_pytorch_image(pytorch_image):\n", + " output_image = pytorch_image.detach().cpu()\n", + " numpy_image = numpy.uint8(numpy.rint(convert_output_image_from_torch_to_numpy(output_image) * 255.0))\n", + " pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA')\n", + " IPython.display.display(pil_image)\n", + "\n", + "upload_input_image_button = ipywidgets.FileUpload(\n", + " accept='.png',\n", + " multiple=False,\n", + " layout={\n", + " 'width': '512px'\n", + " }\n", + ")\n", + "\n", + "output_image_widget = ipywidgets.Output(\n", + " layout={\n", + " 'border': '1px solid black',\n", + " 'width': '512px',\n", + " 'height': '512px'\n", + " }\n", + ")\n", + "\n", + "eyebrow_dropdown = ipywidgets.Dropdown(\n", + " options=[\"troubled\", \"angry\", \"lowered\", \"raised\", \"happy\", \"serious\"],\n", + " value=\"troubled\",\n", + " description=\"Eyebrow:\", \n", + ")\n", + "eyebrow_left_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Left:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "eyebrow_right_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Right:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "\n", + "eye_dropdown = ipywidgets.Dropdown(\n", + " options=[\"wink\", \"happy_wink\", \"surprised\", \"relaxed\", \"unimpressed\", \"raised_lower_eyelid\"],\n", + " value=\"wink\",\n", + " description=\"Eye:\", \n", + ")\n", + "eye_left_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Left:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "eye_right_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Right:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "\n", + "mouth_dropdown = ipywidgets.Dropdown(\n", + " options=[\"aaa\", \"iii\", \"uuu\", \"eee\", \"ooo\", \"delta\", \"lowered_corner\", \"raised_corner\", \"smirk\"],\n", + " value=\"aaa\",\n", + " description=\"Mouth:\", \n", + ")\n", + "mouth_left_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Value:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "mouth_right_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\" \",\n", + " readout=True,\n", + " readout_format=\".2f\",\n", + " disabled=True,\n", + ")\n", + "\n", + "def update_mouth_sliders(change):\n", + " if mouth_dropdown.value == \"lowered_corner\" or mouth_dropdown.value == \"raised_corner\":\n", + " mouth_left_slider.description = \"Left:\"\n", + " mouth_right_slider.description = \"Right:\"\n", + " mouth_right_slider.disabled = False\n", + " else:\n", + " mouth_left_slider.description = \"Value:\"\n", + " mouth_right_slider.description = \" \"\n", + " mouth_right_slider.disabled = True\n", + "\n", + "mouth_dropdown.observe(update_mouth_sliders, names='value')\n", + "\n", + "iris_small_left_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Left:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "iris_small_right_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Right:\",\n", + " readout=True,\n", + " readout_format=\".2f\", \n", + ")\n", + "iris_rotation_x_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=-1.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"X-axis:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "iris_rotation_y_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=-1.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Y-axis:\",\n", + " readout=True,\n", + " readout_format=\".2f\", \n", + ")\n", + "\n", + "head_x_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=-1.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"X-axis:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "head_y_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=-1.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Y-axis:\",\n", + " readout=True,\n", + " readout_format=\".2f\", \n", + ")\n", + "neck_z_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=-1.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Z-axis:\",\n", + " readout=True,\n", + " readout_format=\".2f\", \n", + ")\n", + "body_y_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=-1.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Y-axis rotation:\",\n", + " readout=True,\n", + " readout_format=\".2f\", \n", + ")\n", + "body_z_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=-1.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Z-axis rotation:\",\n", + " readout=True,\n", + " readout_format=\".2f\", \n", + ")\n", + "breathing_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Breathing:\",\n", + " readout=True,\n", + " readout_format=\".2f\", \n", + ")\n", + "\n", + "\n", + "control_panel = ipywidgets.VBox([\n", + " eyebrow_dropdown,\n", + " eyebrow_left_slider,\n", + " eyebrow_right_slider,\n", + " ipywidgets.HTML(value=\"
\"),\n", + " eye_dropdown,\n", + " eye_left_slider,\n", + " eye_right_slider,\n", + " ipywidgets.HTML(value=\"
\"),\n", + " mouth_dropdown,\n", + " mouth_left_slider,\n", + " mouth_right_slider,\n", + " ipywidgets.HTML(value=\"
\"),\n", + " ipywidgets.HTML(value=\"
Iris Shrinkage
\"),\n", + " iris_small_left_slider,\n", + " iris_small_right_slider,\n", + " ipywidgets.HTML(value=\"
Iris Rotation
\"),\n", + " iris_rotation_x_slider,\n", + " iris_rotation_y_slider,\n", + " ipywidgets.HTML(value=\"
\"),\n", + " ipywidgets.HTML(value=\"
Head Rotation
\"),\n", + " head_x_slider,\n", + " head_y_slider,\n", + " neck_z_slider,\n", + " ipywidgets.HTML(value=\"
\"),\n", + " ipywidgets.HTML(value=\"
Body Rotation
\"),\n", + " body_y_slider,\n", + " body_z_slider,\n", + " ipywidgets.HTML(value=\"
\"),\n", + " ipywidgets.HTML(value=\"
Breathing
\"),\n", + " breathing_slider,\n", + "])\n", + "\n", + "controls = ipywidgets.HBox([\n", + " ipywidgets.VBox([\n", + " output_image_widget, \n", + " upload_input_image_button\n", + " ]),\n", + " control_panel,\n", + "])\n", + "\n", + "from tha3.poser.modes.pose_parameters import get_pose_parameters\n", + "pose_parameters = get_pose_parameters()\n", + "pose_size = poser.get_num_parameters()\n", + "last_pose = torch.zeros(1, pose_size, dtype=poser.get_dtype()).to(device)\n", + "\n", + "iris_small_left_index = pose_parameters.get_parameter_index(\"iris_small_left\")\n", + "iris_small_right_index = pose_parameters.get_parameter_index(\"iris_small_right\")\n", + "iris_rotation_x_index = pose_parameters.get_parameter_index(\"iris_rotation_x\")\n", + "iris_rotation_y_index = pose_parameters.get_parameter_index(\"iris_rotation_y\")\n", + "head_x_index = pose_parameters.get_parameter_index(\"head_x\")\n", + "head_y_index = pose_parameters.get_parameter_index(\"head_y\")\n", + "neck_z_index = pose_parameters.get_parameter_index(\"neck_z\")\n", + "body_y_index = pose_parameters.get_parameter_index(\"body_y\")\n", + "body_z_index = pose_parameters.get_parameter_index(\"body_z\")\n", + "breathing_index = pose_parameters.get_parameter_index(\"breathing\")\n", + "\n", + "def get_pose():\n", + " pose = torch.zeros(1, pose_size, dtype=poser.get_dtype())\n", + "\n", + " eyebrow_name = f\"eyebrow_{eyebrow_dropdown.value}\"\n", + " eyebrow_left_index = pose_parameters.get_parameter_index(f\"{eyebrow_name}_left\")\n", + " eyebrow_right_index = pose_parameters.get_parameter_index(f\"{eyebrow_name}_right\")\n", + " pose[0, eyebrow_left_index] = eyebrow_left_slider.value\n", + " pose[0, eyebrow_right_index] = eyebrow_right_slider.value\n", + "\n", + " eye_name = f\"eye_{eye_dropdown.value}\"\n", + " eye_left_index = pose_parameters.get_parameter_index(f\"{eye_name}_left\")\n", + " eye_right_index = pose_parameters.get_parameter_index(f\"{eye_name}_right\")\n", + " pose[0, eye_left_index] = eye_left_slider.value\n", + " pose[0, eye_right_index] = eye_right_slider.value\n", + "\n", + " mouth_name = f\"mouth_{mouth_dropdown.value}\"\n", + " if mouth_name == \"mouth_lowered_corner\" or mouth_name == \"mouth_raised_corner\":\n", + " mouth_left_index = pose_parameters.get_parameter_index(f\"{mouth_name}_left\")\n", + " mouth_right_index = pose_parameters.get_parameter_index(f\"{mouth_name}_right\")\n", + " pose[0, mouth_left_index] = mouth_left_slider.value\n", + " pose[0, mouth_right_index] = mouth_right_slider.value\n", + " else:\n", + " mouth_index = pose_parameters.get_parameter_index(mouth_name)\n", + " pose[0, mouth_index] = mouth_left_slider.value\n", + "\n", + " pose[0, iris_small_left_index] = iris_small_left_slider.value\n", + " pose[0, iris_small_right_index] = iris_small_right_slider.value\n", + " pose[0, iris_rotation_x_index] = iris_rotation_x_slider.value\n", + " pose[0, iris_rotation_y_index] = iris_rotation_y_slider.value\n", + " pose[0, head_x_index] = head_x_slider.value\n", + " pose[0, head_y_index] = head_y_slider.value\n", + " pose[0, neck_z_index] = neck_z_slider.value\n", + " pose[0, body_y_index] = body_y_slider.value\n", + " pose[0, body_z_index] = body_z_slider.value\n", + " pose[0, breathing_index] = breathing_slider.value\n", + "\n", + " return pose.to(device)\n", + "\n", + "display(controls)\n", + "\n", + "def update(change):\n", + " global last_pose\n", + " global last_torch_input_image\n", + "\n", + " if torch_input_image is None:\n", + " return\n", + "\n", + " needs_update = False\n", + " if last_torch_input_image is None:\n", + " needs_update = True \n", + " else:\n", + " if (torch_input_image - last_torch_input_image).abs().max().item() > 0:\n", + " needs_update = True \n", + "\n", + " pose = get_pose()\n", + " if (pose - last_pose).abs().max().item() > 0:\n", + " needs_update = True\n", + "\n", + " if not needs_update:\n", + " return\n", + "\n", + " output_image = poser.pose(torch_input_image, pose)[0]\n", + " with output_image_widget:\n", + " output_image_widget.clear_output(wait=True)\n", + " show_pytorch_image(output_image) \n", + "\n", + " last_torch_input_image = torch_input_image\n", + " last_pose = pose\n", + "\n", + "def upload_image(change):\n", + " global torch_input_image\n", + " for name, file_info in upload_input_image_button.value.items():\n", + " content = io.BytesIO(file_info['content'])\n", + " if content is not None:\n", + " pil_image = resize_PIL_image(extract_PIL_image_from_filelike(content), size=(512,512))\n", + " w, h = pil_image.size\n", + " if pil_image.mode != 'RGBA':\n", + " with output_image_widget:\n", + " torch_input_image = None\n", + " output_image_widget.clear_output(wait=True)\n", + " display(ipywidgets.HTML(\"Image must have an alpha channel!!!\"))\n", + " else:\n", + " torch_input_image = extract_pytorch_image_from_PIL_image(pil_image).to(device)\n", + " if poser.get_dtype() == torch.half:\n", + " torch_input_image = torch_input_image.half()\n", + " update(None)\n", + "\n", + "upload_input_image_button.observe(upload_image, names='value')\n", + "eyebrow_dropdown.observe(update, 'value')\n", + "eyebrow_left_slider.observe(update, 'value')\n", + "eyebrow_right_slider.observe(update, 'value')\n", + "eye_dropdown.observe(update, 'value')\n", + "eye_left_slider.observe(update, 'value')\n", + "eye_right_slider.observe(update, 'value')\n", + "mouth_dropdown.observe(update, 'value')\n", + "mouth_left_slider.observe(update, 'value')\n", + "mouth_right_slider.observe(update, 'value')\n", + "iris_small_left_slider.observe(update, 'value')\n", + "iris_small_right_slider.observe(update, 'value')\n", + "iris_rotation_x_slider.observe(update, 'value')\n", + "iris_rotation_y_slider.observe(update, 'value')\n", + "head_x_slider.observe(update, 'value')\n", + "head_y_slider.observe(update, 'value')\n", + "neck_z_slider.observe(update, 'value')\n", + "body_y_slider.observe(update, 'value')\n", + "body_z_slider.observe(update, 'value')\n", + "breathing_slider.observe(update, 'value')" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "tha3.ipynb", + "provenance": [] + }, + "interpreter": { + "hash": "684906ad716c90e6f3397644b72c2a23821e93080f6b0264e4cd74aee22032ce" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/ifacialmocap_ip.jpg b/docs/ifacialmocap_ip.jpg new file mode 100644 index 0000000000000000000000000000000000000000..87970740ed2f59a37a719449657a055402b549ae Binary files /dev/null and b/docs/ifacialmocap_ip.jpg differ diff --git a/docs/ifacialmocap_puppeteer_click_start_capture.png b/docs/ifacialmocap_puppeteer_click_start_capture.png new file mode 100644 index 0000000000000000000000000000000000000000..254406389af96a18c2be610016717fe9215c4560 Binary files /dev/null and b/docs/ifacialmocap_puppeteer_click_start_capture.png differ diff --git a/docs/ifacialmocap_puppeteer_ip_address_box.png b/docs/ifacialmocap_puppeteer_ip_address_box.png new file mode 100644 index 0000000000000000000000000000000000000000..57183660086e4f0a7744c5255df8bb12c731224b Binary files /dev/null and b/docs/ifacialmocap_puppeteer_ip_address_box.png differ diff --git a/docs/ifacialmocap_puppeteer_numbers.png b/docs/ifacialmocap_puppeteer_numbers.png new file mode 100644 index 0000000000000000000000000000000000000000..98014072c4dccf249375387b2747c08da6825601 Binary files /dev/null and b/docs/ifacialmocap_puppeteer_numbers.png differ diff --git a/docs/input_spec.png b/docs/input_spec.png new file mode 100644 index 0000000000000000000000000000000000000000..3bfc5669bf550ef3144387b7023099e80b42f88c Binary files /dev/null and b/docs/input_spec.png differ diff --git a/docs/pytorch-install-command.png b/docs/pytorch-install-command.png new file mode 100644 index 0000000000000000000000000000000000000000..a26187c3eb44cb56cef1a075bfc27f0e13138e72 Binary files /dev/null and b/docs/pytorch-install-command.png differ diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..5db79969756d6a19e0ee00bb695ac6762beff3bd --- /dev/null +++ b/environment.yml @@ -0,0 +1,141 @@ +name: talking-head-anime-3-demo +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - argon2-cffi=21.3.0=pyhd8ed1ab_0 + - argon2-cffi-bindings=21.2.0=py38h294d835_2 + - asttokens=2.0.5=pyhd8ed1ab_0 + - attrs=21.4.0=pyhd8ed1ab_0 + - backcall=0.2.0=pyh9f0ad1d_0 + - backports=1.0=py_2 + - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 + - beautifulsoup4=4.11.1=pyha770c72_0 + - blas=1.0=mkl + - bleach=5.0.0=pyhd8ed1ab_0 + - brotli=1.0.9=ha925a31_2 + - brotlipy=0.7.0=py38h2bbff1b_1003 + - ca-certificates=2022.5.18.1=h5b45459_0 + - certifi=2022.5.18.1=py38haa244fe_0 + - cffi=1.15.0=py38h2bbff1b_1 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - colorama=0.4.4=pyh9f0ad1d_0 + - cryptography=37.0.1=py38h21b164f_0 + - cudatoolkit=11.3.1=h59b6b97_2 + - cycler=0.11.0=pyhd3eb1b0_0 + - debugpy=1.6.0=py38h885f38d_0 + - decorator=5.1.1=pyhd8ed1ab_0 + - defusedxml=0.7.1=pyhd8ed1ab_0 + - entrypoints=0.4=pyhd8ed1ab_0 + - executing=0.8.3=pyhd8ed1ab_0 + - flit-core=3.7.1=pyhd8ed1ab_0 + - fonttools=4.25.0=pyhd3eb1b0_0 + - freetype=2.10.4=hd328e21_0 + - icc_rt=2019.0.0=h0cc432a_1 + - icu=58.2=ha925a31_3 + - idna=3.3=pyhd3eb1b0_0 + - importlib-metadata=4.11.4=py38haa244fe_0 + - importlib_resources=5.7.1=pyhd8ed1ab_1 + - intel-openmp=2021.4.0=haa95532_3556 + - ipykernel=6.13.1=py38h4317176_0 + - ipython=8.4.0=py38haa244fe_0 + - ipython_genutils=0.2.0=py_1 + - ipywidgets=7.7.0=pyhd8ed1ab_0 + - jedi=0.18.1=py38haa244fe_1 + - jinja2=3.1.2=pyhd8ed1ab_1 + - jpeg=9e=h2bbff1b_0 + - jsonschema=4.6.0=pyhd8ed1ab_0 + - jupyter_client=7.3.4=pyhd8ed1ab_0 + - jupyter_core=4.10.0=py38haa244fe_0 + - jupyterlab_pygments=0.2.2=pyhd8ed1ab_0 + - jupyterlab_widgets=1.1.0=pyhd8ed1ab_0 + - kiwisolver=1.4.2=py38hd77b12b_0 + - libpng=1.6.37=h2a8f88b_0 + - libsodium=1.0.18=h8d14728_1 + - libtiff=4.2.0=he0120a3_1 + - libuv=1.40.0=he774522_0 + - libwebp=1.2.2=h2bbff1b_0 + - lz4-c=1.9.3=h2bbff1b_1 + - markupsafe=2.1.1=py38h294d835_1 + - matplotlib=3.5.1=py38haa95532_1 + - matplotlib-base=3.5.1=py38hd77b12b_1 + - matplotlib-inline=0.1.3=pyhd8ed1ab_0 + - mistune=0.8.4=py38h294d835_1005 + - mkl=2021.4.0=haa95532_640 + - mkl-service=2.4.0=py38h2bbff1b_0 + - mkl_fft=1.3.1=py38h277e83a_0 + - mkl_random=1.2.2=py38hf11a4ad_0 + - munkres=1.1.4=py_0 + - nbclient=0.6.4=pyhd8ed1ab_1 + - nbconvert=6.5.0=pyhd8ed1ab_0 + - nbconvert-core=6.5.0=pyhd8ed1ab_0 + - nbconvert-pandoc=6.5.0=pyhd8ed1ab_0 + - nbformat=5.4.0=pyhd8ed1ab_0 + - nest-asyncio=1.5.5=pyhd8ed1ab_0 + - notebook=6.4.12=pyha770c72_0 + - numpy=1.22.3=py38h7a0a035_0 + - numpy-base=1.22.3=py38hca35cd5_0 + - openssl=1.1.1o=h8ffe710_0 + - packaging=21.3=pyhd3eb1b0_0 + - pandoc=2.18=h57928b3_0 + - pandocfilters=1.5.0=pyhd8ed1ab_0 + - parso=0.8.3=pyhd8ed1ab_0 + - pickleshare=0.7.5=py_1003 + - pillow=9.0.1=py38hdc2b20a_0 + - pip=21.2.2=py38haa95532_0 + - prometheus_client=0.14.1=pyhd8ed1ab_0 + - prompt-toolkit=3.0.29=pyha770c72_0 + - psutil=5.9.1=py38h294d835_0 + - pure_eval=0.2.2=pyhd8ed1ab_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pygments=2.12.0=pyhd8ed1ab_0 + - pyopenssl=22.0.0=pyhd3eb1b0_0 + - pyparsing=3.0.4=pyhd3eb1b0_0 + - pyqt=5.9.2=py38hd77b12b_6 + - pyrsistent=0.18.1=py38h294d835_1 + - pysocks=1.7.1=py38haa95532_0 + - python=3.8.13=h6244533_0 + - python-dateutil=2.8.2=pyhd3eb1b0_0 + - python-fastjsonschema=2.15.3=pyhd8ed1ab_0 + - python_abi=3.8=2_cp38 + - pytorch=1.11.0=py3.8_cuda11.3_cudnn8_0 + - pytorch-mutex=1.0=cuda + - pywin32=303=py38h294d835_0 + - pywinpty=2.0.2=py38h5da7b33_0 + - pyzmq=23.1.0=py38h09162b1_0 + - qt=5.9.7=vc14h73c81de_0 + - requests=2.27.1=pyhd3eb1b0_0 + - scipy=1.7.3=py38h0a974cb_0 + - send2trash=1.8.0=pyhd8ed1ab_0 + - setuptools=61.2.0=py38haa95532_0 + - sip=4.19.13=py38hd77b12b_0 + - six=1.16.0=pyhd3eb1b0_1 + - soupsieve=2.3.1=pyhd8ed1ab_0 + - sqlite=3.38.3=h2bbff1b_0 + - stack_data=0.2.0=pyhd8ed1ab_0 + - terminado=0.15.0=py38haa244fe_0 + - tinycss2=1.1.1=pyhd8ed1ab_0 + - tk=8.6.12=h2bbff1b_0 + - torchaudio=0.11.0=py38_cu113 + - torchvision=0.12.0=py38_cu113 + - tornado=6.1=py38h2bbff1b_0 + - traitlets=5.2.2.post1=pyhd8ed1ab_0 + - typing_extensions=4.1.1=pyh06a4308_0 + - urllib3=1.26.9=py38haa95532_0 + - vc=14.2=h21ff451_1 + - vs2015_runtime=14.27.29016=h5e58377_2 + - wcwidth=0.2.5=pyh9f0ad1d_2 + - webencodings=0.5.1=py_1 + - wheel=0.37.1=pyhd3eb1b0_0 + - widgetsnbextension=3.6.0=py38haa244fe_0 + - win_inet_pton=1.1.0=py38haa95532_0 + - wincertstore=0.2=py38haa95532_2 + - winpty=0.4.3=4 + - xz=5.2.5=h8cc25b3_1 + - zeromq=4.3.4=h0e60522_1 + - zipp=3.8.0=pyhd8ed1ab_0 + - zlib=1.2.12=h8cc25b3_2 + - zstd=1.5.2=h19a0ad4_0 + - pip: + - wxpython==4.1.1 diff --git a/manual_poser.ipynb b/manual_poser.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..eb9fe237ce66249d3ab0265a2c142afa2635b046 --- /dev/null +++ b/manual_poser.ipynb @@ -0,0 +1,460 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "062014f7", + "metadata": { + "id": "breeding-extra" + }, + "outputs": [], + "source": [ + "import torch\n", + "MODEL_NAME = \"standard_float\"\n", + "DEVICE_NAME = 'cuda'\n", + "device = torch.device(DEVICE_NAME)\n", + "\n", + "def load_poser(model: str, device: torch.device):\n", + " print(\"Using the %s model.\" % model)\n", + " if model == \"standard_float\":\n", + " from tha3.poser.modes.standard_float import create_poser\n", + " return create_poser(device)\n", + " elif model == \"standard_half\":\n", + " from tha3.poser.modes.standard_half import create_poser\n", + " return create_poser(device)\n", + " elif model == \"separable_float\":\n", + " from tha3.poser.modes.separable_float import create_poser\n", + " return create_poser(device)\n", + " elif model == \"separable_half\":\n", + " from tha3.poser.modes.separable_half import create_poser\n", + " return create_poser(device)\n", + " else:\n", + " raise RuntimeError(\"Invalid model: '%s'\" % model)\n", + " \n", + "poser = load_poser(MODEL_NAME, DEVICE_NAME)\n", + "poser.get_modules();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "breeding-extra", + "metadata": { + "id": "breeding-extra" + }, + "outputs": [], + "source": [ + "import PIL.Image\n", + "import io\n", + "from io import StringIO, BytesIO\n", + "import IPython.display\n", + "import numpy\n", + "import ipywidgets\n", + "import time\n", + "import threading\n", + "import torch\n", + "from tha3.util import resize_PIL_image, extract_PIL_image_from_filelike, \\\n", + " extract_pytorch_image_from_PIL_image, convert_output_image_from_torch_to_numpy\n", + "\n", + "FRAME_RATE = 30.0\n", + "\n", + "last_torch_input_image = None\n", + "torch_input_image = None\n", + "\n", + "def show_pytorch_image(pytorch_image):\n", + " output_image = pytorch_image.detach().cpu()\n", + " numpy_image = numpy.uint8(numpy.rint(convert_output_image_from_torch_to_numpy(output_image) * 255.0))\n", + " pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA')\n", + " IPython.display.display(pil_image)\n", + "\n", + "upload_input_image_button = ipywidgets.FileUpload(\n", + " accept='.png',\n", + " multiple=False,\n", + " layout={\n", + " 'width': '512px'\n", + " }\n", + ")\n", + "\n", + "output_image_widget = ipywidgets.Output(\n", + " layout={\n", + " 'border': '1px solid black',\n", + " 'width': '512px',\n", + " 'height': '512px'\n", + " }\n", + ")\n", + "\n", + "eyebrow_dropdown = ipywidgets.Dropdown(\n", + " options=[\"troubled\", \"angry\", \"lowered\", \"raised\", \"happy\", \"serious\"],\n", + " value=\"troubled\",\n", + " description=\"Eyebrow:\", \n", + ")\n", + "eyebrow_left_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Left:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "eyebrow_right_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Right:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "\n", + "eye_dropdown = ipywidgets.Dropdown(\n", + " options=[\"wink\", \"happy_wink\", \"surprised\", \"relaxed\", \"unimpressed\", \"raised_lower_eyelid\"],\n", + " value=\"wink\",\n", + " description=\"Eye:\", \n", + ")\n", + "eye_left_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Left:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "eye_right_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Right:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "\n", + "mouth_dropdown = ipywidgets.Dropdown(\n", + " options=[\"aaa\", \"iii\", \"uuu\", \"eee\", \"ooo\", \"delta\", \"lowered_corner\", \"raised_corner\", \"smirk\"],\n", + " value=\"aaa\",\n", + " description=\"Mouth:\", \n", + ")\n", + "mouth_left_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Value:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "mouth_right_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\" \",\n", + " readout=True,\n", + " readout_format=\".2f\",\n", + " disabled=True,\n", + ")\n", + "\n", + "def update_mouth_sliders(change):\n", + " if mouth_dropdown.value == \"lowered_corner\" or mouth_dropdown.value == \"raised_corner\":\n", + " mouth_left_slider.description = \"Left:\"\n", + " mouth_right_slider.description = \"Right:\"\n", + " mouth_right_slider.disabled = False\n", + " else:\n", + " mouth_left_slider.description = \"Value:\"\n", + " mouth_right_slider.description = \" \"\n", + " mouth_right_slider.disabled = True\n", + "\n", + "mouth_dropdown.observe(update_mouth_sliders, names='value')\n", + "\n", + "iris_small_left_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Left:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "iris_small_right_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Right:\",\n", + " readout=True,\n", + " readout_format=\".2f\", \n", + ")\n", + "iris_rotation_x_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=-1.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"X-axis:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "iris_rotation_y_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=-1.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Y-axis:\",\n", + " readout=True,\n", + " readout_format=\".2f\", \n", + ")\n", + "\n", + "head_x_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=-1.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"X-axis:\",\n", + " readout=True,\n", + " readout_format=\".2f\"\n", + ")\n", + "head_y_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=-1.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Y-axis:\",\n", + " readout=True,\n", + " readout_format=\".2f\", \n", + ")\n", + "neck_z_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=-1.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Z-axis:\",\n", + " readout=True,\n", + " readout_format=\".2f\", \n", + ")\n", + "body_y_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=-1.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Y-axis rotation:\",\n", + " readout=True,\n", + " readout_format=\".2f\", \n", + ")\n", + "body_z_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=-1.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Z-axis rotation:\",\n", + " readout=True,\n", + " readout_format=\".2f\", \n", + ")\n", + "breathing_slider = ipywidgets.FloatSlider(\n", + " value=0.0,\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " description=\"Breathing:\",\n", + " readout=True,\n", + " readout_format=\".2f\", \n", + ")\n", + "\n", + "\n", + "control_panel = ipywidgets.VBox([\n", + " eyebrow_dropdown,\n", + " eyebrow_left_slider,\n", + " eyebrow_right_slider,\n", + " ipywidgets.HTML(value=\"
\"),\n", + " eye_dropdown,\n", + " eye_left_slider,\n", + " eye_right_slider,\n", + " ipywidgets.HTML(value=\"
\"),\n", + " mouth_dropdown,\n", + " mouth_left_slider,\n", + " mouth_right_slider,\n", + " ipywidgets.HTML(value=\"
\"),\n", + " ipywidgets.HTML(value=\"
Iris Shrinkage
\"),\n", + " iris_small_left_slider,\n", + " iris_small_right_slider,\n", + " ipywidgets.HTML(value=\"
Iris Rotation
\"),\n", + " iris_rotation_x_slider,\n", + " iris_rotation_y_slider,\n", + " ipywidgets.HTML(value=\"
\"),\n", + " ipywidgets.HTML(value=\"
Head Rotation
\"),\n", + " head_x_slider,\n", + " head_y_slider,\n", + " neck_z_slider,\n", + " ipywidgets.HTML(value=\"
\"),\n", + " ipywidgets.HTML(value=\"
Body Rotation
\"),\n", + " body_y_slider,\n", + " body_z_slider,\n", + " ipywidgets.HTML(value=\"
\"),\n", + " ipywidgets.HTML(value=\"
Breathing
\"),\n", + " breathing_slider,\n", + "])\n", + "\n", + "controls = ipywidgets.HBox([\n", + " ipywidgets.VBox([\n", + " output_image_widget, \n", + " upload_input_image_button\n", + " ]),\n", + " control_panel,\n", + "])\n", + "\n", + "from tha3.poser.modes.pose_parameters import get_pose_parameters\n", + "pose_parameters = get_pose_parameters()\n", + "pose_size = poser.get_num_parameters()\n", + "last_pose = torch.zeros(1, pose_size, dtype=poser.get_dtype()).to(device)\n", + "\n", + "iris_small_left_index = pose_parameters.get_parameter_index(\"iris_small_left\")\n", + "iris_small_right_index = pose_parameters.get_parameter_index(\"iris_small_right\")\n", + "iris_rotation_x_index = pose_parameters.get_parameter_index(\"iris_rotation_x\")\n", + "iris_rotation_y_index = pose_parameters.get_parameter_index(\"iris_rotation_y\")\n", + "head_x_index = pose_parameters.get_parameter_index(\"head_x\")\n", + "head_y_index = pose_parameters.get_parameter_index(\"head_y\")\n", + "neck_z_index = pose_parameters.get_parameter_index(\"neck_z\")\n", + "body_y_index = pose_parameters.get_parameter_index(\"body_y\")\n", + "body_z_index = pose_parameters.get_parameter_index(\"body_z\")\n", + "breathing_index = pose_parameters.get_parameter_index(\"breathing\")\n", + "\n", + "def get_pose():\n", + " pose = torch.zeros(1, pose_size, dtype=poser.get_dtype())\n", + "\n", + " eyebrow_name = f\"eyebrow_{eyebrow_dropdown.value}\"\n", + " eyebrow_left_index = pose_parameters.get_parameter_index(f\"{eyebrow_name}_left\")\n", + " eyebrow_right_index = pose_parameters.get_parameter_index(f\"{eyebrow_name}_right\")\n", + " pose[0, eyebrow_left_index] = eyebrow_left_slider.value\n", + " pose[0, eyebrow_right_index] = eyebrow_right_slider.value\n", + "\n", + " eye_name = f\"eye_{eye_dropdown.value}\"\n", + " eye_left_index = pose_parameters.get_parameter_index(f\"{eye_name}_left\")\n", + " eye_right_index = pose_parameters.get_parameter_index(f\"{eye_name}_right\")\n", + " pose[0, eye_left_index] = eye_left_slider.value\n", + " pose[0, eye_right_index] = eye_right_slider.value\n", + "\n", + " mouth_name = f\"mouth_{mouth_dropdown.value}\"\n", + " if mouth_name == \"mouth_lowered_corner\" or mouth_name == \"mouth_raised_corner\":\n", + " mouth_left_index = pose_parameters.get_parameter_index(f\"{mouth_name}_left\")\n", + " mouth_right_index = pose_parameters.get_parameter_index(f\"{mouth_name}_right\")\n", + " pose[0, mouth_left_index] = mouth_left_slider.value\n", + " pose[0, mouth_right_index] = mouth_right_slider.value\n", + " else:\n", + " mouth_index = pose_parameters.get_parameter_index(mouth_name)\n", + " pose[0, mouth_index] = mouth_left_slider.value\n", + "\n", + " pose[0, iris_small_left_index] = iris_small_left_slider.value\n", + " pose[0, iris_small_right_index] = iris_small_right_slider.value\n", + " pose[0, iris_rotation_x_index] = iris_rotation_x_slider.value\n", + " pose[0, iris_rotation_y_index] = iris_rotation_y_slider.value\n", + " pose[0, head_x_index] = head_x_slider.value\n", + " pose[0, head_y_index] = head_y_slider.value\n", + " pose[0, neck_z_index] = neck_z_slider.value\n", + " pose[0, body_y_index] = body_y_slider.value\n", + " pose[0, body_z_index] = body_z_slider.value\n", + " pose[0, breathing_index] = breathing_slider.value\n", + "\n", + " return pose.to(device)\n", + "\n", + "display(controls)\n", + "\n", + "def update(change):\n", + " global last_pose\n", + " global last_torch_input_image\n", + "\n", + " if torch_input_image is None:\n", + " return\n", + "\n", + " needs_update = False\n", + " if last_torch_input_image is None:\n", + " needs_update = True \n", + " else:\n", + " if (torch_input_image - last_torch_input_image).abs().max().item() > 0:\n", + " needs_update = True \n", + "\n", + " pose = get_pose()\n", + " if (pose - last_pose).abs().max().item() > 0:\n", + " needs_update = True\n", + "\n", + " if not needs_update:\n", + " return\n", + "\n", + " output_image = poser.pose(torch_input_image, pose)[0]\n", + " with output_image_widget:\n", + " output_image_widget.clear_output(wait=True)\n", + " show_pytorch_image(output_image) \n", + "\n", + " last_torch_input_image = torch_input_image\n", + " last_pose = pose\n", + "\n", + "def upload_image(change):\n", + " global torch_input_image\n", + " for name, file_info in upload_input_image_button.value.items():\n", + " content = io.BytesIO(file_info['content'])\n", + " if content is not None:\n", + " pil_image = resize_PIL_image(extract_PIL_image_from_filelike(content), size=(512,512))\n", + " w, h = pil_image.size\n", + " if pil_image.mode != 'RGBA':\n", + " with output_image_widget:\n", + " torch_input_image = None\n", + " output_image_widget.clear_output(wait=True)\n", + " display(ipywidgets.HTML(\"Image must have an alpha channel!!!\"))\n", + " else:\n", + " torch_input_image = extract_pytorch_image_from_PIL_image(pil_image).to(device)\n", + " if poser.get_dtype() == torch.half:\n", + " torch_input_image = torch_input_image.half()\n", + " update(None)\n", + "\n", + "upload_input_image_button.observe(upload_image, names='value')\n", + "eyebrow_dropdown.observe(update, 'value')\n", + "eyebrow_left_slider.observe(update, 'value')\n", + "eyebrow_right_slider.observe(update, 'value')\n", + "eye_dropdown.observe(update, 'value')\n", + "eye_left_slider.observe(update, 'value')\n", + "eye_right_slider.observe(update, 'value')\n", + "mouth_dropdown.observe(update, 'value')\n", + "mouth_left_slider.observe(update, 'value')\n", + "mouth_right_slider.observe(update, 'value')\n", + "iris_small_left_slider.observe(update, 'value')\n", + "iris_small_right_slider.observe(update, 'value')\n", + "iris_rotation_x_slider.observe(update, 'value')\n", + "iris_rotation_y_slider.observe(update, 'value')\n", + "head_x_slider.observe(update, 'value')\n", + "head_y_slider.observe(update, 'value')\n", + "neck_z_slider.observe(update, 'value')\n", + "body_y_slider.observe(update, 'value')\n", + "body_z_slider.observe(update, 'value')\n", + "breathing_slider.observe(update, 'value')" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "tha3.ipynb", + "provenance": [] + }, + "interpreter": { + "hash": "684906ad716c90e6f3397644b72c2a23821e93080f6b0264e4cd74aee22032ce" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tha3/__init__.py b/tha3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tha3/app/__init__.py b/tha3/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tha3/app/ifacialmocap_puppeteer.py b/tha3/app/ifacialmocap_puppeteer.py new file mode 100644 index 0000000000000000000000000000000000000000..0807f98138826b05eb79df493625877110e8d4a7 --- /dev/null +++ b/tha3/app/ifacialmocap_puppeteer.py @@ -0,0 +1,439 @@ +import argparse +import os +import socket +import sys +import threading +import time +from typing import Optional + +sys.path.append(os.getcwd()) + +from tha3.mocap.ifacialmocap_pose import create_default_ifacialmocap_pose +from tha3.mocap.ifacialmocap_v2 import IFACIALMOCAP_PORT, IFACIALMOCAP_START_STRING, parse_ifacialmocap_v2_pose, \ + parse_ifacialmocap_v1_pose +from tha3.poser.modes.load_poser import load_poser + +import torch +import wx + +from tha3.poser.poser import Poser +from tha3.mocap.ifacialmocap_constants import * +from tha3.mocap.ifacialmocap_pose_converter import IFacialMocapPoseConverter +from tha3.util import torch_linear_to_srgb, resize_PIL_image, extract_PIL_image_from_filelike, \ + extract_pytorch_image_from_PIL_image + + +def convert_linear_to_srgb(image: torch.Tensor) -> torch.Tensor: + rgb_image = torch_linear_to_srgb(image[0:3, :, :]) + return torch.cat([rgb_image, image[3:4, :, :]], dim=0) + + +class FpsStatistics: + def __init__(self): + self.count = 100 + self.fps = [] + + def add_fps(self, fps): + self.fps.append(fps) + while len(self.fps) > self.count: + del self.fps[0] + + def get_average_fps(self): + if len(self.fps) == 0: + return 0.0 + else: + return sum(self.fps) / len(self.fps) + + +class MainFrame(wx.Frame): + def __init__(self, poser: Poser, pose_converter: IFacialMocapPoseConverter, device: torch.device): + super().__init__(None, wx.ID_ANY, "iFacialMocap Puppeteer (Marigold)") + self.pose_converter = pose_converter + self.poser = poser + self.device = device + + + self.ifacialmocap_pose = create_default_ifacialmocap_pose() + self.source_image_bitmap = wx.Bitmap(self.poser.get_image_size(), self.poser.get_image_size()) + self.result_image_bitmap = wx.Bitmap(self.poser.get_image_size(), self.poser.get_image_size()) + self.wx_source_image = None + self.torch_source_image = None + self.last_pose = None + self.fps_statistics = FpsStatistics() + self.last_update_time = None + + self.create_receiving_socket() + self.create_ui() + self.create_timers() + self.Bind(wx.EVT_CLOSE, self.on_close) + + self.update_source_image_bitmap() + self.update_result_image_bitmap() + + def create_receiving_socket(self): + self.receiving_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.receiving_socket.bind(("", IFACIALMOCAP_PORT)) + self.receiving_socket.setblocking(False) + + def create_timers(self): + self.capture_timer = wx.Timer(self, wx.ID_ANY) + self.Bind(wx.EVT_TIMER, self.update_capture_panel, id=self.capture_timer.GetId()) + self.animation_timer = wx.Timer(self, wx.ID_ANY) + self.Bind(wx.EVT_TIMER, self.update_result_image_bitmap, id=self.animation_timer.GetId()) + + def on_close(self, event: wx.Event): + # Stop the timers + self.animation_timer.Stop() + self.capture_timer.Stop() + + # Close receiving socket + self.receiving_socket.close() + + # Destroy the windows + self.Destroy() + event.Skip() + + def on_start_capture(self, event: wx.Event): + capture_device_ip_address = self.capture_device_ip_text_ctrl.GetValue() + out_socket = None + try: + address = (capture_device_ip_address, IFACIALMOCAP_PORT) + out_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + out_socket.sendto(IFACIALMOCAP_START_STRING, address) + except Exception as e: + message_dialog = wx.MessageDialog(self, str(e), "Error!", wx.OK) + message_dialog.ShowModal() + message_dialog.Destroy() + finally: + if out_socket is not None: + out_socket.close() + + def read_ifacialmocap_pose(self): + if not self.animation_timer.IsRunning(): + return self.ifacialmocap_pose + socket_bytes = None + while True: + try: + socket_bytes = self.receiving_socket.recv(8192) + except socket.error as e: + break + if socket_bytes is not None: + socket_string = socket_bytes.decode("utf-8") + self.ifacialmocap_pose = parse_ifacialmocap_v2_pose(socket_string) + return self.ifacialmocap_pose + + def on_erase_background(self, event: wx.Event): + pass + + def create_animation_panel(self, parent): + self.animation_panel = wx.Panel(parent, style=wx.RAISED_BORDER) + self.animation_panel_sizer = wx.BoxSizer(wx.HORIZONTAL) + self.animation_panel.SetSizer(self.animation_panel_sizer) + self.animation_panel.SetAutoLayout(1) + + image_size = self.poser.get_image_size() + + if True: + self.input_panel = wx.Panel(self.animation_panel, size=(image_size, image_size + 128), + style=wx.SIMPLE_BORDER) + self.input_panel_sizer = wx.BoxSizer(wx.VERTICAL) + self.input_panel.SetSizer(self.input_panel_sizer) + self.input_panel.SetAutoLayout(1) + self.animation_panel_sizer.Add(self.input_panel, 0, wx.FIXED_MINSIZE) + + self.source_image_panel = wx.Panel(self.input_panel, size=(image_size, image_size), style=wx.SIMPLE_BORDER) + self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel) + self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background) + self.input_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE) + + self.load_image_button = wx.Button(self.input_panel, wx.ID_ANY, "Load Image") + self.input_panel_sizer.Add(self.load_image_button, 1, wx.EXPAND) + self.load_image_button.Bind(wx.EVT_BUTTON, self.load_image) + + self.input_panel_sizer.Fit(self.input_panel) + + if True: + self.pose_converter.init_pose_converter_panel(self.animation_panel) + + if True: + self.animation_left_panel = wx.Panel(self.animation_panel, style=wx.SIMPLE_BORDER) + self.animation_left_panel_sizer = wx.BoxSizer(wx.VERTICAL) + self.animation_left_panel.SetSizer(self.animation_left_panel_sizer) + self.animation_left_panel.SetAutoLayout(1) + self.animation_panel_sizer.Add(self.animation_left_panel, 0, wx.EXPAND) + + self.result_image_panel = wx.Panel(self.animation_left_panel, size=(image_size, image_size), + style=wx.SIMPLE_BORDER) + self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel) + self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background) + self.animation_left_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE) + + separator = wx.StaticLine(self.animation_left_panel, -1, size=(256, 5)) + self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND) + + background_text = wx.StaticText(self.animation_left_panel, label="--- Background ---", + style=wx.ALIGN_CENTER) + self.animation_left_panel_sizer.Add(background_text, 0, wx.EXPAND) + + self.output_background_choice = wx.Choice( + self.animation_left_panel, + choices=[ + "TRANSPARENT", + "GREEN", + "BLUE", + "BLACK", + "WHITE" + ]) + self.output_background_choice.SetSelection(0) + self.animation_left_panel_sizer.Add(self.output_background_choice, 0, wx.EXPAND) + + separator = wx.StaticLine(self.animation_left_panel, -1, size=(256, 5)) + self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND) + + self.fps_text = wx.StaticText(self.animation_left_panel, label="") + self.animation_left_panel_sizer.Add(self.fps_text, wx.SizerFlags().Border()) + + self.animation_left_panel_sizer.Fit(self.animation_left_panel) + + self.animation_panel_sizer.Fit(self.animation_panel) + + def create_ui(self): + self.main_sizer = wx.BoxSizer(wx.VERTICAL) + self.SetSizer(self.main_sizer) + self.SetAutoLayout(1) + + self.capture_pose_lock = threading.Lock() + + self.create_connection_panel(self) + self.main_sizer.Add(self.connection_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5)) + + self.create_animation_panel(self) + self.main_sizer.Add(self.animation_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5)) + + self.create_capture_panel(self) + self.main_sizer.Add(self.capture_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5)) + + self.main_sizer.Fit(self) + + def create_connection_panel(self, parent): + self.connection_panel = wx.Panel(parent, style=wx.RAISED_BORDER) + self.connection_panel_sizer = wx.BoxSizer(wx.HORIZONTAL) + self.connection_panel.SetSizer(self.connection_panel_sizer) + self.connection_panel.SetAutoLayout(1) + + capture_device_ip_text = wx.StaticText(self.connection_panel, label="Capture Device IP:", style=wx.ALIGN_RIGHT) + self.connection_panel_sizer.Add(capture_device_ip_text, wx.SizerFlags(0).FixedMinSize().Border(wx.ALL, 3)) + + self.capture_device_ip_text_ctrl = wx.TextCtrl(self.connection_panel, value="192.168.0.1") + self.connection_panel_sizer.Add(self.capture_device_ip_text_ctrl, wx.SizerFlags(1).Expand().Border(wx.ALL, 3)) + + self.start_capture_button = wx.Button(self.connection_panel, label="START CAPTURE!") + self.connection_panel_sizer.Add(self.start_capture_button, wx.SizerFlags(0).FixedMinSize().Border(wx.ALL, 3)) + self.start_capture_button.Bind(wx.EVT_BUTTON, self.on_start_capture) + + def create_capture_panel(self, parent): + self.capture_panel = wx.Panel(parent, style=wx.RAISED_BORDER) + self.capture_panel_sizer = wx.FlexGridSizer(cols=5) + for i in range(5): + self.capture_panel_sizer.AddGrowableCol(i) + self.capture_panel.SetSizer(self.capture_panel_sizer) + self.capture_panel.SetAutoLayout(1) + + self.rotation_labels = {} + self.rotation_value_labels = {} + rotation_column_0 = self.create_rotation_column(self.capture_panel, RIGHT_EYE_BONE_ROTATIONS) + self.capture_panel_sizer.Add(rotation_column_0, wx.SizerFlags(0).Expand().Border(wx.ALL, 3)) + rotation_column_1 = self.create_rotation_column(self.capture_panel, LEFT_EYE_BONE_ROTATIONS) + self.capture_panel_sizer.Add(rotation_column_1, wx.SizerFlags(0).Expand().Border(wx.ALL, 3)) + rotation_column_2 = self.create_rotation_column(self.capture_panel, HEAD_BONE_ROTATIONS) + self.capture_panel_sizer.Add(rotation_column_2, wx.SizerFlags(0).Expand().Border(wx.ALL, 3)) + + def create_rotation_column(self, parent, rotation_names): + column_panel = wx.Panel(parent, style=wx.SIMPLE_BORDER) + column_panel_sizer = wx.FlexGridSizer(cols=2) + column_panel_sizer.AddGrowableCol(1) + column_panel.SetSizer(column_panel_sizer) + column_panel.SetAutoLayout(1) + + for rotation_name in rotation_names: + self.rotation_labels[rotation_name] = wx.StaticText( + column_panel, label=rotation_name, style=wx.ALIGN_RIGHT) + column_panel_sizer.Add(self.rotation_labels[rotation_name], + wx.SizerFlags(1).Expand().Border(wx.ALL, 3)) + + self.rotation_value_labels[rotation_name] = wx.TextCtrl( + column_panel, style=wx.TE_RIGHT) + self.rotation_value_labels[rotation_name].SetValue("0.00") + self.rotation_value_labels[rotation_name].Disable() + column_panel_sizer.Add(self.rotation_value_labels[rotation_name], + wx.SizerFlags(1).Expand().Border(wx.ALL, 3)) + + column_panel.GetSizer().Fit(column_panel) + return column_panel + + def paint_capture_panel(self, event: wx.Event): + self.update_capture_panel(event) + + def update_capture_panel(self, event: wx.Event): + data = self.ifacialmocap_pose + for rotation_name in ROTATION_NAMES: + value = data[rotation_name] + self.rotation_value_labels[rotation_name].SetValue("%0.2f" % value) + + @staticmethod + def convert_to_100(x): + return int(max(0.0, min(1.0, x)) * 100) + + def paint_source_image_panel(self, event: wx.Event): + wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap) + + def update_source_image_bitmap(self): + dc = wx.MemoryDC() + dc.SelectObject(self.source_image_bitmap) + if self.wx_source_image is None: + self.draw_nothing_yet_string(dc) + else: + dc.Clear() + dc.DrawBitmap(self.wx_source_image, 0, 0, True) + del dc + + def draw_nothing_yet_string(self, dc): + dc.Clear() + font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS)) + dc.SetFont(font) + w, h = dc.GetTextExtent("Nothing yet!") + dc.DrawText("Nothing yet!", (self.poser.get_image_size() - w) // 2, (self.poser.get_image_size() - h) // 2) + + def paint_result_image_panel(self, event: wx.Event): + wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap) + + def update_result_image_bitmap(self, event: Optional[wx.Event] = None): + ifacialmocap_pose = self.read_ifacialmocap_pose() + current_pose = self.pose_converter.convert(ifacialmocap_pose) + if self.last_pose is not None and self.last_pose == current_pose: + return + self.last_pose = current_pose + + if self.torch_source_image is None: + dc = wx.MemoryDC() + dc.SelectObject(self.result_image_bitmap) + self.draw_nothing_yet_string(dc) + del dc + return + + pose = torch.tensor(current_pose, device=self.device, dtype=self.poser.get_dtype()) + + with torch.no_grad(): + output_image = self.poser.pose(self.torch_source_image, pose)[0].float() + output_image = convert_linear_to_srgb((output_image + 1.0) / 2.0) + + background_choice = self.output_background_choice.GetSelection() + if background_choice == 0: + pass + else: + background = torch.zeros(4, output_image.shape[1], output_image.shape[2], device=self.device) + background[3, :, :] = 1.0 + if background_choice == 1: + background[1, :, :] = 1.0 + output_image = self.blend_with_background(output_image, background) + elif background_choice == 2: + background[2, :, :] = 1.0 + output_image = self.blend_with_background(output_image, background) + elif background_choice == 3: + output_image = self.blend_with_background(output_image, background) + else: + background[0:3, :, :] = 1.0 + output_image = self.blend_with_background(output_image, background) + + c, h, w = output_image.shape + output_image = 255.0 * torch.transpose(output_image.reshape(c, h * w), 0, 1).reshape(h, w, c) + output_image = output_image.byte() + + numpy_image = output_image.detach().cpu().numpy() + wx_image = wx.ImageFromBuffer(numpy_image.shape[0], + numpy_image.shape[1], + numpy_image[:, :, 0:3].tobytes(), + numpy_image[:, :, 3].tobytes()) + wx_bitmap = wx_image.ConvertToBitmap() + + dc = wx.MemoryDC() + dc.SelectObject(self.result_image_bitmap) + dc.Clear() + dc.DrawBitmap(wx_bitmap, + (self.poser.get_image_size() - numpy_image.shape[0]) // 2, + (self.poser.get_image_size() - numpy_image.shape[1]) // 2, True) + del dc + + time_now = time.time_ns() + if self.last_update_time is not None: + elapsed_time = time_now - self.last_update_time + fps = 1.0 / (elapsed_time / 10**9) + if self.torch_source_image is not None: + self.fps_statistics.add_fps(fps) + self.fps_text.SetLabelText("FPS = %0.2f" % self.fps_statistics.get_average_fps()) + self.last_update_time = time_now + + self.Refresh() + + def blend_with_background(self, numpy_image, background): + alpha = numpy_image[3:4, :, :] + color = numpy_image[0:3, :, :] + new_color = color * alpha + (1.0 - alpha) * background[0:3, :, :] + return torch.cat([new_color, background[3:4, :, :]], dim=0) + + def load_image(self, event: wx.Event): + dir_name = "data/images" + file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_OPEN) + if file_dialog.ShowModal() == wx.ID_OK: + image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename()) + try: + pil_image = resize_PIL_image( + extract_PIL_image_from_filelike(image_file_name), + (self.poser.get_image_size(), self.poser.get_image_size())) + w, h = pil_image.size + if pil_image.mode != 'RGBA': + self.source_image_string = "Image must have alpha channel!" + self.wx_source_image = None + self.torch_source_image = None + else: + self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes()) + self.torch_source_image = extract_pytorch_image_from_PIL_image(pil_image) \ + .to(self.device).to(self.poser.get_dtype()) + self.update_source_image_bitmap() + except: + message_dialog = wx.MessageDialog(self, "Could not load image " + image_file_name, "Poser", wx.OK) + message_dialog.ShowModal() + message_dialog.Destroy() + file_dialog.Destroy() + self.Refresh() + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Control characters with movement captured by iFacialMocap.') + parser.add_argument( + '--model', + type=str, + required=False, + default='standard_float', + choices=['standard_float', 'separable_float', 'standard_half', 'separable_half'], + help='The model to use.') + args = parser.parse_args() + + device = torch.device('cuda') + try: + poser = load_poser(args.model, device) + except RuntimeError as e: + print(e) + sys.exit() + + from tha3.mocap.ifacialmocap_poser_converter_25 import create_ifacialmocap_pose_converter + + pose_converter = create_ifacialmocap_pose_converter() + + app = wx.App() + main_frame = MainFrame(poser, pose_converter, device) + main_frame.Show(True) + main_frame.capture_timer.Start(10) + main_frame.animation_timer.Start(10) + app.MainLoop() diff --git a/tha3/app/manual_poser.py b/tha3/app/manual_poser.py new file mode 100644 index 0000000000000000000000000000000000000000..a89280f22ef3b736bbfef97edb7863b2f05135f7 --- /dev/null +++ b/tha3/app/manual_poser.py @@ -0,0 +1,464 @@ +import argparse +import logging +import os +import sys +from typing import List + +sys.path.append(os.getcwd()) + +import PIL.Image +import numpy +import torch +import wx + +from tha3.poser.modes.load_poser import load_poser +from tha3.poser.poser import Poser, PoseParameterCategory, PoseParameterGroup +from tha3.util import extract_pytorch_image_from_filelike, rgba_to_numpy_image, grid_change_to_numpy_image, \ + rgb_to_numpy_image, resize_PIL_image, extract_PIL_image_from_filelike, extract_pytorch_image_from_PIL_image + + +class MorphCategoryControlPanel(wx.Panel): + def __init__(self, + parent, + title: str, + pose_param_category: PoseParameterCategory, + param_groups: List[PoseParameterGroup]): + super().__init__(parent, style=wx.SIMPLE_BORDER) + self.pose_param_category = pose_param_category + self.sizer = wx.BoxSizer(wx.VERTICAL) + self.SetSizer(self.sizer) + self.SetAutoLayout(1) + + title_text = wx.StaticText(self, label=title, style=wx.ALIGN_CENTER) + self.sizer.Add(title_text, 0, wx.EXPAND) + + self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category] + self.choice = wx.Choice(self, choices=[group.get_group_name() for group in self.param_groups]) + if len(self.param_groups) > 0: + self.choice.SetSelection(0) + self.choice.Bind(wx.EVT_CHOICE, self.on_choice_updated) + self.sizer.Add(self.choice, 0, wx.EXPAND) + + self.left_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL) + self.sizer.Add(self.left_slider, 0, wx.EXPAND) + + self.right_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL) + self.sizer.Add(self.right_slider, 0, wx.EXPAND) + + self.checkbox = wx.CheckBox(self, label="Show") + self.checkbox.SetValue(True) + self.sizer.Add(self.checkbox, 0, wx.SHAPED | wx.ALIGN_CENTER) + + self.update_ui() + + self.sizer.Fit(self) + + def update_ui(self): + param_group = self.param_groups[self.choice.GetSelection()] + if param_group.is_discrete(): + self.left_slider.Enable(False) + self.right_slider.Enable(False) + self.checkbox.Enable(True) + elif param_group.get_arity() == 1: + self.left_slider.Enable(True) + self.right_slider.Enable(False) + self.checkbox.Enable(False) + else: + self.left_slider.Enable(True) + self.right_slider.Enable(True) + self.checkbox.Enable(False) + + def on_choice_updated(self, event: wx.Event): + param_group = self.param_groups[self.choice.GetSelection()] + if param_group.is_discrete(): + self.checkbox.SetValue(True) + self.update_ui() + + def set_param_value(self, pose: List[float]): + if len(self.param_groups) == 0: + return + selected_morph_index = self.choice.GetSelection() + param_group = self.param_groups[selected_morph_index] + param_index = param_group.get_parameter_index() + if param_group.is_discrete(): + if self.checkbox.GetValue(): + for i in range(param_group.get_arity()): + pose[param_index + i] = 1.0 + else: + param_range = param_group.get_range() + alpha = (self.left_slider.GetValue() + 1000) / 2000.0 + pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha + if param_group.get_arity() == 2: + alpha = (self.right_slider.GetValue() + 1000) / 2000.0 + pose[param_index + 1] = param_range[0] + (param_range[1] - param_range[0]) * alpha + + +class SimpleParamGroupsControlPanel(wx.Panel): + def __init__(self, parent, + pose_param_category: PoseParameterCategory, + param_groups: List[PoseParameterGroup]): + super().__init__(parent, style=wx.SIMPLE_BORDER) + self.sizer = wx.BoxSizer(wx.VERTICAL) + self.SetSizer(self.sizer) + self.SetAutoLayout(1) + + self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category] + for param_group in self.param_groups: + assert not param_group.is_discrete() + assert param_group.get_arity() == 1 + + self.sliders = [] + for param_group in self.param_groups: + static_text = wx.StaticText( + self, + label=" ------------ %s ------------ " % param_group.get_group_name(), style=wx.ALIGN_CENTER) + self.sizer.Add(static_text, 0, wx.EXPAND) + range = param_group.get_range() + min_value = int(range[0] * 1000) + max_value = int(range[1] * 1000) + slider = wx.Slider(self, minValue=min_value, maxValue=max_value, value=0, style=wx.HORIZONTAL) + self.sizer.Add(slider, 0, wx.EXPAND) + self.sliders.append(slider) + + self.sizer.Fit(self) + + def set_param_value(self, pose: List[float]): + if len(self.param_groups) == 0: + return + for param_group_index in range(len(self.param_groups)): + param_group = self.param_groups[param_group_index] + slider = self.sliders[param_group_index] + param_range = param_group.get_range() + param_index = param_group.get_parameter_index() + alpha = (slider.GetValue() - slider.GetMin()) * 1.0 / (slider.GetMax() - slider.GetMin()) + pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha + + +def convert_output_image_from_torch_to_numpy(output_image): + if output_image.shape[2] == 2: + h, w, c = output_image.shape + numpy_image = torch.transpose(output_image.reshape(h * w, c), 0, 1).reshape(c, h, w) + elif output_image.shape[0] == 4: + numpy_image = rgba_to_numpy_image(output_image) + elif output_image.shape[0] == 3: + numpy_image = rgb_to_numpy_image(output_image) + elif output_image.shape[0] == 1: + c, h, w = output_image.shape + alpha_image = torch.cat([output_image.repeat(3, 1, 1) * 2.0 - 1.0, torch.ones(1, h, w)], dim=0) + numpy_image = rgba_to_numpy_image(alpha_image) + elif output_image.shape[0] == 2: + numpy_image = grid_change_to_numpy_image(output_image, num_channels=4) + else: + raise RuntimeError("Unsupported # image channels: %d" % output_image.shape[0]) + numpy_image = numpy.uint8(numpy.rint(numpy_image * 255.0)) + return numpy_image + + +class MainFrame(wx.Frame): + def __init__(self, poser: Poser, device: torch.device): + super().__init__(None, wx.ID_ANY, "Poser") + self.poser = poser + self.dtype = self.poser.get_dtype() + self.device = device + self.image_size = self.poser.get_image_size() + + self.wx_source_image = None + self.torch_source_image = None + + self.main_sizer = wx.BoxSizer(wx.HORIZONTAL) + self.SetSizer(self.main_sizer) + self.SetAutoLayout(1) + self.init_left_panel() + self.init_control_panel() + self.init_right_panel() + self.main_sizer.Fit(self) + + self.timer = wx.Timer(self, wx.ID_ANY) + self.Bind(wx.EVT_TIMER, self.update_images, self.timer) + + save_image_id = wx.NewIdRef() + self.Bind(wx.EVT_MENU, self.on_save_image, id=save_image_id) + accelerator_table = wx.AcceleratorTable([ + (wx.ACCEL_CTRL, ord('S'), save_image_id) + ]) + self.SetAcceleratorTable(accelerator_table) + + self.last_pose = None + self.last_output_index = self.output_index_choice.GetSelection() + self.last_output_numpy_image = None + + self.wx_source_image = None + self.torch_source_image = None + self.source_image_bitmap = wx.Bitmap(self.image_size, self.image_size) + self.result_image_bitmap = wx.Bitmap(self.image_size, self.image_size) + self.source_image_dirty = True + + def init_left_panel(self): + self.control_panel = wx.Panel(self, style=wx.SIMPLE_BORDER, size=(self.image_size, -1)) + self.left_panel = wx.Panel(self, style=wx.SIMPLE_BORDER) + left_panel_sizer = wx.BoxSizer(wx.VERTICAL) + self.left_panel.SetSizer(left_panel_sizer) + self.left_panel.SetAutoLayout(1) + + self.source_image_panel = wx.Panel(self.left_panel, size=(self.image_size, self.image_size), + style=wx.SIMPLE_BORDER) + self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel) + self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background) + left_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE) + + self.load_image_button = wx.Button(self.left_panel, wx.ID_ANY, "\nLoad Image\n\n") + left_panel_sizer.Add(self.load_image_button, 1, wx.EXPAND) + self.load_image_button.Bind(wx.EVT_BUTTON, self.load_image) + + left_panel_sizer.Fit(self.left_panel) + self.main_sizer.Add(self.left_panel, 0, wx.FIXED_MINSIZE) + + def on_erase_background(self, event: wx.Event): + pass + + def init_control_panel(self): + self.control_panel_sizer = wx.BoxSizer(wx.VERTICAL) + self.control_panel.SetSizer(self.control_panel_sizer) + self.control_panel.SetMinSize(wx.Size(256, 1)) + + morph_categories = [ + PoseParameterCategory.EYEBROW, + PoseParameterCategory.EYE, + PoseParameterCategory.MOUTH, + PoseParameterCategory.IRIS_MORPH + ] + morph_category_titles = { + PoseParameterCategory.EYEBROW: " ------------ Eyebrow ------------ ", + PoseParameterCategory.EYE: " ------------ Eye ------------ ", + PoseParameterCategory.MOUTH: " ------------ Mouth ------------ ", + PoseParameterCategory.IRIS_MORPH: " ------------ Iris morphs ------------ ", + } + self.morph_control_panels = {} + for category in morph_categories: + param_groups = self.poser.get_pose_parameter_groups() + filtered_param_groups = [group for group in param_groups if group.get_category() == category] + if len(filtered_param_groups) == 0: + continue + control_panel = MorphCategoryControlPanel( + self.control_panel, + morph_category_titles[category], + category, + self.poser.get_pose_parameter_groups()) + self.morph_control_panels[category] = control_panel + self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND) + + self.non_morph_control_panels = {} + non_morph_categories = [ + PoseParameterCategory.IRIS_ROTATION, + PoseParameterCategory.FACE_ROTATION, + PoseParameterCategory.BODY_ROTATION, + PoseParameterCategory.BREATHING + ] + for category in non_morph_categories: + param_groups = self.poser.get_pose_parameter_groups() + filtered_param_groups = [group for group in param_groups if group.get_category() == category] + if len(filtered_param_groups) == 0: + continue + control_panel = SimpleParamGroupsControlPanel( + self.control_panel, + category, + self.poser.get_pose_parameter_groups()) + self.non_morph_control_panels[category] = control_panel + self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND) + + self.control_panel_sizer.Fit(self.control_panel) + self.main_sizer.Add(self.control_panel, 1, wx.FIXED_MINSIZE) + + def init_right_panel(self): + self.right_panel = wx.Panel(self, style=wx.SIMPLE_BORDER) + right_panel_sizer = wx.BoxSizer(wx.VERTICAL) + self.right_panel.SetSizer(right_panel_sizer) + self.right_panel.SetAutoLayout(1) + + self.result_image_panel = wx.Panel(self.right_panel, + size=(self.image_size, self.image_size), + style=wx.SIMPLE_BORDER) + self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel) + self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background) + self.output_index_choice = wx.Choice( + self.right_panel, + choices=[str(i) for i in range(self.poser.get_output_length())]) + self.output_index_choice.SetSelection(0) + right_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE) + right_panel_sizer.Add(self.output_index_choice, 0, wx.EXPAND) + + self.save_image_button = wx.Button(self.right_panel, wx.ID_ANY, "\nSave Image\n\n") + right_panel_sizer.Add(self.save_image_button, 1, wx.EXPAND) + self.save_image_button.Bind(wx.EVT_BUTTON, self.on_save_image) + + right_panel_sizer.Fit(self.right_panel) + self.main_sizer.Add(self.right_panel, 0, wx.FIXED_MINSIZE) + + def create_param_category_choice(self, param_category: PoseParameterCategory): + params = [] + for param_group in self.poser.get_pose_parameter_groups(): + if param_group.get_category() == param_category: + params.append(param_group.get_group_name()) + choice = wx.Choice(self.control_panel, choices=params) + if len(params) > 0: + choice.SetSelection(0) + return choice + + def load_image(self, event: wx.Event): + dir_name = "data/images" + file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_OPEN) + if file_dialog.ShowModal() == wx.ID_OK: + image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename()) + try: + pil_image = resize_PIL_image(extract_PIL_image_from_filelike(image_file_name), + (self.poser.get_image_size(), self.poser.get_image_size())) + w, h = pil_image.size + if pil_image.mode != 'RGBA': + self.source_image_string = "Image must have alpha channel!" + self.wx_source_image = None + self.torch_source_image = None + else: + self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes()) + self.torch_source_image = extract_pytorch_image_from_PIL_image(pil_image)\ + .to(self.device).to(self.dtype) + self.source_image_dirty = True + self.Refresh() + self.Update() + except: + message_dialog = wx.MessageDialog(self, "Could not load image " + image_file_name, "Poser", wx.OK) + message_dialog.ShowModal() + message_dialog.Destroy() + file_dialog.Destroy() + + def paint_source_image_panel(self, event: wx.Event): + wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap) + + def paint_result_image_panel(self, event: wx.Event): + wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap) + + def draw_nothing_yet_string_to_bitmap(self, bitmap): + dc = wx.MemoryDC() + dc.SelectObject(bitmap) + + dc.Clear() + font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS)) + dc.SetFont(font) + w, h = dc.GetTextExtent("Nothing yet!") + dc.DrawText("Nothing yet!", (self.image_size - w) // 2, (self.image_size - - h) // 2) + + del dc + + def get_current_pose(self): + current_pose = [0.0 for i in range(self.poser.get_num_parameters())] + for morph_control_panel in self.morph_control_panels.values(): + morph_control_panel.set_param_value(current_pose) + for rotation_control_panel in self.non_morph_control_panels.values(): + rotation_control_panel.set_param_value(current_pose) + return current_pose + + def update_images(self, event: wx.Event): + current_pose = self.get_current_pose() + if not self.source_image_dirty \ + and self.last_pose is not None \ + and self.last_pose == current_pose \ + and self.last_output_index == self.output_index_choice.GetSelection(): + return + self.last_pose = current_pose + self.last_output_index = self.output_index_choice.GetSelection() + + if self.torch_source_image is None: + self.draw_nothing_yet_string_to_bitmap(self.source_image_bitmap) + self.draw_nothing_yet_string_to_bitmap(self.result_image_bitmap) + self.source_image_dirty = False + self.Refresh() + self.Update() + return + + if self.source_image_dirty: + dc = wx.MemoryDC() + dc.SelectObject(self.source_image_bitmap) + dc.Clear() + dc.DrawBitmap(self.wx_source_image, 0, 0) + self.source_image_dirty = False + + pose = torch.tensor(current_pose, device=self.device, dtype=self.dtype) + output_index = self.output_index_choice.GetSelection() + with torch.no_grad(): + output_image = self.poser.pose(self.torch_source_image, pose, output_index)[0].detach().cpu() + + numpy_image = convert_output_image_from_torch_to_numpy(output_image) + self.last_output_numpy_image = numpy_image + wx_image = wx.ImageFromBuffer( + numpy_image.shape[0], + numpy_image.shape[1], + numpy_image[:, :, 0:3].tobytes(), + numpy_image[:, :, 3].tobytes()) + wx_bitmap = wx_image.ConvertToBitmap() + + dc = wx.MemoryDC() + dc.SelectObject(self.result_image_bitmap) + dc.Clear() + dc.DrawBitmap(wx_bitmap, + (self.image_size - numpy_image.shape[0]) // 2, + (self.image_size - numpy_image.shape[1]) // 2, + True) + del dc + + self.Refresh() + self.Update() + + def on_save_image(self, event: wx.Event): + if self.last_output_numpy_image is None: + logging.info("There is no output image to save!!!") + return + + dir_name = "data/images" + file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_SAVE) + if file_dialog.ShowModal() == wx.ID_OK: + image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename()) + try: + if os.path.exists(image_file_name): + message_dialog = wx.MessageDialog(self, f"Override {image_file_name}", "Manual Poser", + wx.YES_NO | wx.ICON_QUESTION) + result = message_dialog.ShowModal() + if result == wx.ID_YES: + self.save_last_numpy_image(image_file_name) + message_dialog.Destroy() + else: + self.save_last_numpy_image(image_file_name) + except: + message_dialog = wx.MessageDialog(self, f"Could not save {image_file_name}", "Manual Poser", wx.OK) + message_dialog.ShowModal() + message_dialog.Destroy() + file_dialog.Destroy() + + def save_last_numpy_image(self, image_file_name): + numpy_image = self.last_output_numpy_image + pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA') + os.makedirs(os.path.dirname(image_file_name), exist_ok=True) + pil_image.save(image_file_name) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Manually pose a character image.') + parser.add_argument( + '--model', + type=str, + required=False, + default='standard_float', + choices=['standard_float', 'separable_float', 'standard_half', 'separable_half'], + help='The model to use.') + args = parser.parse_args() + + device = torch.device('cuda') + try: + poser = load_poser(args.model, device) + except RuntimeError as e: + print(e) + sys.exit() + + app = wx.App() + main_frame = MainFrame(poser, device) + main_frame.Show(True) + main_frame.timer.Start(30) + app.MainLoop() diff --git a/tha3/compute/__init__.py b/tha3/compute/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tha3/compute/cached_computation_func.py b/tha3/compute/cached_computation_func.py new file mode 100644 index 0000000000000000000000000000000000000000..4641629c1bc2ea2d8a3409a95bc2ae9dadd58289 --- /dev/null +++ b/tha3/compute/cached_computation_func.py @@ -0,0 +1,9 @@ +from typing import Callable, Dict, List + +from torch import Tensor +from torch.nn import Module + +TensorCachedComputationFunc = Callable[ + [Dict[str, Module], List[Tensor], Dict[str, List[Tensor]]], Tensor] +TensorListCachedComputationFunc = Callable[ + [Dict[str, Module], List[Tensor], Dict[str, List[Tensor]]], List[Tensor]] diff --git a/tha3/compute/cached_computation_protocol.py b/tha3/compute/cached_computation_protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..d8f352e1247434d05778f21bf3c6891565f9bd6f --- /dev/null +++ b/tha3/compute/cached_computation_protocol.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod +from typing import Dict, List + +from torch import Tensor +from torch.nn import Module + +from tha3.compute.cached_computation_func import TensorCachedComputationFunc, TensorListCachedComputationFunc + + +class CachedComputationProtocol(ABC): + def get_output(self, + key: str, + modules: Dict[str, Module], + batch: List[Tensor], + outputs: Dict[str, List[Tensor]]): + if key in outputs: + return outputs[key] + else: + output = self.compute_output(key, modules, batch, outputs) + outputs[key] = output + return outputs[key] + + @abstractmethod + def compute_output(self, + key: str, + modules: Dict[str, Module], + batch: List[Tensor], + outputs: Dict[str, List[Tensor]]) -> List[Tensor]: + pass + + def get_output_tensor_func(self, key: str, index: int) -> TensorCachedComputationFunc: + def func(modules: Dict[str, Module], + batch: List[Tensor], + outputs: Dict[str, List[Tensor]]): + return self.get_output(key, modules, batch, outputs)[index] + return func + + def get_output_tensor_list_func(self, key: str) -> TensorListCachedComputationFunc: + def func(modules: Dict[str, Module], + batch: List[Tensor], + outputs: Dict[str, List[Tensor]]): + return self.get_output(key, modules, batch, outputs) + return func \ No newline at end of file diff --git a/tha3/mocap/__init__.py b/tha3/mocap/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tha3/mocap/ifacialmocap_constants.py b/tha3/mocap/ifacialmocap_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..27031ac728a0a77d6e21ff50f6bfbcafc6a1131b --- /dev/null +++ b/tha3/mocap/ifacialmocap_constants.py @@ -0,0 +1,239 @@ +EYE_LOOK_IN_LEFT = "eyeLookInLeft" +EYE_LOOK_OUT_LEFT = "eyeLookOutLeft" +EYE_LOOK_DOWN_LEFT = "eyeLookDownLeft" +EYE_LOOK_UP_LEFT = "eyeLookUpLeft" +EYE_BLINK_LEFT = "eyeBlinkLeft" +EYE_SQUINT_LEFT = "eyeSquintLeft" +EYE_WIDE_LEFT = "eyeWideLeft" +EYE_LOOK_IN_RIGHT = "eyeLookInRight" +EYE_LOOK_OUT_RIGHT = "eyeLookOutRight" +EYE_LOOK_DOWN_RIGHT = "eyeLookDownRight" +EYE_LOOK_UP_RIGHT = "eyeLookUpRight" +EYE_BLINK_RIGHT = "eyeBlinkRight" +EYE_SQUINT_RIGHT = "eyeSquintRight" +EYE_WIDE_RIGHT = "eyeWideRight" +BROW_DOWN_LEFT = "browDownLeft" +BROW_OUTER_UP_LEFT = "browOuterUpLeft" +BROW_DOWN_RIGHT = "browDownRight" +BROW_OUTER_UP_RIGHT = "browOuterUpRight" +BROW_INNER_UP = "browInnerUp" +NOSE_SNEER_LEFT = "noseSneerLeft" +NOSE_SNEER_RIGHT = "noseSneerRight" +CHEEK_SQUINT_LEFT = "cheekSquintLeft" +CHEEK_SQUINT_RIGHT = "cheekSquintRight" +CHEEK_PUFF = "cheekPuff" +MOUTH_LEFT = "mouthLeft" +MOUTH_DIMPLE_LEFT = "mouthDimpleLeft" +MOUTH_FROWN_LEFT = "mouthFrownLeft" +MOUTH_LOWER_DOWN_LEFT = "mouthLowerDownLeft" +MOUTH_PRESS_LEFT = "mouthPressLeft" +MOUTH_SMILE_LEFT = "mouthSmileLeft" +MOUTH_STRETCH_LEFT = "mouthStretchLeft" +MOUTH_UPPER_UP_LEFT = "mouthUpperUpLeft" +MOUTH_RIGHT = "mouthRight" +MOUTH_DIMPLE_RIGHT = "mouthDimpleRight" +MOUTH_FROWN_RIGHT = "mouthFrownRight" +MOUTH_LOWER_DOWN_RIGHT = "mouthLowerDownRight" +MOUTH_PRESS_RIGHT = "mouthPressRight" +MOUTH_SMILE_RIGHT = "mouthSmileRight" +MOUTH_STRETCH_RIGHT = "mouthStretchRight" +MOUTH_UPPER_UP_RIGHT = "mouthUpperUpRight" +MOUTH_CLOSE = "mouthClose" +MOUTH_FUNNEL = "mouthFunnel" +MOUTH_PUCKER = "mouthPucker" +MOUTH_ROLL_LOWER = "mouthRollLower" +MOUTH_ROLL_UPPER = "mouthRollUpper" +MOUTH_SHRUG_LOWER = "mouthShrugLower" +MOUTH_SHRUG_UPPER = "mouthShrugUpper" +JAW_LEFT = "jawLeft" +JAW_RIGHT = "jawRight" +JAW_FORWARD = "jawForward" +JAW_OPEN = "jawOpen" +TONGUE_OUT = "tongueOut" + +BLENDSHAPE_NAMES = [ + EYE_LOOK_IN_LEFT, # 0 + EYE_LOOK_OUT_LEFT, # 1 + EYE_LOOK_DOWN_LEFT, # 2 + EYE_LOOK_UP_LEFT, # 3 + EYE_BLINK_LEFT, # 4 + EYE_SQUINT_LEFT, # 5 + EYE_WIDE_LEFT, # 6 + EYE_LOOK_IN_RIGHT, # 7 + EYE_LOOK_OUT_RIGHT, # 8 + EYE_LOOK_DOWN_RIGHT, # 9 + EYE_LOOK_UP_RIGHT, # 10 + EYE_BLINK_RIGHT, # 11 + EYE_SQUINT_RIGHT, # 12 + EYE_WIDE_RIGHT, # 13 + BROW_DOWN_LEFT, # 14 + BROW_OUTER_UP_LEFT, # 15 + BROW_DOWN_RIGHT, # 16 + BROW_OUTER_UP_RIGHT, # 17 + BROW_INNER_UP, # 18 + NOSE_SNEER_LEFT, # 19 + NOSE_SNEER_RIGHT, # 20 + CHEEK_SQUINT_LEFT, # 21 + CHEEK_SQUINT_RIGHT, # 22 + CHEEK_PUFF, # 23 + MOUTH_LEFT, # 24 + MOUTH_DIMPLE_LEFT, # 25 + MOUTH_FROWN_LEFT, # 26 + MOUTH_LOWER_DOWN_LEFT, # 27 + MOUTH_PRESS_LEFT, # 28 + MOUTH_SMILE_LEFT, # 29 + MOUTH_STRETCH_LEFT, # 30 + MOUTH_UPPER_UP_LEFT, # 31 + MOUTH_RIGHT, # 32 + MOUTH_DIMPLE_RIGHT, # 33 + MOUTH_FROWN_RIGHT, # 34 + MOUTH_LOWER_DOWN_RIGHT, # 35 + MOUTH_PRESS_RIGHT, # 36 + MOUTH_SMILE_RIGHT, # 37 + MOUTH_STRETCH_RIGHT, # 38 + MOUTH_UPPER_UP_RIGHT, # 39 + MOUTH_CLOSE, # 40 + MOUTH_FUNNEL, # 41 + MOUTH_PUCKER, # 42 + MOUTH_ROLL_LOWER, # 43 + MOUTH_ROLL_UPPER, # 44 + MOUTH_SHRUG_LOWER, # 45 + MOUTH_SHRUG_UPPER, # 46 + JAW_LEFT, # 47 + JAW_RIGHT, # 48 + JAW_FORWARD, # 49 + JAW_OPEN, # 50 + TONGUE_OUT, # 51 +] + +EYE_LEFT_BLENDSHAPES = [ + EYE_LOOK_IN_LEFT, # 0 + EYE_LOOK_OUT_LEFT, # 1 + EYE_LOOK_DOWN_LEFT, # 2 + EYE_LOOK_UP_LEFT, # 3 + EYE_BLINK_LEFT, # 4 + EYE_SQUINT_LEFT, # 5 + EYE_WIDE_LEFT, # 6 +] + +EYE_RIGHT_BLENDSHAPES = [ + EYE_LOOK_IN_RIGHT, # 7 + EYE_LOOK_OUT_RIGHT, # 8 + EYE_LOOK_DOWN_RIGHT, # 9 + EYE_LOOK_UP_RIGHT, # 10 + EYE_BLINK_RIGHT, # 11 + EYE_SQUINT_RIGHT, # 12 + EYE_WIDE_RIGHT, # 13 +] + +BROW_LEFT_BLENDSHAPES = [ + BROW_DOWN_LEFT, # 14 + BROW_OUTER_UP_LEFT, # 15 + +] + +BROW_RIGHT_BLENDSHAPES = [ + BROW_DOWN_RIGHT, # 16 + BROW_OUTER_UP_RIGHT, # 17 + +] + +BROW_BOTH_BLENDSHAPES = [ + BROW_INNER_UP, # 18 +] + +NOSE_BLENDSHAPES = [ + NOSE_SNEER_LEFT, # 19 + NOSE_SNEER_RIGHT, # 20 +] + +CHECK_BLENDSHAPES = [ + CHEEK_SQUINT_LEFT, # 21 + CHEEK_SQUINT_RIGHT, # 22 + CHEEK_PUFF, # 23 +] + +MOUTH_LEFT_BLENDSHAPES = [ + MOUTH_LEFT, # 24 + MOUTH_DIMPLE_LEFT, # 25 + MOUTH_FROWN_LEFT, # 26 + MOUTH_LOWER_DOWN_LEFT, # 27 + MOUTH_PRESS_LEFT, # 28 + MOUTH_SMILE_LEFT, # 29 + MOUTH_STRETCH_LEFT, # 30 + MOUTH_UPPER_UP_LEFT, # 31 +] + +MOUTH_RIGHT_BLENDSHAPES = [ + MOUTH_RIGHT, # 32 + MOUTH_DIMPLE_RIGHT, # 33 + MOUTH_FROWN_RIGHT, # 34 + MOUTH_LOWER_DOWN_RIGHT, # 35 + MOUTH_PRESS_RIGHT, # 36 + MOUTH_SMILE_RIGHT, # 37 + MOUTH_STRETCH_RIGHT, # 38 + MOUTH_UPPER_UP_RIGHT, # 39 +] + +MOUTH_BOTH_BLENDSHAPES = [ + MOUTH_CLOSE, # 40 + MOUTH_FUNNEL, # 41 + MOUTH_PUCKER, # 42 + MOUTH_ROLL_LOWER, # 43 + MOUTH_ROLL_UPPER, # 44 + MOUTH_SHRUG_LOWER, # 45 + MOUTH_SHRUG_UPPER, # 46 +] + +JAW_BLENDSHAPES = [ + JAW_LEFT, # 47 + JAW_RIGHT, # 48 + JAW_FORWARD, # 49 + JAW_OPEN, # 50 +] + +TONGUE_BLENDSHAPES = [ + TONGUE_OUT, # 51 +] + +COLUMN_0_BLENDSHAPES = EYE_RIGHT_BLENDSHAPES + BROW_RIGHT_BLENDSHAPES + [NOSE_SNEER_RIGHT, CHEEK_SQUINT_RIGHT] +COLUMN_1_BLENDSHAPES = EYE_LEFT_BLENDSHAPES + BROW_LEFT_BLENDSHAPES + [NOSE_SNEER_LEFT, CHEEK_SQUINT_LEFT] +COLUMN_2_BLENDSHAPES = MOUTH_RIGHT_BLENDSHAPES + [JAW_RIGHT] +COLUMN_3_BLENDSHAPES = MOUTH_LEFT_BLENDSHAPES + [JAW_LEFT] +COLUMN_4_BLENDSHAPES = [BROW_INNER_UP, CHEEK_PUFF] + MOUTH_BOTH_BLENDSHAPES + [JAW_FORWARD, JAW_OPEN, TONGUE_OUT] + +BLENDSHAPE_COLUMNS = [ + COLUMN_0_BLENDSHAPES, + COLUMN_1_BLENDSHAPES, + COLUMN_2_BLENDSHAPES, + COLUMN_3_BLENDSHAPES, + COLUMN_4_BLENDSHAPES, +] + +RIGHT_EYE_BONE_X = "rightEyeBoneX" +RIGHT_EYE_BONE_Y = "rightEyeBoneY" +RIGHT_EYE_BONE_Z = "rightEyeBoneZ" +RIGHT_EYE_BONE_ROTATIONS = [RIGHT_EYE_BONE_X, RIGHT_EYE_BONE_Y, RIGHT_EYE_BONE_Z] + +LEFT_EYE_BONE_X = "leftEyeBoneX" +LEFT_EYE_BONE_Y = "leftEyeBoneY" +LEFT_EYE_BONE_Z = "leftEyeBoneZ" +LEFT_EYE_BONE_ROTATIONS = [LEFT_EYE_BONE_X, LEFT_EYE_BONE_Y, LEFT_EYE_BONE_Z] + +HEAD_BONE_X = "headBoneX" +HEAD_BONE_Y = "headBoneY" +HEAD_BONE_Z = "headBoneZ" +HEAD_BONE_ROTATIONS = [HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z] + +ROTATION_NAMES = RIGHT_EYE_BONE_ROTATIONS + LEFT_EYE_BONE_ROTATIONS + HEAD_BONE_ROTATIONS + +RIGHT_EYE_BONE_QUAT = "rightEyeBoneQuat" +LEFT_EYE_BONE_QUAT = "leftEyeBoneQuat" +HEAD_BONE_QUAT = "headBoneQuat" +QUATERNION_NAMES = [ + RIGHT_EYE_BONE_QUAT, + LEFT_EYE_BONE_QUAT, + HEAD_BONE_QUAT +] + +IFACIALMOCAP_DATETIME_FORMAT = "%Y/%m/%d-%H:%M:%S.%f" diff --git a/tha3/mocap/ifacialmocap_pose.py b/tha3/mocap/ifacialmocap_pose.py new file mode 100644 index 0000000000000000000000000000000000000000..d90936e4cc906293ea673f613f16520a57650883 --- /dev/null +++ b/tha3/mocap/ifacialmocap_pose.py @@ -0,0 +1,27 @@ +from tha3.mocap.ifacialmocap_constants import BLENDSHAPE_NAMES, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z, \ + HEAD_BONE_QUAT, LEFT_EYE_BONE_X, LEFT_EYE_BONE_Y, LEFT_EYE_BONE_Z, LEFT_EYE_BONE_QUAT, RIGHT_EYE_BONE_X, \ + RIGHT_EYE_BONE_Y, RIGHT_EYE_BONE_Z, RIGHT_EYE_BONE_QUAT + + +def create_default_ifacialmocap_pose(): + data = {} + + for blendshape_name in BLENDSHAPE_NAMES: + data[blendshape_name] = 0.0 + + data[HEAD_BONE_X] = 0.0 + data[HEAD_BONE_Y] = 0.0 + data[HEAD_BONE_Z] = 0.0 + data[HEAD_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] + + data[LEFT_EYE_BONE_X] = 0.0 + data[LEFT_EYE_BONE_Y] = 0.0 + data[LEFT_EYE_BONE_Z] = 0.0 + data[LEFT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] + + data[RIGHT_EYE_BONE_X] = 0.0 + data[RIGHT_EYE_BONE_Y] = 0.0 + data[RIGHT_EYE_BONE_Z] = 0.0 + data[RIGHT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] + + return data \ No newline at end of file diff --git a/tha3/mocap/ifacialmocap_pose_converter.py b/tha3/mocap/ifacialmocap_pose_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..390460b4bfe963141f9d787b3405adc8f79e6d7b --- /dev/null +++ b/tha3/mocap/ifacialmocap_pose_converter.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod +from typing import Dict, List + + +class IFacialMocapPoseConverter(ABC): + @abstractmethod + def convert(self, ifacialmocap_pose: Dict[str, float]) -> List[float]: + pass + + @abstractmethod + def init_pose_converter_panel(self, parent): + pass \ No newline at end of file diff --git a/tha3/mocap/ifacialmocap_poser_converter_25.py b/tha3/mocap/ifacialmocap_poser_converter_25.py new file mode 100644 index 0000000000000000000000000000000000000000..bff74d4cfe989c7ba6084167698427d07af92797 --- /dev/null +++ b/tha3/mocap/ifacialmocap_poser_converter_25.py @@ -0,0 +1,463 @@ +import math +import time +from enum import Enum +from typing import Optional, Dict, List + +import numpy +import scipy.optimize +import wx + +from tha3.mocap.ifacialmocap_constants import MOUTH_SMILE_LEFT, MOUTH_SHRUG_UPPER, MOUTH_SMILE_RIGHT, \ + BROW_INNER_UP, BROW_OUTER_UP_RIGHT, BROW_OUTER_UP_LEFT, BROW_DOWN_LEFT, BROW_DOWN_RIGHT, EYE_WIDE_LEFT, \ + EYE_WIDE_RIGHT, EYE_BLINK_LEFT, EYE_BLINK_RIGHT, CHEEK_SQUINT_LEFT, CHEEK_SQUINT_RIGHT, EYE_LOOK_IN_LEFT, \ + EYE_LOOK_OUT_LEFT, EYE_LOOK_IN_RIGHT, EYE_LOOK_OUT_RIGHT, EYE_LOOK_UP_LEFT, EYE_LOOK_UP_RIGHT, EYE_LOOK_DOWN_RIGHT, \ + EYE_LOOK_DOWN_LEFT, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z, JAW_OPEN, MOUTH_FROWN_LEFT, MOUTH_FROWN_RIGHT, \ + MOUTH_LOWER_DOWN_LEFT, MOUTH_LOWER_DOWN_RIGHT, MOUTH_FUNNEL, MOUTH_PUCKER +from tha3.mocap.ifacialmocap_pose_converter import IFacialMocapPoseConverter +from tha3.poser.modes.pose_parameters import get_pose_parameters + + +class EyebrowDownMode(Enum): + TROUBLED = 1 + ANGRY = 2 + LOWERED = 3 + SERIOUS = 4 + + +class WinkMode(Enum): + NORMAL = 1 + RELAXED = 2 + + +def rad_to_deg(rad): + return rad * 180.0 / math.pi + + +def deg_to_rad(deg): + return deg * math.pi / 180.0 + + +def clamp(x, min_value, max_value): + return max(min_value, min(max_value, x)) + + +class IFacialMocapPoseConverter25Args: + def __init__(self, + lower_smile_threshold: float = 0.4, + upper_smile_threshold: float = 0.6, + eyebrow_down_mode: EyebrowDownMode = EyebrowDownMode.ANGRY, + wink_mode: WinkMode = WinkMode.NORMAL, + eye_surprised_max_value: float = 0.5, + eye_wink_max_value: float = 0.8, + eyebrow_down_max_value: float = 0.4, + cheek_squint_min_value: float = 0.1, + cheek_squint_max_value: float = 0.7, + eye_rotation_factor: float = 1.0 / 0.75, + jaw_open_min_value: float = 0.1, + jaw_open_max_value: float = 0.4, + mouth_frown_max_value: float = 0.6, + mouth_funnel_min_value: float = 0.25, + mouth_funnel_max_value: float = 0.5, + iris_small_left=0.0, + iris_small_right=0.0): + self.iris_small_right = iris_small_left + self.iris_small_left = iris_small_right + self.wink_mode = wink_mode + self.mouth_funnel_max_value = mouth_funnel_max_value + self.mouth_funnel_min_value = mouth_funnel_min_value + self.mouth_frown_max_value = mouth_frown_max_value + self.jaw_open_max_value = jaw_open_max_value + self.jaw_open_min_value = jaw_open_min_value + self.eye_rotation_factor = eye_rotation_factor + self.cheek_squint_max_value = cheek_squint_max_value + self.cheek_squint_min_value = cheek_squint_min_value + self.eyebrow_down_max_value = eyebrow_down_max_value + self.eye_blink_max_value = eye_wink_max_value + self.eye_wide_max_value = eye_surprised_max_value + self.eyebrow_down_mode = eyebrow_down_mode + self.lower_smile_threshold = lower_smile_threshold + self.upper_smile_threshold = upper_smile_threshold + + +class IFacialMocapPoseConverter25(IFacialMocapPoseConverter): + def __init__(self, args: Optional[IFacialMocapPoseConverter25Args] = None): + super().__init__() + if args is None: + args = IFacialMocapPoseConverter25Args() + self.args = args + pose_parameters = get_pose_parameters() + self.pose_size = 45 + + self.eyebrow_troubled_left_index = pose_parameters.get_parameter_index("eyebrow_troubled_left") + self.eyebrow_troubled_right_index = pose_parameters.get_parameter_index("eyebrow_troubled_right") + self.eyebrow_angry_left_index = pose_parameters.get_parameter_index("eyebrow_angry_left") + self.eyebrow_angry_right_index = pose_parameters.get_parameter_index("eyebrow_angry_right") + self.eyebrow_happy_left_index = pose_parameters.get_parameter_index("eyebrow_happy_left") + self.eyebrow_happy_right_index = pose_parameters.get_parameter_index("eyebrow_happy_right") + self.eyebrow_raised_left_index = pose_parameters.get_parameter_index("eyebrow_raised_left") + self.eyebrow_raised_right_index = pose_parameters.get_parameter_index("eyebrow_raised_right") + self.eyebrow_lowered_left_index = pose_parameters.get_parameter_index("eyebrow_lowered_left") + self.eyebrow_lowered_right_index = pose_parameters.get_parameter_index("eyebrow_lowered_right") + self.eyebrow_serious_left_index = pose_parameters.get_parameter_index("eyebrow_serious_left") + self.eyebrow_serious_right_index = pose_parameters.get_parameter_index("eyebrow_serious_right") + + self.eye_surprised_left_index = pose_parameters.get_parameter_index("eye_surprised_left") + self.eye_surprised_right_index = pose_parameters.get_parameter_index("eye_surprised_right") + self.eye_wink_left_index = pose_parameters.get_parameter_index("eye_wink_left") + self.eye_wink_right_index = pose_parameters.get_parameter_index("eye_wink_right") + self.eye_happy_wink_left_index = pose_parameters.get_parameter_index("eye_happy_wink_left") + self.eye_happy_wink_right_index = pose_parameters.get_parameter_index("eye_happy_wink_right") + self.eye_relaxed_left_index = pose_parameters.get_parameter_index("eye_relaxed_left") + self.eye_relaxed_right_index = pose_parameters.get_parameter_index("eye_relaxed_right") + self.eye_raised_lower_eyelid_left_index = pose_parameters.get_parameter_index("eye_raised_lower_eyelid_left") + self.eye_raised_lower_eyelid_right_index = pose_parameters.get_parameter_index("eye_raised_lower_eyelid_right") + + self.iris_small_left_index = pose_parameters.get_parameter_index("iris_small_left") + self.iris_small_right_index = pose_parameters.get_parameter_index("iris_small_right") + + self.iris_rotation_x_index = pose_parameters.get_parameter_index("iris_rotation_x") + self.iris_rotation_y_index = pose_parameters.get_parameter_index("iris_rotation_y") + + self.head_x_index = pose_parameters.get_parameter_index("head_x") + self.head_y_index = pose_parameters.get_parameter_index("head_y") + self.neck_z_index = pose_parameters.get_parameter_index("neck_z") + + self.mouth_aaa_index = pose_parameters.get_parameter_index("mouth_aaa") + self.mouth_iii_index = pose_parameters.get_parameter_index("mouth_iii") + self.mouth_uuu_index = pose_parameters.get_parameter_index("mouth_uuu") + self.mouth_eee_index = pose_parameters.get_parameter_index("mouth_eee") + self.mouth_ooo_index = pose_parameters.get_parameter_index("mouth_ooo") + + self.mouth_lowered_corner_left_index = pose_parameters.get_parameter_index("mouth_lowered_corner_left") + self.mouth_lowered_corner_right_index = pose_parameters.get_parameter_index("mouth_lowered_corner_right") + self.mouth_raised_corner_left_index = pose_parameters.get_parameter_index("mouth_raised_corner_left") + self.mouth_raised_corner_right_index = pose_parameters.get_parameter_index("mouth_raised_corner_right") + + self.body_y_index = pose_parameters.get_parameter_index("body_y") + self.body_z_index = pose_parameters.get_parameter_index("body_z") + self.breathing_index = pose_parameters.get_parameter_index("breathing") + + self.breathing_start_time = time.time() + + self.panel = None + + def init_pose_converter_panel(self, parent): + self.panel = wx.Panel(parent, style=wx.SIMPLE_BORDER) + self.panel_sizer = wx.BoxSizer(wx.VERTICAL) + self.panel.SetSizer(self.panel_sizer) + self.panel.SetAutoLayout(1) + parent.GetSizer().Add(self.panel, 0, wx.EXPAND) + + if True: + eyebrow_down_mode_text = wx.StaticText(self.panel, label=" --- Eyebrow Down Mode --- ", + style=wx.ALIGN_CENTER) + self.panel_sizer.Add(eyebrow_down_mode_text, 0, wx.EXPAND) + + self.eyebrow_down_mode_choice = wx.Choice( + self.panel, + choices=[ + "ANGRY", + "TROUBLED", + "SERIOUS", + "LOWERED", + ]) + self.eyebrow_down_mode_choice.SetSelection(0) + self.panel_sizer.Add(self.eyebrow_down_mode_choice, 0, wx.EXPAND) + self.eyebrow_down_mode_choice.Bind(wx.EVT_CHOICE, self.change_eyebrow_down_mode) + + separator = wx.StaticLine(self.panel, -1, size=(256, 5)) + self.panel_sizer.Add(separator, 0, wx.EXPAND) + + if True: + wink_mode_text = wx.StaticText(self.panel, label=" --- Wink Mode --- ", style=wx.ALIGN_CENTER) + self.panel_sizer.Add(wink_mode_text, 0, wx.EXPAND) + + self.wink_mode_choice = wx.Choice( + self.panel, + choices=[ + "NORMAL", + "RELAXED", + ]) + self.wink_mode_choice.SetSelection(0) + self.panel_sizer.Add(self.wink_mode_choice, 0, wx.EXPAND) + self.wink_mode_choice.Bind(wx.EVT_CHOICE, self.change_wink_mode) + + separator = wx.StaticLine(self.panel, -1, size=(256, 5)) + self.panel_sizer.Add(separator, 0, wx.EXPAND) + + if True: + iris_size_text = wx.StaticText(self.panel, label=" --- Iris Size --- ", style=wx.ALIGN_CENTER) + self.panel_sizer.Add(iris_size_text, 0, wx.EXPAND) + + self.iris_left_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL) + self.panel_sizer.Add(self.iris_left_slider, 0, wx.EXPAND) + self.iris_left_slider.Bind(wx.EVT_SLIDER, self.change_iris_size) + + self.iris_right_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL) + self.panel_sizer.Add(self.iris_right_slider, 0, wx.EXPAND) + self.iris_right_slider.Bind(wx.EVT_SLIDER, self.change_iris_size) + self.iris_right_slider.Enable(False) + + self.link_left_right_irises = wx.CheckBox( + self.panel, label="Use same value for both sides") + self.link_left_right_irises.SetValue(True) + self.panel_sizer.Add(self.link_left_right_irises, wx.SizerFlags().CenterHorizontal().Border()) + self.link_left_right_irises.Bind(wx.EVT_CHECKBOX, self.link_left_right_irises_clicked) + + separator = wx.StaticLine(self.panel, -1, size=(256, 5)) + self.panel_sizer.Add(separator, 0, wx.EXPAND) + + if True: + breathing_frequency_text = wx.StaticText( + self.panel, label=" --- Breathing --- ", style=wx.ALIGN_CENTER) + self.panel_sizer.Add(breathing_frequency_text, 0, wx.EXPAND) + + self.restart_breathing_cycle_button = wx.Button(self.panel, label="Restart Breathing Cycle") + self.restart_breathing_cycle_button.Bind(wx.EVT_BUTTON, self.restart_breathing_cycle_clicked) + self.panel_sizer.Add(self.restart_breathing_cycle_button, 0, wx.EXPAND) + + self.breathing_frequency_slider = wx.Slider( + self.panel, minValue=0, maxValue=60, value=20, style=wx.HORIZONTAL) + self.panel_sizer.Add(self.breathing_frequency_slider, 0, wx.EXPAND) + + self.breathing_gauge = wx.Gauge(self.panel, style=wx.GA_HORIZONTAL, range=1000) + self.panel_sizer.Add(self.breathing_gauge, 0, wx.EXPAND) + + self.panel_sizer.Fit(self.panel) + + def restart_breathing_cycle_clicked(self, event: wx.Event): + self.breathing_start_time = time.time() + + def change_eyebrow_down_mode(self, event: wx.Event): + selected_index = self.eyebrow_down_mode_choice.GetSelection() + if selected_index == 0: + self.args.eyebrow_down_mode = EyebrowDownMode.ANGRY + elif selected_index == 1: + self.args.eyebrow_down_mode = EyebrowDownMode.TROUBLED + elif selected_index == 2: + self.args.eyebrow_down_mode = EyebrowDownMode.SERIOUS + else: + self.args.eyebrow_down_mode = EyebrowDownMode.LOWERED + + def change_wink_mode(self, event: wx.Event): + selected_index = self.wink_mode_choice.GetSelection() + if selected_index == 0: + self.args.wink_mode = WinkMode.NORMAL + else: + self.args.wink_mode = WinkMode.RELAXED + + def change_iris_size(self, event: wx.Event): + if self.link_left_right_irises.GetValue(): + left_value = self.iris_left_slider.GetValue() + right_value = self.iris_right_slider.GetValue() + if left_value != right_value: + self.iris_right_slider.SetValue(left_value) + self.args.iris_small_left = left_value / 1000.0 + self.args.iris_small_right = left_value / 1000.0 + else: + self.args.iris_small_left = self.iris_left_slider.GetValue() / 1000.0 + self.args.iris_small_right = self.iris_right_slider.GetValue() / 1000.0 + + def link_left_right_irises_clicked(self, event: wx.Event): + if self.link_left_right_irises.GetValue(): + self.iris_right_slider.Enable(False) + else: + self.iris_right_slider.Enable(True) + self.change_iris_size(event) + + def decompose_head_body_param(self, param, threshold=2.0 / 3): + if abs(param) < threshold: + return (param, 0.0) + else: + if param < 0: + sign = -1.0 + else: + sign = 1.0 + return (threshold * sign, (abs(param) - threshold) * sign) + + def convert(self, ifacialmocap_pose: Dict[str, float]) -> List[float]: + pose = [0.0 for i in range(self.pose_size)] + + smile_value = \ + (ifacialmocap_pose[MOUTH_SMILE_LEFT] + ifacialmocap_pose[MOUTH_SMILE_RIGHT]) / 2.0 \ + + ifacialmocap_pose[MOUTH_SHRUG_UPPER] + if smile_value < self.args.lower_smile_threshold: + smile_degree = 0.0 + elif smile_value > self.args.upper_smile_threshold: + smile_degree = 1.0 + else: + smile_degree = (smile_value - self.args.lower_smile_threshold) / ( + self.args.upper_smile_threshold - self.args.lower_smile_threshold) + + # Eyebrow + if True: + brow_inner_up = ifacialmocap_pose[BROW_INNER_UP] + brow_outer_up_right = ifacialmocap_pose[BROW_OUTER_UP_RIGHT] + brow_outer_up_left = ifacialmocap_pose[BROW_OUTER_UP_LEFT] + + brow_up_left = clamp(brow_inner_up + brow_outer_up_left, 0.0, 1.0) + brow_up_right = clamp(brow_inner_up + brow_outer_up_right, 0.0, 1.0) + pose[self.eyebrow_raised_left_index] = brow_up_left + pose[self.eyebrow_raised_right_index] = brow_up_right + + brow_down_left = (1.0 - smile_degree) \ + * clamp(ifacialmocap_pose[BROW_DOWN_LEFT] / self.args.eyebrow_down_max_value, 0.0, 1.0) + brow_down_right = (1.0 - smile_degree) \ + * clamp(ifacialmocap_pose[BROW_DOWN_RIGHT] / self.args.eyebrow_down_max_value, 0.0, 1.0) + if self.args.eyebrow_down_mode == EyebrowDownMode.TROUBLED: + pose[self.eyebrow_troubled_left_index] = brow_down_left + pose[self.eyebrow_troubled_right_index] = brow_down_right + elif self.args.eyebrow_down_mode == EyebrowDownMode.ANGRY: + pose[self.eyebrow_angry_left_index] = brow_down_left + pose[self.eyebrow_angry_right_index] = brow_down_right + elif self.args.eyebrow_down_mode == EyebrowDownMode.LOWERED: + pose[self.eyebrow_lowered_left_index] = brow_down_left + pose[self.eyebrow_lowered_right_index] = brow_down_right + elif self.args.eyebrow_down_mode == EyebrowDownMode.SERIOUS: + pose[self.eyebrow_serious_left_index] = brow_down_left + pose[self.eyebrow_serious_right_index] = brow_down_right + + brow_happy_value = clamp(smile_value, 0.0, 1.0) * smile_degree + pose[self.eyebrow_happy_left_index] = brow_happy_value + pose[self.eyebrow_happy_right_index] = brow_happy_value + + # Eye + if True: + # Surprised + pose[self.eye_surprised_left_index] = clamp( + ifacialmocap_pose[EYE_WIDE_LEFT] / self.args.eye_wide_max_value, 0.0, 1.0) + pose[self.eye_surprised_right_index] = clamp( + ifacialmocap_pose[EYE_WIDE_RIGHT] / self.args.eye_wide_max_value, 0.0, 1.0) + + # Wink + if self.args.wink_mode == WinkMode.NORMAL: + wink_left_index = self.eye_wink_left_index + wink_right_index = self.eye_wink_right_index + else: + wink_left_index = self.eye_relaxed_left_index + wink_right_index = self.eye_relaxed_right_index + pose[wink_left_index] = (1.0 - smile_degree) * clamp( + ifacialmocap_pose[EYE_BLINK_LEFT] / self.args.eye_blink_max_value, 0.0, 1.0) + pose[wink_right_index] = (1.0 - smile_degree) * clamp( + ifacialmocap_pose[EYE_BLINK_RIGHT] / self.args.eye_blink_max_value, 0.0, 1.0) + pose[self.eye_happy_wink_left_index] = smile_degree * clamp( + ifacialmocap_pose[EYE_BLINK_LEFT] / self.args.eye_blink_max_value, 0.0, 1.0) + pose[self.eye_happy_wink_right_index] = smile_degree * clamp( + ifacialmocap_pose[EYE_BLINK_RIGHT] / self.args.eye_blink_max_value, 0.0, 1.0) + + # Lower eyelid + cheek_squint_denom = self.args.cheek_squint_max_value - self.args.cheek_squint_min_value + pose[self.eye_raised_lower_eyelid_left_index] = \ + clamp( + (ifacialmocap_pose[CHEEK_SQUINT_LEFT] - self.args.cheek_squint_min_value) / cheek_squint_denom, + 0.0, 1.0) + pose[self.eye_raised_lower_eyelid_right_index] = \ + clamp( + (ifacialmocap_pose[CHEEK_SQUINT_RIGHT] - self.args.cheek_squint_min_value) / cheek_squint_denom, + 0.0, 1.0) + + # Iris rotation + if True: + eye_rotation_y = (ifacialmocap_pose[EYE_LOOK_IN_LEFT] + - ifacialmocap_pose[EYE_LOOK_OUT_LEFT] + - ifacialmocap_pose[EYE_LOOK_IN_RIGHT] + + ifacialmocap_pose[EYE_LOOK_OUT_RIGHT]) / 2.0 * self.args.eye_rotation_factor + pose[self.iris_rotation_y_index] = clamp(eye_rotation_y, -1.0, 1.0) + + eye_rotation_x = (ifacialmocap_pose[EYE_LOOK_UP_LEFT] + + ifacialmocap_pose[EYE_LOOK_UP_RIGHT] + - ifacialmocap_pose[EYE_LOOK_DOWN_LEFT] + - ifacialmocap_pose[EYE_LOOK_DOWN_RIGHT]) / 2.0 * self.args.eye_rotation_factor + pose[self.iris_rotation_x_index] = clamp(eye_rotation_x, -1.0, 1.0) + + # Iris size + if True: + pose[self.iris_small_left_index] = self.args.iris_small_left + pose[self.iris_small_right_index] = self.args.iris_small_right + + # Head rotation + if True: + x_param = clamp(-ifacialmocap_pose[HEAD_BONE_X] * 180.0 / math.pi, -15.0, 15.0) / 15.0 + pose[self.head_x_index] = x_param + + y_param = clamp(-ifacialmocap_pose[HEAD_BONE_Y] * 180.0 / math.pi, -10.0, 10.0) / 10.0 + pose[self.head_y_index] = y_param + pose[self.body_y_index] = y_param + + z_param = clamp(ifacialmocap_pose[HEAD_BONE_Z] * 180.0 / math.pi, -15.0, 15.0) / 15.0 + pose[self.neck_z_index] = z_param + pose[self.body_z_index] = z_param + + # Mouth + if True: + jaw_open_denom = self.args.jaw_open_max_value - self.args.jaw_open_min_value + mouth_open = clamp((ifacialmocap_pose[JAW_OPEN] - self.args.jaw_open_min_value) / jaw_open_denom, 0.0, 1.0) + pose[self.mouth_aaa_index] = mouth_open + pose[self.mouth_raised_corner_left_index] = clamp(smile_value, 0.0, 1.0) + pose[self.mouth_raised_corner_right_index] = clamp(smile_value, 0.0, 1.0) + + is_mouth_open = mouth_open > 0.0 + if not is_mouth_open: + mouth_frown_value = clamp( + (ifacialmocap_pose[MOUTH_FROWN_LEFT] + ifacialmocap_pose[ + MOUTH_FROWN_RIGHT]) / self.args.mouth_frown_max_value, 0.0, 1.0) + pose[self.mouth_lowered_corner_left_index] = mouth_frown_value + pose[self.mouth_lowered_corner_right_index] = mouth_frown_value + else: + mouth_lower_down = clamp( + ifacialmocap_pose[MOUTH_LOWER_DOWN_LEFT] + ifacialmocap_pose[MOUTH_LOWER_DOWN_RIGHT], 0.0, 1.0) + mouth_funnel = ifacialmocap_pose[MOUTH_FUNNEL] + mouth_pucker = ifacialmocap_pose[MOUTH_PUCKER] + + mouth_point = [mouth_open, mouth_lower_down, mouth_funnel, mouth_pucker] + + aaa_point = [1.0, 1.0, 0.0, 0.0] + iii_point = [0.0, 1.0, 0.0, 0.0] + uuu_point = [0.5, 0.3, 0.25, 0.75] + ooo_point = [1.0, 0.5, 0.5, 0.4] + + decomp = numpy.array([0, 0, 0, 0]) + M = numpy.array([ + aaa_point, + iii_point, + uuu_point, + ooo_point + ]) + + def loss(decomp): + return numpy.linalg.norm(numpy.matmul(decomp, M) - mouth_point) \ + + 0.01 * numpy.linalg.norm(decomp, ord=1) + + opt_result = scipy.optimize.minimize( + loss, decomp, bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0), (0.0, 1.0)]) + decomp = opt_result["x"] + restricted_decomp = [decomp.item(0), decomp.item(1), decomp.item(2), decomp.item(3)] + pose[self.mouth_aaa_index] = restricted_decomp[0] + pose[self.mouth_iii_index] = restricted_decomp[1] + mouth_funnel_denom = self.args.mouth_funnel_max_value - self.args.mouth_funnel_min_value + ooo_alpha = clamp((mouth_funnel - self.args.mouth_funnel_min_value) / mouth_funnel_denom, 0.0, 1.0) + uo_value = clamp(restricted_decomp[2] + restricted_decomp[3], 0.0, 1.0) + pose[self.mouth_uuu_index] = uo_value * (1.0 - ooo_alpha) + pose[self.mouth_ooo_index] = uo_value * ooo_alpha + + if self.panel is not None: + frequency = self.breathing_frequency_slider.GetValue() + if frequency == 0: + value = 0.0 + pose[self.breathing_index] = value + self.breathing_start_time = time.time() + else: + period = 60.0 / frequency + now = time.time() + diff = now - self.breathing_start_time + frac = (diff % period) / period + value = (-math.cos(2 * math.pi * frac) + 1.0) / 2.0 + pose[self.breathing_index] = value + self.breathing_gauge.SetValue(int(1000 * value)) + + return pose + + +def create_ifacialmocap_pose_converter( + args: Optional[IFacialMocapPoseConverter25Args] = None) -> IFacialMocapPoseConverter: + return IFacialMocapPoseConverter25(args) diff --git a/tha3/mocap/ifacialmocap_v2.py b/tha3/mocap/ifacialmocap_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..dae46eaaf72fa22e091998451ab27e1e19d61773 --- /dev/null +++ b/tha3/mocap/ifacialmocap_v2.py @@ -0,0 +1,89 @@ +import math + +from tha3.mocap.ifacialmocap_constants import BLENDSHAPE_NAMES, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z, \ + RIGHT_EYE_BONE_X, RIGHT_EYE_BONE_Y, RIGHT_EYE_BONE_Z, LEFT_EYE_BONE_X, LEFT_EYE_BONE_Y, LEFT_EYE_BONE_Z, \ + HEAD_BONE_QUAT, LEFT_EYE_BONE_QUAT, RIGHT_EYE_BONE_QUAT + +IFACIALMOCAP_PORT = 49983 +IFACIALMOCAP_START_STRING = "iFacialMocap_sahuasouryya9218sauhuiayeta91555dy3719|sendDataVersion=v2".encode('utf-8') + + +def parse_ifacialmocap_v2_pose(ifacialmocap_output): + output = {} + parts = ifacialmocap_output.split("|") + for part in parts: + part = part.strip() + if len(part) == 0: + continue + if "&" in part: + components = part.split("&") + assert len(components) == 2 + key = components[0] + value = float(components[1]) / 100.0 + if key.endswith("_L"): + key = key[:-2] + "Left" + elif key.endswith("_R"): + key = key[:-2] + "Right" + if key in BLENDSHAPE_NAMES: + output[key] = value + elif part.startswith("=head#"): + components = part[len("=head#"):].split(",") + assert len(components) == 6 + output[HEAD_BONE_X] = float(components[0]) * math.pi / 180 + output[HEAD_BONE_Y] = float(components[1]) * math.pi / 180 + output[HEAD_BONE_Z] = float(components[2]) * math.pi / 180 + elif part.startswith("rightEye#"): + components = part[len("rightEye#"):].split(",") + output[RIGHT_EYE_BONE_X] = float(components[0]) * math.pi / 180 + output[RIGHT_EYE_BONE_Y] = float(components[1]) * math.pi / 180 + output[RIGHT_EYE_BONE_Z] = float(components[2]) * math.pi / 180 + elif part.startswith("leftEye#"): + components = part[len("leftEye#"):].split(",") + output[LEFT_EYE_BONE_X] = float(components[0]) * math.pi / 180 + output[LEFT_EYE_BONE_Y] = float(components[1]) * math.pi / 180 + output[LEFT_EYE_BONE_Z] = float(components[2]) * math.pi / 180 + output[HEAD_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] + output[LEFT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] + output[RIGHT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] + return output + + +def parse_ifacialmocap_v1_pose(ifacialmocap_output): + output = {} + parts = ifacialmocap_output.split("|") + for part in parts: + part = part.strip() + if len(part) == 0: + continue + if part.startswith("=head#"): + components = part[len("=head#"):].split(",") + assert len(components) == 6 + output[HEAD_BONE_X] = float(components[0]) * math.pi / 180 + output[HEAD_BONE_Y] = float(components[1]) * math.pi / 180 + output[HEAD_BONE_Z] = float(components[2]) * math.pi / 180 + elif part.startswith("rightEye#"): + components = part[len("rightEye#"):].split(",") + output[RIGHT_EYE_BONE_X] = float(components[0]) * math.pi / 180 + output[RIGHT_EYE_BONE_Y] = float(components[1]) * math.pi / 180 + output[RIGHT_EYE_BONE_Z] = float(components[2]) * math.pi / 180 + elif part.startswith("leftEye#"): + components = part[len("leftEye#"):].split(",") + output[LEFT_EYE_BONE_X] = float(components[0]) * math.pi / 180 + output[LEFT_EYE_BONE_Y] = float(components[1]) * math.pi / 180 + output[LEFT_EYE_BONE_Z] = float(components[2]) * math.pi / 180 + else: + components = part.split("-") + assert len(components) == 2 + key = components[0] + value = float(components[1]) / 100.0 + if key.endswith("_L"): + key = key[:-2] + "Left" + elif key.endswith("_R"): + key = key[:-2] + "Right" + if key in BLENDSHAPE_NAMES: + output[key] = value + output[HEAD_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] + output[LEFT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] + output[RIGHT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0] + return output + diff --git a/tha3/module/__init__.py b/tha3/module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tha3/module/module_factory.py b/tha3/module/module_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..f6231c3dde48fc8a1679562e384ed039d8365752 --- /dev/null +++ b/tha3/module/module_factory.py @@ -0,0 +1,9 @@ +from abc import ABC, abstractmethod + +from torch.nn import Module + + +class ModuleFactory(ABC): + @abstractmethod + def create(self) -> Module: + pass \ No newline at end of file diff --git a/tha3/nn/__init__.py b/tha3/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tha3/nn/common/__init__.py b/tha3/nn/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tha3/nn/common/conv_block_factory.py b/tha3/nn/common/conv_block_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..f2756bc1d18d2e0310727911cde95b6512742094 --- /dev/null +++ b/tha3/nn/common/conv_block_factory.py @@ -0,0 +1,55 @@ +from typing import Optional + +from tha3.nn.conv import create_conv7_block_from_block_args, create_conv3_block_from_block_args, \ + create_downsample_block_from_block_args, create_conv3 +from tha3.nn.resnet_block import ResnetBlock +from tha3.nn.resnet_block_seperable import ResnetBlockSeparable +from tha3.nn.separable_conv import create_separable_conv7_block, create_separable_conv3_block, \ + create_separable_downsample_block, create_separable_conv3 +from tha3.nn.util import BlockArgs + + +class ConvBlockFactory: + def __init__(self, + block_args: BlockArgs, + use_separable_convolution: bool = False): + self.use_separable_convolution = use_separable_convolution + self.block_args = block_args + + def create_conv3(self, + in_channels: int, + out_channels: int, + bias: bool, + initialization_method: Optional[str] = None): + if initialization_method is None: + initialization_method = self.block_args.initialization_method + if self.use_separable_convolution: + return create_separable_conv3( + in_channels, out_channels, bias, initialization_method, self.block_args.use_spectral_norm) + else: + return create_conv3( + in_channels, out_channels, bias, initialization_method, self.block_args.use_spectral_norm) + + def create_conv7_block(self, in_channels: int, out_channels: int): + if self.use_separable_convolution: + return create_separable_conv7_block(in_channels, out_channels, self.block_args) + else: + return create_conv7_block_from_block_args(in_channels, out_channels, self.block_args) + + def create_conv3_block(self, in_channels: int, out_channels: int): + if self.use_separable_convolution: + return create_separable_conv3_block(in_channels, out_channels, self.block_args) + else: + return create_conv3_block_from_block_args(in_channels, out_channels, self.block_args) + + def create_downsample_block(self, in_channels: int, out_channels: int, is_output_1x1: bool): + if self.use_separable_convolution: + return create_separable_downsample_block(in_channels, out_channels, is_output_1x1, self.block_args) + else: + return create_downsample_block_from_block_args(in_channels, out_channels, is_output_1x1) + + def create_resnet_block(self, num_channels: int, is_1x1: bool): + if self.use_separable_convolution: + return ResnetBlockSeparable.create(num_channels, is_1x1, block_args=self.block_args) + else: + return ResnetBlock.create(num_channels, is_1x1, block_args=self.block_args) \ No newline at end of file diff --git a/tha3/nn/common/poser_args.py b/tha3/nn/common/poser_args.py new file mode 100644 index 0000000000000000000000000000000000000000..d5a23aa02c18ce9a196229bc8a938c100920a70c --- /dev/null +++ b/tha3/nn/common/poser_args.py @@ -0,0 +1,68 @@ +from typing import Optional + +from torch.nn import Sigmoid, Sequential, Tanh + +from tha3.nn.conv import create_conv3, create_conv3_from_block_args +from tha3.nn.nonlinearity_factory import ReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.util import BlockArgs + + +class PoserArgs00: + def __init__(self, + image_size: int, + input_image_channels: int, + output_image_channels: int, + start_channels: int, + num_pose_params: int, + block_args: Optional[BlockArgs] = None): + self.num_pose_params = num_pose_params + self.start_channels = start_channels + self.output_image_channels = output_image_channels + self.input_image_channels = input_image_channels + self.image_size = image_size + if block_args is None: + self.block_args = BlockArgs( + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=True)) + else: + self.block_args = block_args + + def create_alpha_block(self): + from torch.nn import Sequential + return Sequential( + create_conv3( + in_channels=self.start_channels, + out_channels=1, + bias=True, + initialization_method=self.block_args.initialization_method, + use_spectral_norm=False), + Sigmoid()) + + def create_all_channel_alpha_block(self): + from torch.nn import Sequential + return Sequential( + create_conv3( + in_channels=self.start_channels, + out_channels=self.output_image_channels, + bias=True, + initialization_method=self.block_args.initialization_method, + use_spectral_norm=False), + Sigmoid()) + + def create_color_change_block(self): + return Sequential( + create_conv3_from_block_args( + in_channels=self.start_channels, + out_channels=self.output_image_channels, + bias=True, + block_args=self.block_args), + Tanh()) + + def create_grid_change_block(self): + return create_conv3( + in_channels=self.start_channels, + out_channels=2, + bias=False, + initialization_method='zero', + use_spectral_norm=False) \ No newline at end of file diff --git a/tha3/nn/common/poser_encoder_decoder_00.py b/tha3/nn/common/poser_encoder_decoder_00.py new file mode 100644 index 0000000000000000000000000000000000000000..acd59e873ef0f7aa45c705096da67740cd33b9b0 --- /dev/null +++ b/tha3/nn/common/poser_encoder_decoder_00.py @@ -0,0 +1,121 @@ +import math +from typing import Optional, List + +import torch +from torch import Tensor +from torch.nn import ModuleList, Module + +from tha3.nn.common.poser_args import PoserArgs00 +from tha3.nn.conv import create_conv3_block_from_block_args, create_downsample_block_from_block_args, \ + create_upsample_block_from_block_args +from tha3.nn.nonlinearity_factory import ReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.resnet_block import ResnetBlock +from tha3.nn.util import BlockArgs + + +class PoserEncoderDecoder00Args(PoserArgs00): + def __init__(self, + image_size: int, + input_image_channels: int, + output_image_channels: int, + num_pose_params: int , + start_channels: int, + bottleneck_image_size, + num_bottleneck_blocks, + max_channels: int, + block_args: Optional[BlockArgs] = None): + super().__init__( + image_size, input_image_channels, output_image_channels, start_channels, num_pose_params, block_args) + self.max_channels = max_channels + self.num_bottleneck_blocks = num_bottleneck_blocks + self.bottleneck_image_size = bottleneck_image_size + assert bottleneck_image_size > 1 + + if block_args is None: + self.block_args = BlockArgs( + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=True)) + else: + self.block_args = block_args + + +class PoserEncoderDecoder00(Module): + def __init__(self, args: PoserEncoderDecoder00Args): + super().__init__() + self.args = args + + self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1 + + self.downsample_blocks = ModuleList() + self.downsample_blocks.append( + create_conv3_block_from_block_args( + args.input_image_channels, + args.start_channels, + args.block_args)) + current_image_size = args.image_size + current_num_channels = args.start_channels + while current_image_size > args.bottleneck_image_size: + next_image_size = current_image_size // 2 + next_num_channels = self.get_num_output_channels_from_image_size(next_image_size) + self.downsample_blocks.append(create_downsample_block_from_block_args( + in_channels=current_num_channels, + out_channels=next_num_channels, + is_output_1x1=False, + block_args=args.block_args)) + current_image_size = next_image_size + current_num_channels = next_num_channels + assert len(self.downsample_blocks) == self.num_levels + + self.bottleneck_blocks = ModuleList() + self.bottleneck_blocks.append(create_conv3_block_from_block_args( + in_channels=current_num_channels + args.num_pose_params, + out_channels=current_num_channels, + block_args=args.block_args)) + for i in range(1, args.num_bottleneck_blocks): + self.bottleneck_blocks.append( + ResnetBlock.create( + num_channels=current_num_channels, + is1x1=False, + block_args=args.block_args)) + + self.upsample_blocks = ModuleList() + while current_image_size < args.image_size: + next_image_size = current_image_size * 2 + next_num_channels = self.get_num_output_channels_from_image_size(next_image_size) + self.upsample_blocks.append(create_upsample_block_from_block_args( + in_channels=current_num_channels, + out_channels=next_num_channels, + block_args=args.block_args)) + current_image_size = next_image_size + current_num_channels = next_num_channels + + def get_num_output_channels_from_level(self, level: int): + return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level)) + + def get_num_output_channels_from_image_size(self, image_size: int): + return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels) + + def forward(self, image: Tensor, pose: Optional[Tensor] = None) -> List[Tensor]: + if self.args.num_pose_params != 0: + assert pose is not None + else: + assert pose is None + outputs = [] + feature = image + outputs.append(feature) + for block in self.downsample_blocks: + feature = block(feature) + outputs.append(feature) + if pose is not None: + n, c = pose.shape + pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size) + feature = torch.cat([feature, pose], dim=1) + for block in self.bottleneck_blocks: + feature = block(feature) + outputs.append(feature) + for block in self.upsample_blocks: + feature = block(feature) + outputs.append(feature) + outputs.reverse() + return outputs diff --git a/tha3/nn/common/poser_encoder_decoder_00_separable.py b/tha3/nn/common/poser_encoder_decoder_00_separable.py new file mode 100644 index 0000000000000000000000000000000000000000..5a83eb094035e4fcdacc1fbc8d452340f28bfa23 --- /dev/null +++ b/tha3/nn/common/poser_encoder_decoder_00_separable.py @@ -0,0 +1,92 @@ +import math +from typing import Optional, List + +import torch +from torch import Tensor +from torch.nn import ModuleList, Module + +from tha3.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args +from tha3.nn.resnet_block_seperable import ResnetBlockSeparable +from tha3.nn.separable_conv import create_separable_conv3_block, create_separable_downsample_block, \ + create_separable_upsample_block + + +class PoserEncoderDecoder00Separable(Module): + def __init__(self, args: PoserEncoderDecoder00Args): + super().__init__() + self.args = args + + self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1 + + self.downsample_blocks = ModuleList() + self.downsample_blocks.append( + create_separable_conv3_block( + args.input_image_channels, + args.start_channels, + args.block_args)) + current_image_size = args.image_size + current_num_channels = args.start_channels + while current_image_size > args.bottleneck_image_size: + next_image_size = current_image_size // 2 + next_num_channels = self.get_num_output_channels_from_image_size(next_image_size) + self.downsample_blocks.append(create_separable_downsample_block( + in_channels=current_num_channels, + out_channels=next_num_channels, + is_output_1x1=False, + block_args=args.block_args)) + current_image_size = next_image_size + current_num_channels = next_num_channels + assert len(self.downsample_blocks) == self.num_levels + + self.bottleneck_blocks = ModuleList() + self.bottleneck_blocks.append(create_separable_conv3_block( + in_channels=current_num_channels + args.num_pose_params, + out_channels=current_num_channels, + block_args=args.block_args)) + for i in range(1, args.num_bottleneck_blocks): + self.bottleneck_blocks.append( + ResnetBlockSeparable.create( + num_channels=current_num_channels, + is1x1=False, + block_args=args.block_args)) + + self.upsample_blocks = ModuleList() + while current_image_size < args.image_size: + next_image_size = current_image_size * 2 + next_num_channels = self.get_num_output_channels_from_image_size(next_image_size) + self.upsample_blocks.append(create_separable_upsample_block( + in_channels=current_num_channels, + out_channels=next_num_channels, + block_args=args.block_args)) + current_image_size = next_image_size + current_num_channels = next_num_channels + + def get_num_output_channels_from_level(self, level: int): + return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level)) + + def get_num_output_channels_from_image_size(self, image_size: int): + return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels) + + def forward(self, image: Tensor, pose: Optional[Tensor] = None) -> List[Tensor]: + if self.args.num_pose_params != 0: + assert pose is not None + else: + assert pose is None + outputs = [] + feature = image + outputs.append(feature) + for block in self.downsample_blocks: + feature = block(feature) + outputs.append(feature) + if pose is not None: + n, c = pose.shape + pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size) + feature = torch.cat([feature, pose], dim=1) + for block in self.bottleneck_blocks: + feature = block(feature) + outputs.append(feature) + for block in self.upsample_blocks: + feature = block(feature) + outputs.append(feature) + outputs.reverse() + return outputs diff --git a/tha3/nn/common/resize_conv_encoder_decoder.py b/tha3/nn/common/resize_conv_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d78aba4cb81e969f64c3735d2a487eed0d14d9cc --- /dev/null +++ b/tha3/nn/common/resize_conv_encoder_decoder.py @@ -0,0 +1,125 @@ +import math +from typing import Optional, List + +import torch +from torch import Tensor +from torch.nn import Module, ModuleList, Sequential, Upsample + +from tha3.nn.common.conv_block_factory import ConvBlockFactory +from tha3.nn.nonlinearity_factory import LeakyReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.util import BlockArgs + + +class ResizeConvEncoderDecoderArgs: + def __init__(self, + image_size: int, + input_channels: int, + start_channels: int, + bottleneck_image_size, + num_bottleneck_blocks, + max_channels: int, + block_args: Optional[BlockArgs] = None, + upsample_mode: str = 'bilinear', + use_separable_convolution=False): + self.use_separable_convolution = use_separable_convolution + self.upsample_mode = upsample_mode + self.block_args = block_args + self.max_channels = max_channels + self.num_bottleneck_blocks = num_bottleneck_blocks + self.bottleneck_image_size = bottleneck_image_size + self.start_channels = start_channels + self.image_size = image_size + self.input_channels = input_channels + + +class ResizeConvEncoderDecoder(Module): + def __init__(self, args: ResizeConvEncoderDecoderArgs): + super().__init__() + self.args = args + + self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1 + + conv_block_factory = ConvBlockFactory(args.block_args, args.use_separable_convolution) + + self.downsample_blocks = ModuleList() + self.downsample_blocks.append(conv_block_factory.create_conv7_block(args.input_channels, args.start_channels)) + current_image_size = args.image_size + current_num_channels = args.start_channels + while current_image_size > args.bottleneck_image_size: + next_image_size = current_image_size // 2 + next_num_channels = self.get_num_output_channels_from_image_size(next_image_size) + self.downsample_blocks.append(conv_block_factory.create_downsample_block( + in_channels=current_num_channels, + out_channels=next_num_channels, + is_output_1x1=False)) + current_image_size = next_image_size + current_num_channels = next_num_channels + assert len(self.downsample_blocks) == self.num_levels + + self.bottleneck_blocks = ModuleList() + for i in range(args.num_bottleneck_blocks): + self.bottleneck_blocks.append(conv_block_factory.create_resnet_block(current_num_channels, is_1x1=False)) + + self.output_image_sizes = [current_image_size] + self.output_num_channels = [current_num_channels] + self.upsample_blocks = ModuleList() + if args.upsample_mode == 'nearest': + align_corners = None + else: + align_corners = False + while current_image_size < args.image_size: + next_image_size = current_image_size * 2 + next_num_channels = self.get_num_output_channels_from_image_size(next_image_size) + self.upsample_blocks.append( + Sequential( + Upsample(scale_factor=2, mode=args.upsample_mode, align_corners=align_corners), + conv_block_factory.create_conv3_block( + in_channels=current_num_channels, out_channels=next_num_channels))) + current_image_size = next_image_size + current_num_channels = next_num_channels + self.output_image_sizes.append(current_image_size) + self.output_num_channels.append(current_num_channels) + + def get_num_output_channels_from_level(self, level: int): + return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level)) + + def get_num_output_channels_from_image_size(self, image_size: int): + return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels) + + def forward(self, feature: Tensor) -> List[Tensor]: + outputs = [] + for block in self.downsample_blocks: + feature = block(feature) + for block in self.bottleneck_blocks: + feature = block(feature) + outputs.append(feature) + for block in self.upsample_blocks: + feature = block(feature) + outputs.append(feature) + return outputs + + +if __name__ == "__main__": + device = torch.device('cuda') + args = ResizeConvEncoderDecoderArgs( + image_size=512, + input_channels=4 + 6, + start_channels=32, + bottleneck_image_size=32, + num_bottleneck_blocks=6, + max_channels=512, + use_separable_convolution=True, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=LeakyReLUFactory(inplace=False, negative_slope=0.1))) + module = ResizeConvEncoderDecoder(args).to(device) + print(module.output_image_sizes) + print(module.output_num_channels) + + input = torch.zeros(8, 4 + 6, 512, 512, device=device) + outputs = module(input) + for output in outputs: + print(output.shape) diff --git a/tha3/nn/common/resize_conv_unet.py b/tha3/nn/common/resize_conv_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b3c41479484d1c0989ebec4f98fb16b36e2079 --- /dev/null +++ b/tha3/nn/common/resize_conv_unet.py @@ -0,0 +1,155 @@ +from typing import Optional, List + +import torch +from torch import Tensor +from torch.nn import ModuleList, Module, Upsample + +from tha3.nn.common.conv_block_factory import ConvBlockFactory +from tha3.nn.nonlinearity_factory import ReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.util import BlockArgs + + +class ResizeConvUNetArgs: + def __init__(self, + image_size: int, + input_channels: int, + start_channels: int, + bottleneck_image_size: int, + num_bottleneck_blocks: int, + max_channels: int, + upsample_mode: str = 'bilinear', + block_args: Optional[BlockArgs] = None, + use_separable_convolution: bool = False): + if block_args is None: + block_args = BlockArgs( + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=False)) + + self.use_separable_convolution = use_separable_convolution + self.block_args = block_args + self.upsample_mode = upsample_mode + self.max_channels = max_channels + self.num_bottleneck_blocks = num_bottleneck_blocks + self.bottleneck_image_size = bottleneck_image_size + self.input_channels = input_channels + self.start_channels = start_channels + self.image_size = image_size + + +class ResizeConvUNet(Module): + def __init__(self, args: ResizeConvUNetArgs): + super().__init__() + self.args = args + conv_block_factory = ConvBlockFactory(args.block_args, args.use_separable_convolution) + + self.downsample_blocks = ModuleList() + self.downsample_blocks.append(conv_block_factory.create_conv3_block( + self.args.input_channels, + self.args.start_channels)) + current_channels = self.args.start_channels + current_size = self.args.image_size + + size_to_channel = { + current_size: current_channels + } + while current_size > self.args.bottleneck_image_size: + next_size = current_size // 2 + next_channels = min(self.args.max_channels, current_channels * 2) + self.downsample_blocks.append(conv_block_factory.create_downsample_block( + current_channels, + next_channels, + is_output_1x1=False)) + current_size = next_size + current_channels = next_channels + size_to_channel[current_size] = current_channels + + self.bottleneck_blocks = ModuleList() + for i in range(self.args.num_bottleneck_blocks): + self.bottleneck_blocks.append(conv_block_factory.create_resnet_block(current_channels, is_1x1=False)) + + self.output_image_sizes = [current_size] + self.output_num_channels = [current_channels] + self.upsample_blocks = ModuleList() + while current_size < self.args.image_size: + next_size = current_size * 2 + next_channels = size_to_channel[next_size] + self.upsample_blocks.append(conv_block_factory.create_conv3_block( + current_channels + next_channels, + next_channels)) + current_size = next_size + current_channels = next_channels + self.output_image_sizes.append(current_size) + self.output_num_channels.append(current_channels) + + if args.upsample_mode == 'nearest': + align_corners = None + else: + align_corners = False + self.double_resolution = Upsample(scale_factor=2, mode=args.upsample_mode, align_corners=align_corners) + + def forward(self, feature: Tensor) -> List[Tensor]: + downsampled_features = [] + for block in self.downsample_blocks: + feature = block(feature) + downsampled_features.append(feature) + + for block in self.bottleneck_blocks: + feature = block(feature) + + outputs = [feature] + for i in range(0, len(self.upsample_blocks)): + feature = self.double_resolution(feature) + feature = torch.cat([feature, downsampled_features[-i - 2]], dim=1) + feature = self.upsample_blocks[i](feature) + outputs.append(feature) + + return outputs + + +if __name__ == "__main__": + device = torch.device('cuda') + + image_size = 512 + image_channels = 4 + num_pose_params = 6 + args = ResizeConvUNetArgs( + image_size=512, + input_channels=10, + start_channels=32, + bottleneck_image_size=32, + num_bottleneck_blocks=6, + max_channels=512, + upsample_mode='nearest', + use_separable_convolution=False, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=False))) + module = ResizeConvUNet(args).to(device) + + image_count = 8 + input = torch.zeros(image_count, 10, 512, 512, device=device) + outputs = module.forward(input) + for output in outputs: + print(output.shape) + + + if True: + repeat = 100 + acc = 0.0 + for i in range(repeat + 2): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + module.forward(input) + end.record() + torch.cuda.synchronize() + if i >= 2: + elapsed_time = start.elapsed_time(end) + print("%d:" % i, elapsed_time) + acc = acc + elapsed_time + + print("average:", acc / repeat) \ No newline at end of file diff --git a/tha3/nn/conv.py b/tha3/nn/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..0ad46c4c74383ac0ddb966f76768d6f9ab75e551 --- /dev/null +++ b/tha3/nn/conv.py @@ -0,0 +1,189 @@ +from typing import Optional, Union, Callable + +from torch.nn import Conv2d, Module, Sequential, ConvTranspose2d + +from tha3.module.module_factory import ModuleFactory +from tha3.nn.nonlinearity_factory import resolve_nonlinearity_factory +from tha3.nn.normalization import NormalizationLayerFactory +from tha3.nn.util import wrap_conv_or_linear_module, BlockArgs + + +def create_conv7(in_channels: int, out_channels: int, + bias: bool = False, + initialization_method: Union[str, Callable[[Module], Module]] = 'he', + use_spectral_norm: bool = False) -> Module: + return wrap_conv_or_linear_module( + Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=bias), + initialization_method, + use_spectral_norm) + + +def create_conv7_from_block_args(in_channels: int, + out_channels: int, + bias: bool = False, + block_args: Optional[BlockArgs] = None) -> Module: + if block_args is None: + block_args = BlockArgs() + return create_conv7( + in_channels, out_channels, bias, + block_args.initialization_method, + block_args.use_spectral_norm) + + +def create_conv3(in_channels: int, + out_channels: int, + bias: bool = False, + initialization_method: Union[str, Callable[[Module], Module]] = 'he', + use_spectral_norm: bool = False) -> Module: + return wrap_conv_or_linear_module( + Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias), + initialization_method, + use_spectral_norm) + + +def create_conv3_from_block_args(in_channels: int, out_channels: int, + bias: bool = False, + block_args: Optional[BlockArgs] = None): + if block_args is None: + block_args = BlockArgs() + return create_conv3(in_channels, out_channels, bias, + block_args.initialization_method, + block_args.use_spectral_norm) + + +def create_conv1(in_channels: int, out_channels: int, + initialization_method: Union[str, Callable[[Module], Module]] = 'he', + bias: bool = False, + use_spectral_norm: bool = False) -> Module: + return wrap_conv_or_linear_module( + Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias), + initialization_method, + use_spectral_norm) + + +def create_conv1_from_block_args(in_channels: int, + out_channels: int, + bias: bool = False, + block_args: Optional[BlockArgs] = None) -> Module: + if block_args is None: + block_args = BlockArgs() + return create_conv1( + in_channels=in_channels, + out_channels=out_channels, + initialization_method=block_args.initialization_method, + bias=bias, + use_spectral_norm=block_args.use_spectral_norm) + + +def create_conv7_block(in_channels: int, out_channels: int, + initialization_method: Union[str, Callable[[Module], Module]] = 'he', + nonlinearity_factory: Optional[ModuleFactory] = None, + normalization_layer_factory: Optional[NormalizationLayerFactory] = None, + use_spectral_norm: bool = False) -> Module: + nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory) + return Sequential( + create_conv7(in_channels, out_channels, + bias=False, initialization_method=initialization_method, use_spectral_norm=use_spectral_norm), + NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True), + resolve_nonlinearity_factory(nonlinearity_factory).create()) + + +def create_conv7_block_from_block_args( + in_channels: int, out_channels: int, + block_args: Optional[BlockArgs] = None) -> Module: + if block_args is None: + block_args = BlockArgs() + return create_conv7_block(in_channels, out_channels, + block_args.initialization_method, + block_args.nonlinearity_factory, + block_args.normalization_layer_factory, + block_args.use_spectral_norm) + + +def create_conv3_block(in_channels: int, out_channels: int, + initialization_method: Union[str, Callable[[Module], Module]] = 'he', + nonlinearity_factory: Optional[ModuleFactory] = None, + normalization_layer_factory: Optional[NormalizationLayerFactory] = None, + use_spectral_norm: bool = False) -> Module: + nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory) + return Sequential( + create_conv3(in_channels, out_channels, + bias=False, initialization_method=initialization_method, use_spectral_norm=use_spectral_norm), + NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True), + resolve_nonlinearity_factory(nonlinearity_factory).create()) + + +def create_conv3_block_from_block_args( + in_channels: int, out_channels: int, block_args: Optional[BlockArgs] = None): + if block_args is None: + block_args = BlockArgs() + return create_conv3_block(in_channels, out_channels, + block_args.initialization_method, + block_args.nonlinearity_factory, + block_args.normalization_layer_factory, + block_args.use_spectral_norm) + + +def create_downsample_block(in_channels: int, out_channels: int, + is_output_1x1: bool = False, + initialization_method: Union[str, Callable[[Module], Module]] = 'he', + nonlinearity_factory: Optional[ModuleFactory] = None, + normalization_layer_factory: Optional[NormalizationLayerFactory] = None, + use_spectral_norm: bool = False) -> Module: + if is_output_1x1: + return Sequential( + wrap_conv_or_linear_module( + Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False), + initialization_method, + use_spectral_norm), + resolve_nonlinearity_factory(nonlinearity_factory).create()) + else: + return Sequential( + wrap_conv_or_linear_module( + Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False), + initialization_method, + use_spectral_norm), + NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True), + resolve_nonlinearity_factory(nonlinearity_factory).create()) + + +def create_downsample_block_from_block_args(in_channels: int, out_channels: int, + is_output_1x1: bool = False, + block_args: Optional[BlockArgs] = None): + if block_args is None: + block_args = BlockArgs() + return create_downsample_block( + in_channels, out_channels, + is_output_1x1, + block_args.initialization_method, + block_args.nonlinearity_factory, + block_args.normalization_layer_factory, + block_args.use_spectral_norm) + + +def create_upsample_block(in_channels: int, + out_channels: int, + initialization_method: Union[str, Callable[[Module], Module]] = 'he', + nonlinearity_factory: Optional[ModuleFactory] = None, + normalization_layer_factory: Optional[NormalizationLayerFactory] = None, + use_spectral_norm: bool = False) -> Module: + nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory) + return Sequential( + wrap_conv_or_linear_module( + ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False), + initialization_method, + use_spectral_norm), + NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True), + resolve_nonlinearity_factory(nonlinearity_factory).create()) + + +def create_upsample_block_from_block_args(in_channels: int, + out_channels: int, + block_args: Optional[BlockArgs] = None) -> Module: + if block_args is None: + block_args = BlockArgs() + return create_upsample_block(in_channels, out_channels, + block_args.initialization_method, + block_args.nonlinearity_factory, + block_args.normalization_layer_factory, + block_args.use_spectral_norm) diff --git a/tha3/nn/editor/__init__.py b/tha3/nn/editor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tha3/nn/editor/editor_07.py b/tha3/nn/editor/editor_07.py new file mode 100644 index 0000000000000000000000000000000000000000..08c9fe3fed09ae464287e998790ac35d9e503030 --- /dev/null +++ b/tha3/nn/editor/editor_07.py @@ -0,0 +1,180 @@ +from typing import Optional, List + +import torch +from matplotlib import pyplot +from torch import Tensor +from torch.nn import Module, Sequential, Tanh, Sigmoid + +from tha3.nn.image_processing_util import GridChangeApplier, apply_color_change +from tha3.nn.common.resize_conv_unet import ResizeConvUNet, ResizeConvUNetArgs +from tha3.util import numpy_linear_to_srgb +from tha3.module.module_factory import ModuleFactory +from tha3.nn.conv import create_conv3_from_block_args, create_conv3 +from tha3.nn.nonlinearity_factory import ReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.util import BlockArgs + + +class Editor07Args: + def __init__(self, + image_size: int = 512, + image_channels: int = 4, + num_pose_params: int = 6, + start_channels: int = 32, + bottleneck_image_size=32, + num_bottleneck_blocks=6, + max_channels: int = 512, + upsampling_mode: str = 'nearest', + block_args: Optional[BlockArgs] = None, + use_separable_convolution: bool = False): + if block_args is None: + block_args = BlockArgs( + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=False)) + + self.block_args = block_args + self.upsampling_mode = upsampling_mode + self.max_channels = max_channels + self.num_bottleneck_blocks = num_bottleneck_blocks + self.bottleneck_image_size = bottleneck_image_size + self.start_channels = start_channels + self.num_pose_params = num_pose_params + self.image_channels = image_channels + self.image_size = image_size + self.use_separable_convolution = use_separable_convolution + + +class Editor07(Module): + def __init__(self, args: Editor07Args): + super().__init__() + self.args = args + + self.body = ResizeConvUNet(ResizeConvUNetArgs( + image_size=args.image_size, + input_channels=2 * args.image_channels + args.num_pose_params + 2, + start_channels=args.start_channels, + bottleneck_image_size=args.bottleneck_image_size, + num_bottleneck_blocks=args.num_bottleneck_blocks, + max_channels=args.max_channels, + upsample_mode=args.upsampling_mode, + block_args=args.block_args, + use_separable_convolution=args.use_separable_convolution)) + self.color_change_creator = Sequential( + create_conv3_from_block_args( + in_channels=self.args.start_channels, + out_channels=self.args.image_channels, + bias=True, + block_args=self.args.block_args), + Tanh()) + self.alpha_creator = Sequential( + create_conv3_from_block_args( + in_channels=self.args.start_channels, + out_channels=self.args.image_channels, + bias=True, + block_args=self.args.block_args), + Sigmoid()) + self.grid_change_creator = create_conv3( + in_channels=self.args.start_channels, + out_channels=2, + bias=False, + initialization_method='zero', + use_spectral_norm=False) + self.grid_change_applier = GridChangeApplier() + + def forward(self, + input_original_image: Tensor, + input_warped_image: Tensor, + input_grid_change: Tensor, + pose: Tensor, + *args) -> List[Tensor]: + n, c = pose.shape + pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.image_size, self.args.image_size) + feature = torch.cat([input_original_image, input_warped_image, input_grid_change, pose], dim=1) + + feature = self.body.forward(feature)[-1] + output_grid_change = input_grid_change + self.grid_change_creator(feature) + + output_color_change = self.color_change_creator(feature) + output_color_change_alpha = self.alpha_creator(feature) + output_warped_image = self.grid_change_applier.apply(output_grid_change, input_original_image) + output_color_changed = apply_color_change(output_color_change_alpha, output_color_change, output_warped_image) + + return [ + output_color_changed, + output_color_change_alpha, + output_color_change, + output_warped_image, + output_grid_change, + ] + + COLOR_CHANGED_IMAGE_INDEX = 0 + COLOR_CHANGE_ALPHA_INDEX = 1 + COLOR_CHANGE_IMAGE_INDEX = 2 + WARPED_IMAGE_INDEX = 3 + GRID_CHANGE_INDEX = 4 + OUTPUT_LENGTH = 5 + + +class Editor07Factory(ModuleFactory): + def __init__(self, args: Editor07Args): + super().__init__() + self.args = args + + def create(self) -> Module: + return Editor07(self.args) + + +def show_image(pytorch_image): + numpy_image = ((pytorch_image + 1.0) / 2.0).squeeze(0).numpy() + numpy_image[0:3, :, :] = numpy_linear_to_srgb(numpy_image[0:3, :, :]) + c, h, w = numpy_image.shape + numpy_image = numpy_image.reshape((c, h * w)).transpose().reshape((h, w, c)) + pyplot.imshow(numpy_image) + pyplot.show() + + +if __name__ == "__main__": + cuda = torch.device('cuda') + + image_size = 512 + image_channels = 4 + num_pose_params = 6 + args = Editor07Args( + image_size=512, + image_channels=4, + start_channels=32, + num_pose_params=6, + bottleneck_image_size=32, + num_bottleneck_blocks=6, + max_channels=512, + upsampling_mode='nearest', + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=False))) + module = Editor07(args).to(cuda) + + image_count = 1 + input_image = torch.zeros(image_count, 4, image_size, image_size, device=cuda) + direct_image = torch.zeros(image_count, 4, image_size, image_size, device=cuda) + warped_image = torch.zeros(image_count, 4, image_size, image_size, device=cuda) + grid_change = torch.zeros(image_count, 2, image_size, image_size, device=cuda) + pose = torch.zeros(image_count, num_pose_params, device=cuda) + + repeat = 100 + acc = 0.0 + for i in range(repeat + 2): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + module.forward(input_image, warped_image, grid_change, pose) + end.record() + torch.cuda.synchronize() + if i >= 2: + elapsed_time = start.elapsed_time(end) + print("%d:" % i, elapsed_time) + acc = acc + elapsed_time + + print("average:", acc / repeat) diff --git a/tha3/nn/eyebrow_decomposer/__init__.py b/tha3/nn/eyebrow_decomposer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tha3/nn/eyebrow_decomposer/eyebrow_decomposer_00.py b/tha3/nn/eyebrow_decomposer/eyebrow_decomposer_00.py new file mode 100644 index 0000000000000000000000000000000000000000..0c6fad9748c953f5cb78e73dc190da0410a93642 --- /dev/null +++ b/tha3/nn/eyebrow_decomposer/eyebrow_decomposer_00.py @@ -0,0 +1,102 @@ +from typing import List, Optional + +import torch +from torch import Tensor +from torch.nn import Module + +from tha3.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args, PoserEncoderDecoder00 +from tha3.nn.image_processing_util import apply_color_change +from tha3.module.module_factory import ModuleFactory +from tha3.nn.nonlinearity_factory import ReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.util import BlockArgs + + +class EyebrowDecomposer00Args(PoserEncoderDecoder00Args): + def __init__(self, + image_size: int = 128, + image_channels: int = 4, + start_channels: int = 64, + bottleneck_image_size=16, + num_bottleneck_blocks=6, + max_channels: int = 512, + block_args: Optional[BlockArgs] = None): + super().__init__( + image_size, + image_channels, + image_channels, + 0, + start_channels, + bottleneck_image_size, + num_bottleneck_blocks, + max_channels, + block_args) + + +class EyebrowDecomposer00(Module): + def __init__(self, args: EyebrowDecomposer00Args): + super().__init__() + self.args = args + self.body = PoserEncoderDecoder00(args) + self.background_layer_alpha = self.args.create_alpha_block() + self.background_layer_color_change = self.args.create_color_change_block() + self.eyebrow_layer_alpha = self.args.create_alpha_block() + self.eyebrow_layer_color_change = self.args.create_color_change_block() + + def forward(self, image: Tensor, *args) -> List[Tensor]: + feature = self.body(image)[0] + + background_layer_alpha = self.background_layer_alpha(feature) + background_layer_color_change = self.background_layer_color_change(feature) + background_layer_1 = apply_color_change(background_layer_alpha, background_layer_color_change, image) + + eyebrow_layer_alpha = self.eyebrow_layer_alpha(feature) + eyebrow_layer_color_change = self.eyebrow_layer_color_change(feature) + eyebrow_layer = apply_color_change(eyebrow_layer_alpha, image, eyebrow_layer_color_change) + + return [ + eyebrow_layer, # 0 + eyebrow_layer_alpha, # 1 + eyebrow_layer_color_change, # 2 + background_layer_1, # 3 + background_layer_alpha, # 4 + background_layer_color_change, # 5 + ] + + EYEBROW_LAYER_INDEX = 0 + EYEBROW_LAYER_ALPHA_INDEX = 1 + EYEBROW_LAYER_COLOR_CHANGE_INDEX = 2 + BACKGROUND_LAYER_INDEX = 3 + BACKGROUND_LAYER_ALPHA_INDEX = 4 + BACKGROUND_LAYER_COLOR_CHANGE_INDEX = 5 + OUTPUT_LENGTH = 6 + + +class EyebrowDecomposer00Factory(ModuleFactory): + def __init__(self, args: EyebrowDecomposer00Args): + super().__init__() + self.args = args + + def create(self) -> Module: + return EyebrowDecomposer00(self.args) + + +if __name__ == "__main__": + cuda = torch.device('cuda') + args = EyebrowDecomposer00Args( + image_size=128, + image_channels=4, + start_channels=64, + bottleneck_image_size=16, + num_bottleneck_blocks=3, + block_args=BlockArgs( + initialization_method='xavier', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=True))) + face_morpher = EyebrowDecomposer00(args).to(cuda) + + image = torch.randn(8, 4, 128, 128, device=cuda) + outputs = face_morpher.forward(image) + for i in range(len(outputs)): + print(i, outputs[i].shape) diff --git a/tha3/nn/eyebrow_decomposer/eyebrow_decomposer_03.py b/tha3/nn/eyebrow_decomposer/eyebrow_decomposer_03.py new file mode 100644 index 0000000000000000000000000000000000000000..a8c5212a8715e727a9514cd33daf9b32c7fd0a5c --- /dev/null +++ b/tha3/nn/eyebrow_decomposer/eyebrow_decomposer_03.py @@ -0,0 +1,109 @@ +from typing import List, Optional + +import torch +from torch import Tensor +from torch.nn import Module + +from tha3.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args +from tha3.nn.common.poser_encoder_decoder_00_separable import PoserEncoderDecoder00Separable +from tha3.nn.image_processing_util import apply_color_change +from tha3.module.module_factory import ModuleFactory +from tha3.nn.nonlinearity_factory import ReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.util import BlockArgs + + +class EyebrowDecomposer03Args(PoserEncoderDecoder00Args): + def __init__(self, + image_size: int = 128, + image_channels: int = 4, + start_channels: int = 64, + bottleneck_image_size=16, + num_bottleneck_blocks=6, + max_channels: int = 512, + block_args: Optional[BlockArgs] = None): + super().__init__( + image_size, + image_channels, + image_channels, + 0, + start_channels, + bottleneck_image_size, + num_bottleneck_blocks, + max_channels, + block_args) + + +class EyebrowDecomposer03(Module): + def __init__(self, args: EyebrowDecomposer03Args): + super().__init__() + self.args = args + self.body = PoserEncoderDecoder00Separable(args) + self.background_layer_alpha = self.args.create_alpha_block() + self.background_layer_color_change = self.args.create_color_change_block() + self.eyebrow_layer_alpha = self.args.create_alpha_block() + self.eyebrow_layer_color_change = self.args.create_color_change_block() + + def forward(self, image: Tensor, *args) -> List[Tensor]: + feature = self.body(image)[0] + + background_layer_alpha = self.background_layer_alpha(feature) + background_layer_color_change = self.background_layer_color_change(feature) + background_layer_1 = apply_color_change(background_layer_alpha, background_layer_color_change, image) + + eyebrow_layer_alpha = self.eyebrow_layer_alpha(feature) + eyebrow_layer_color_change = self.eyebrow_layer_color_change(feature) + eyebrow_layer = apply_color_change(eyebrow_layer_alpha, image, eyebrow_layer_color_change) + + return [ + eyebrow_layer, # 0 + eyebrow_layer_alpha, # 1 + eyebrow_layer_color_change, # 2 + background_layer_1, # 3 + background_layer_alpha, # 4 + background_layer_color_change, # 5 + ] + + EYEBROW_LAYER_INDEX = 0 + EYEBROW_LAYER_ALPHA_INDEX = 1 + EYEBROW_LAYER_COLOR_CHANGE_INDEX = 2 + BACKGROUND_LAYER_INDEX = 3 + BACKGROUND_LAYER_ALPHA_INDEX = 4 + BACKGROUND_LAYER_COLOR_CHANGE_INDEX = 5 + OUTPUT_LENGTH = 6 + + +class EyebrowDecomposer03Factory(ModuleFactory): + def __init__(self, args: EyebrowDecomposer03Args): + super().__init__() + self.args = args + + def create(self) -> Module: + return EyebrowDecomposer03(self.args) + + +if __name__ == "__main__": + cuda = torch.device('cuda') + args = EyebrowDecomposer03Args( + image_size=128, + image_channels=4, + start_channels=64, + bottleneck_image_size=16, + num_bottleneck_blocks=6, + block_args=BlockArgs( + initialization_method='xavier', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=True))) + face_morpher = EyebrowDecomposer03(args).to(cuda) + + #image = torch.randn(8, 4, 128, 128, device=cuda) + #outputs = face_morpher.forward(image) + #for i in range(len(outputs)): + # print(i, outputs[i].shape) + + state_dict = face_morpher.state_dict() + index = 0 + for key in state_dict: + print(f"[{index}]", key, state_dict[key].shape) + index += 1 diff --git a/tha3/nn/eyebrow_morphing_combiner/__init__.py b/tha3/nn/eyebrow_morphing_combiner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tha3/nn/eyebrow_morphing_combiner/eyebrow_morphing_combiner_00.py b/tha3/nn/eyebrow_morphing_combiner/eyebrow_morphing_combiner_00.py new file mode 100644 index 0000000000000000000000000000000000000000..0447b05e5d6bdafbc6b63f21710908b9f0bf1c20 --- /dev/null +++ b/tha3/nn/eyebrow_morphing_combiner/eyebrow_morphing_combiner_00.py @@ -0,0 +1,115 @@ +from typing import List, Optional + +import torch +from torch import Tensor +from torch.nn import Module + +from tha3.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args, PoserEncoderDecoder00 +from tha3.nn.image_processing_util import apply_color_change, apply_grid_change, apply_rgb_change +from tha3.module.module_factory import ModuleFactory +from tha3.nn.nonlinearity_factory import ReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.util import BlockArgs + + +class EyebrowMorphingCombiner00Args(PoserEncoderDecoder00Args): + def __init__(self, + image_size: int = 128, + image_channels: int = 4, + num_pose_params: int = 12, + start_channels: int = 64, + bottleneck_image_size=16, + num_bottleneck_blocks=6, + max_channels: int = 512, + block_args: Optional[BlockArgs] = None): + super().__init__( + image_size, + 2 * image_channels, + image_channels, + num_pose_params, + start_channels, + bottleneck_image_size, + num_bottleneck_blocks, + max_channels, + block_args) + + +class EyebrowMorphingCombiner00(Module): + def __init__(self, args: EyebrowMorphingCombiner00Args): + super().__init__() + self.args = args + self.body = PoserEncoderDecoder00(args) + self.morphed_eyebrow_layer_grid_change = self.args.create_grid_change_block() + self.morphed_eyebrow_layer_alpha = self.args.create_alpha_block() + self.morphed_eyebrow_layer_color_change = self.args.create_color_change_block() + self.combine_alpha = self.args.create_alpha_block() + + def forward(self, background_layer: Tensor, eyebrow_layer: Tensor, pose: Tensor, *args) -> List[Tensor]: + combined_image = torch.cat([background_layer, eyebrow_layer], dim=1) + feature = self.body(combined_image, pose)[0] + + morphed_eyebrow_layer_grid_change = self.morphed_eyebrow_layer_grid_change(feature) + morphed_eyebrow_layer_alpha = self.morphed_eyebrow_layer_alpha(feature) + morphed_eyebrow_layer_color_change = self.morphed_eyebrow_layer_color_change(feature) + warped_eyebrow_layer = apply_grid_change(morphed_eyebrow_layer_grid_change, eyebrow_layer) + morphed_eyebrow_layer = apply_color_change( + morphed_eyebrow_layer_alpha, morphed_eyebrow_layer_color_change, warped_eyebrow_layer) + + combine_alpha = self.combine_alpha(feature) + eyebrow_image = apply_rgb_change(combine_alpha, morphed_eyebrow_layer, background_layer) + eyebrow_image_no_combine_alpha = apply_rgb_change( + (morphed_eyebrow_layer[:, 3:4, :, :] + 1.0) / 2.0, morphed_eyebrow_layer, background_layer) + + return [ + eyebrow_image, # 0 + combine_alpha, # 1 + eyebrow_image_no_combine_alpha, # 2 + morphed_eyebrow_layer, # 3 + morphed_eyebrow_layer_alpha, # 4 + morphed_eyebrow_layer_color_change, # 5 + warped_eyebrow_layer, # 6 + morphed_eyebrow_layer_grid_change, # 7 + ] + + EYEBROW_IMAGE_INDEX = 0 + COMBINE_ALPHA_INDEX = 1 + EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX = 2 + MORPHED_EYEBROW_LAYER_INDEX = 3 + MORPHED_EYEBROW_LAYER_ALPHA_INDEX = 4 + MORPHED_EYEBROW_LAYER_COLOR_CHANGE_INDEX = 5 + WARPED_EYEBROW_LAYER_INDEX = 6 + MORPHED_EYEBROW_LAYER_GRID_CHANGE_INDEX = 7 + OUTPUT_LENGTH = 8 + + +class EyebrowMorphingCombiner00Factory(ModuleFactory): + def __init__(self, args: EyebrowMorphingCombiner00Args): + super().__init__() + self.args = args + + def create(self) -> Module: + return EyebrowMorphingCombiner00(self.args) + + +if __name__ == "__main__": + cuda = torch.device('cuda') + args = EyebrowMorphingCombiner00Args( + image_size=128, + image_channels=4, + num_pose_params=12, + start_channels=64, + bottleneck_image_size=16, + num_bottleneck_blocks=3, + block_args=BlockArgs( + initialization_method='xavier', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=True))) + face_morpher = EyebrowMorphingCombiner00(args).to(cuda) + + background_layer = torch.randn(8, 4, 128, 128, device=cuda) + eyebrow_layer = torch.randn(8, 4, 128, 128, device=cuda) + pose = torch.randn(8, 12, device=cuda) + outputs = face_morpher.forward(background_layer, eyebrow_layer, pose) + for i in range(len(outputs)): + print(i, outputs[i].shape) diff --git a/tha3/nn/eyebrow_morphing_combiner/eyebrow_morphing_combiner_03.py b/tha3/nn/eyebrow_morphing_combiner/eyebrow_morphing_combiner_03.py new file mode 100644 index 0000000000000000000000000000000000000000..e37cfcc9d255a891a2417a0a6c18074b59444c99 --- /dev/null +++ b/tha3/nn/eyebrow_morphing_combiner/eyebrow_morphing_combiner_03.py @@ -0,0 +1,117 @@ +from typing import List, Optional + +import torch +from torch import Tensor +from torch.nn import Module + +from tha3.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args +from tha3.nn.common.poser_encoder_decoder_00_separable import PoserEncoderDecoder00Separable +from tha3.nn.image_processing_util import apply_color_change, apply_rgb_change, GridChangeApplier +from tha3.module.module_factory import ModuleFactory +from tha3.nn.nonlinearity_factory import ReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.util import BlockArgs + + +class EyebrowMorphingCombiner03Args(PoserEncoderDecoder00Args): + def __init__(self, + image_size: int = 128, + image_channels: int = 4, + num_pose_params: int = 12, + start_channels: int = 64, + bottleneck_image_size=16, + num_bottleneck_blocks=6, + max_channels: int = 512, + block_args: Optional[BlockArgs] = None): + super().__init__( + image_size, + 2 * image_channels, + image_channels, + num_pose_params, + start_channels, + bottleneck_image_size, + num_bottleneck_blocks, + max_channels, + block_args) + + +class EyebrowMorphingCombiner03(Module): + def __init__(self, args: EyebrowMorphingCombiner03Args): + super().__init__() + self.args = args + self.body = PoserEncoderDecoder00Separable(args) + self.morphed_eyebrow_layer_grid_change = self.args.create_grid_change_block() + self.morphed_eyebrow_layer_alpha = self.args.create_alpha_block() + self.morphed_eyebrow_layer_color_change = self.args.create_color_change_block() + self.combine_alpha = self.args.create_alpha_block() + self.grid_change_applier = GridChangeApplier() + + def forward(self, background_layer: Tensor, eyebrow_layer: Tensor, pose: Tensor, *args) -> List[Tensor]: + combined_image = torch.cat([background_layer, eyebrow_layer], dim=1) + feature = self.body(combined_image, pose)[0] + + morphed_eyebrow_layer_grid_change = self.morphed_eyebrow_layer_grid_change(feature) + morphed_eyebrow_layer_alpha = self.morphed_eyebrow_layer_alpha(feature) + morphed_eyebrow_layer_color_change = self.morphed_eyebrow_layer_color_change(feature) + warped_eyebrow_layer = self.grid_change_applier.apply(morphed_eyebrow_layer_grid_change, eyebrow_layer) + morphed_eyebrow_layer = apply_color_change( + morphed_eyebrow_layer_alpha, morphed_eyebrow_layer_color_change, warped_eyebrow_layer) + + combine_alpha = self.combine_alpha(feature) + eyebrow_image = apply_rgb_change(combine_alpha, morphed_eyebrow_layer, background_layer) + eyebrow_image_no_combine_alpha = apply_rgb_change( + (morphed_eyebrow_layer[:, 3:4, :, :] + 1.0) / 2.0, morphed_eyebrow_layer, background_layer) + + return [ + eyebrow_image, # 0 + combine_alpha, # 1 + eyebrow_image_no_combine_alpha, # 2 + morphed_eyebrow_layer, # 3 + morphed_eyebrow_layer_alpha, # 4 + morphed_eyebrow_layer_color_change, # 5 + warped_eyebrow_layer, # 6 + morphed_eyebrow_layer_grid_change, # 7 + ] + + EYEBROW_IMAGE_INDEX = 0 + COMBINE_ALPHA_INDEX = 1 + EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX = 2 + MORPHED_EYEBROW_LAYER_INDEX = 3 + MORPHED_EYEBROW_LAYER_ALPHA_INDEX = 4 + MORPHED_EYEBROW_LAYER_COLOR_CHANGE_INDEX = 5 + WARPED_EYEBROW_LAYER_INDEX = 6 + MORPHED_EYEBROW_LAYER_GRID_CHANGE_INDEX = 7 + OUTPUT_LENGTH = 8 + + +class EyebrowMorphingCombiner03Factory(ModuleFactory): + def __init__(self, args: EyebrowMorphingCombiner03Args): + super().__init__() + self.args = args + + def create(self) -> Module: + return EyebrowMorphingCombiner03(self.args) + + +if __name__ == "__main__": + cuda = torch.device('cuda') + args = EyebrowMorphingCombiner03Args( + image_size=128, + image_channels=4, + num_pose_params=12, + start_channels=64, + bottleneck_image_size=16, + num_bottleneck_blocks=3, + block_args=BlockArgs( + initialization_method='xavier', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=True))) + face_morpher = EyebrowMorphingCombiner03(args).to(cuda) + + background_layer = torch.randn(8, 4, 128, 128, device=cuda) + eyebrow_layer = torch.randn(8, 4, 128, 128, device=cuda) + pose = torch.randn(8, 12, device=cuda) + outputs = face_morpher.forward(background_layer, eyebrow_layer, pose) + for i in range(len(outputs)): + print(i, outputs[i].shape) diff --git a/tha3/nn/face_morpher/__init__.py b/tha3/nn/face_morpher/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tha3/nn/face_morpher/face_morpher_08.py b/tha3/nn/face_morpher/face_morpher_08.py new file mode 100644 index 0000000000000000000000000000000000000000..2b77b469e04a03538fb9759a2b95e2f6fb3c08be --- /dev/null +++ b/tha3/nn/face_morpher/face_morpher_08.py @@ -0,0 +1,241 @@ +import math +from typing import List, Optional + +import torch +from torch import Tensor +from torch.nn import ModuleList, Sequential, Sigmoid, Tanh, Module +from torch.nn.functional import affine_grid, grid_sample + +from tha3.module.module_factory import ModuleFactory +from tha3.nn.conv import create_conv3_block_from_block_args, \ + create_downsample_block_from_block_args, create_upsample_block_from_block_args, create_conv3_from_block_args, \ + create_conv3 +from tha3.nn.nonlinearity_factory import LeakyReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.resnet_block import ResnetBlock +from tha3.nn.util import BlockArgs + + +class FaceMorpher08Args: + def __init__(self, + image_size: int = 256, + image_channels: int = 4, + num_expression_params: int = 67, + start_channels: int = 16, + bottleneck_image_size=4, + num_bottleneck_blocks=3, + max_channels: int = 512, + block_args: Optional[BlockArgs] = None): + self.max_channels = max_channels + self.num_bottleneck_blocks = num_bottleneck_blocks + assert bottleneck_image_size > 1 + self.bottleneck_image_size = bottleneck_image_size + self.start_channels = start_channels + self.image_channels = image_channels + self.num_expression_params = num_expression_params + self.image_size = image_size + + if block_args is None: + self.block_args = BlockArgs( + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=LeakyReLUFactory(negative_slope=0.2, inplace=True)) + else: + self.block_args = block_args + + +class FaceMorpher08(Module): + def __init__(self, args: FaceMorpher08Args): + super().__init__() + self.args = args + self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1 + + self.downsample_blocks = ModuleList() + self.downsample_blocks.append( + create_conv3_block_from_block_args( + args.image_channels, + args.start_channels, + args.block_args)) + current_image_size = args.image_size + current_num_channels = args.start_channels + while current_image_size > args.bottleneck_image_size: + next_image_size = current_image_size // 2 + next_num_channels = self.get_num_output_channels_from_image_size(next_image_size) + self.downsample_blocks.append(create_downsample_block_from_block_args( + in_channels=current_num_channels, + out_channels=next_num_channels, + is_output_1x1=False, + block_args=args.block_args)) + current_image_size = next_image_size + current_num_channels = next_num_channels + assert len(self.downsample_blocks) == self.num_levels + + self.bottleneck_blocks = ModuleList() + self.bottleneck_blocks.append(create_conv3_block_from_block_args( + in_channels=current_num_channels + args.num_expression_params, + out_channels=current_num_channels, + block_args=args.block_args)) + for i in range(1, args.num_bottleneck_blocks): + self.bottleneck_blocks.append( + ResnetBlock.create( + num_channels=current_num_channels, + is1x1=False, + block_args=args.block_args)) + + self.upsample_blocks = ModuleList() + while current_image_size < args.image_size: + next_image_size = current_image_size * 2 + next_num_channels = self.get_num_output_channels_from_image_size(next_image_size) + self.upsample_blocks.append(create_upsample_block_from_block_args( + in_channels=current_num_channels, + out_channels=next_num_channels, + block_args=args.block_args)) + current_image_size = next_image_size + current_num_channels = next_num_channels + + self.iris_mouth_grid_change = self.create_grid_change_block() + self.iris_mouth_color_change = self.create_color_change_block() + self.iris_mouth_alpha = self.create_alpha_block() + + self.eye_color_change = self.create_color_change_block() + self.eye_alpha = self.create_alpha_block() + + def create_alpha_block(self): + return Sequential( + create_conv3( + in_channels=self.args.start_channels, + out_channels=1, + bias=True, + initialization_method=self.args.block_args.initialization_method, + use_spectral_norm=False), + Sigmoid()) + + def create_color_change_block(self): + return Sequential( + create_conv3_from_block_args( + in_channels=self.args.start_channels, + out_channels=self.args.image_channels, + bias=True, + block_args=self.args.block_args), + Tanh()) + + def create_grid_change_block(self): + return create_conv3( + in_channels=self.args.start_channels, + out_channels=2, + bias=False, + initialization_method='zero', + use_spectral_norm=False) + + def get_num_output_channels_from_level(self, level: int): + return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level)) + + def get_num_output_channels_from_image_size(self, image_size: int): + return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels) + + def merge_down(self, top_layer: Tensor, bottom_layer: Tensor): + top_layer_rgb = top_layer[:, 0:3, :, :] + top_layer_a = top_layer[:, 3:4, :, :] + return bottom_layer * (1-top_layer_a) + torch.cat([top_layer_rgb * top_layer_a, top_layer_a], dim=1) + + def apply_grid_change(self, grid_change, image: Tensor) -> Tensor: + n, c, h, w = image.shape + device = grid_change.device + grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2) + identity = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + device=device, + dtype=grid_change.dtype).unsqueeze(0).repeat(n, 1, 1) + base_grid = affine_grid(identity, [n, c, h, w], align_corners=False) + grid = base_grid + grid_change + resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=False) + return resampled_image + + def apply_color_change(self, alpha, color_change, image: Tensor) -> Tensor: + return color_change * alpha + image * (1 - alpha) + + def forward(self, image: Tensor, pose: Tensor, *args) -> List[Tensor]: + feature = image + for block in self.downsample_blocks: + feature = block(feature) + n, c = pose.shape + pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size) + feature = torch.cat([feature, pose], dim=1) + for block in self.bottleneck_blocks: + feature = block(feature) + for block in self.upsample_blocks: + feature = block(feature) + + iris_mouth_grid_change = self.iris_mouth_grid_change(feature) + iris_mouth_image_0 = self.apply_grid_change(iris_mouth_grid_change, image) + iris_mouth_color_change = self.iris_mouth_color_change(feature) + iris_mouth_alpha = self.iris_mouth_alpha(feature) + iris_mouth_image_1 = self.apply_color_change(iris_mouth_alpha, iris_mouth_color_change, iris_mouth_image_0) + + eye_color_change = self.eye_color_change(feature) + eye_alpha = self.eye_alpha(feature) + output_image = self.apply_color_change(eye_alpha, eye_color_change, iris_mouth_image_1.detach()) + + return [ + output_image, #0 + eye_alpha, #1 + eye_color_change, #2 + iris_mouth_image_1, #3 + iris_mouth_alpha, #4 + iris_mouth_color_change, #5 + iris_mouth_image_0, #6 + ] + + OUTPUT_IMAGE_INDEX = 0 + EYE_ALPHA_INDEX = 1 + EYE_COLOR_CHANGE_INDEX = 2 + IRIS_MOUTH_IMAGE_1_INDEX = 3 + IRIS_MOUTH_ALPHA_INDEX = 4 + IRIS_MOUTH_COLOR_CHANGE_INDEX = 5 + IRIS_MOUTh_IMAGE_0_INDEX = 6 + + +class FaceMorpher08Factory(ModuleFactory): + def __init__(self, args: FaceMorpher08Args): + super().__init__() + self.args = args + + def create(self) -> Module: + return FaceMorpher08(self.args) + + +if __name__ == "__main__": + cuda = torch.device('cuda') + args = FaceMorpher08Args( + image_size=256, + image_channels=4, + num_expression_params=12, + start_channels=64, + bottleneck_image_size=32, + num_bottleneck_blocks=6, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=LeakyReLUFactory(inplace=True, negative_slope=0.2))) + module = FaceMorpher08(args).to(cuda) + + image = torch.zeros(16, 4, 256, 256, device=cuda) + pose = torch.zeros(16, 12, device=cuda) + + repeat = 100 + acc = 0.0 + for i in range(repeat + 2): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + module.forward(image, pose) + end.record() + torch.cuda.synchronize() + + if i >= 2: + elapsed_time = start.elapsed_time(end) + print("%d:" % i, elapsed_time) + acc += elapsed_time + + print("average:", acc / repeat) \ No newline at end of file diff --git a/tha3/nn/face_morpher/face_morpher_09.py b/tha3/nn/face_morpher/face_morpher_09.py new file mode 100644 index 0000000000000000000000000000000000000000..46678e2e0d39c52a8645e10d8f2994a0aa87a0d0 --- /dev/null +++ b/tha3/nn/face_morpher/face_morpher_09.py @@ -0,0 +1,187 @@ +from typing import List, Optional + +import torch +from torch import Tensor +from torch.nn import Sequential, Sigmoid, Tanh, Module +from torch.nn.functional import affine_grid, grid_sample + +from tha3.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args +from tha3.nn.common.poser_encoder_decoder_00_separable import PoserEncoderDecoder00Separable +from tha3.nn.image_processing_util import GridChangeApplier +from tha3.module.module_factory import ModuleFactory +from tha3.nn.conv import create_conv3_from_block_args, create_conv3 +from tha3.nn.nonlinearity_factory import LeakyReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.util import BlockArgs + + +class FaceMorpher09Args(PoserEncoderDecoder00Args): + def __init__(self, + image_size: int = 256, + image_channels: int = 4, + num_pose_params: int = 67, + start_channels: int = 16, + bottleneck_image_size=4, + num_bottleneck_blocks=3, + max_channels: int = 512, + block_args: Optional[BlockArgs] = None): + super().__init__( + image_size, + image_channels, + image_channels, + num_pose_params, + start_channels, + bottleneck_image_size, + num_bottleneck_blocks, + max_channels, + block_args) + + +class FaceMorpher09(Module): + def __init__(self, args: FaceMorpher09Args): + super().__init__() + self.args = args + self.body = PoserEncoderDecoder00Separable(args) + + self.iris_mouth_grid_change = self.create_grid_change_block() + self.iris_mouth_color_change = self.create_color_change_block() + self.iris_mouth_alpha = self.create_alpha_block() + + self.eye_color_change = self.create_color_change_block() + self.eye_alpha = self.create_alpha_block() + + self.grid_change_applier = GridChangeApplier() + + def create_alpha_block(self): + return Sequential( + create_conv3( + in_channels=self.args.start_channels, + out_channels=1, + bias=True, + initialization_method=self.args.block_args.initialization_method, + use_spectral_norm=False), + Sigmoid()) + + def create_color_change_block(self): + return Sequential( + create_conv3_from_block_args( + in_channels=self.args.start_channels, + out_channels=self.args.input_image_channels, + bias=True, + block_args=self.args.block_args), + Tanh()) + + def create_grid_change_block(self): + return create_conv3( + in_channels=self.args.start_channels, + out_channels=2, + bias=False, + initialization_method='zero', + use_spectral_norm=False) + + def get_num_output_channels_from_level(self, level: int): + return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level)) + + def get_num_output_channels_from_image_size(self, image_size: int): + return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels) + + def forward(self, image: Tensor, pose: Tensor, *args) -> List[Tensor]: + feature = self.body(image, pose)[0] + + iris_mouth_grid_change = self.iris_mouth_grid_change(feature) + iris_mouth_image_0 = self.grid_change_applier.apply(iris_mouth_grid_change, image) + iris_mouth_color_change = self.iris_mouth_color_change(feature) + iris_mouth_alpha = self.iris_mouth_alpha(feature) + iris_mouth_image_1 = self.apply_color_change(iris_mouth_alpha, iris_mouth_color_change, iris_mouth_image_0) + + eye_color_change = self.eye_color_change(feature) + eye_alpha = self.eye_alpha(feature) + output_image = self.apply_color_change(eye_alpha, eye_color_change, iris_mouth_image_1.detach()) + + return [ + output_image, # 0 + eye_alpha, # 1 + eye_color_change, # 2 + iris_mouth_image_1, # 3 + iris_mouth_alpha, # 4 + iris_mouth_color_change, # 5 + iris_mouth_image_0, # 6 + ] + + OUTPUT_IMAGE_INDEX = 0 + EYE_ALPHA_INDEX = 1 + EYE_COLOR_CHANGE_INDEX = 2 + IRIS_MOUTH_IMAGE_1_INDEX = 3 + IRIS_MOUTH_ALPHA_INDEX = 4 + IRIS_MOUTH_COLOR_CHANGE_INDEX = 5 + IRIS_MOUTh_IMAGE_0_INDEX = 6 + + def merge_down(self, top_layer: Tensor, bottom_layer: Tensor): + top_layer_rgb = top_layer[:, 0:3, :, :] + top_layer_a = top_layer[:, 3:4, :, :] + return bottom_layer * (1 - top_layer_a) + torch.cat([top_layer_rgb * top_layer_a, top_layer_a], dim=1) + + def apply_grid_change(self, grid_change, image: Tensor) -> Tensor: + n, c, h, w = image.shape + device = grid_change.device + grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2) + identity = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device).unsqueeze(0).repeat(n, 1, 1) + base_grid = affine_grid(identity, [n, c, h, w], align_corners=False) + grid = base_grid + grid_change + resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=False) + return resampled_image + + def apply_color_change(self, alpha, color_change, image: Tensor) -> Tensor: + return color_change * alpha + image * (1 - alpha) + + +class FaceMorpher09Factory(ModuleFactory): + def __init__(self, args: FaceMorpher09Args): + super().__init__() + self.args = args + + def create(self) -> Module: + return FaceMorpher09(self.args) + + +if __name__ == "__main__": + cuda = torch.device('cuda') + args = FaceMorpher09Args( + image_size=256, + image_channels=4, + num_pose_params=12, + start_channels=64, + bottleneck_image_size=32, + num_bottleneck_blocks=6, + block_args=BlockArgs( + initialization_method='xavier', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=LeakyReLUFactory(inplace=True, negative_slope=0.2))) + module = FaceMorpher09(args).to(cuda) + + image = torch.zeros(16, 4, 256, 256, device=cuda) + pose = torch.zeros(16, 12, device=cuda) + + state_dict = module.state_dict() + for key in state_dict: + print(key, state_dict[key].shape) + + if False: + repeat = 100 + acc = 0.0 + for i in range(repeat + 2): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + module.forward(image, pose) + end.record() + torch.cuda.synchronize() + + if i >= 2: + elapsed_time = start.elapsed_time(end) + print("%d:" % i, elapsed_time) + acc += elapsed_time + + print("average:", acc / repeat) \ No newline at end of file diff --git a/tha3/nn/image_processing_util.py b/tha3/nn/image_processing_util.py new file mode 100644 index 0000000000000000000000000000000000000000..e1d1a05fb07569e5c91daa298139ec081b5845a9 --- /dev/null +++ b/tha3/nn/image_processing_util.py @@ -0,0 +1,58 @@ +import torch +from torch import Tensor +from torch.nn.functional import affine_grid, grid_sample + + +def apply_rgb_change(alpha: Tensor, color_change: Tensor, image: Tensor): + image_rgb = image[:, 0:3, :, :] + color_change_rgb = color_change[:, 0:3, :, :] + output_rgb = color_change_rgb * alpha + image_rgb * (1 - alpha) + return torch.cat([output_rgb, image[:, 3:4, :, :]], dim=1) + + +def apply_grid_change(grid_change, image: Tensor) -> Tensor: + n, c, h, w = image.shape + device = grid_change.device + grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2) + identity = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + dtype=grid_change.dtype, + device=device).unsqueeze(0).repeat(n, 1, 1) + base_grid = affine_grid(identity, [n, c, h, w], align_corners=False) + grid = base_grid + grid_change + resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=False) + return resampled_image + + +class GridChangeApplier: + def __init__(self): + self.last_n = None + self.last_device = None + self.last_identity = None + + def apply(self, grid_change: Tensor, image: Tensor, align_corners: bool = False) -> Tensor: + n, c, h, w = image.shape + device = grid_change.device + grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2) + + if n == self.last_n and device == self.last_device: + identity = self.last_identity + else: + identity = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + dtype=grid_change.dtype, + device=device, + requires_grad=False) \ + .unsqueeze(0).repeat(n, 1, 1) + self.last_identity = identity + self.last_n = n + self.last_device = device + base_grid = affine_grid(identity, [n, c, h, w], align_corners=align_corners) + + grid = base_grid + grid_change + resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=align_corners) + return resampled_image + + +def apply_color_change(alpha, color_change, image: Tensor) -> Tensor: + return color_change * alpha + image * (1 - alpha) diff --git a/tha3/nn/init_function.py b/tha3/nn/init_function.py new file mode 100644 index 0000000000000000000000000000000000000000..b5061bde67c9ff573149de095f1c531c498ad8db --- /dev/null +++ b/tha3/nn/init_function.py @@ -0,0 +1,76 @@ +from typing import Callable + +import torch +from torch import zero_ +from torch.nn import Module +from torch.nn.init import kaiming_normal_, xavier_normal_, normal_ + + +def create_init_function(method: str = 'none') -> Callable[[Module], Module]: + def init(module: Module): + if method == 'none': + return module + elif method == 'he': + kaiming_normal_(module.weight) + return module + elif method == 'xavier': + xavier_normal_(module.weight) + return module + elif method == 'dcgan': + normal_(module.weight, 0.0, 0.02) + return module + elif method == 'dcgan_001': + normal_(module.weight, 0.0, 0.01) + return module + elif method == "zero": + with torch.no_grad(): + zero_(module.weight) + return module + else: + raise ("Invalid initialization method %s" % method) + + return init + + +class HeInitialization: + def __init__(self, a: int = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'): + self.nonlinearity = nonlinearity + self.mode = mode + self.a = a + + def __call__(self, module: Module) -> Module: + with torch.no_grad(): + kaiming_normal_(module.weight, a=self.a, mode=self.mode, nonlinearity=self.nonlinearity) + return module + + +class NormalInitialization: + def __init__(self, mean: float = 0.0, std: float = 1.0): + self.std = std + self.mean = mean + + def __call__(self, module: Module) -> Module: + with torch.no_grad(): + normal_(module.weight, self.mean, self.std) + return module + + +class XavierInitialization: + def __init__(self, gain: float = 1.0): + self.gain = gain + + def __call__(self, module: Module) -> Module: + with torch.no_grad(): + xavier_normal_(module.weight, self.gain) + return module + + +class ZeroInitialization: + def __call__(self, module: Module) -> Module: + with torch.no_grad: + zero_(module.weight) + return module + +class NoInitialization: + def __call__(self, module: Module) -> Module: + return module \ No newline at end of file diff --git a/tha3/nn/nonlinearity_factory.py b/tha3/nn/nonlinearity_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..db8af392a8315beaac8a565697f6ca458b02d7b9 --- /dev/null +++ b/tha3/nn/nonlinearity_factory.py @@ -0,0 +1,72 @@ +from typing import Optional + +from torch.nn import Module, ReLU, LeakyReLU, ELU, ReLU6, Hardswish, SiLU, Tanh, Sigmoid + +from tha3.module.module_factory import ModuleFactory + + +class ReLUFactory(ModuleFactory): + def __init__(self, inplace: bool = False): + self.inplace = inplace + + def create(self) -> Module: + return ReLU(self.inplace) + + +class LeakyReLUFactory(ModuleFactory): + def __init__(self, inplace: bool = False, negative_slope: float = 1e-2): + self.negative_slope = negative_slope + self.inplace = inplace + + def create(self) -> Module: + return LeakyReLU(inplace=self.inplace, negative_slope=self.negative_slope) + + +class ELUFactory(ModuleFactory): + def __init__(self, inplace: bool = False, alpha: float = 1.0): + self.alpha = alpha + self.inplace = inplace + + def create(self) -> Module: + return ELU(inplace=self.inplace, alpha=self.alpha) + + +class ReLU6Factory(ModuleFactory): + def __init__(self, inplace: bool = False): + self.inplace = inplace + + def create(self) -> Module: + return ReLU6(inplace=self.inplace) + + +class SiLUFactory(ModuleFactory): + def __init__(self, inplace: bool = False): + self.inplace = inplace + + def create(self) -> Module: + return SiLU(inplace=self.inplace) + + +class HardswishFactory(ModuleFactory): + def __init__(self, inplace: bool = False): + self.inplace = inplace + + def create(self) -> Module: + return Hardswish(inplace=self.inplace) + + +class TanhFactory(ModuleFactory): + def create(self) -> Module: + return Tanh() + + +class SigmoidFactory(ModuleFactory): + def create(self) -> Module: + return Sigmoid() + + +def resolve_nonlinearity_factory(nonlinearity_fatory: Optional[ModuleFactory]) -> ModuleFactory: + if nonlinearity_fatory is None: + return ReLUFactory(inplace=False) + else: + return nonlinearity_fatory diff --git a/tha3/nn/normalization.py b/tha3/nn/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a21e0de45bc95f38fbb7ca3f8da523e6b08225 --- /dev/null +++ b/tha3/nn/normalization.py @@ -0,0 +1,126 @@ +from abc import ABC, abstractmethod +from typing import Optional + +import torch +from torch import layer_norm +from torch.nn import Module, BatchNorm2d, InstanceNorm2d, Parameter +from torch.nn.init import normal_, constant_ + +from tha3.nn.pass_through import PassThrough + + +class PixelNormalization(Module): + def __init__(self, epsilon=1e-8): + super().__init__() + self.epsilon = epsilon + + def forward(self, x): + return x / torch.sqrt((x ** 2).mean(dim=1, keepdim=True) + self.epsilon) + + +class NormalizationLayerFactory(ABC): + def __init__(self): + super().__init__() + + @abstractmethod + def create(self, num_features: int, affine: bool = True) -> Module: + pass + + @staticmethod + def resolve_2d(factory: Optional['NormalizationLayerFactory']) -> 'NormalizationLayerFactory': + if factory is None: + return InstanceNorm2dFactory() + else: + return factory + + +class Bias2d(Module): + def __init__(self, num_features: int): + super().__init__() + self.num_features = num_features + self.bias = Parameter(torch.zeros(1, num_features, 1, 1)) + + def forward(self, x): + return x + self.bias + + +class NoNorm2dFactory(NormalizationLayerFactory): + def __init__(self): + super().__init__() + + def create(self, num_features: int, affine: bool = True) -> Module: + if affine: + return Bias2d(num_features) + else: + return PassThrough() + + +class BatchNorm2dFactory(NormalizationLayerFactory): + def __init__(self, + weight_mean: Optional[float] = None, + weight_std: Optional[float] = None, + bias: Optional[float] = None): + super().__init__() + self.bias = bias + self.weight_std = weight_std + self.weight_mean = weight_mean + + def get_weight_mean(self): + if self.weight_mean is None: + return 1.0 + else: + return self.weight_mean + + def get_weight_std(self): + if self.weight_std is None: + return 0.02 + else: + return self.weight_std + + def create(self, num_features: int, affine: bool = True) -> Module: + module = BatchNorm2d(num_features=num_features, affine=affine) + if affine: + if self.weight_mean is not None or self.weight_std is not None: + normal_(module.weight, self.get_weight_mean(), self.get_weight_std()) + if self.bias is not None: + constant_(module.bias, self.bias) + return module + + +class InstanceNorm2dFactory(NormalizationLayerFactory): + def __init__(self): + super().__init__() + + def create(self, num_features: int, affine: bool = True) -> Module: + return InstanceNorm2d(num_features=num_features, affine=affine) + + +class PixelNormFactory(NormalizationLayerFactory): + def __init__(self): + super().__init__() + + def create(self, num_features: int, affine: bool = True) -> Module: + return PixelNormalization() + + +class LayerNorm2d(Module): + def __init__(self, channels: int, affine: bool = True): + super(LayerNorm2d, self).__init__() + self.channels = channels + self.affine = affine + + if self.affine: + self.weight = Parameter(torch.ones(1, channels, 1, 1)) + self.bias = Parameter(torch.zeros(1, channels, 1, 1)) + + def forward(self, x): + shape = x.size()[1:] + y = layer_norm(x, shape) * self.weight + self.bias + return y + +class LayerNorm2dFactory(NormalizationLayerFactory): + def __init__(self): + super().__init__() + + def create(self, num_features: int, affine: bool = True) -> Module: + return LayerNorm2d(channels=num_features, affine=affine) diff --git a/tha3/nn/pass_through.py b/tha3/nn/pass_through.py new file mode 100644 index 0000000000000000000000000000000000000000..c64d6786df0fb57fbff19ebaa80e0f41ab6fce71 --- /dev/null +++ b/tha3/nn/pass_through.py @@ -0,0 +1,9 @@ +from torch.nn import Module + + +class PassThrough(Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x \ No newline at end of file diff --git a/tha3/nn/resnet_block.py b/tha3/nn/resnet_block.py new file mode 100644 index 0000000000000000000000000000000000000000..60480253fc6019e0c79f2228c7e73434391184fd --- /dev/null +++ b/tha3/nn/resnet_block.py @@ -0,0 +1,67 @@ +from typing import Optional + +import torch +from torch.nn import Module, Sequential, Parameter + +from tha3.module.module_factory import ModuleFactory +from tha3.nn.conv import create_conv1, create_conv3 +from tha3.nn.nonlinearity_factory import resolve_nonlinearity_factory +from tha3.nn.normalization import NormalizationLayerFactory +from tha3.nn.util import BlockArgs + + +class ResnetBlock(Module): + @staticmethod + def create(num_channels: int, + is1x1: bool = False, + use_scale_parameters: bool = False, + block_args: Optional[BlockArgs] = None): + if block_args is None: + block_args = BlockArgs() + return ResnetBlock(num_channels, + is1x1, + block_args.initialization_method, + block_args.nonlinearity_factory, + block_args.normalization_layer_factory, + block_args.use_spectral_norm, + use_scale_parameters) + + def __init__(self, + num_channels: int, + is1x1: bool = False, + initialization_method: str = 'he', + nonlinearity_factory: ModuleFactory = None, + normalization_layer_factory: Optional[NormalizationLayerFactory] = None, + use_spectral_norm: bool = False, + use_scale_parameter: bool = False): + super().__init__() + self.use_scale_parameter = use_scale_parameter + if self.use_scale_parameter: + self.scale = Parameter(torch.zeros(1)) + nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory) + if is1x1: + self.resnet_path = Sequential( + create_conv1(num_channels, num_channels, initialization_method, + bias=True, + use_spectral_norm=use_spectral_norm), + nonlinearity_factory.create(), + create_conv1(num_channels, num_channels, initialization_method, + bias=True, + use_spectral_norm=use_spectral_norm)) + else: + self.resnet_path = Sequential( + create_conv3(num_channels, num_channels, + bias=False, initialization_method=initialization_method, + use_spectral_norm=use_spectral_norm), + NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True), + nonlinearity_factory.create(), + create_conv3(num_channels, num_channels, + bias=False, initialization_method=initialization_method, + use_spectral_norm=use_spectral_norm), + NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True)) + + def forward(self, x): + if self.use_scale_parameter: + return x + self.scale * self.resnet_path(x) + else: + return x + self.resnet_path(x) diff --git a/tha3/nn/resnet_block_seperable.py b/tha3/nn/resnet_block_seperable.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd3621020419a682ae82e55d612e8eda7a0dfe1 --- /dev/null +++ b/tha3/nn/resnet_block_seperable.py @@ -0,0 +1,71 @@ +from typing import Optional + +import torch +from torch.nn import Module, Sequential, Parameter + +from tha3.module.module_factory import ModuleFactory +from tha3.nn.conv import create_conv1 +from tha3.nn.nonlinearity_factory import resolve_nonlinearity_factory +from tha3.nn.normalization import NormalizationLayerFactory +from tha3.nn.separable_conv import create_separable_conv3 +from tha3.nn.util import BlockArgs + + +class ResnetBlockSeparable(Module): + @staticmethod + def create(num_channels: int, + is1x1: bool = False, + use_scale_parameters: bool = False, + block_args: Optional[BlockArgs] = None): + if block_args is None: + block_args = BlockArgs() + return ResnetBlockSeparable( + num_channels, + is1x1, + block_args.initialization_method, + block_args.nonlinearity_factory, + block_args.normalization_layer_factory, + block_args.use_spectral_norm, + use_scale_parameters) + + def __init__(self, + num_channels: int, + is1x1: bool = False, + initialization_method: str = 'he', + nonlinearity_factory: ModuleFactory = None, + normalization_layer_factory: Optional[NormalizationLayerFactory] = None, + use_spectral_norm: bool = False, + use_scale_parameter: bool = False): + super().__init__() + self.use_scale_parameter = use_scale_parameter + if self.use_scale_parameter: + self.scale = Parameter(torch.zeros(1)) + nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory) + if is1x1: + self.resnet_path = Sequential( + create_conv1(num_channels, num_channels, initialization_method, + bias=True, + use_spectral_norm=use_spectral_norm), + nonlinearity_factory.create(), + create_conv1(num_channels, num_channels, initialization_method, + bias=True, + use_spectral_norm=use_spectral_norm)) + else: + self.resnet_path = Sequential( + create_separable_conv3( + num_channels, num_channels, + bias=False, initialization_method=initialization_method, + use_spectral_norm=use_spectral_norm), + NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True), + nonlinearity_factory.create(), + create_separable_conv3( + num_channels, num_channels, + bias=False, initialization_method=initialization_method, + use_spectral_norm=use_spectral_norm), + NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(num_channels, affine=True)) + + def forward(self, x): + if self.use_scale_parameter: + return x + self.scale * self.resnet_path(x) + else: + return x + self.resnet_path(x) diff --git a/tha3/nn/separable_conv.py b/tha3/nn/separable_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..e33bce3c25aa279e7a7fbb0a7998a3f3788e4c25 --- /dev/null +++ b/tha3/nn/separable_conv.py @@ -0,0 +1,119 @@ +from typing import Optional + +from torch.nn import Sequential, Conv2d, ConvTranspose2d, Module + +from tha3.nn.normalization import NormalizationLayerFactory +from tha3.nn.util import BlockArgs, wrap_conv_or_linear_module + + +def create_separable_conv3(in_channels: int, out_channels: int, + bias: bool = False, + initialization_method='he', + use_spectral_norm: bool = False) -> Module: + return Sequential( + wrap_conv_or_linear_module( + Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False, groups=in_channels), + initialization_method, + use_spectral_norm), + wrap_conv_or_linear_module( + Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias), + initialization_method, + use_spectral_norm)) + + +def create_separable_conv7(in_channels: int, out_channels: int, + bias: bool = False, + initialization_method='he', + use_spectral_norm: bool = False) -> Module: + return Sequential( + wrap_conv_or_linear_module( + Conv2d(in_channels, in_channels, kernel_size=7, stride=1, padding=3, bias=False, groups=in_channels), + initialization_method, + use_spectral_norm), + wrap_conv_or_linear_module( + Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias), + initialization_method, + use_spectral_norm)) + + +def create_separable_conv3_block( + in_channels: int, out_channels: int, block_args: Optional[BlockArgs] = None): + if block_args is None: + block_args = BlockArgs() + return Sequential( + wrap_conv_or_linear_module( + Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False, groups=in_channels), + block_args.initialization_method, + block_args.use_spectral_norm), + wrap_conv_or_linear_module( + Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), + block_args.initialization_method, + block_args.use_spectral_norm), + NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory).create(out_channels, affine=True), + block_args.nonlinearity_factory.create()) + + +def create_separable_conv7_block( + in_channels: int, out_channels: int, block_args: Optional[BlockArgs] = None): + if block_args is None: + block_args = BlockArgs() + return Sequential( + wrap_conv_or_linear_module( + Conv2d(in_channels, in_channels, kernel_size=7, stride=1, padding=3, bias=False, groups=in_channels), + block_args.initialization_method, + block_args.use_spectral_norm), + wrap_conv_or_linear_module( + Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), + block_args.initialization_method, + block_args.use_spectral_norm), + NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory).create(out_channels, affine=True), + block_args.nonlinearity_factory.create()) + + +def create_separable_downsample_block( + in_channels: int, out_channels: int, is_output_1x1: bool, block_args: Optional[BlockArgs] = None): + if block_args is None: + block_args = BlockArgs() + if is_output_1x1: + return Sequential( + wrap_conv_or_linear_module( + Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, bias=False, groups=in_channels), + block_args.initialization_method, + block_args.use_spectral_norm), + wrap_conv_or_linear_module( + Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), + block_args.initialization_method, + block_args.use_spectral_norm), + block_args.nonlinearity_factory.create()) + else: + return Sequential( + wrap_conv_or_linear_module( + Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, bias=False, groups=in_channels), + block_args.initialization_method, + block_args.use_spectral_norm), + wrap_conv_or_linear_module( + Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), + block_args.initialization_method, + block_args.use_spectral_norm), + NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory) + .create(out_channels, affine=True), + block_args.nonlinearity_factory.create()) + + +def create_separable_upsample_block( + in_channels: int, out_channels: int, block_args: Optional[BlockArgs] = None): + if block_args is None: + block_args = BlockArgs() + return Sequential( + wrap_conv_or_linear_module( + ConvTranspose2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1, bias=False, groups=in_channels), + block_args.initialization_method, + block_args.use_spectral_norm), + wrap_conv_or_linear_module( + Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), + block_args.initialization_method, + block_args.use_spectral_norm), + NormalizationLayerFactory.resolve_2d(block_args.normalization_layer_factory) + .create(out_channels, affine=True), + block_args.nonlinearity_factory.create()) diff --git a/tha3/nn/spectral_norm.py b/tha3/nn/spectral_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..4f08dd49ce96516002e9b69ce17d63a2c89ec802 --- /dev/null +++ b/tha3/nn/spectral_norm.py @@ -0,0 +1,9 @@ +from torch.nn import Module +from torch.nn.utils import spectral_norm + + +def apply_spectral_norm(module: Module, use_spectrial_norm: bool = False) -> Module: + if use_spectrial_norm: + return spectral_norm(module) + else: + return module diff --git a/tha3/nn/two_algo_body_rotator/__init__.py b/tha3/nn/two_algo_body_rotator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tha3/nn/two_algo_body_rotator/two_algo_face_body_rotator_05.py b/tha3/nn/two_algo_body_rotator/two_algo_face_body_rotator_05.py new file mode 100644 index 0000000000000000000000000000000000000000..7793ac41407031fbeeb772d290437a9ed2b626cd --- /dev/null +++ b/tha3/nn/two_algo_body_rotator/two_algo_face_body_rotator_05.py @@ -0,0 +1,149 @@ +from typing import Optional, List + +import torch +from torch import Tensor +from torch.nn import Module, Sequential, Tanh + +from tha3.nn.image_processing_util import GridChangeApplier +from tha3.nn.common.resize_conv_encoder_decoder import ResizeConvEncoderDecoder, ResizeConvEncoderDecoderArgs +from tha3.module.module_factory import ModuleFactory +from tha3.nn.conv import create_conv3_from_block_args, create_conv3 +from tha3.nn.nonlinearity_factory import ReLUFactory, LeakyReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.util import BlockArgs + + +class TwoAlgoFaceBodyRotator05Args: + def __init__(self, + image_size: int = 512, + image_channels: int = 4, + num_pose_params: int = 6, + start_channels: int = 32, + bottleneck_image_size=32, + num_bottleneck_blocks=6, + max_channels: int = 512, + upsample_mode: str = 'bilinear', + block_args: Optional[BlockArgs] = None, + use_separable_convolution=False): + if block_args is None: + block_args = BlockArgs( + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=False)) + + self.use_separable_convolution = use_separable_convolution + self.upsample_mode = upsample_mode + self.max_channels = max_channels + self.num_bottleneck_blocks = num_bottleneck_blocks + self.bottleneck_image_size = bottleneck_image_size + self.start_channels = start_channels + self.num_pose_params = num_pose_params + self.image_channels = image_channels + self.image_size = image_size + self.block_args = block_args + + +class TwoAlgoFaceBodyRotator05(Module): + def __init__(self, args: TwoAlgoFaceBodyRotator05Args): + super().__init__() + self.args = args + + self.encoder_decoder = ResizeConvEncoderDecoder( + ResizeConvEncoderDecoderArgs( + image_size=args.image_size, + input_channels=args.image_channels + args.num_pose_params, + start_channels=args.start_channels, + bottleneck_image_size=args.bottleneck_image_size, + num_bottleneck_blocks=args.num_bottleneck_blocks, + max_channels=args.max_channels, + block_args=args.block_args, + upsample_mode=args.upsample_mode, + use_separable_convolution=args.use_separable_convolution)) + + self.direct_creator = Sequential( + create_conv3_from_block_args( + in_channels=self.args.start_channels, + out_channels=self.args.image_channels, + bias=True, + block_args=self.args.block_args), + Tanh()) + self.grid_change_creator = create_conv3( + in_channels=self.args.start_channels, + out_channels=2, + bias=False, + initialization_method='zero', + use_spectral_norm=False) + self.grid_change_applier = GridChangeApplier() + + def forward(self, image: Tensor, pose: Tensor, *args) -> List[Tensor]: + n, c = pose.shape + pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.image_size, self.args.image_size) + feature = torch.cat([image, pose], dim=1) + + feature = self.encoder_decoder.forward(feature)[-1] + grid_change = self.grid_change_creator(feature) + direct_image = self.direct_creator(feature) + warped_image = self.grid_change_applier.apply(grid_change, image) + + return [ + direct_image, + warped_image, + grid_change] + + DIRECT_IMAGE_INDEX = 0 + WARPED_IMAGE_INDEX = 1 + GRID_CHANGE_INDEX = 2 + OUTPUT_LENGTH = 3 + + +class TwoAlgoFaceBodyRotator05Factory(ModuleFactory): + def __init__(self, args: TwoAlgoFaceBodyRotator05Args): + super().__init__() + self.args = args + + def create(self) -> Module: + return TwoAlgoFaceBodyRotator05(self.args) + + +if __name__ == "__main__": + cuda = torch.device('cuda') + + image_size = 256 + image_channels = 4 + num_pose_params = 6 + args = TwoAlgoFaceBodyRotator05Args( + image_size=256, + image_channels=4, + start_channels=64, + num_pose_params=6, + bottleneck_image_size=32, + num_bottleneck_blocks=6, + max_channels=512, + upsample_mode='nearest', + use_separable_convolution=True, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=LeakyReLUFactory(inplace=False, negative_slope=0.1))) + module = TwoAlgoFaceBodyRotator05(args).to(cuda) + + image_count = 1 + image = torch.zeros(image_count, 4, image_size, image_size, device=cuda) + pose = torch.zeros(image_count, num_pose_params, device=cuda) + + repeat = 100 + acc = 0.0 + for i in range(repeat + 2): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + module.forward(image, pose) + end.record() + torch.cuda.synchronize() + if i >= 2: + elapsed_time = start.elapsed_time(end) + print("%d:" % i, elapsed_time) + acc = acc + elapsed_time + + print("average:", acc / repeat) diff --git a/tha3/nn/util.py b/tha3/nn/util.py new file mode 100644 index 0000000000000000000000000000000000000000..f9ae05d01343cab2bdc079020c0e65257eded7b1 --- /dev/null +++ b/tha3/nn/util.py @@ -0,0 +1,40 @@ +from typing import Optional, Callable, Union + +from torch.nn import Module + +from tha3.module.module_factory import ModuleFactory +from tha3.nn.init_function import create_init_function +from tha3.nn.nonlinearity_factory import resolve_nonlinearity_factory +from tha3.nn.normalization import NormalizationLayerFactory +from tha3.nn.spectral_norm import apply_spectral_norm + + +def wrap_conv_or_linear_module(module: Module, + initialization_method: Union[str, Callable[[Module], Module]], + use_spectral_norm: bool): + if isinstance(initialization_method, str): + init = create_init_function(initialization_method) + else: + init = initialization_method + return apply_spectral_norm(init(module), use_spectral_norm) + + +class BlockArgs: + def __init__(self, + initialization_method: Union[str, Callable[[Module], Module]] = 'he', + use_spectral_norm: bool = False, + normalization_layer_factory: Optional[NormalizationLayerFactory] = None, + nonlinearity_factory: Optional[ModuleFactory] = None): + self.nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory) + self.normalization_layer_factory = normalization_layer_factory + self.use_spectral_norm = use_spectral_norm + self.initialization_method = initialization_method + + def wrap_module(self, module: Module) -> Module: + return wrap_conv_or_linear_module(module, self.get_init_func(), self.use_spectral_norm) + + def get_init_func(self) -> Callable[[Module], Module]: + if isinstance(self.initialization_method, str): + return create_init_function(self.initialization_method) + else: + return self.initialization_method diff --git a/tha3/poser/__init__.py b/tha3/poser/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tha3/poser/general_poser_02.py b/tha3/poser/general_poser_02.py new file mode 100644 index 0000000000000000000000000000000000000000..bf40cadd209d836d210c32a1a104f5e9d2c5ad8f --- /dev/null +++ b/tha3/poser/general_poser_02.py @@ -0,0 +1,85 @@ +from typing import List, Optional, Tuple, Dict, Callable + +import torch +from torch import Tensor +from torch.nn import Module + +from tha3.poser.poser import PoseParameterGroup, Poser +from tha3.compute.cached_computation_func import TensorListCachedComputationFunc + + +class GeneralPoser02(Poser): + def __init__(self, + module_loaders: Dict[str, Callable[[], Module]], + device: torch.device, + output_length: int, + pose_parameters: List[PoseParameterGroup], + output_list_func: TensorListCachedComputationFunc, + subrect: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None, + default_output_index: int = 0, + image_size: int = 256, + dtype: torch.dtype = torch.float): + self.dtype = dtype + self.image_size = image_size + self.default_output_index = default_output_index + self.output_list_func = output_list_func + self.subrect = subrect + self.pose_parameters = pose_parameters + self.device = device + self.module_loaders = module_loaders + + self.modules = None + + self.num_parameters = 0 + for pose_parameter in self.pose_parameters: + self.num_parameters += pose_parameter.get_arity() + + self.output_length = output_length + + def get_image_size(self) -> int: + return self.image_size + + def get_modules(self): + if self.modules is None: + self.modules = {} + for key in self.module_loaders: + module = self.module_loaders[key]() + self.modules[key] = module + module.to(self.device) + module.train(False) + return self.modules + + def get_pose_parameter_groups(self) -> List[PoseParameterGroup]: + return self.pose_parameters + + def get_num_parameters(self) -> int: + return self.num_parameters + + def pose(self, image: Tensor, pose: Tensor, output_index: Optional[int] = None) -> Tensor: + if output_index is None: + output_index = self.default_output_index + output_list = self.get_posing_outputs(image, pose) + return output_list[output_index] + + def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]: + modules = self.get_modules() + + if len(image.shape) == 3: + image = image.unsqueeze(0) + if len(pose.shape) == 1: + pose = pose.unsqueeze(0) + if self.subrect is not None: + image = image[:, :, self.subrect[0][0]:self.subrect[0][1], self.subrect[1][0]:self.subrect[1][1]] + batch = [image, pose] + + outputs = {} + return self.output_list_func(modules, batch, outputs) + + def get_output_length(self) -> int: + return self.output_length + + def free(self): + self.modules = None + + def get_dtype(self) -> torch.dtype: + return self.dtype diff --git a/tha3/poser/modes/__init__.py b/tha3/poser/modes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tha3/poser/modes/load_poser.py b/tha3/poser/modes/load_poser.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f371a0bffb6a91a0e41248d304539bcb33ec01 --- /dev/null +++ b/tha3/poser/modes/load_poser.py @@ -0,0 +1,19 @@ +import torch + + +def load_poser(model: str, device: torch.device): + print("Using the %s model." % model) + if model == "standard_float": + from tha3.poser.modes.standard_float import create_poser + return create_poser(device) + elif model == "standard_half": + from tha3.poser.modes.standard_half import create_poser + return create_poser(device) + elif model == "separable_float": + from tha3.poser.modes.separable_float import create_poser + return create_poser(device) + elif model == "separable_half": + from tha3.poser.modes.separable_half import create_poser + return create_poser(device) + else: + raise RuntimeError("Invalid model: '%s'" % model) \ No newline at end of file diff --git a/tha3/poser/modes/pose_parameters.py b/tha3/poser/modes/pose_parameters.py new file mode 100644 index 0000000000000000000000000000000000000000..433a68dd556fabb878633913209364c2634fdadd --- /dev/null +++ b/tha3/poser/modes/pose_parameters.py @@ -0,0 +1,36 @@ +from tha3.poser.poser import PoseParameters, PoseParameterCategory + + +def get_pose_parameters(): + return PoseParameters.Builder() \ + .add_parameter_group("eyebrow_troubled", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_angry", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_lowered", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_raised", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_happy", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_serious", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eye_wink", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_happy_wink", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_surprised", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_relaxed", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_unimpressed", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_raised_lower_eyelid", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("iris_small", PoseParameterCategory.IRIS_MORPH, arity=2) \ + .add_parameter_group("mouth_aaa", PoseParameterCategory.MOUTH, arity=1, default_value=1.0) \ + .add_parameter_group("mouth_iii", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_uuu", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_eee", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_ooo", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_delta", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_lowered_corner", PoseParameterCategory.MOUTH, arity=2) \ + .add_parameter_group("mouth_raised_corner", PoseParameterCategory.MOUTH, arity=2) \ + .add_parameter_group("mouth_smirk", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("iris_rotation_x", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("iris_rotation_y", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("head_x", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("head_y", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("neck_z", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("body_y", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("body_z", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \ + .build() \ No newline at end of file diff --git a/tha3/poser/modes/separable_float.py b/tha3/poser/modes/separable_float.py new file mode 100644 index 0000000000000000000000000000000000000000..ca22d61eee6688bbdfd5531b6ac74b4fa8a3de77 --- /dev/null +++ b/tha3/poser/modes/separable_float.py @@ -0,0 +1,331 @@ +from enum import Enum +from typing import Dict, Optional, List + +import torch +from torch import Tensor +from torch.nn import Module +from torch.nn.functional import interpolate + +from tha3.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_00 import EyebrowMorphingCombiner00 +from tha3.nn.eyebrow_decomposer.eyebrow_decomposer_03 import EyebrowDecomposer03Factory, \ + EyebrowDecomposer03Args, EyebrowDecomposer03 +from tha3.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_03 import \ + EyebrowMorphingCombiner03Factory, EyebrowMorphingCombiner03Args +from tha3.nn.face_morpher.face_morpher_09 import FaceMorpher09Factory, FaceMorpher09Args +from tha3.poser.general_poser_02 import GeneralPoser02 +from tha3.nn.editor.editor_07 import Editor07, Editor07Args +from tha3.nn.two_algo_body_rotator.two_algo_face_body_rotator_05 import TwoAlgoFaceBodyRotator05, \ + TwoAlgoFaceBodyRotator05Args +from tha3.poser.modes.pose_parameters import get_pose_parameters +from tha3.util import torch_load +from tha3.compute.cached_computation_func import TensorListCachedComputationFunc +from tha3.compute.cached_computation_protocol import CachedComputationProtocol +from tha3.nn.nonlinearity_factory import ReLUFactory, LeakyReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.util import BlockArgs + + +class Network(Enum): + eyebrow_decomposer = 1 + eyebrow_morphing_combiner = 2 + face_morpher = 3 + two_algo_face_body_rotator = 4 + editor = 5 + + @property + def outputs_key(self): + return f"{self.name}_outputs" + + +class Branch(Enum): + face_morphed_half = 1 + face_morphed_full = 2 + all_outputs = 3 + + +NUM_EYEBROW_PARAMS = 12 +NUM_FACE_PARAMS = 27 +NUM_ROTATION_PARAMS = 6 + + +class FiveStepPoserComputationProtocol(CachedComputationProtocol): + def __init__(self, eyebrow_morphed_image_index: int): + super().__init__() + self.eyebrow_morphed_image_index = eyebrow_morphed_image_index + self.cached_batch_0 = None + self.cached_eyebrow_decomposer_output = None + + def compute_func(self) -> TensorListCachedComputationFunc: + def func(modules: Dict[str, Module], + batch: List[Tensor], + outputs: Dict[str, List[Tensor]]): + if self.cached_batch_0 is None: + new_batch_0 = True + elif batch[0].shape[0] != self.cached_batch_0.shape[0]: + new_batch_0 = True + else: + new_batch_0 = torch.max((batch[0] - self.cached_batch_0).abs()).item() > 0 + if not new_batch_0: + outputs[Network.eyebrow_decomposer.outputs_key] = self.cached_eyebrow_decomposer_output + output = self.get_output(Branch.all_outputs.name, modules, batch, outputs) + if new_batch_0: + self.cached_batch_0 = batch[0] + self.cached_eyebrow_decomposer_output = outputs[Network.eyebrow_decomposer.outputs_key] + return output + + return func + + def compute_output(self, key: str, modules: Dict[str, Module], batch: List[Tensor], + outputs: Dict[str, List[Tensor]]) -> List[Tensor]: + if key == Network.eyebrow_decomposer.outputs_key: + input_image = batch[0][:, :, 64:192, 64 + 128:192 + 128] + return modules[Network.eyebrow_decomposer.name].forward(input_image) + elif key == Network.eyebrow_morphing_combiner.outputs_key: + eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, modules, batch, outputs) + background_layer = eyebrow_decomposer_output[EyebrowDecomposer03.BACKGROUND_LAYER_INDEX] + eyebrow_layer = eyebrow_decomposer_output[EyebrowDecomposer03.EYEBROW_LAYER_INDEX] + eyebrow_pose = batch[1][:, :NUM_EYEBROW_PARAMS] + return modules[Network.eyebrow_morphing_combiner.name].forward( + background_layer, + eyebrow_layer, + eyebrow_pose) + elif key == Network.face_morpher.outputs_key: + eyebrow_morphing_combiner_output = self.get_output( + Network.eyebrow_morphing_combiner.outputs_key, modules, batch, outputs) + eyebrow_morphed_image = eyebrow_morphing_combiner_output[self.eyebrow_morphed_image_index] + input_image = batch[0][:, :, 32:32 + 192, (32 + 128):(32 + 192 + 128)].clone() + input_image[:, :, 32:32 + 128, 32:32 + 128] = eyebrow_morphed_image + face_pose = batch[1][:, NUM_EYEBROW_PARAMS:NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS] + return modules[Network.face_morpher.name].forward(input_image, face_pose) + elif key == Branch.face_morphed_full.name: + face_morpher_output = self.get_output(Network.face_morpher.outputs_key, modules, batch, outputs) + face_morphed_image = face_morpher_output[0] + input_image = batch[0].clone() + input_image[:, :, 32:32 + 192, 32 + 128:32 + 192 + 128] = face_morphed_image + return [input_image] + elif key == Branch.face_morphed_half.name: + face_morphed_full = self.get_output(Branch.face_morphed_full.name, modules, batch, outputs)[0] + return [ + interpolate(face_morphed_full, size=(256, 256), mode='bilinear', align_corners=False) + ] + elif key == Network.two_algo_face_body_rotator.outputs_key: + face_morphed_half = self.get_output(Branch.face_morphed_half.name, modules, batch, outputs)[0] + rotation_pose = batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:] + output = modules[Network.two_algo_face_body_rotator.name].forward(face_morphed_half, rotation_pose) + return output + elif key == Network.editor.outputs_key: + input_original_image = self.get_output(Branch.face_morphed_full.name, modules, batch, outputs)[0] + rotator_outputs = self.get_output( + Network.two_algo_face_body_rotator.outputs_key, modules, batch, outputs) + half_warped_image = rotator_outputs[TwoAlgoFaceBodyRotator05.WARPED_IMAGE_INDEX] + full_warped_image = interpolate( + half_warped_image, size=(512, 512), mode='bilinear', align_corners=False) + half_grid_change = rotator_outputs[TwoAlgoFaceBodyRotator05.GRID_CHANGE_INDEX] + full_grid_change = interpolate( + half_grid_change, size=(512, 512), mode='bilinear', align_corners=False) + rotation_pose = batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:] + return modules[Network.editor.name].forward( + input_original_image, full_warped_image, full_grid_change, rotation_pose) + elif key == Branch.all_outputs.name: + editor_output = self.get_output(Network.editor.outputs_key, modules, batch, outputs) + rotater_output = self.get_output(Network.two_algo_face_body_rotator.outputs_key, modules, batch, outputs) + face_morpher_output = self.get_output(Network.face_morpher.outputs_key, modules, batch, outputs) + eyebrow_morphing_combiner_output = self.get_output( + Network.eyebrow_morphing_combiner.outputs_key, modules, batch, outputs) + eyebrow_decomposer_output = self.get_output( + Network.eyebrow_decomposer.outputs_key, modules, batch, outputs) + output = editor_output \ + + rotater_output \ + + face_morpher_output \ + + eyebrow_morphing_combiner_output \ + + eyebrow_decomposer_output + return output + else: + raise RuntimeError("Unsupported key: " + key) + + +def load_eyebrow_decomposer(file_name: str): + factory = EyebrowDecomposer03Factory( + EyebrowDecomposer03Args( + image_size=128, + image_channels=4, + start_channels=64, + bottleneck_image_size=16, + num_bottleneck_blocks=6, + max_channels=512, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=True)))) + print("Loading the eyebrow decomposer ... ", end="") + module = factory.create() + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_eyebrow_morphing_combiner(file_name: str): + factory = EyebrowMorphingCombiner03Factory( + EyebrowMorphingCombiner03Args( + image_size=128, + image_channels=4, + start_channels=64, + num_pose_params=12, + bottleneck_image_size=16, + num_bottleneck_blocks=6, + max_channels=512, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=True)))) + print("Loading the eyebrow morphing conbiner ... ", end="") + module = factory.create() + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_face_morpher(file_name: str): + factory = FaceMorpher09Factory( + FaceMorpher09Args( + image_size=192, + image_channels=4, + num_pose_params=27, + start_channels=64, + bottleneck_image_size=24, + num_bottleneck_blocks=6, + max_channels=512, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=False)))) + print("Loading the face morpher ... ", end="") + module = factory.create() + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_two_algo_generator(file_name) -> Module: + module = TwoAlgoFaceBodyRotator05( + TwoAlgoFaceBodyRotator05Args( + image_size=256, + image_channels=4, + start_channels=64, + num_pose_params=6, + bottleneck_image_size=32, + num_bottleneck_blocks=6, + max_channels=512, + upsample_mode='nearest', + use_separable_convolution=True, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=LeakyReLUFactory(inplace=False, negative_slope=0.1)))) + print("Loading the face-body rotator ... ", end="") + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_editor(file_name) -> Module: + module = Editor07( + Editor07Args( + image_size=512, + image_channels=4, + num_pose_params=6, + start_channels=32, + bottleneck_image_size=64, + num_bottleneck_blocks=6, + max_channels=512, + upsampling_mode='nearest', + use_separable_convolution=True, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=LeakyReLUFactory(inplace=False, negative_slope=0.1)))) + print("Loading the combiner ... ", end="") + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def create_poser( + device: torch.device, + module_file_names: Optional[Dict[str, str]] = None, + eyebrow_morphed_image_index: int = EyebrowMorphingCombiner00.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX, + default_output_index: int = 0) -> GeneralPoser02: + if module_file_names is None: + module_file_names = {} + if Network.eyebrow_decomposer.name not in module_file_names: + dir = "data/models/separable_float" + file_name = dir + "/eyebrow_decomposer.pt" + module_file_names[Network.eyebrow_decomposer.name] = file_name + if Network.eyebrow_morphing_combiner.name not in module_file_names: + dir = "data/models/separable_float" + file_name = dir + "/eyebrow_morphing_combiner.pt" + module_file_names[Network.eyebrow_morphing_combiner.name] = file_name + if Network.face_morpher.name not in module_file_names: + dir = "data/models/separable_float" + file_name = dir + "/face_morpher.pt" + module_file_names[Network.face_morpher.name] = file_name + if Network.two_algo_face_body_rotator.name not in module_file_names: + dir = "data/models/separable_float" + file_name = dir + "/two_algo_face_body_rotator.pt" + module_file_names[Network.two_algo_face_body_rotator.name] = file_name + if Network.editor.name not in module_file_names: + dir = "data/models/separable_float" + file_name = dir + "/editor.pt" + module_file_names[Network.editor.name] = file_name + + loaders = { + Network.eyebrow_decomposer.name: + lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]), + Network.eyebrow_morphing_combiner.name: + lambda: load_eyebrow_morphing_combiner(module_file_names[Network.eyebrow_morphing_combiner.name]), + Network.face_morpher.name: + lambda: load_face_morpher(module_file_names[Network.face_morpher.name]), + Network.two_algo_face_body_rotator.name: + lambda: load_two_algo_generator(module_file_names[Network.two_algo_face_body_rotator.name]), + Network.editor.name: + lambda: load_editor(module_file_names[Network.editor.name]), + } + return GeneralPoser02( + image_size=512, + module_loaders=loaders, + pose_parameters=get_pose_parameters().get_pose_parameter_groups(), + output_list_func=FiveStepPoserComputationProtocol(eyebrow_morphed_image_index).compute_func(), + subrect=None, + device=device, + output_length=29, + default_output_index=default_output_index) + + +if __name__ == "__main__": + device = torch.device('cuda') + poser = create_poser(device) + + image = torch.zeros(1, 4, 512, 512, device=device) + pose = torch.zeros(1, 45, device=device) + + repeat = 100 + acc = 0.0 + for i in range(repeat + 2): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + poser.pose(image, pose) + end.record() + torch.cuda.synchronize() + if i >= 2: + elapsed_time = start.elapsed_time(end) + print("%d:" % i, elapsed_time) + acc = acc + elapsed_time + + print("average:", acc / repeat) \ No newline at end of file diff --git a/tha3/poser/modes/separable_half.py b/tha3/poser/modes/separable_half.py new file mode 100644 index 0000000000000000000000000000000000000000..cab8e0b2f33189c6f731c99841348f149f86c9c3 --- /dev/null +++ b/tha3/poser/modes/separable_half.py @@ -0,0 +1,365 @@ +from enum import Enum +from typing import List, Dict, Optional + +import torch +from torch import Tensor +from torch.nn import Module +from torch.nn.functional import interpolate + +from tha3.nn.eyebrow_decomposer.eyebrow_decomposer_03 import EyebrowDecomposer03Factory, \ + EyebrowDecomposer03Args, EyebrowDecomposer03 +from tha3.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_03 import \ + EyebrowMorphingCombiner03Factory, EyebrowMorphingCombiner03Args, EyebrowMorphingCombiner03 +from tha3.nn.face_morpher.face_morpher_09 import FaceMorpher09Factory, FaceMorpher09Args +from tha3.poser.general_poser_02 import GeneralPoser02 +from tha3.poser.poser import PoseParameterCategory, PoseParameters +from tha3.nn.editor.editor_07 import Editor07, Editor07Args +from tha3.nn.two_algo_body_rotator.two_algo_face_body_rotator_05 import TwoAlgoFaceBodyRotator05, \ + TwoAlgoFaceBodyRotator05Args +from tha3.util import torch_load +from tha3.compute.cached_computation_func import TensorListCachedComputationFunc +from tha3.compute.cached_computation_protocol import CachedComputationProtocol +from tha3.nn.nonlinearity_factory import ReLUFactory, LeakyReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.util import BlockArgs + + +class Network(Enum): + eyebrow_decomposer = 1 + eyebrow_morphing_combiner = 2 + face_morpher = 3 + two_algo_face_body_rotator = 4 + editor = 5 + + @property + def outputs_key(self): + return f"{self.name}_outputs" + + +class Branch(Enum): + face_morphed_half = 1 + face_morphed_full = 2 + all_outputs = 3 + + +NUM_EYEBROW_PARAMS = 12 +NUM_FACE_PARAMS = 27 +NUM_ROTATION_PARAMS = 6 + + +class FiveStepPoserComputationProtocol(CachedComputationProtocol): + def __init__(self, eyebrow_morphed_image_index: int): + super().__init__() + self.eyebrow_morphed_image_index = eyebrow_morphed_image_index + self.cached_batch_0 = None + self.cached_eyebrow_decomposer_output = None + + def compute_func(self) -> TensorListCachedComputationFunc: + def func(modules: Dict[str, Module], + batch: List[Tensor], + outputs: Dict[str, List[Tensor]]): + if self.cached_batch_0 is None: + new_batch_0 = True + elif batch[0].shape[0] != self.cached_batch_0.shape[0]: + new_batch_0 = True + else: + new_batch_0 = torch.max((batch[0] - self.cached_batch_0).abs()).item() > 0 + if not new_batch_0: + outputs[Network.eyebrow_decomposer.outputs_key] = self.cached_eyebrow_decomposer_output + output = self.get_output(Branch.all_outputs.name, modules, batch, outputs) + if new_batch_0: + self.cached_batch_0 = batch[0] + self.cached_eyebrow_decomposer_output = outputs[Network.eyebrow_decomposer.outputs_key] + return output + + return func + + def compute_output(self, key: str, modules: Dict[str, Module], batch: List[Tensor], + outputs: Dict[str, List[Tensor]]) -> List[Tensor]: + if key == Network.eyebrow_decomposer.outputs_key: + input_image = batch[0][:, :, 64:192, 64 + 128:192 + 128] + return modules[Network.eyebrow_decomposer.name].forward(input_image) + elif key == Network.eyebrow_morphing_combiner.outputs_key: + eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, modules, batch, outputs) + background_layer = eyebrow_decomposer_output[EyebrowDecomposer03.BACKGROUND_LAYER_INDEX] + eyebrow_layer = eyebrow_decomposer_output[EyebrowDecomposer03.EYEBROW_LAYER_INDEX] + eyebrow_pose = batch[1][:, :NUM_EYEBROW_PARAMS] + return modules[Network.eyebrow_morphing_combiner.name].forward( + background_layer, + eyebrow_layer, + eyebrow_pose) + elif key == Network.face_morpher.outputs_key: + eyebrow_morphing_combiner_output = self.get_output( + Network.eyebrow_morphing_combiner.outputs_key, modules, batch, outputs) + eyebrow_morphed_image = eyebrow_morphing_combiner_output[self.eyebrow_morphed_image_index] + input_image = batch[0][:, :, 32:32 + 192, (32 + 128):(32 + 192 + 128)].clone() + input_image[:, :, 32:32 + 128, 32:32 + 128] = eyebrow_morphed_image + face_pose = batch[1][:, NUM_EYEBROW_PARAMS:NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS] + return modules[Network.face_morpher.name].forward(input_image, face_pose) + elif key == Branch.face_morphed_full.name: + face_morpher_output = self.get_output(Network.face_morpher.outputs_key, modules, batch, outputs) + face_morphed_image = face_morpher_output[0] + input_image = batch[0].clone() + input_image[:, :, 32:32 + 192, 32 + 128:32 + 192 + 128] = face_morphed_image + return [input_image] + elif key == Branch.face_morphed_half.name: + face_morphed_full = self.get_output(Branch.face_morphed_full.name, modules, batch, outputs)[0] + return [ + interpolate(face_morphed_full, size=(256, 256), mode='bilinear', align_corners=False) + ] + elif key == Network.two_algo_face_body_rotator.outputs_key: + face_morphed_half = self.get_output(Branch.face_morphed_half.name, modules, batch, outputs)[0] + rotation_pose = batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:] + return modules[Network.two_algo_face_body_rotator.name].forward(face_morphed_half, rotation_pose) + elif key == Network.editor.outputs_key: + input_original_image = self.get_output(Branch.face_morphed_full.name, modules, batch, outputs)[0] + rotator_outputs = self.get_output( + Network.two_algo_face_body_rotator.outputs_key, modules, batch, outputs) + half_warped_image = rotator_outputs[TwoAlgoFaceBodyRotator05.WARPED_IMAGE_INDEX] + full_warped_image = interpolate( + half_warped_image, size=(512, 512), mode='bilinear', align_corners=False) + half_grid_change = rotator_outputs[TwoAlgoFaceBodyRotator05.GRID_CHANGE_INDEX] + full_grid_change = interpolate( + half_grid_change, size=(512, 512), mode='bilinear', align_corners=False) + rotation_pose = batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:] + return modules[Network.editor.name].forward( + input_original_image, full_warped_image, full_grid_change, rotation_pose) + elif key == Branch.all_outputs.name: + editor_output = self.get_output(Network.editor.outputs_key, modules, batch, outputs) + rotater_output = self.get_output(Network.two_algo_face_body_rotator.outputs_key, modules, batch, outputs) + face_morpher_output = self.get_output(Network.face_morpher.outputs_key, modules, batch, outputs) + eyebrow_morphing_combiner_output = self.get_output( + Network.eyebrow_morphing_combiner.outputs_key, modules, batch, outputs) + eyebrow_decomposer_output = self.get_output( + Network.eyebrow_decomposer.outputs_key, modules, batch, outputs) + output = editor_output \ + + rotater_output \ + + face_morpher_output \ + + eyebrow_morphing_combiner_output \ + + eyebrow_decomposer_output + return output + else: + raise RuntimeError("Unsupported key: " + key) + + +def load_eyebrow_decomposer(file_name: str): + factory = EyebrowDecomposer03Factory( + EyebrowDecomposer03Args( + image_size=128, + image_channels=4, + start_channels=64, + bottleneck_image_size=16, + num_bottleneck_blocks=6, + max_channels=512, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=True)))) + print("Loading the eyebrow decomposer ... ", end="") + module = factory.create().half() + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_eyebrow_morphing_combiner(file_name: str): + factory = EyebrowMorphingCombiner03Factory( + EyebrowMorphingCombiner03Args( + image_size=128, + image_channels=4, + start_channels=64, + num_pose_params=12, + bottleneck_image_size=16, + num_bottleneck_blocks=6, + max_channels=512, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=True)))) + print("Loading the eyebrow morphing conbiner ... ", end="") + module = factory.create().half() + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_face_morpher(file_name: str): + factory = FaceMorpher09Factory( + FaceMorpher09Args( + image_size=192, + image_channels=4, + num_pose_params=27, + start_channels=64, + bottleneck_image_size=24, + num_bottleneck_blocks=6, + max_channels=512, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=False)))) + print("Loading the face morpher ... ", end="") + module = factory.create().half() + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_two_algo_generator(file_name) -> Module: + module = TwoAlgoFaceBodyRotator05( + TwoAlgoFaceBodyRotator05Args( + image_size=256, + image_channels=4, + start_channels=64, + num_pose_params=6, + bottleneck_image_size=32, + num_bottleneck_blocks=6, + max_channels=512, + upsample_mode='nearest', + use_separable_convolution=True, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=LeakyReLUFactory(inplace=False, negative_slope=0.1)))).half() + print("Loading the face-body rotator ... ", end="") + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_editor(file_name) -> Module: + module = Editor07( + Editor07Args( + image_size=512, + image_channels=4, + num_pose_params=6, + start_channels=32, + bottleneck_image_size=64, + num_bottleneck_blocks=6, + max_channels=512, + upsampling_mode='nearest', + use_separable_convolution=True, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=LeakyReLUFactory(inplace=False, negative_slope=0.1)))).half() + print("Loading the combiner ... ", end="") + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def get_pose_parameters(): + return PoseParameters.Builder() \ + .add_parameter_group("eyebrow_troubled", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_angry", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_lowered", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_raised", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_happy", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_serious", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eye_wink", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_happy_wink", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_surprised", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_relaxed", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_unimpressed", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_raised_lower_eyelid", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("iris_small", PoseParameterCategory.IRIS_MORPH, arity=2) \ + .add_parameter_group("mouth_aaa", PoseParameterCategory.MOUTH, arity=1, default_value=1.0) \ + .add_parameter_group("mouth_iii", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_uuu", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_eee", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_ooo", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_delta", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_lowered_corner", PoseParameterCategory.MOUTH, arity=2) \ + .add_parameter_group("mouth_raised_corner", PoseParameterCategory.MOUTH, arity=2) \ + .add_parameter_group("mouth_smirk", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("iris_rotation_x", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("iris_rotation_y", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("head_x", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("head_y", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("neck_z", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("body_y", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("body_z", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \ + .build() + + +def create_poser( + device: torch.device, + module_file_names: Optional[Dict[str, str]] = None, + eyebrow_morphed_image_index: int = EyebrowMorphingCombiner03.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX, + default_output_index: int = 0) -> GeneralPoser02: + if module_file_names is None: + module_file_names = {} + if Network.eyebrow_decomposer.name not in module_file_names: + dir = "data/models/separable_half" + file_name = dir + "/eyebrow_decomposer.pt" + module_file_names[Network.eyebrow_decomposer.name] = file_name + if Network.eyebrow_morphing_combiner.name not in module_file_names: + dir = "data/models/separable_half" + file_name = dir + "/eyebrow_morphing_combiner.pt" + module_file_names[Network.eyebrow_morphing_combiner.name] = file_name + if Network.face_morpher.name not in module_file_names: + dir = "data/models/separable_half" + file_name = dir + "/face_morpher.pt" + module_file_names[Network.face_morpher.name] = file_name + if Network.two_algo_face_body_rotator.name not in module_file_names: + dir = "data/models/separable_half" + file_name = dir + "/two_algo_face_body_rotator.pt" + module_file_names[Network.two_algo_face_body_rotator.name] = file_name + if Network.editor.name not in module_file_names: + dir = "data/models/separable_half" + file_name = dir + "/editor.pt" + module_file_names[Network.editor.name] = file_name + + loaders = { + Network.eyebrow_decomposer.name: + lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]), + Network.eyebrow_morphing_combiner.name: + lambda: load_eyebrow_morphing_combiner(module_file_names[Network.eyebrow_morphing_combiner.name]), + Network.face_morpher.name: + lambda: load_face_morpher(module_file_names[Network.face_morpher.name]), + Network.two_algo_face_body_rotator.name: + lambda: load_two_algo_generator(module_file_names[Network.two_algo_face_body_rotator.name]), + Network.editor.name: + lambda: load_editor(module_file_names[Network.editor.name]), + } + return GeneralPoser02( + image_size=512, + module_loaders=loaders, + pose_parameters=get_pose_parameters().get_pose_parameter_groups(), + output_list_func=FiveStepPoserComputationProtocol(eyebrow_morphed_image_index).compute_func(), + subrect=None, + device=device, + output_length=29, + dtype=torch.half, + default_output_index=default_output_index) + + +if __name__ == "__main__": + device = torch.device('cuda') + poser = create_poser(device) + + image = torch.zeros(1, 4, 512, 512, device=device, dtype=torch.half) + pose = torch.zeros(1, 45, device=device, dtype=torch.half) + + repeat = 100 + acc = 0.0 + for i in range(repeat + 2): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + poser.pose(image, pose) + end.record() + torch.cuda.synchronize() + if i >= 2: + elapsed_time = start.elapsed_time(end) + print("%d:" % i, elapsed_time) + acc = acc + elapsed_time + + print("average:", acc / repeat) diff --git a/tha3/poser/modes/standard_float.py b/tha3/poser/modes/standard_float.py new file mode 100644 index 0000000000000000000000000000000000000000..ec96ace7e1c4aced4b5de3d7878f78542d8add1d --- /dev/null +++ b/tha3/poser/modes/standard_float.py @@ -0,0 +1,362 @@ +from enum import Enum +from typing import List, Dict, Optional + +import torch +from torch import Tensor +from torch.nn import Module +from torch.nn.functional import interpolate + +from tha3.nn.eyebrow_decomposer.eyebrow_decomposer_00 import EyebrowDecomposer00, \ + EyebrowDecomposer00Factory, EyebrowDecomposer00Args +from tha3.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_00 import \ + EyebrowMorphingCombiner00Factory, EyebrowMorphingCombiner00Args, EyebrowMorphingCombiner00 +from tha3.nn.face_morpher.face_morpher_08 import FaceMorpher08Args, FaceMorpher08Factory +from tha3.poser.general_poser_02 import GeneralPoser02 +from tha3.poser.poser import PoseParameterCategory, PoseParameters +from tha3.nn.editor.editor_07 import Editor07, Editor07Args +from tha3.nn.two_algo_body_rotator.two_algo_face_body_rotator_05 import TwoAlgoFaceBodyRotator05, \ + TwoAlgoFaceBodyRotator05Args +from tha3.util import torch_load +from tha3.compute.cached_computation_func import TensorListCachedComputationFunc +from tha3.compute.cached_computation_protocol import CachedComputationProtocol +from tha3.nn.nonlinearity_factory import ReLUFactory, LeakyReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.util import BlockArgs + + +class Network(Enum): + eyebrow_decomposer = 1 + eyebrow_morphing_combiner = 2 + face_morpher = 3 + two_algo_face_body_rotator = 4 + editor = 5 + + @property + def outputs_key(self): + return f"{self.name}_outputs" + + +class Branch(Enum): + face_morphed_half = 1 + face_morphed_full = 2 + all_outputs = 3 + + +NUM_EYEBROW_PARAMS = 12 +NUM_FACE_PARAMS = 27 +NUM_ROTATION_PARAMS = 6 + + +class FiveStepPoserComputationProtocol(CachedComputationProtocol): + def __init__(self, eyebrow_morphed_image_index: int): + super().__init__() + self.eyebrow_morphed_image_index = eyebrow_morphed_image_index + self.cached_batch_0 = None + self.cached_eyebrow_decomposer_output = None + + def compute_func(self) -> TensorListCachedComputationFunc: + def func(modules: Dict[str, Module], + batch: List[Tensor], + outputs: Dict[str, List[Tensor]]): + if self.cached_batch_0 is None: + new_batch_0 = True + elif batch[0].shape[0] != self.cached_batch_0.shape[0]: + new_batch_0 = True + else: + new_batch_0 = torch.max((batch[0] - self.cached_batch_0).abs()).item() > 0 + if not new_batch_0: + outputs[Network.eyebrow_decomposer.outputs_key] = self.cached_eyebrow_decomposer_output + output = self.get_output(Branch.all_outputs.name, modules, batch, outputs) + if new_batch_0: + self.cached_batch_0 = batch[0] + self.cached_eyebrow_decomposer_output = outputs[Network.eyebrow_decomposer.outputs_key] + return output + + return func + + def compute_output(self, key: str, modules: Dict[str, Module], batch: List[Tensor], + outputs: Dict[str, List[Tensor]]) -> List[Tensor]: + if key == Network.eyebrow_decomposer.outputs_key: + input_image = batch[0][:, :, 64:192, 64 + 128:192 + 128] + return modules[Network.eyebrow_decomposer.name].forward(input_image) + elif key == Network.eyebrow_morphing_combiner.outputs_key: + eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, modules, batch, outputs) + background_layer = eyebrow_decomposer_output[EyebrowDecomposer00.BACKGROUND_LAYER_INDEX] + eyebrow_layer = eyebrow_decomposer_output[EyebrowDecomposer00.EYEBROW_LAYER_INDEX] + eyebrow_pose = batch[1][:, :NUM_EYEBROW_PARAMS] + return modules[Network.eyebrow_morphing_combiner.name].forward( + background_layer, + eyebrow_layer, + eyebrow_pose) + elif key == Network.face_morpher.outputs_key: + eyebrow_morphing_combiner_output = self.get_output( + Network.eyebrow_morphing_combiner.outputs_key, modules, batch, outputs) + eyebrow_morphed_image = eyebrow_morphing_combiner_output[self.eyebrow_morphed_image_index] + input_image = batch[0][:, :, 32:32 + 192, (32 + 128):(32 + 192 + 128)].clone() + input_image[:, :, 32:32 + 128, 32:32 + 128] = eyebrow_morphed_image + face_pose = batch[1][:, NUM_EYEBROW_PARAMS:NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS] + return modules[Network.face_morpher.name].forward(input_image, face_pose) + elif key == Branch.face_morphed_full.name: + face_morpher_output = self.get_output(Network.face_morpher.outputs_key, modules, batch, outputs) + face_morphed_image = face_morpher_output[0] + input_image = batch[0].clone() + input_image[:, :, 32:32 + 192, 32 + 128:32 + 192 + 128] = face_morphed_image + return [input_image] + elif key == Branch.face_morphed_half.name: + face_morphed_full = self.get_output(Branch.face_morphed_full.name, modules, batch, outputs)[0] + return [ + interpolate(face_morphed_full, size=(256, 256), mode='bilinear', align_corners=False) + ] + elif key == Network.two_algo_face_body_rotator.outputs_key: + face_morphed_half = self.get_output(Branch.face_morphed_half.name, modules, batch, outputs)[0] + rotation_pose = batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:] + return modules[Network.two_algo_face_body_rotator.name].forward(face_morphed_half, rotation_pose) + elif key == Network.editor.outputs_key: + input_original_image = self.get_output(Branch.face_morphed_full.name, modules, batch, outputs)[0] + rotator_outputs = self.get_output( + Network.two_algo_face_body_rotator.outputs_key, modules, batch, outputs) + half_warped_image = rotator_outputs[TwoAlgoFaceBodyRotator05.WARPED_IMAGE_INDEX] + full_warped_image = interpolate( + half_warped_image, size=(512, 512), mode='bilinear', align_corners=False) + half_grid_change = rotator_outputs[TwoAlgoFaceBodyRotator05.GRID_CHANGE_INDEX] + full_grid_change = interpolate( + half_grid_change, size=(512, 512), mode='bilinear', align_corners=False) + rotation_pose = batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:] + return modules[Network.editor.name].forward( + input_original_image, full_warped_image, full_grid_change, rotation_pose) + elif key == Branch.all_outputs.name: + editor_output = self.get_output(Network.editor.outputs_key, modules, batch, outputs) + rotater_output = self.get_output(Network.two_algo_face_body_rotator.outputs_key, modules, batch, outputs) + face_morpher_output = self.get_output(Network.face_morpher.outputs_key, modules, batch, outputs) + eyebrow_morphing_combiner_output = self.get_output( + Network.eyebrow_morphing_combiner.outputs_key, modules, batch, outputs) + eyebrow_decomposer_output = self.get_output( + Network.eyebrow_decomposer.outputs_key, modules, batch, outputs) + output = editor_output \ + + rotater_output \ + + face_morpher_output \ + + eyebrow_morphing_combiner_output \ + + eyebrow_decomposer_output + return output + else: + raise RuntimeError("Unsupported key: " + key) + + +def load_eyebrow_decomposer(file_name: str): + factory = EyebrowDecomposer00Factory( + EyebrowDecomposer00Args( + image_size=128, + image_channels=4, + start_channels=64, + bottleneck_image_size=16, + num_bottleneck_blocks=6, + max_channels=512, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=True)))) + print("Loading the eyebrow decomposer ... ", end="") + module = factory.create() + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_eyebrow_morphing_combiner(file_name: str): + factory = EyebrowMorphingCombiner00Factory( + EyebrowMorphingCombiner00Args( + image_size=128, + image_channels=4, + start_channels=64, + num_pose_params=12, + bottleneck_image_size=16, + num_bottleneck_blocks=6, + max_channels=512, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=True)))) + print("Loading the eyebrow morphing conbiner ... ", end="") + module = factory.create() + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_face_morpher(file_name: str): + factory = FaceMorpher08Factory( + FaceMorpher08Args( + image_size=192, + image_channels=4, + num_expression_params=27, + start_channels=64, + bottleneck_image_size=24, + num_bottleneck_blocks=6, + max_channels=512, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=False)))) + print("Loading the face morpher ... ", end="") + module = factory.create() + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_two_algo_generator(file_name) -> Module: + module = TwoAlgoFaceBodyRotator05( + TwoAlgoFaceBodyRotator05Args( + image_size=256, + image_channels=4, + start_channels=64, + num_pose_params=6, + bottleneck_image_size=32, + num_bottleneck_blocks=6, + max_channels=512, + upsample_mode='nearest', + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=LeakyReLUFactory(inplace=False, negative_slope=0.1)))) + print("Loading the face-body rotator ... ", end="") + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_editor(file_name) -> Module: + module = Editor07( + Editor07Args( + image_size=512, + image_channels=4, + num_pose_params=6, + start_channels=32, + bottleneck_image_size=64, + num_bottleneck_blocks=6, + max_channels=512, + upsampling_mode='nearest', + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=LeakyReLUFactory(inplace=False, negative_slope=0.1)))) + print("Loading the combiner ... ", end="") + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def get_pose_parameters(): + return PoseParameters.Builder() \ + .add_parameter_group("eyebrow_troubled", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_angry", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_lowered", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_raised", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_happy", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_serious", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eye_wink", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_happy_wink", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_surprised", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_relaxed", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_unimpressed", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_raised_lower_eyelid", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("iris_small", PoseParameterCategory.IRIS_MORPH, arity=2) \ + .add_parameter_group("mouth_aaa", PoseParameterCategory.MOUTH, arity=1, default_value=1.0) \ + .add_parameter_group("mouth_iii", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_uuu", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_eee", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_ooo", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_delta", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_lowered_corner", PoseParameterCategory.MOUTH, arity=2) \ + .add_parameter_group("mouth_raised_corner", PoseParameterCategory.MOUTH, arity=2) \ + .add_parameter_group("mouth_smirk", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("iris_rotation_x", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("iris_rotation_y", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("head_x", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("head_y", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("neck_z", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("body_y", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("body_z", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \ + .build() + + +def create_poser( + device: torch.device, + module_file_names: Optional[Dict[str, str]] = None, + eyebrow_morphed_image_index: int = EyebrowMorphingCombiner00.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX, + default_output_index: int = 0) -> GeneralPoser02: + if module_file_names is None: + module_file_names = {} + if Network.eyebrow_decomposer.name not in module_file_names: + dir = "data/models/standard_float" + file_name = dir + "/eyebrow_decomposer.pt" + module_file_names[Network.eyebrow_decomposer.name] = file_name + if Network.eyebrow_morphing_combiner.name not in module_file_names: + dir = "data/models/standard_float" + file_name = dir + "/eyebrow_morphing_combiner.pt" + module_file_names[Network.eyebrow_morphing_combiner.name] = file_name + if Network.face_morpher.name not in module_file_names: + dir = "data/models/standard_float" + file_name = dir + "/face_morpher.pt" + module_file_names[Network.face_morpher.name] = file_name + if Network.two_algo_face_body_rotator.name not in module_file_names: + dir = "data/models/standard_float" + file_name = dir + "/two_algo_face_body_rotator.pt" + module_file_names[Network.two_algo_face_body_rotator.name] = file_name + if Network.editor.name not in module_file_names: + dir = "data/models/standard_float" + file_name = dir + "/editor.pt" + module_file_names[Network.editor.name] = file_name + + loaders = { + Network.eyebrow_decomposer.name: + lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]), + Network.eyebrow_morphing_combiner.name: + lambda: load_eyebrow_morphing_combiner(module_file_names[Network.eyebrow_morphing_combiner.name]), + Network.face_morpher.name: + lambda: load_face_morpher(module_file_names[Network.face_morpher.name]), + Network.two_algo_face_body_rotator.name: + lambda: load_two_algo_generator(module_file_names[Network.two_algo_face_body_rotator.name]), + Network.editor.name: + lambda: load_editor(module_file_names[Network.editor.name]), + } + return GeneralPoser02( + image_size=512, + module_loaders=loaders, + pose_parameters=get_pose_parameters().get_pose_parameter_groups(), + output_list_func=FiveStepPoserComputationProtocol(eyebrow_morphed_image_index).compute_func(), + subrect=None, + device=device, + output_length=29, + default_output_index=default_output_index) + + +if __name__ == "__main__": + device = torch.device('cuda') + poser = create_poser(device) + + image = torch.zeros(1, 4, 512, 512, device=device) + pose = torch.zeros(1, 45, device=device) + + repeat = 100 + acc = 0.0 + for i in range(repeat + 2): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + poser.pose(image, pose) + end.record() + torch.cuda.synchronize() + if i >= 2: + elapsed_time = start.elapsed_time(end) + print("%d:" % i, elapsed_time) + acc = acc + elapsed_time + + print("average:", acc / repeat) diff --git a/tha3/poser/modes/standard_half.py b/tha3/poser/modes/standard_half.py new file mode 100644 index 0000000000000000000000000000000000000000..aa6a6de20e480ddb6452828a8a54b00e497bbe0d --- /dev/null +++ b/tha3/poser/modes/standard_half.py @@ -0,0 +1,363 @@ +from enum import Enum +from typing import List, Dict, Optional + +import torch +from torch import Tensor +from torch.nn import Module +from torch.nn.functional import interpolate + +from tha3.nn.eyebrow_decomposer.eyebrow_decomposer_00 import EyebrowDecomposer00, \ + EyebrowDecomposer00Factory, EyebrowDecomposer00Args +from tha3.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_00 import \ + EyebrowMorphingCombiner00Factory, EyebrowMorphingCombiner00Args, EyebrowMorphingCombiner00 +from tha3.nn.face_morpher.face_morpher_08 import FaceMorpher08Args, FaceMorpher08Factory +from tha3.poser.general_poser_02 import GeneralPoser02 +from tha3.poser.poser import PoseParameterCategory, PoseParameters +from tha3.nn.editor.editor_07 import Editor07, Editor07Args +from tha3.nn.two_algo_body_rotator.two_algo_face_body_rotator_05 import TwoAlgoFaceBodyRotator05, \ + TwoAlgoFaceBodyRotator05Args +from tha3.util import torch_load +from tha3.compute.cached_computation_func import TensorListCachedComputationFunc +from tha3.compute.cached_computation_protocol import CachedComputationProtocol +from tha3.nn.nonlinearity_factory import ReLUFactory, LeakyReLUFactory +from tha3.nn.normalization import InstanceNorm2dFactory +from tha3.nn.util import BlockArgs + + +class Network(Enum): + eyebrow_decomposer = 1 + eyebrow_morphing_combiner = 2 + face_morpher = 3 + two_algo_face_body_rotator = 4 + editor = 5 + + @property + def outputs_key(self): + return f"{self.name}_outputs" + + +class Branch(Enum): + face_morphed_half = 1 + face_morphed_full = 2 + all_outputs = 3 + + +NUM_EYEBROW_PARAMS = 12 +NUM_FACE_PARAMS = 27 +NUM_ROTATION_PARAMS = 6 + + +class FiveStepPoserComputationProtocol(CachedComputationProtocol): + def __init__(self, eyebrow_morphed_image_index: int): + super().__init__() + self.eyebrow_morphed_image_index = eyebrow_morphed_image_index + self.cached_batch_0 = None + self.cached_eyebrow_decomposer_output = None + + def compute_func(self) -> TensorListCachedComputationFunc: + def func(modules: Dict[str, Module], + batch: List[Tensor], + outputs: Dict[str, List[Tensor]]): + if self.cached_batch_0 is None: + new_batch_0 = True + elif batch[0].shape[0] != self.cached_batch_0.shape[0]: + new_batch_0 = True + else: + new_batch_0 = torch.max((batch[0] - self.cached_batch_0).abs()).item() > 0 + if not new_batch_0: + outputs[Network.eyebrow_decomposer.outputs_key] = self.cached_eyebrow_decomposer_output + output = self.get_output(Branch.all_outputs.name, modules, batch, outputs) + if new_batch_0: + self.cached_batch_0 = batch[0] + self.cached_eyebrow_decomposer_output = outputs[Network.eyebrow_decomposer.outputs_key] + return output + + return func + + def compute_output(self, key: str, modules: Dict[str, Module], batch: List[Tensor], + outputs: Dict[str, List[Tensor]]) -> List[Tensor]: + if key == Network.eyebrow_decomposer.outputs_key: + input_image = batch[0][:, :, 64:192, 64 + 128:192 + 128] + return modules[Network.eyebrow_decomposer.name].forward(input_image) + elif key == Network.eyebrow_morphing_combiner.outputs_key: + eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, modules, batch, outputs) + background_layer = eyebrow_decomposer_output[EyebrowDecomposer00.BACKGROUND_LAYER_INDEX] + eyebrow_layer = eyebrow_decomposer_output[EyebrowDecomposer00.EYEBROW_LAYER_INDEX] + eyebrow_pose = batch[1][:, :NUM_EYEBROW_PARAMS] + return modules[Network.eyebrow_morphing_combiner.name].forward( + background_layer, + eyebrow_layer, + eyebrow_pose) + elif key == Network.face_morpher.outputs_key: + eyebrow_morphing_combiner_output = self.get_output( + Network.eyebrow_morphing_combiner.outputs_key, modules, batch, outputs) + eyebrow_morphed_image = eyebrow_morphing_combiner_output[self.eyebrow_morphed_image_index] + input_image = batch[0][:, :, 32:32 + 192, (32 + 128):(32 + 192 + 128)].clone() + input_image[:, :, 32:32 + 128, 32:32 + 128] = eyebrow_morphed_image + face_pose = batch[1][:, NUM_EYEBROW_PARAMS:NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS] + return modules[Network.face_morpher.name].forward(input_image, face_pose) + elif key == Branch.face_morphed_full.name: + face_morpher_output = self.get_output(Network.face_morpher.outputs_key, modules, batch, outputs) + face_morphed_image = face_morpher_output[0] + input_image = batch[0].clone() + input_image[:, :, 32:32 + 192, 32 + 128:32 + 192 + 128] = face_morphed_image + return [input_image] + elif key == Branch.face_morphed_half.name: + face_morphed_full = self.get_output(Branch.face_morphed_full.name, modules, batch, outputs)[0] + return [ + interpolate(face_morphed_full, size=(256, 256), mode='bilinear', align_corners=False) + ] + elif key == Network.two_algo_face_body_rotator.outputs_key: + face_morphed_half = self.get_output(Branch.face_morphed_half.name, modules, batch, outputs)[0] + rotation_pose = batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:] + return modules[Network.two_algo_face_body_rotator.name].forward(face_morphed_half, rotation_pose) + elif key == Network.editor.outputs_key: + input_original_image = self.get_output(Branch.face_morphed_full.name, modules, batch, outputs)[0] + rotator_outputs = self.get_output( + Network.two_algo_face_body_rotator.outputs_key, modules, batch, outputs) + half_warped_image = rotator_outputs[TwoAlgoFaceBodyRotator05.WARPED_IMAGE_INDEX] + full_warped_image = interpolate( + half_warped_image, size=(512, 512), mode='bilinear', align_corners=False) + half_grid_change = rotator_outputs[TwoAlgoFaceBodyRotator05.GRID_CHANGE_INDEX] + full_grid_change = interpolate( + half_grid_change, size=(512, 512), mode='bilinear', align_corners=False) + rotation_pose = batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:] + return modules[Network.editor.name].forward( + input_original_image, full_warped_image, full_grid_change, rotation_pose) + elif key == Branch.all_outputs.name: + editor_output = self.get_output(Network.editor.outputs_key, modules, batch, outputs) + rotater_output = self.get_output(Network.two_algo_face_body_rotator.outputs_key, modules, batch, outputs) + face_morpher_output = self.get_output(Network.face_morpher.outputs_key, modules, batch, outputs) + eyebrow_morphing_combiner_output = self.get_output( + Network.eyebrow_morphing_combiner.outputs_key, modules, batch, outputs) + eyebrow_decomposer_output = self.get_output( + Network.eyebrow_decomposer.outputs_key, modules, batch, outputs) + output = editor_output \ + + rotater_output \ + + face_morpher_output \ + + eyebrow_morphing_combiner_output \ + + eyebrow_decomposer_output + return output + else: + raise RuntimeError("Unsupported key: " + key) + + +def load_eyebrow_decomposer(file_name: str): + factory = EyebrowDecomposer00Factory( + EyebrowDecomposer00Args( + image_size=128, + image_channels=4, + start_channels=64, + bottleneck_image_size=16, + num_bottleneck_blocks=6, + max_channels=512, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=True)))) + print("Loading the eyebrow decomposer ... ", end="") + module = factory.create().half() + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_eyebrow_morphing_combiner(file_name: str): + factory = EyebrowMorphingCombiner00Factory( + EyebrowMorphingCombiner00Args( + image_size=128, + image_channels=4, + start_channels=64, + num_pose_params=12, + bottleneck_image_size=16, + num_bottleneck_blocks=6, + max_channels=512, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=True)))) + print("Loading the eyebrow morphing conbiner ... ", end="") + module = factory.create().half() + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_face_morpher(file_name: str): + factory = FaceMorpher08Factory( + FaceMorpher08Args( + image_size=192, + image_channels=4, + num_expression_params=27, + start_channels=64, + bottleneck_image_size=24, + num_bottleneck_blocks=6, + max_channels=512, + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=ReLUFactory(inplace=False)))) + print("Loading the face morpher ... ", end="") + module = factory.create().half() + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_two_algo_generator(file_name) -> Module: + module = TwoAlgoFaceBodyRotator05( + TwoAlgoFaceBodyRotator05Args( + image_size=256, + image_channels=4, + start_channels=64, + num_pose_params=6, + bottleneck_image_size=32, + num_bottleneck_blocks=6, + max_channels=512, + upsample_mode='nearest', + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=LeakyReLUFactory(inplace=False, negative_slope=0.1)))).half() + print("Loading the face-body rotator ... ", end="") + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def load_editor(file_name) -> Module: + module = Editor07( + Editor07Args( + image_size=512, + image_channels=4, + num_pose_params=6, + start_channels=32, + bottleneck_image_size=64, + num_bottleneck_blocks=6, + max_channels=512, + upsampling_mode='nearest', + block_args=BlockArgs( + initialization_method='he', + use_spectral_norm=False, + normalization_layer_factory=InstanceNorm2dFactory(), + nonlinearity_factory=LeakyReLUFactory(inplace=False, negative_slope=0.1)))).half() + print("Loading the combiner ... ", end="") + module.load_state_dict(torch_load(file_name)) + print("DONE!!!") + return module + + +def get_pose_parameters(): + return PoseParameters.Builder() \ + .add_parameter_group("eyebrow_troubled", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_angry", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_lowered", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_raised", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_happy", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eyebrow_serious", PoseParameterCategory.EYEBROW, arity=2) \ + .add_parameter_group("eye_wink", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_happy_wink", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_surprised", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_relaxed", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_unimpressed", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("eye_raised_lower_eyelid", PoseParameterCategory.EYE, arity=2) \ + .add_parameter_group("iris_small", PoseParameterCategory.IRIS_MORPH, arity=2) \ + .add_parameter_group("mouth_aaa", PoseParameterCategory.MOUTH, arity=1, default_value=1.0) \ + .add_parameter_group("mouth_iii", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_uuu", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_eee", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_ooo", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_delta", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("mouth_lowered_corner", PoseParameterCategory.MOUTH, arity=2) \ + .add_parameter_group("mouth_raised_corner", PoseParameterCategory.MOUTH, arity=2) \ + .add_parameter_group("mouth_smirk", PoseParameterCategory.MOUTH, arity=1) \ + .add_parameter_group("iris_rotation_x", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("iris_rotation_y", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("head_x", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("head_y", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("neck_z", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("body_y", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("body_z", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ + .add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \ + .build() + + +def create_poser( + device: torch.device, + module_file_names: Optional[Dict[str, str]] = None, + eyebrow_morphed_image_index: int = EyebrowMorphingCombiner00.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX, + default_output_index: int = 0) -> GeneralPoser02: + if module_file_names is None: + module_file_names = {} + if Network.eyebrow_decomposer.name not in module_file_names: + dir = "data/models/standard_half" + file_name = dir + "/eyebrow_decomposer.pt" + module_file_names[Network.eyebrow_decomposer.name] = file_name + if Network.eyebrow_morphing_combiner.name not in module_file_names: + dir = "data/models/standard_half" + file_name = dir + "/eyebrow_morphing_combiner.pt" + module_file_names[Network.eyebrow_morphing_combiner.name] = file_name + if Network.face_morpher.name not in module_file_names: + dir = "data/models/standard_half" + file_name = dir + "/face_morpher.pt" + module_file_names[Network.face_morpher.name] = file_name + if Network.two_algo_face_body_rotator.name not in module_file_names: + dir = "data/models/standard_half" + file_name = dir + "/two_algo_face_body_rotator.pt" + module_file_names[Network.two_algo_face_body_rotator.name] = file_name + if Network.editor.name not in module_file_names: + dir = "data/models/standard_half" + file_name = dir + "/editor.pt" + module_file_names[Network.editor.name] = file_name + + loaders = { + Network.eyebrow_decomposer.name: + lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]), + Network.eyebrow_morphing_combiner.name: + lambda: load_eyebrow_morphing_combiner(module_file_names[Network.eyebrow_morphing_combiner.name]), + Network.face_morpher.name: + lambda: load_face_morpher(module_file_names[Network.face_morpher.name]), + Network.two_algo_face_body_rotator.name: + lambda: load_two_algo_generator(module_file_names[Network.two_algo_face_body_rotator.name]), + Network.editor.name: + lambda: load_editor(module_file_names[Network.editor.name]), + } + return GeneralPoser02( + image_size=512, + module_loaders=loaders, + pose_parameters=get_pose_parameters().get_pose_parameter_groups(), + output_list_func=FiveStepPoserComputationProtocol(eyebrow_morphed_image_index).compute_func(), + subrect=None, + device=device, + output_length=29, + dtype=torch.half, + default_output_index=default_output_index) + + +if __name__ == "__main__": + device = torch.device('cuda') + poser = create_poser(device) + + image = torch.zeros(1, 4, 512, 512, device=device, dtype=torch.half) + pose = torch.zeros(1, 45, device=device, dtype=torch.half) + + repeat = 100 + acc = 0.0 + for i in range(repeat + 2): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + poser.pose(image, pose) + end.record() + torch.cuda.synchronize() + if i >= 2: + elapsed_time = start.elapsed_time(end) + print("%d:" % i, elapsed_time) + acc = acc + elapsed_time + + print("average:", acc / repeat) diff --git a/tha3/poser/poser.py b/tha3/poser/poser.py new file mode 100644 index 0000000000000000000000000000000000000000..e0deeb1fa58f0a47083d44ec4d820ba86ce21da5 --- /dev/null +++ b/tha3/poser/poser.py @@ -0,0 +1,158 @@ +from abc import ABC, abstractmethod +from enum import Enum +from typing import Tuple, List, Optional + +import torch +from torch import Tensor + + +class PoseParameterCategory(Enum): + EYEBROW = 1 + EYE = 2 + IRIS_MORPH = 3 + IRIS_ROTATION = 4 + MOUTH = 5 + FACE_ROTATION = 6 + BODY_ROTATION = 7 + BREATHING = 8 + + +class PoseParameterGroup: + def __init__(self, + group_name: str, + parameter_index: int, + category: PoseParameterCategory, + arity: int = 1, + discrete: bool = False, + default_value: float = 0.0, + range: Optional[Tuple[float, float]] = None): + assert arity == 1 or arity == 2 + if range is None: + range = (0.0, 1.0) + if arity == 1: + parameter_names = [group_name] + else: + parameter_names = [group_name + "_left", group_name + "_right"] + assert len(parameter_names) == arity + + self.parameter_names = parameter_names + self.range = range + self.default_value = default_value + self.discrete = discrete + self.arity = arity + self.category = category + self.parameter_index = parameter_index + self.group_name = group_name + + def get_arity(self) -> int: + return self.arity + + def get_group_name(self) -> str: + return self.group_name + + def get_parameter_names(self) -> List[str]: + return self.parameter_names + + def is_discrete(self) -> bool: + return self.discrete + + def get_range(self) -> Tuple[float, float]: + return self.range + + def get_default_value(self): + return self.default_value + + def get_parameter_index(self): + return self.parameter_index + + def get_category(self) -> PoseParameterCategory: + return self.category + + +class PoseParameters: + def __init__(self, pose_parameter_groups: List[PoseParameterGroup]): + self.pose_parameter_groups = pose_parameter_groups + + def get_parameter_index(self, name: str) -> int: + index = 0 + for parameter_group in self.pose_parameter_groups: + for param_name in parameter_group.parameter_names: + if name == param_name: + return index + index += 1 + raise RuntimeError("Cannot find parameter with name %s" % name) + + def get_parameter_name(self, index: int) -> str: + assert index >= 0 and index < self.get_parameter_count() + + for group in self.pose_parameter_groups: + if index < group.get_arity(): + return group.get_parameter_names()[index] + index -= group.arity + + raise RuntimeError("Something is wrong here!!!") + + def get_pose_parameter_groups(self): + return self.pose_parameter_groups + + def get_parameter_count(self): + count = 0 + for group in self.pose_parameter_groups: + count += group.arity + return count + + class Builder: + def __init__(self): + self.index = 0 + self.pose_parameter_groups = [] + + def add_parameter_group(self, + group_name: str, + category: PoseParameterCategory, + arity: int = 1, + discrete: bool = False, + default_value: float = 0.0, + range: Optional[Tuple[float, float]] = None): + self.pose_parameter_groups.append( + PoseParameterGroup( + group_name, + self.index, + category, + arity, + discrete, + default_value, + range)) + self.index += arity + return self + + def build(self) -> 'PoseParameters': + return PoseParameters(self.pose_parameter_groups) + + +class Poser(ABC): + @abstractmethod + def get_image_size(self) -> int: + pass + + @abstractmethod + def get_output_length(self) -> int: + pass + + @abstractmethod + def get_pose_parameter_groups(self) -> List[PoseParameterGroup]: + pass + + @abstractmethod + def get_num_parameters(self) -> int: + pass + + @abstractmethod + def pose(self, image: Tensor, pose: Tensor, output_index: int = 0) -> Tensor: + pass + + @abstractmethod + def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]: + pass + + def get_dtype(self) -> torch.dtype: + return torch.float diff --git a/tha3/util.py b/tha3/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a161150f4ce39d7217703be2668ae3f9db21cabe --- /dev/null +++ b/tha3/util.py @@ -0,0 +1,281 @@ +import math +import os +from typing import List + +import PIL.Image +import numpy +import torch +from matplotlib import cm +from torch import Tensor + + +def is_power2(x): + return x != 0 and ((x & (x - 1)) == 0) + + +def numpy_srgb_to_linear(x): + x = numpy.clip(x, 0.0, 1.0) + return numpy.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4) + + +def numpy_linear_to_srgb(x): + x = numpy.clip(x, 0.0, 1.0) + return numpy.where(x <= 0.003130804953560372, x * 12.92, 1.055 * (x ** (1.0 / 2.4)) - 0.055) + + +def torch_srgb_to_linear(x: torch.Tensor): + x = torch.clip(x, 0.0, 1.0) + return torch.where(torch.le(x, 0.04045), x / 12.92, ((x + 0.055) / 1.055) ** 2.4) + + +def torch_linear_to_srgb(x): + x = torch.clip(x, 0.0, 1.0) + return torch.where(torch.le(x, 0.003130804953560372), x * 12.92, 1.055 * (x ** (1.0 / 2.4)) - 0.055) + + +def image_linear_to_srgb(image): + assert image.shape[2] == 3 or image.shape[2] == 4 + if image.shape[2] == 3: + return numpy_linear_to_srgb(image) + else: + height, width, _ = image.shape + rgb_image = numpy_linear_to_srgb(image[:, :, 0:3]) + a_image = image[:, :, 3:4] + return numpy.concatenate((rgb_image, a_image), axis=2) + + +def image_srgb_to_linear(image): + assert image.shape[2] == 3 or image.shape[2] == 4 + if image.shape[2] == 3: + return numpy_srgb_to_linear(image) + else: + height, width, _ = image.shape + rgb_image = numpy_srgb_to_linear(image[:, :, 0:3]) + a_image = image[:, :, 3:4] + return numpy.concatenate((rgb_image, a_image), axis=2) + + +def save_rng_state(file_name): + rng_state = torch.get_rng_state() + torch_save(rng_state, file_name) + + +def load_rng_state(file_name): + rng_state = torch_load(file_name) + torch.set_rng_state(rng_state) + + +def grid_change_to_numpy_image(torch_image, num_channels=3): + height = torch_image.shape[1] + width = torch_image.shape[2] + size_image = (torch_image[0, :, :] ** 2 + torch_image[1, :, :] ** 2).sqrt().view(height, width, 1).numpy() + hsv = cm.get_cmap('hsv') + angle_image = hsv(((torch.atan2( + torch_image[0, :, :].view(height * width), + torch_image[1, :, :].view(height * width)).view(height, width) + math.pi) / (2 * math.pi)).numpy()) * 3 + numpy_image = size_image * angle_image[:, :, 0:3] + rgb_image = numpy_linear_to_srgb(numpy_image) + if num_channels == 3: + return rgb_image + elif num_channels == 4: + return numpy.concatenate([rgb_image, numpy.ones_like(size_image)], axis=2) + else: + raise RuntimeError("Unsupported num_channels: " + str(num_channels)) + + +def rgb_to_numpy_image(torch_image: Tensor, min_pixel_value=-1.0, max_pixel_value=1.0): + assert torch_image.dim() == 3 + assert torch_image.shape[0] == 3 + height = torch_image.shape[1] + width = torch_image.shape[2] + + reshaped_image = torch_image.numpy().reshape(3, height * width).transpose().reshape(height, width, 3) + numpy_image = (reshaped_image - min_pixel_value) / (max_pixel_value - min_pixel_value) + return numpy_linear_to_srgb(numpy_image) + + +def rgba_to_numpy_image_greenscreen(torch_image: Tensor, + min_pixel_value=-1.0, + max_pixel_value=1.0, + include_alpha=False): + height = torch_image.shape[1] + width = torch_image.shape[2] + + numpy_image = (torch_image.numpy().reshape(4, height * width).transpose().reshape(height, width, + 4) - min_pixel_value) \ + / (max_pixel_value - min_pixel_value) + rgb_image = numpy_linear_to_srgb(numpy_image[:, :, 0:3]) + a_image = numpy_image[:, :, 3] + rgb_image[:, :, 0:3] = rgb_image[:, :, 0:3] * a_image.reshape(a_image.shape[0], a_image.shape[1], 1) + rgb_image[:, :, 1] = rgb_image[:, :, 1] + (1 - a_image) + + if not include_alpha: + return rgb_image + else: + return numpy.concatenate((rgb_image, numpy.ones_like(numpy_image[:, :, 3:4])), axis=2) + + +def rgba_to_numpy_image(torch_image: Tensor, min_pixel_value=-1.0, max_pixel_value=1.0): + assert torch_image.dim() == 3 + assert torch_image.shape[0] == 4 + height = torch_image.shape[1] + width = torch_image.shape[2] + + reshaped_image = torch_image.numpy().reshape(4, height * width).transpose().reshape(height, width, 4) + numpy_image = (reshaped_image - min_pixel_value) / (max_pixel_value - min_pixel_value) + rgb_image = numpy_linear_to_srgb(numpy_image[:, :, 0:3]) + a_image = numpy.clip(numpy_image[:, :, 3], 0.0, 1.0) + rgba_image = numpy.concatenate((rgb_image, a_image.reshape(height, width, 1)), axis=2) + return rgba_image + + +def extract_numpy_image_from_filelike_with_pytorch_layout(file, has_alpha=True, scale=2.0, offset=-1.0): + try: + pil_image = PIL.Image.open(file) + except Exception as e: + raise RuntimeError(file) + return extract_numpy_image_from_PIL_image_with_pytorch_layout(pil_image, has_alpha, scale, offset) + + +def extract_numpy_image_from_PIL_image_with_pytorch_layout(pil_image, has_alpha=True, scale=2.0, offset=-1.0): + if has_alpha: + num_channel = 4 + else: + num_channel = 3 + image_size = pil_image.width + + # search for transparent pixels(alpha==0) and change them to [0 0 0 0] to avoid the color influence to the model + for i, px in enumerate(pil_image.getdata()): + if px[3] <= 0: + y = i // image_size + x = i % image_size + pil_image.putpixel((x, y), (0, 0, 0, 0)) + + raw_image = numpy.asarray(pil_image) + image = (raw_image / 255.0).reshape(image_size, image_size, num_channel) + image[:, :, 0:3] = numpy_srgb_to_linear(image[:, :, 0:3]) + image = image \ + .reshape(image_size * image_size, num_channel) \ + .transpose() \ + .reshape(num_channel, image_size, image_size) * scale + offset + return image + + +def extract_pytorch_image_from_filelike(file, has_alpha=True, scale=2.0, offset=-1.0): + try: + pil_image = PIL.Image.open(file) + except Exception as e: + raise RuntimeError(file) + image = extract_numpy_image_from_PIL_image_with_pytorch_layout(pil_image, has_alpha, scale, offset) + return torch.from_numpy(image).float() + + +def extract_pytorch_image_from_PIL_image(pil_image, has_alpha=True, scale=2.0, offset=-1.0): + image = extract_numpy_image_from_PIL_image_with_pytorch_layout(pil_image, has_alpha, scale, offset) + return torch.from_numpy(image).float() + + +def extract_numpy_image_from_filelike(file): + pil_image = PIL.Image.open(file) + image_width = pil_image.width + image_height = pil_image.height + if pil_image.mode == "RGBA": + image = (numpy.asarray(pil_image) / 255.0).reshape(image_height, image_width, 4) + else: + image = (numpy.asarray(pil_image) / 255.0).reshape(image_height, image_width, 3) + image[:, :, 0:3] = numpy_srgb_to_linear(image[:, :, 0:3]) + return image + + +def convert_avs_to_avi(avs_file, avi_file): + os.makedirs(os.path.dirname(avi_file), exist_ok=True) + + file = open("temp.vdub", "w") + file.write("VirtualDub.Open(\"%s\");" % avs_file) + file.write("VirtualDub.video.SetCompression(\"cvid\", 0, 10000, 0);") + file.write("VirtualDub.SaveAVI(\"%s\");" % avi_file) + file.write("VirtualDub.Close();") + file.close() + + os.system("C:\\ProgramData\\chocolatey\\lib\\virtualdub\\tools\\vdub64.exe /i temp.vdub") + + os.remove("temp.vdub") + + +def convert_avi_to_mp4(avi_file, mp4_file): + os.makedirs(os.path.dirname(mp4_file), exist_ok=True) + os.system("ffmpeg -y -i %s -c:v libx264 -preset slow -crf 22 -c:a libfaac -b:a 128k %s" % \ + (avi_file, mp4_file)) + + +def convert_avi_to_webm(avi_file, webm_file): + os.makedirs(os.path.dirname(webm_file), exist_ok=True) + os.system("ffmpeg -y -i %s -vcodec libvpx -qmin 0 -qmax 50 -crf 10 -b:v 1M -acodec libvorbis %s" % \ + (avi_file, webm_file)) + + +def convert_mp4_to_webm(mp4_file, webm_file): + os.makedirs(os.path.dirname(webm_file), exist_ok=True) + os.system("ffmpeg -y -i %s -vcodec libvpx -qmin 0 -qmax 50 -crf 10 -b:v 1M -acodec libvorbis %s" % \ + (mp4_file, webm_file)) + + +def create_parent_dir(file_name): + os.makedirs(os.path.dirname(file_name), exist_ok=True) + + +def run_command(command_parts: List[str]): + command = " ".join(command_parts) + os.system(command) + + +def save_pytorch_image(image, file_name): + if image.shape[0] == 1: + image = image.squeeze() + if image.shape[0] == 4: + numpy_image = rgba_to_numpy_image(image.detach().cpu()) + pil_image = PIL.Image.fromarray(numpy.uint8(numpy.rint(numpy_image * 255.0)), mode='RGBA') + else: + numpy_image = rgb_to_numpy_image(image.detach().cpu()) + pil_image = PIL.Image.fromarray(numpy.uint8(numpy.rint(numpy_image * 255.0)), mode='RGB') + os.makedirs(os.path.dirname(file_name), exist_ok=True) + pil_image.save(file_name) + + +def torch_load(file_name): + with open(file_name, 'rb') as f: + return torch.load(f) + + +def torch_save(content, file_name): + os.makedirs(os.path.dirname(file_name), exist_ok=True) + with open(file_name, 'wb') as f: + torch.save(content, f) + + +def resize_PIL_image(pil_image, size=(256, 256)): + w, h = pil_image.size + d = min(w, h) + r = ((w - d) // 2, (h - d) // 2, (w + d) // 2, (h + d) // 2) + return pil_image.resize(size, resample=PIL.Image.LANCZOS, box=r) + + +def extract_PIL_image_from_filelike(file): + return PIL.Image.open(file) + + +def convert_output_image_from_torch_to_numpy(output_image): + if output_image.shape[2] == 2: + h, w, c = output_image.shape + output_image = torch.transpose(output_image.reshape(h * w, c), 0, 1).reshape(c, h, w) + if output_image.shape[0] == 4: + numpy_image = rgba_to_numpy_image(output_image) + elif output_image.shape[0] == 1: + c, h, w = output_image.shape + alpha_image = torch.cat([output_image.repeat(3, 1, 1) * 2.0 - 1.0, torch.ones(1, h, w)], dim=0) + numpy_image = rgba_to_numpy_image(alpha_image) + elif output_image.shape[0] == 2: + numpy_image = grid_change_to_numpy_image(output_image, num_channels=4) + else: + raise RuntimeError("Unsupported # image channels: %d" % output_image.shape[0]) + return numpy_image