Harry_FBK committed on
Commit
60094bd
1 Parent(s): 277d7e6

Clone original THA3

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +134 -0
  2. LICENSE +21 -0
  3. README.md +239 -13
  4. colab.ipynb +542 -0
  5. docs/ifacialmocap_ip.jpg +0 -0
  6. docs/ifacialmocap_puppeteer_click_start_capture.png +0 -0
  7. docs/ifacialmocap_puppeteer_ip_address_box.png +0 -0
  8. docs/ifacialmocap_puppeteer_numbers.png +0 -0
  9. docs/input_spec.png +0 -0
  10. docs/pytorch-install-command.png +0 -0
  11. environment.yml +141 -0
  12. manual_poser.ipynb +460 -0
  13. tha3/__init__.py +0 -0
  14. tha3/app/__init__.py +0 -0
  15. tha3/app/ifacialmocap_puppeteer.py +439 -0
  16. tha3/app/manual_poser.py +464 -0
  17. tha3/compute/__init__.py +0 -0
  18. tha3/compute/cached_computation_func.py +9 -0
  19. tha3/compute/cached_computation_protocol.py +43 -0
  20. tha3/mocap/__init__.py +0 -0
  21. tha3/mocap/ifacialmocap_constants.py +239 -0
  22. tha3/mocap/ifacialmocap_pose.py +27 -0
  23. tha3/mocap/ifacialmocap_pose_converter.py +12 -0
  24. tha3/mocap/ifacialmocap_poser_converter_25.py +463 -0
  25. tha3/mocap/ifacialmocap_v2.py +89 -0
  26. tha3/module/__init__.py +0 -0
  27. tha3/module/module_factory.py +9 -0
  28. tha3/nn/__init__.py +0 -0
  29. tha3/nn/common/__init__.py +0 -0
  30. tha3/nn/common/conv_block_factory.py +55 -0
  31. tha3/nn/common/poser_args.py +68 -0
  32. tha3/nn/common/poser_encoder_decoder_00.py +121 -0
  33. tha3/nn/common/poser_encoder_decoder_00_separable.py +92 -0
  34. tha3/nn/common/resize_conv_encoder_decoder.py +125 -0
  35. tha3/nn/common/resize_conv_unet.py +155 -0
  36. tha3/nn/conv.py +189 -0
  37. tha3/nn/editor/__init__.py +0 -0
  38. tha3/nn/editor/editor_07.py +180 -0
  39. tha3/nn/eyebrow_decomposer/__init__.py +0 -0
  40. tha3/nn/eyebrow_decomposer/eyebrow_decomposer_00.py +102 -0
  41. tha3/nn/eyebrow_decomposer/eyebrow_decomposer_03.py +109 -0
  42. tha3/nn/eyebrow_morphing_combiner/__init__.py +0 -0
  43. tha3/nn/eyebrow_morphing_combiner/eyebrow_morphing_combiner_00.py +115 -0
  44. tha3/nn/eyebrow_morphing_combiner/eyebrow_morphing_combiner_03.py +117 -0
  45. tha3/nn/face_morpher/__init__.py +0 -0
  46. tha3/nn/face_morpher/face_morpher_08.py +241 -0
  47. tha3/nn/face_morpher/face_morpher_09.py +187 -0
  48. tha3/nn/image_processing_util.py +58 -0
  49. tha3/nn/init_function.py +76 -0
  50. tha3/nn/nonlinearity_factory.py +72 -0
.gitignore ADDED
@@ -0,0 +1,134 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ data/
+ *.iml
+ .idea/
+ *.pt
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 Pramook Khungurn
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,239 @@
- ---
- title: Talking Head Anime 3
- emoji: 🐠
- colorFrom: pink
- colorTo: red
- sdk: gradio
- sdk_version: 3.18.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Demo Code for "Talking Head(?) Anime from A Single Image 3: Now the Body Too"
+
+ This repository contains demo programs for the [Talking Head(?) Anime from a Single Image 3: Now the Body Too](https://pkhungurn.github.io/talking-head-anime-3/index.html) project. As the name implies, the project allows you to animate anime characters, and you only need a single image of that character to do so. There are two demo programs:
+
+ * The ``manual_poser`` lets you manipulate a character's facial expression, head rotation, body rotation, and chest expansion due to breathing through a graphical user interface.
+ * The ``ifacialmocap_puppeteer`` lets you transfer your facial motion to an anime character.
+
+ ## Try the Manual Poser on Google Colab
+
+ If you do not have the required hardware (discussed below) or do not want to download the code and set up an environment to run it, click [![this link](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pkhungurn/talking-head-anime-3-demo/blob/master/colab.ipynb) to try running the manual poser on [Google Colab](https://research.google.com/colaboratory/faq.html).
+
+ ## Hardware Requirements
+
+ Both programs require a recent and powerful Nvidia GPU to run. I was personally able to run them at good speed with the Nvidia Titan RTX. However, I think recent high-end gaming GPUs such as the RTX 2080, the RTX 3080, or better would do just as well.
+
+ The `ifacialmocap_puppeteer` requires an iOS device that is capable of computing [blend shape parameters](https://developer.apple.com/documentation/arkit/arfaceanchor/2928251-blendshapes) from a video feed. This means that the device must be able to run iOS 11.0 or higher and must have a TrueDepth front-facing camera. (See [this page](https://developer.apple.com/documentation/arkit/content_anchors/tracking_and_visualizing_faces) for more info.) In other words, if you have the iPhone X or something better, you should be all set. Personally, I have used an iPhone 12 mini.
+
+ ## Software Requirements
+
+ ### GPU Related Software
+
+ Please update your GPU's device driver and install the [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit) that is compatible with your GPU and is newer than the version you will be installing in the next subsection.
+
+ ### Python Environment
+
+ Both ``manual_poser`` and ``ifacialmocap_puppeteer`` are available as desktop applications. To run them, you need to set up an environment for running programs written in the [Python](http://www.python.org) language. The environment needs to have the following software packages:
+
+ * Python >= 3.8
+ * PyTorch >= 1.11.0 with CUDA support
+ * SciPy >= 1.7.3
+ * wxPython >= 4.1.1
+ * Matplotlib >= 3.5.1
+
+ One way to do so is to install [Anaconda](https://www.anaconda.com/) and run the following commands in your shell:
+
+ ```
+ > conda create -n talking-head-anime-3-demo python=3.8
+ > conda activate talking-head-anime-3-demo
+ > conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
+ > conda install scipy
+ > pip install wxpython
+ > conda install matplotlib
+ ```
+
+ #### Caveat 1: Do not use Python 3.10 on Windows
+
+ As of June 2022, you cannot use [wxPython](https://www.wxpython.org/) with Python 3.10 on Windows. As a result, do not use Python 3.10 until [this bug](https://github.com/wxWidgets/Phoenix/issues/2024) is fixed. This means you should not set ``python=3.10`` in the first ``conda`` command in the listing above.
+
+ #### Caveat 2: Adjust versions of Python and CUDA Toolkit as needed
+
+ The environment created by the commands above gives you Python version 3.8 and an installation of [PyTorch](http://pytorch.org) that was compiled with CUDA Toolkit version 11.3. This particular setup might not work in the future because you may find that this particular PyTorch package does not work with your new computer. The solution is to:
+
+ 1. Change the Python version in the first command to a recent one that works for your OS. (That is, do not use 3.10 if you are using Windows.)
+ 2. Change the version of the CUDA Toolkit in the third command to one that PyTorch's website says is available. In particular, scroll to the "Install PyTorch" section and use the chooser there to pick the right command for your computer. Use that command to install PyTorch instead of the third command above.
+
+ ![The command to install PyTorch](docs/pytorch-install-command.png "The command to install PyTorch")
+
+ ### Jupyter Environment
+
+ The ``manual_poser`` is also available as a [Jupyter Notebook](http://jupyter.org). To run it on your local machine, you also need to install:
+
+ * Jupyter Notebook >= 7.3.4
+ * IPywidgets >= 7.7.0
+
+ In some cases, you will also need to enable the ``widgetsnbextension``. To do so, run
+
+ ```
+ > jupyter nbextension enable --py widgetsnbextension
+ ```
+
+ after installing the above two packages. Using Anaconda, I managed to do all of the above with the following commands:
+
+ ```
+ > conda install -c conda-forge notebook
+ > conda install -c conda-forge ipywidgets
+ > jupyter nbextension enable --py widgetsnbextension
+ ```
+
+ ### Automatic Environment Construction with Anaconda
+
+ You can also use Anaconda to download and install all Python packages in one command. Open your shell, change the directory to where you cloned the repository, and run:
+
+ ```
+ > conda env create -f environment.yml
+ ```
+
+ This will create an environment called ``talking-head-anime-3-demo`` containing all the required Python packages.
+
+ ### iFacialMocap
+
+ If you want to use ``ifacialmocap_puppeteer``, you will also need an iOS app called [iFacialMocap](https://www.ifacialmocap.com/) (a 980 yen purchase in the App Store). You do not need to download the paired desktop application this time. Your iOS device and your computer must be on the same network. For example, you may connect them to the same wireless router.
+
+ ## Download the Models
+
+ Before running the programs, you need to download the model files from this [Dropbox link](https://www.dropbox.com/s/y7b8jl4n2euv8xe/talking-head-anime-3-models.zip?dl=0) and unzip them to the ``data/models`` folder under the repository's root directory. In the end, the data folder should look like:
+
+ ```
+ + data
+   + images
+     - crypko_00.png
+     - crypko_01.png
+     :
+     - crypko_07.png
+     - lambda_00.png
+     - lambda_01.png
+   + models
+     + separable_float
+       - editor.pt
+       - eyebrow_decomposer.pt
+       - eyebrow_morphing_combiner.pt
+       - face_morpher.pt
+       - two_algo_face_body_rotator.pt
+     + separable_half
+       - editor.pt
+       :
+       - two_algo_face_body_rotator.pt
+     + standard_float
+       - editor.pt
+       :
+       - two_algo_face_body_rotator.pt
+     + standard_half
+       - editor.pt
+       :
+       - two_algo_face_body_rotator.pt
+ ```
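+
+ To sanity-check the download, you can verify that every variant folder contains the five model files. Below is a minimal sketch (an illustrative helper, not part of the repository); its expected file list mirrors the tree above:
+
+ ```
+ import os
+
+ VARIANTS = ["separable_float", "separable_half", "standard_float", "standard_half"]
+ MODEL_FILES = [
+     "editor.pt",
+     "eyebrow_decomposer.pt",
+     "eyebrow_morphing_combiner.pt",
+     "face_morpher.pt",
+     "two_algo_face_body_rotator.pt",
+ ]
+
+ # Print every expected model file that is missing from data/models.
+ for variant in VARIANTS:
+     for file_name in MODEL_FILES:
+         path = os.path.join("data", "models", variant, file_name)
+         if not os.path.isfile(path):
+             print("Missing:", path)
+ ```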
+
+ The model files are distributed with the [Creative Commons Attribution 4.0 International License](https://creativecommons.org/licenses/by/4.0/legalcode), which means that you can use them for commercial purposes. However, if you distribute them, you must, among other things, say that I am the creator.
+
+ ## Running the `manual_poser` Desktop Application
+
+ Open a shell. Change your working directory to the repository's root directory. Then, run:
+
+ ```
+ > python tha3/app/manual_poser.py
+ ```
+
+ Note that before running the command above, you might have to activate the Python environment that contains the required packages. If you created an environment using Anaconda as was discussed above, you need to run
+
+ ```
+ > conda activate talking-head-anime-3-demo
+ ```
+
+ if you have not already activated the environment.
+
+ ### Choosing System Variant to Use
+
+ As noted in the [project's write-up](http://pkhungurn.github.io/talking-head-anime-3/index.html), I created 4 variants of the neural network system. They are called ``standard_float``, ``separable_float``, ``standard_half``, and ``separable_half``. All of them have the same functionalities, but they differ in their sizes, RAM usage, speed, and accuracy. You can specify which variant the ``manual_poser`` program uses through the ``--model`` command line option.
+
+ ```
+ > python tha3/app/manual_poser.py --model <variant_name>
+ ```
+
+ where ``<variant_name>`` must be one of the 4 names above. If no variant is specified, the ``standard_float`` variant (which is the largest, slowest, and most accurate) will be used.
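+
+ The same choice can also be made programmatically. The sketch below mirrors the ``load_poser`` helper used by the bundled notebooks, which dispatches on the variant name to the matching ``create_poser`` factory in ``tha3.poser.modes``:
+
+ ```
+ import torch
+
+ def load_poser(model: str, device: torch.device):
+     # Import the create_poser factory that matches the requested variant.
+     if model == "standard_float":
+         from tha3.poser.modes.standard_float import create_poser
+     elif model == "standard_half":
+         from tha3.poser.modes.standard_half import create_poser
+     elif model == "separable_float":
+         from tha3.poser.modes.separable_float import create_poser
+     elif model == "separable_half":
+         from tha3.poser.modes.separable_half import create_poser
+     else:
+         raise RuntimeError("Invalid model: '%s'" % model)
+     return create_poser(device)
+
+ poser = load_poser("standard_float", torch.device("cuda"))
+ ```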
+
+ ## Running the `manual_poser` Jupyter Notebook
+
+ Open a shell. Activate the environment. Change your working directory to the repository's root directory. Then, run:
+
+ ```
+ > jupyter notebook
+ ```
+
+ A browser window should open. In it, open `manual_poser.ipynb`. Once you have done so, you should see that it has two cells. Run the two cells in order. Then, scroll down to the end of the document, and you'll see the GUI there.
+
+ You can choose the system variant to use by changing the ``MODEL_NAME`` variable in the first cell. If you do, you will need to rerun both cells in order for the variant to be loaded and the GUI to be properly updated to use it.
+
+ ## Running the `ifacialmocap_puppeteer`
+
+ First, run iFacialMocap on your iOS device. It should show you the device's IP address. Jot it down. Keep the app open.
+
+ ![IP address in iFacialMocap screen](docs/ifacialmocap_ip.jpg "IP address in iFacialMocap screen")
+
+ Open a shell. Activate the Python environment. Change your working directory to the repository's root directory. Then, run:
+
+ ```
+ > python tha3/app/ifacialmocap_puppeteer.py
+ ```
+
+ You will see a text box labeled "Capture Device IP." Enter the iOS device's IP address that you jotted down.
+
+ ![Write IP address of your iOS device in the 'Capture Device IP' text box.](docs/ifacialmocap_puppeteer_ip_address_box.png "Write IP address of your iOS device in the 'Capture Device IP' text box.")
+
+ Click the "START CAPTURE!" button to the right.
+
+ ![Click the 'START CAPTURE!' button.](docs/ifacialmocap_puppeteer_click_start_capture.png "Click the 'START CAPTURE!' button.")
+
+ If the programs are connected properly, you should see the numbers in the bottom part of the window change when you move your head.
+
+ ![The numbers in the bottom part of the window should change when you move your head.](docs/ifacialmocap_puppeteer_numbers.png "The numbers in the bottom part of the window should change when you move your head.")
+
+ Now, you can load an image of a character, and it should follow your facial movement.
+
+ ## Constraints on Input Images
+
+ In order for the system to work well, the input image must obey the following constraints (a small validation sketch follows the example image below):
+
+ * It should be of resolution 512 x 512. (If the demo programs receive an input image of any other size, they will resize it to this resolution and also output at this resolution.)
+ * It must have an alpha channel.
+ * It must contain only one humanoid character.
+ * The character should be standing upright and facing forward.
+ * The character's hands should be below and far from the head.
+ * The head of the character should roughly be contained in the 128 x 128 box in the middle of the top half of the image.
+ * The alpha channels of all pixels that do not belong to the character (i.e., background pixels) must be 0.
+
+ ![An example of an image that conforms to the above criteria](docs/input_spec.png "An example of an image that conforms to the above criteria")
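+
+ Only the first two constraints can be checked mechanically; the rest concern the character's pose. Here is a minimal sketch using Pillow (an illustrative helper, not part of the repository):
+
+ ```
+ import PIL.Image
+
+ def check_input_image(path):
+     image = PIL.Image.open(path)
+     # The demo programs resize other sizes, but 512 x 512 works best.
+     if image.size != (512, 512):
+         print("Warning: image is %d x %d; it will be resized to 512 x 512." % image.size)
+     # An alpha channel is required.
+     if image.mode != "RGBA":
+         raise ValueError("Image must have an alpha channel (RGBA mode).")
+     return image
+
+ check_input_image("data/images/lambda_00.png")
+ ```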
+
+ See the project's [write-up](http://pkhungurn.github.io/talking-head-anime-3/full.html#sec:problem-spec) for more details on the input image.
+
+ ## Citation
+
+ If your academic work benefits from the code in this repository, please cite the project's web page as follows:
+
+ > Pramook Khungurn. **Talking Head(?) Anime from a Single Image 3: Now the Body Too.** http://pkhungurn.github.io/talking-head-anime-3/, 2022. Accessed: YYYY-MM-DD.
+
+ You can also use the following BibTeX entry:
+
+ ```
+ @misc{Khungurn:2022,
+     author = {Pramook Khungurn},
+     title = {Talking Head(?) Anime from a Single Image 3: Now the Body Too},
+     howpublished = {\url{http://pkhungurn.github.io/talking-head-anime-3/}},
+     year = 2022,
+     note = {Accessed: YYYY-MM-DD},
+ }
+ ```
+
+ ## Disclaimer
+
+ While the author is an employee of [Google Japan](https://careers.google.com/locations/tokyo/), this software is not Google's product and is not supported by Google.
+
+ The copyright of this software belongs to me as I have requested it using the [IARC process](https://opensource.google/documentation/reference/releasing#iarc). However, Google might claim the rights to the intellectual property of this invention.
+
+ The code is released under the [MIT license](https://github.com/pkhungurn/talking-head-anime-2-demo/blob/master/LICENSE). The model is released under the [Creative Commons Attribution 4.0 International License](https://creativecommons.org/licenses/by/4.0/legalcode). Please see the README.md file in the ``data/images`` directory for the licenses for the images there.
colab.ipynb ADDED
@@ -0,0 +1,542 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "1027b46a",
+ "metadata": {},
+ "source": [
+ "# Talking Head(?) Anime from a Single Image 3: Now the Body Too (Manual Poser Tool)\n",
+ "\n",
+ "**Instruction**\n",
+ "\n",
+ "1. Run the five cells below, one by one, in order by clicking the \"Play\" button to the left of each. Wait for each cell to finish before going to the next one.\n",
+ "2. Scroll down to the end of the last cell, and play with the GUI.\n",
+ "\n",
+ "**Links**\n",
+ "\n",
+ "* Github repository: http://github.com/pkhungurn/talking-head-anime-3-demo\n",
+ "* Project writeup: http://pkhungurn.github.io/talking-head-anime-3/"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "54cc96d7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Clone the repository\n",
+ "%cd /content\n",
+ "!git clone https://github.com/pkhungurn/talking-head-anime-3-demo.git"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "77f2016c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# CD into the repository directory.\n",
+ "%cd /content/talking-head-anime-3-demo"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1771c927",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Download model files\n",
+ "!mkdir -p data/models/standard_float\n",
+ "!wget -O data/models/standard_float/editor.pt https://www.dropbox.com/s/zp3e5ox57sdws3y/editor.pt?dl=0\n",
+ "!wget -O data/models/standard_float/eyebrow_decomposer.pt https://www.dropbox.com/s/bcp42knbrk7egk8/eyebrow_decomposer.pt?dl=0\n",
+ "!wget -O data/models/standard_float/eyebrow_morphing_combiner.pt https://www.dropbox.com/s/oywaiio2s53lc57/eyebrow_morphing_combiner.pt?dl=0\n",
+ "!wget -O data/models/standard_float/face_morpher.pt https://www.dropbox.com/s/8qvo0u5lw7hqvtq/face_morpher.pt?dl=0\n",
+ "!wget -O data/models/standard_float/two_algo_face_body_rotator.pt https://www.dropbox.com/s/qmq1dnxrmzsxb4h/two_algo_face_body_rotator.pt?dl=0\n",
+ "\n",
+ "!mkdir -p data/models/standard_half\n",
+ "!wget -O data/models/standard_half/editor.pt https://www.dropbox.com/s/g21ps8gfuvz4kbo/editor.pt?dl=0\n",
+ "!wget -O data/models/standard_half/eyebrow_decomposer.pt https://www.dropbox.com/s/nwwwevzpmxiilgn/eyebrow_decomposer.pt?dl=0\n",
+ "!wget -O data/models/standard_half/eyebrow_morphing_combiner.pt https://www.dropbox.com/s/z5v0amgqif7yup1/eyebrow_morphing_combiner.pt?dl=0\n",
+ "!wget -O data/models/standard_half/face_morpher.pt https://www.dropbox.com/s/g03sfnd5yfs0m65/face_morpher.pt?dl=0\n",
+ "!wget -O data/models/standard_half/two_algo_face_body_rotator.pt https://www.dropbox.com/s/c5lrn7z34x12317/two_algo_face_body_rotator.pt?dl=0\n",
+ "\n",
+ "!mkdir -p data/models/separable_float \n",
+ "!wget -O data/models/separable_float/editor.pt https://www.dropbox.com/s/nwdxhrpa9fy19r4/editor.pt?dl=0\n",
+ "!wget -O data/models/separable_float/eyebrow_decomposer.pt https://www.dropbox.com/s/hfzjcu9cqr9wm3i/eyebrow_decomposer.pt?dl=0\n",
+ "!wget -O data/models/separable_float/eyebrow_morphing_combiner.pt https://www.dropbox.com/s/g04dyyyavh5o1e2/eyebrow_morphing_combiner.pt?dl=0\n",
+ "!wget -O data/models/separable_float/face_morpher.pt https://www.dropbox.com/s/vgi9dsj95y0rrwv/face_morpher.pt?dl=0\n",
+ "!wget -O data/models/separable_float/two_algo_face_body_rotator.pt https://www.dropbox.com/s/8u0qond8po34l24/two_algo_face_body_rotator.pt?dl=0\n",
+ "\n",
+ "!mkdir -p data/models/separable_half\n",
+ "!wget -O data/models/separable_half/editor.pt https://www.dropbox.com/s/on8kn6z9fj95j0h/editor.pt?dl=0\n",
+ "!wget -O data/models/separable_half/eyebrow_decomposer.pt https://www.dropbox.com/s/0hxu8opu1hmghqe/eyebrow_decomposer.pt?dl=0\n",
+ "!wget -O data/models/separable_half/eyebrow_morphing_combiner.pt https://www.dropbox.com/s/bgz02afp0xojqfs/eyebrow_morphing_combiner.pt?dl=0\n",
+ "!wget -O data/models/separable_half/face_morpher.pt https://www.dropbox.com/s/bgz02afp0xojqfs/eyebrow_morphing_combiner.pt?dl=0\n",
+ "!wget -O data/models/separable_half/two_algo_face_body_rotator.pt https://www.dropbox.com/s/vr8h2xxltszhw7w/two_algo_face_body_rotator.pt?dl=0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "062014f7",
+ "metadata": {
+ "id": "breeding-extra"
+ },
+ "outputs": [],
+ "source": [
+ "# Set this constant to specify which system variant to use.\n",
+ "MODEL_NAME = \"standard_float\" \n",
+ "\n",
+ "# Load the models.\n",
+ "import torch\n",
+ "DEVICE_NAME = 'cuda'\n",
+ "device = torch.device(DEVICE_NAME)\n",
+ "\n",
+ "def load_poser(model: str, device: torch.device):\n",
+ "    print(\"Using the %s model.\" % model)\n",
+ "    if model == \"standard_float\":\n",
+ "        from tha3.poser.modes.standard_float import create_poser\n",
+ "        return create_poser(device)\n",
+ "    elif model == \"standard_half\":\n",
+ "        from tha3.poser.modes.standard_half import create_poser\n",
+ "        return create_poser(device)\n",
+ "    elif model == \"separable_float\":\n",
+ "        from tha3.poser.modes.separable_float import create_poser\n",
+ "        return create_poser(device)\n",
+ "    elif model == \"separable_half\":\n",
+ "        from tha3.poser.modes.separable_half import create_poser\n",
+ "        return create_poser(device)\n",
+ "    else:\n",
+ "        raise RuntimeError(\"Invalid model: '%s'\" % model)\n",
+ "    \n",
+ "poser = load_poser(MODEL_NAME, DEVICE_NAME)\n",
+ "poser.get_modules();"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "breeding-extra",
+ "metadata": {
+ "id": "breeding-extra"
+ },
+ "outputs": [],
+ "source": [
+ "# Create the GUI for manipulating character images.\n",
+ "import PIL.Image\n",
+ "import io\n",
+ "from io import StringIO, BytesIO\n",
+ "import IPython.display\n",
+ "import numpy\n",
+ "import ipywidgets\n",
+ "import time\n",
+ "import threading\n",
+ "import torch\n",
+ "from tha3.util import resize_PIL_image, extract_PIL_image_from_filelike, \\\n",
+ "    extract_pytorch_image_from_PIL_image, convert_output_image_from_torch_to_numpy\n",
+ "\n",
+ "FRAME_RATE = 30.0\n",
+ "\n",
+ "last_torch_input_image = None\n",
+ "torch_input_image = None\n",
+ "\n",
+ "def show_pytorch_image(pytorch_image):\n",
+ "    output_image = pytorch_image.detach().cpu()\n",
+ "    numpy_image = numpy.uint8(numpy.rint(convert_output_image_from_torch_to_numpy(output_image) * 255.0))\n",
+ "    pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA')\n",
+ "    IPython.display.display(pil_image)\n",
+ "\n",
+ "upload_input_image_button = ipywidgets.FileUpload(\n",
+ "    accept='.png',\n",
+ "    multiple=False,\n",
+ "    layout={\n",
+ "        'width': '512px'\n",
+ "    }\n",
+ ")\n",
+ "\n",
+ "output_image_widget = ipywidgets.Output(\n",
+ "    layout={\n",
+ "        'border': '1px solid black',\n",
+ "        'width': '512px',\n",
+ "        'height': '512px'\n",
+ "    }\n",
+ ")\n",
+ "\n",
+ "eyebrow_dropdown = ipywidgets.Dropdown(\n",
+ "    options=[\"troubled\", \"angry\", \"lowered\", \"raised\", \"happy\", \"serious\"],\n",
+ "    value=\"troubled\",\n",
+ "    description=\"Eyebrow:\", \n",
+ ")\n",
+ "eyebrow_left_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Left:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "eyebrow_right_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Right:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "\n",
+ "eye_dropdown = ipywidgets.Dropdown(\n",
+ "    options=[\"wink\", \"happy_wink\", \"surprised\", \"relaxed\", \"unimpressed\", \"raised_lower_eyelid\"],\n",
+ "    value=\"wink\",\n",
+ "    description=\"Eye:\", \n",
+ ")\n",
+ "eye_left_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Left:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "eye_right_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Right:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "\n",
+ "mouth_dropdown = ipywidgets.Dropdown(\n",
+ "    options=[\"aaa\", \"iii\", \"uuu\", \"eee\", \"ooo\", \"delta\", \"lowered_corner\", \"raised_corner\", \"smirk\"],\n",
+ "    value=\"aaa\",\n",
+ "    description=\"Mouth:\", \n",
+ ")\n",
+ "mouth_left_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Value:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "mouth_right_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\" \",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\",\n",
+ "    disabled=True,\n",
+ ")\n",
+ "\n",
+ "def update_mouth_sliders(change):\n",
+ "    if mouth_dropdown.value == \"lowered_corner\" or mouth_dropdown.value == \"raised_corner\":\n",
+ "        mouth_left_slider.description = \"Left:\"\n",
+ "        mouth_right_slider.description = \"Right:\"\n",
+ "        mouth_right_slider.disabled = False\n",
+ "    else:\n",
+ "        mouth_left_slider.description = \"Value:\"\n",
+ "        mouth_right_slider.description = \" \"\n",
+ "        mouth_right_slider.disabled = True\n",
+ "\n",
+ "mouth_dropdown.observe(update_mouth_sliders, names='value')\n",
+ "\n",
+ "iris_small_left_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Left:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "iris_small_right_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Right:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\", \n",
+ ")\n",
+ "iris_rotation_x_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=-1.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"X-axis:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "iris_rotation_y_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=-1.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Y-axis:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\", \n",
+ ")\n",
+ "\n",
+ "head_x_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=-1.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"X-axis:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "head_y_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=-1.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Y-axis:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\", \n",
+ ")\n",
+ "neck_z_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=-1.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Z-axis:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\", \n",
+ ")\n",
+ "body_y_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=-1.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Y-axis rotation:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\", \n",
+ ")\n",
+ "body_z_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=-1.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Z-axis rotation:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\", \n",
+ ")\n",
+ "breathing_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Breathing:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\", \n",
+ ")\n",
+ "\n",
+ "\n",
+ "control_panel = ipywidgets.VBox([\n",
+ "    eyebrow_dropdown,\n",
+ "    eyebrow_left_slider,\n",
+ "    eyebrow_right_slider,\n",
+ "    ipywidgets.HTML(value=\"<hr>\"),\n",
+ "    eye_dropdown,\n",
+ "    eye_left_slider,\n",
+ "    eye_right_slider,\n",
+ "    ipywidgets.HTML(value=\"<hr>\"),\n",
+ "    mouth_dropdown,\n",
+ "    mouth_left_slider,\n",
+ "    mouth_right_slider,\n",
+ "    ipywidgets.HTML(value=\"<hr>\"),\n",
+ "    ipywidgets.HTML(value=\"<center><b>Iris Shrinkage</b></center>\"),\n",
+ "    iris_small_left_slider,\n",
+ "    iris_small_right_slider,\n",
+ "    ipywidgets.HTML(value=\"<center><b>Iris Rotation</b></center>\"),\n",
+ "    iris_rotation_x_slider,\n",
+ "    iris_rotation_y_slider,\n",
+ "    ipywidgets.HTML(value=\"<hr>\"),\n",
+ "    ipywidgets.HTML(value=\"<center><b>Head Rotation</b></center>\"),\n",
+ "    head_x_slider,\n",
+ "    head_y_slider,\n",
+ "    neck_z_slider,\n",
+ "    ipywidgets.HTML(value=\"<hr>\"),\n",
+ "    ipywidgets.HTML(value=\"<center><b>Body Rotation</b></center>\"),\n",
+ "    body_y_slider,\n",
+ "    body_z_slider,\n",
+ "    ipywidgets.HTML(value=\"<hr>\"),\n",
+ "    ipywidgets.HTML(value=\"<center><b>Breathing</b></center>\"),\n",
+ "    breathing_slider,\n",
+ "])\n",
+ "\n",
+ "controls = ipywidgets.HBox([\n",
+ "    ipywidgets.VBox([\n",
+ "        output_image_widget, \n",
+ "        upload_input_image_button\n",
+ "    ]),\n",
+ "    control_panel,\n",
+ "])\n",
+ "\n",
+ "from tha3.poser.modes.pose_parameters import get_pose_parameters\n",
+ "pose_parameters = get_pose_parameters()\n",
+ "pose_size = poser.get_num_parameters()\n",
+ "last_pose = torch.zeros(1, pose_size, dtype=poser.get_dtype()).to(device)\n",
+ "\n",
+ "iris_small_left_index = pose_parameters.get_parameter_index(\"iris_small_left\")\n",
+ "iris_small_right_index = pose_parameters.get_parameter_index(\"iris_small_right\")\n",
+ "iris_rotation_x_index = pose_parameters.get_parameter_index(\"iris_rotation_x\")\n",
+ "iris_rotation_y_index = pose_parameters.get_parameter_index(\"iris_rotation_y\")\n",
+ "head_x_index = pose_parameters.get_parameter_index(\"head_x\")\n",
+ "head_y_index = pose_parameters.get_parameter_index(\"head_y\")\n",
+ "neck_z_index = pose_parameters.get_parameter_index(\"neck_z\")\n",
+ "body_y_index = pose_parameters.get_parameter_index(\"body_y\")\n",
+ "body_z_index = pose_parameters.get_parameter_index(\"body_z\")\n",
+ "breathing_index = pose_parameters.get_parameter_index(\"breathing\")\n",
+ "\n",
+ "def get_pose():\n",
+ "    pose = torch.zeros(1, pose_size, dtype=poser.get_dtype())\n",
+ "\n",
+ "    eyebrow_name = f\"eyebrow_{eyebrow_dropdown.value}\"\n",
+ "    eyebrow_left_index = pose_parameters.get_parameter_index(f\"{eyebrow_name}_left\")\n",
+ "    eyebrow_right_index = pose_parameters.get_parameter_index(f\"{eyebrow_name}_right\")\n",
+ "    pose[0, eyebrow_left_index] = eyebrow_left_slider.value\n",
+ "    pose[0, eyebrow_right_index] = eyebrow_right_slider.value\n",
+ "\n",
+ "    eye_name = f\"eye_{eye_dropdown.value}\"\n",
+ "    eye_left_index = pose_parameters.get_parameter_index(f\"{eye_name}_left\")\n",
+ "    eye_right_index = pose_parameters.get_parameter_index(f\"{eye_name}_right\")\n",
+ "    pose[0, eye_left_index] = eye_left_slider.value\n",
+ "    pose[0, eye_right_index] = eye_right_slider.value\n",
+ "\n",
+ "    mouth_name = f\"mouth_{mouth_dropdown.value}\"\n",
+ "    if mouth_name == \"mouth_lowered_corner\" or mouth_name == \"mouth_raised_corner\":\n",
+ "        mouth_left_index = pose_parameters.get_parameter_index(f\"{mouth_name}_left\")\n",
+ "        mouth_right_index = pose_parameters.get_parameter_index(f\"{mouth_name}_right\")\n",
+ "        pose[0, mouth_left_index] = mouth_left_slider.value\n",
+ "        pose[0, mouth_right_index] = mouth_right_slider.value\n",
+ "    else:\n",
+ "        mouth_index = pose_parameters.get_parameter_index(mouth_name)\n",
+ "        pose[0, mouth_index] = mouth_left_slider.value\n",
+ "\n",
+ "    pose[0, iris_small_left_index] = iris_small_left_slider.value\n",
+ "    pose[0, iris_small_right_index] = iris_small_right_slider.value\n",
+ "    pose[0, iris_rotation_x_index] = iris_rotation_x_slider.value\n",
+ "    pose[0, iris_rotation_y_index] = iris_rotation_y_slider.value\n",
+ "    pose[0, head_x_index] = head_x_slider.value\n",
+ "    pose[0, head_y_index] = head_y_slider.value\n",
+ "    pose[0, neck_z_index] = neck_z_slider.value\n",
+ "    pose[0, body_y_index] = body_y_slider.value\n",
+ "    pose[0, body_z_index] = body_z_slider.value\n",
+ "    pose[0, breathing_index] = breathing_slider.value\n",
+ "\n",
+ "    return pose.to(device)\n",
+ "\n",
+ "display(controls)\n",
+ "\n",
+ "def update(change):\n",
+ "    global last_pose\n",
+ "    global last_torch_input_image\n",
+ "\n",
+ "    if torch_input_image is None:\n",
+ "        return\n",
+ "\n",
+ "    needs_update = False\n",
+ "    if last_torch_input_image is None:\n",
+ "        needs_update = True \n",
+ "    else:\n",
+ "        if (torch_input_image - last_torch_input_image).abs().max().item() > 0:\n",
+ "            needs_update = True \n",
+ "\n",
+ "    pose = get_pose()\n",
+ "    if (pose - last_pose).abs().max().item() > 0:\n",
+ "        needs_update = True\n",
+ "\n",
+ "    if not needs_update:\n",
+ "        return\n",
+ "\n",
+ "    output_image = poser.pose(torch_input_image, pose)[0]\n",
+ "    with output_image_widget:\n",
+ "        output_image_widget.clear_output(wait=True)\n",
+ "        show_pytorch_image(output_image) \n",
+ "\n",
+ "    last_torch_input_image = torch_input_image\n",
+ "    last_pose = pose\n",
+ "\n",
+ "def upload_image(change):\n",
+ "    global torch_input_image\n",
+ "    for name, file_info in upload_input_image_button.value.items():\n",
+ "        content = io.BytesIO(file_info['content'])\n",
+ "    if content is not None:\n",
+ "        pil_image = resize_PIL_image(extract_PIL_image_from_filelike(content), size=(512,512))\n",
+ "        w, h = pil_image.size\n",
+ "        if pil_image.mode != 'RGBA':\n",
+ "            with output_image_widget:\n",
+ "                torch_input_image = None\n",
+ "                output_image_widget.clear_output(wait=True)\n",
+ "                display(ipywidgets.HTML(\"Image must have an alpha channel!!!\"))\n",
+ "        else:\n",
+ "            torch_input_image = extract_pytorch_image_from_PIL_image(pil_image).to(device)\n",
+ "            if poser.get_dtype() == torch.half:\n",
+ "                torch_input_image = torch_input_image.half()\n",
+ "            update(None)\n",
+ "\n",
+ "upload_input_image_button.observe(upload_image, names='value')\n",
+ "eyebrow_dropdown.observe(update, 'value')\n",
+ "eyebrow_left_slider.observe(update, 'value')\n",
+ "eyebrow_right_slider.observe(update, 'value')\n",
+ "eye_dropdown.observe(update, 'value')\n",
+ "eye_left_slider.observe(update, 'value')\n",
+ "eye_right_slider.observe(update, 'value')\n",
+ "mouth_dropdown.observe(update, 'value')\n",
+ "mouth_left_slider.observe(update, 'value')\n",
+ "mouth_right_slider.observe(update, 'value')\n",
+ "iris_small_left_slider.observe(update, 'value')\n",
+ "iris_small_right_slider.observe(update, 'value')\n",
+ "iris_rotation_x_slider.observe(update, 'value')\n",
+ "iris_rotation_y_slider.observe(update, 'value')\n",
+ "head_x_slider.observe(update, 'value')\n",
+ "head_y_slider.observe(update, 'value')\n",
+ "neck_z_slider.observe(update, 'value')\n",
+ "body_y_slider.observe(update, 'value')\n",
+ "body_z_slider.observe(update, 'value')\n",
+ "breathing_slider.observe(update, 'value')"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "name": "tha3.ipynb",
+ "provenance": []
+ },
+ "interpreter": {
+ "hash": "684906ad716c90e6f3397644b72c2a23821e93080f6b0264e4cd74aee22032ce"
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
docs/ifacialmocap_ip.jpg ADDED
docs/ifacialmocap_puppeteer_click_start_capture.png ADDED
docs/ifacialmocap_puppeteer_ip_address_box.png ADDED
docs/ifacialmocap_puppeteer_numbers.png ADDED
docs/input_spec.png ADDED
docs/pytorch-install-command.png ADDED
environment.yml ADDED
@@ -0,0 +1,141 @@
+ name: talking-head-anime-3-demo
+ channels:
+   - pytorch
+   - conda-forge
+   - defaults
+ dependencies:
+   - argon2-cffi=21.3.0=pyhd8ed1ab_0
+   - argon2-cffi-bindings=21.2.0=py38h294d835_2
+   - asttokens=2.0.5=pyhd8ed1ab_0
+   - attrs=21.4.0=pyhd8ed1ab_0
+   - backcall=0.2.0=pyh9f0ad1d_0
+   - backports=1.0=py_2
+   - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
+   - beautifulsoup4=4.11.1=pyha770c72_0
+   - blas=1.0=mkl
+   - bleach=5.0.0=pyhd8ed1ab_0
+   - brotli=1.0.9=ha925a31_2
+   - brotlipy=0.7.0=py38h2bbff1b_1003
+   - ca-certificates=2022.5.18.1=h5b45459_0
+   - certifi=2022.5.18.1=py38haa244fe_0
+   - cffi=1.15.0=py38h2bbff1b_1
+   - charset-normalizer=2.0.4=pyhd3eb1b0_0
+   - colorama=0.4.4=pyh9f0ad1d_0
+   - cryptography=37.0.1=py38h21b164f_0
+   - cudatoolkit=11.3.1=h59b6b97_2
+   - cycler=0.11.0=pyhd3eb1b0_0
+   - debugpy=1.6.0=py38h885f38d_0
+   - decorator=5.1.1=pyhd8ed1ab_0
+   - defusedxml=0.7.1=pyhd8ed1ab_0
+   - entrypoints=0.4=pyhd8ed1ab_0
+   - executing=0.8.3=pyhd8ed1ab_0
+   - flit-core=3.7.1=pyhd8ed1ab_0
+   - fonttools=4.25.0=pyhd3eb1b0_0
+   - freetype=2.10.4=hd328e21_0
+   - icc_rt=2019.0.0=h0cc432a_1
+   - icu=58.2=ha925a31_3
+   - idna=3.3=pyhd3eb1b0_0
+   - importlib-metadata=4.11.4=py38haa244fe_0
+   - importlib_resources=5.7.1=pyhd8ed1ab_1
+   - intel-openmp=2021.4.0=haa95532_3556
+   - ipykernel=6.13.1=py38h4317176_0
+   - ipython=8.4.0=py38haa244fe_0
+   - ipython_genutils=0.2.0=py_1
+   - ipywidgets=7.7.0=pyhd8ed1ab_0
+   - jedi=0.18.1=py38haa244fe_1
+   - jinja2=3.1.2=pyhd8ed1ab_1
+   - jpeg=9e=h2bbff1b_0
+   - jsonschema=4.6.0=pyhd8ed1ab_0
+   - jupyter_client=7.3.4=pyhd8ed1ab_0
+   - jupyter_core=4.10.0=py38haa244fe_0
+   - jupyterlab_pygments=0.2.2=pyhd8ed1ab_0
+   - jupyterlab_widgets=1.1.0=pyhd8ed1ab_0
+   - kiwisolver=1.4.2=py38hd77b12b_0
+   - libpng=1.6.37=h2a8f88b_0
+   - libsodium=1.0.18=h8d14728_1
+   - libtiff=4.2.0=he0120a3_1
+   - libuv=1.40.0=he774522_0
+   - libwebp=1.2.2=h2bbff1b_0
+   - lz4-c=1.9.3=h2bbff1b_1
+   - markupsafe=2.1.1=py38h294d835_1
+   - matplotlib=3.5.1=py38haa95532_1
+   - matplotlib-base=3.5.1=py38hd77b12b_1
+   - matplotlib-inline=0.1.3=pyhd8ed1ab_0
+   - mistune=0.8.4=py38h294d835_1005
+   - mkl=2021.4.0=haa95532_640
+   - mkl-service=2.4.0=py38h2bbff1b_0
+   - mkl_fft=1.3.1=py38h277e83a_0
+   - mkl_random=1.2.2=py38hf11a4ad_0
+   - munkres=1.1.4=py_0
+   - nbclient=0.6.4=pyhd8ed1ab_1
+   - nbconvert=6.5.0=pyhd8ed1ab_0
+   - nbconvert-core=6.5.0=pyhd8ed1ab_0
+   - nbconvert-pandoc=6.5.0=pyhd8ed1ab_0
+   - nbformat=5.4.0=pyhd8ed1ab_0
+   - nest-asyncio=1.5.5=pyhd8ed1ab_0
+   - notebook=6.4.12=pyha770c72_0
+   - numpy=1.22.3=py38h7a0a035_0
+   - numpy-base=1.22.3=py38hca35cd5_0
+   - openssl=1.1.1o=h8ffe710_0
+   - packaging=21.3=pyhd3eb1b0_0
+   - pandoc=2.18=h57928b3_0
+   - pandocfilters=1.5.0=pyhd8ed1ab_0
+   - parso=0.8.3=pyhd8ed1ab_0
+   - pickleshare=0.7.5=py_1003
+   - pillow=9.0.1=py38hdc2b20a_0
+   - pip=21.2.2=py38haa95532_0
+   - prometheus_client=0.14.1=pyhd8ed1ab_0
+   - prompt-toolkit=3.0.29=pyha770c72_0
+   - psutil=5.9.1=py38h294d835_0
+   - pure_eval=0.2.2=pyhd8ed1ab_0
+   - pycparser=2.21=pyhd3eb1b0_0
+   - pygments=2.12.0=pyhd8ed1ab_0
+   - pyopenssl=22.0.0=pyhd3eb1b0_0
+   - pyparsing=3.0.4=pyhd3eb1b0_0
+   - pyqt=5.9.2=py38hd77b12b_6
+   - pyrsistent=0.18.1=py38h294d835_1
+   - pysocks=1.7.1=py38haa95532_0
+   - python=3.8.13=h6244533_0
+   - python-dateutil=2.8.2=pyhd3eb1b0_0
+   - python-fastjsonschema=2.15.3=pyhd8ed1ab_0
+   - python_abi=3.8=2_cp38
+   - pytorch=1.11.0=py3.8_cuda11.3_cudnn8_0
+   - pytorch-mutex=1.0=cuda
+   - pywin32=303=py38h294d835_0
+   - pywinpty=2.0.2=py38h5da7b33_0
+   - pyzmq=23.1.0=py38h09162b1_0
+   - qt=5.9.7=vc14h73c81de_0
+   - requests=2.27.1=pyhd3eb1b0_0
+   - scipy=1.7.3=py38h0a974cb_0
+   - send2trash=1.8.0=pyhd8ed1ab_0
+   - setuptools=61.2.0=py38haa95532_0
+   - sip=4.19.13=py38hd77b12b_0
+   - six=1.16.0=pyhd3eb1b0_1
+   - soupsieve=2.3.1=pyhd8ed1ab_0
+   - sqlite=3.38.3=h2bbff1b_0
+   - stack_data=0.2.0=pyhd8ed1ab_0
+   - terminado=0.15.0=py38haa244fe_0
+   - tinycss2=1.1.1=pyhd8ed1ab_0
+   - tk=8.6.12=h2bbff1b_0
+   - torchaudio=0.11.0=py38_cu113
+   - torchvision=0.12.0=py38_cu113
+   - tornado=6.1=py38h2bbff1b_0
+   - traitlets=5.2.2.post1=pyhd8ed1ab_0
+   - typing_extensions=4.1.1=pyh06a4308_0
+   - urllib3=1.26.9=py38haa95532_0
+   - vc=14.2=h21ff451_1
+   - vs2015_runtime=14.27.29016=h5e58377_2
+   - wcwidth=0.2.5=pyh9f0ad1d_2
+   - webencodings=0.5.1=py_1
+   - wheel=0.37.1=pyhd3eb1b0_0
+   - widgetsnbextension=3.6.0=py38haa244fe_0
+   - win_inet_pton=1.1.0=py38haa95532_0
+   - wincertstore=0.2=py38haa95532_2
+   - winpty=0.4.3=4
+   - xz=5.2.5=h8cc25b3_1
+   - zeromq=4.3.4=h0e60522_1
+   - zipp=3.8.0=pyhd8ed1ab_0
+   - zlib=1.2.12=h8cc25b3_2
+   - zstd=1.5.2=h19a0ad4_0
+   - pip:
+       - wxpython==4.1.1
manual_poser.ipynb ADDED
@@ -0,0 +1,460 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "062014f7",
+ "metadata": {
+ "id": "breeding-extra"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "MODEL_NAME = \"standard_float\"\n",
+ "DEVICE_NAME = 'cuda'\n",
+ "device = torch.device(DEVICE_NAME)\n",
+ "\n",
+ "def load_poser(model: str, device: torch.device):\n",
+ "    print(\"Using the %s model.\" % model)\n",
+ "    if model == \"standard_float\":\n",
+ "        from tha3.poser.modes.standard_float import create_poser\n",
+ "        return create_poser(device)\n",
+ "    elif model == \"standard_half\":\n",
+ "        from tha3.poser.modes.standard_half import create_poser\n",
+ "        return create_poser(device)\n",
+ "    elif model == \"separable_float\":\n",
+ "        from tha3.poser.modes.separable_float import create_poser\n",
+ "        return create_poser(device)\n",
+ "    elif model == \"separable_half\":\n",
+ "        from tha3.poser.modes.separable_half import create_poser\n",
+ "        return create_poser(device)\n",
+ "    else:\n",
+ "        raise RuntimeError(\"Invalid model: '%s'\" % model)\n",
+ "    \n",
+ "poser = load_poser(MODEL_NAME, DEVICE_NAME)\n",
+ "poser.get_modules();"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "breeding-extra",
+ "metadata": {
+ "id": "breeding-extra"
+ },
+ "outputs": [],
+ "source": [
+ "import PIL.Image\n",
+ "import io\n",
+ "from io import StringIO, BytesIO\n",
+ "import IPython.display\n",
+ "import numpy\n",
+ "import ipywidgets\n",
+ "import time\n",
+ "import threading\n",
+ "import torch\n",
+ "from tha3.util import resize_PIL_image, extract_PIL_image_from_filelike, \\\n",
+ "    extract_pytorch_image_from_PIL_image, convert_output_image_from_torch_to_numpy\n",
+ "\n",
+ "FRAME_RATE = 30.0\n",
+ "\n",
+ "last_torch_input_image = None\n",
+ "torch_input_image = None\n",
+ "\n",
+ "def show_pytorch_image(pytorch_image):\n",
+ "    output_image = pytorch_image.detach().cpu()\n",
+ "    numpy_image = numpy.uint8(numpy.rint(convert_output_image_from_torch_to_numpy(output_image) * 255.0))\n",
+ "    pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA')\n",
+ "    IPython.display.display(pil_image)\n",
+ "\n",
+ "upload_input_image_button = ipywidgets.FileUpload(\n",
+ "    accept='.png',\n",
+ "    multiple=False,\n",
+ "    layout={\n",
+ "        'width': '512px'\n",
+ "    }\n",
+ ")\n",
+ "\n",
+ "output_image_widget = ipywidgets.Output(\n",
+ "    layout={\n",
+ "        'border': '1px solid black',\n",
+ "        'width': '512px',\n",
+ "        'height': '512px'\n",
+ "    }\n",
+ ")\n",
+ "\n",
+ "eyebrow_dropdown = ipywidgets.Dropdown(\n",
+ "    options=[\"troubled\", \"angry\", \"lowered\", \"raised\", \"happy\", \"serious\"],\n",
+ "    value=\"troubled\",\n",
+ "    description=\"Eyebrow:\", \n",
+ ")\n",
+ "eyebrow_left_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Left:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "eyebrow_right_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Right:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "\n",
+ "eye_dropdown = ipywidgets.Dropdown(\n",
+ "    options=[\"wink\", \"happy_wink\", \"surprised\", \"relaxed\", \"unimpressed\", \"raised_lower_eyelid\"],\n",
+ "    value=\"wink\",\n",
+ "    description=\"Eye:\", \n",
+ ")\n",
+ "eye_left_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Left:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "eye_right_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Right:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "\n",
+ "mouth_dropdown = ipywidgets.Dropdown(\n",
+ "    options=[\"aaa\", \"iii\", \"uuu\", \"eee\", \"ooo\", \"delta\", \"lowered_corner\", \"raised_corner\", \"smirk\"],\n",
+ "    value=\"aaa\",\n",
+ "    description=\"Mouth:\", \n",
+ ")\n",
+ "mouth_left_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Value:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "mouth_right_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\" \",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\",\n",
+ "    disabled=True,\n",
+ ")\n",
+ "\n",
+ "def update_mouth_sliders(change):\n",
+ "    if mouth_dropdown.value == \"lowered_corner\" or mouth_dropdown.value == \"raised_corner\":\n",
+ "        mouth_left_slider.description = \"Left:\"\n",
+ "        mouth_right_slider.description = \"Right:\"\n",
+ "        mouth_right_slider.disabled = False\n",
+ "    else:\n",
+ "        mouth_left_slider.description = \"Value:\"\n",
+ "        mouth_right_slider.description = \" \"\n",
+ "        mouth_right_slider.disabled = True\n",
+ "\n",
+ "mouth_dropdown.observe(update_mouth_sliders, names='value')\n",
+ "\n",
+ "iris_small_left_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Left:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "iris_small_right_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Right:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\", \n",
+ ")\n",
+ "iris_rotation_x_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=-1.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"X-axis:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "iris_rotation_y_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=-1.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Y-axis:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\", \n",
+ ")\n",
+ "\n",
+ "head_x_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=-1.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"X-axis:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\"\n",
+ ")\n",
+ "head_y_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=-1.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Y-axis:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\", \n",
+ ")\n",
+ "neck_z_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=-1.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Z-axis:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\", \n",
+ ")\n",
+ "body_y_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=-1.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Y-axis rotation:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\", \n",
+ ")\n",
+ "body_z_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=-1.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Z-axis rotation:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\", \n",
+ ")\n",
+ "breathing_slider = ipywidgets.FloatSlider(\n",
+ "    value=0.0,\n",
+ "    min=0.0,\n",
+ "    max=1.0,\n",
+ "    step=0.01,\n",
+ "    description=\"Breathing:\",\n",
+ "    readout=True,\n",
+ "    readout_format=\".2f\", \n",
+ ")\n",
+ "\n",
+ "\n",
+ "control_panel = ipywidgets.VBox([\n",
+ "    eyebrow_dropdown,\n",
+ "    eyebrow_left_slider,\n",
+ "    eyebrow_right_slider,\n",
+ "    ipywidgets.HTML(value=\"<hr>\"),\n",
+ "    eye_dropdown,\n",
+ "    eye_left_slider,\n",
+ "    eye_right_slider,\n",
+ "    ipywidgets.HTML(value=\"<hr>\"),\n",
+ "    mouth_dropdown,\n",
+ "    mouth_left_slider,\n",
+ "    mouth_right_slider,\n",
+ "    ipywidgets.HTML(value=\"<hr>\"),\n",
+ "    ipywidgets.HTML(value=\"<center><b>Iris Shrinkage</b></center>\"),\n",
+ "    iris_small_left_slider,\n",
+ "    iris_small_right_slider,\n",
+ "    ipywidgets.HTML(value=\"<center><b>Iris Rotation</b></center>\"),\n",
+ "    iris_rotation_x_slider,\n",
+ "    iris_rotation_y_slider,\n",
+ "    ipywidgets.HTML(value=\"<hr>\"),\n",
+ "    ipywidgets.HTML(value=\"<center><b>Head Rotation</b></center>\"),\n",
+ "    head_x_slider,\n",
+ "    head_y_slider,\n",
+ "    neck_z_slider,\n",
+ "    ipywidgets.HTML(value=\"<hr>\"),\n",
+ "    ipywidgets.HTML(value=\"<center><b>Body Rotation</b></center>\"),\n",
+ "    body_y_slider,\n",
+ "    body_z_slider,\n",
+ "    ipywidgets.HTML(value=\"<hr>\"),\n",
+ "    ipywidgets.HTML(value=\"<center><b>Breathing</b></center>\"),\n",
+ "    breathing_slider,\n",
+ "])\n",
+ "\n",
+ "controls = ipywidgets.HBox([\n",
+ "    ipywidgets.VBox([\n",
+ "        output_image_widget, \n",
+ "        upload_input_image_button\n",
+ "    ]),\n",
+ "    control_panel,\n",
+ "])\n",
+ "\n",
+ "from tha3.poser.modes.pose_parameters import get_pose_parameters\n",
+ "pose_parameters = get_pose_parameters()\n",
+ "pose_size = poser.get_num_parameters()\n",
+ "last_pose = torch.zeros(1, pose_size, dtype=poser.get_dtype()).to(device)\n",
+ "\n",
+ "iris_small_left_index = pose_parameters.get_parameter_index(\"iris_small_left\")\n",
+ "iris_small_right_index = pose_parameters.get_parameter_index(\"iris_small_right\")\n",
+ "iris_rotation_x_index = pose_parameters.get_parameter_index(\"iris_rotation_x\")\n",
313
+ "iris_rotation_y_index = pose_parameters.get_parameter_index(\"iris_rotation_y\")\n",
314
+ "head_x_index = pose_parameters.get_parameter_index(\"head_x\")\n",
315
+ "head_y_index = pose_parameters.get_parameter_index(\"head_y\")\n",
316
+ "neck_z_index = pose_parameters.get_parameter_index(\"neck_z\")\n",
317
+ "body_y_index = pose_parameters.get_parameter_index(\"body_y\")\n",
318
+ "body_z_index = pose_parameters.get_parameter_index(\"body_z\")\n",
319
+ "breathing_index = pose_parameters.get_parameter_index(\"breathing\")\n",
320
+ "\n",
321
+ "def get_pose():\n",
322
+ " pose = torch.zeros(1, pose_size, dtype=poser.get_dtype())\n",
323
+ "\n",
324
+ " eyebrow_name = f\"eyebrow_{eyebrow_dropdown.value}\"\n",
325
+ " eyebrow_left_index = pose_parameters.get_parameter_index(f\"{eyebrow_name}_left\")\n",
326
+ " eyebrow_right_index = pose_parameters.get_parameter_index(f\"{eyebrow_name}_right\")\n",
327
+ " pose[0, eyebrow_left_index] = eyebrow_left_slider.value\n",
328
+ " pose[0, eyebrow_right_index] = eyebrow_right_slider.value\n",
329
+ "\n",
330
+ " eye_name = f\"eye_{eye_dropdown.value}\"\n",
331
+ " eye_left_index = pose_parameters.get_parameter_index(f\"{eye_name}_left\")\n",
332
+ " eye_right_index = pose_parameters.get_parameter_index(f\"{eye_name}_right\")\n",
333
+ " pose[0, eye_left_index] = eye_left_slider.value\n",
334
+ " pose[0, eye_right_index] = eye_right_slider.value\n",
335
+ "\n",
336
+ " mouth_name = f\"mouth_{mouth_dropdown.value}\"\n",
337
+ " if mouth_name == \"mouth_lowered_corner\" or mouth_name == \"mouth_raised_corner\":\n",
338
+ " mouth_left_index = pose_parameters.get_parameter_index(f\"{mouth_name}_left\")\n",
339
+ " mouth_right_index = pose_parameters.get_parameter_index(f\"{mouth_name}_right\")\n",
340
+ " pose[0, mouth_left_index] = mouth_left_slider.value\n",
341
+ " pose[0, mouth_right_index] = mouth_right_slider.value\n",
342
+ " else:\n",
343
+ " mouth_index = pose_parameters.get_parameter_index(mouth_name)\n",
344
+ " pose[0, mouth_index] = mouth_left_slider.value\n",
345
+ "\n",
346
+ " pose[0, iris_small_left_index] = iris_small_left_slider.value\n",
347
+ " pose[0, iris_small_right_index] = iris_small_right_slider.value\n",
348
+ " pose[0, iris_rotation_x_index] = iris_rotation_x_slider.value\n",
349
+ " pose[0, iris_rotation_y_index] = iris_rotation_y_slider.value\n",
350
+ " pose[0, head_x_index] = head_x_slider.value\n",
351
+ " pose[0, head_y_index] = head_y_slider.value\n",
352
+ " pose[0, neck_z_index] = neck_z_slider.value\n",
353
+ " pose[0, body_y_index] = body_y_slider.value\n",
354
+ " pose[0, body_z_index] = body_z_slider.value\n",
355
+ " pose[0, breathing_index] = breathing_slider.value\n",
356
+ "\n",
357
+ " return pose.to(device)\n",
358
+ "\n",
359
+ "display(controls)\n",
360
+ "\n",
361
+ "def update(change):\n",
362
+ " global last_pose\n",
363
+ " global last_torch_input_image\n",
364
+ "\n",
365
+ " if torch_input_image is None:\n",
366
+ " return\n",
367
+ "\n",
368
+ " needs_update = False\n",
369
+ " if last_torch_input_image is None:\n",
370
+ " needs_update = True \n",
371
+ " else:\n",
372
+ " if (torch_input_image - last_torch_input_image).abs().max().item() > 0:\n",
373
+ " needs_update = True \n",
374
+ "\n",
375
+ " pose = get_pose()\n",
376
+ " if (pose - last_pose).abs().max().item() > 0:\n",
377
+ " needs_update = True\n",
378
+ "\n",
379
+ " if not needs_update:\n",
380
+ " return\n",
381
+ "\n",
382
+ " output_image = poser.pose(torch_input_image, pose)[0]\n",
383
+ " with output_image_widget:\n",
384
+ " output_image_widget.clear_output(wait=True)\n",
385
+ " show_pytorch_image(output_image) \n",
386
+ "\n",
387
+ " last_torch_input_image = torch_input_image\n",
388
+ " last_pose = pose\n",
389
+ "\n",
390
+ "def upload_image(change):\n",
391
+ " global torch_input_image\n",
392
+ " for name, file_info in upload_input_image_button.value.items():\n",
393
+ " content = io.BytesIO(file_info['content'])\n",
394
+ " if content is not None:\n",
395
+ " pil_image = resize_PIL_image(extract_PIL_image_from_filelike(content), size=(512,512))\n",
396
+ " w, h = pil_image.size\n",
397
+ " if pil_image.mode != 'RGBA':\n",
398
+ " with output_image_widget:\n",
399
+ " torch_input_image = None\n",
400
+ " output_image_widget.clear_output(wait=True)\n",
401
+ " display(ipywidgets.HTML(\"Image must have an alpha channel!!!\"))\n",
402
+ " else:\n",
403
+ " torch_input_image = extract_pytorch_image_from_PIL_image(pil_image).to(device)\n",
404
+ " if poser.get_dtype() == torch.half:\n",
405
+ " torch_input_image = torch_input_image.half()\n",
406
+ " update(None)\n",
407
+ "\n",
408
+ "upload_input_image_button.observe(upload_image, names='value')\n",
409
+ "eyebrow_dropdown.observe(update, 'value')\n",
410
+ "eyebrow_left_slider.observe(update, 'value')\n",
411
+ "eyebrow_right_slider.observe(update, 'value')\n",
412
+ "eye_dropdown.observe(update, 'value')\n",
413
+ "eye_left_slider.observe(update, 'value')\n",
414
+ "eye_right_slider.observe(update, 'value')\n",
415
+ "mouth_dropdown.observe(update, 'value')\n",
416
+ "mouth_left_slider.observe(update, 'value')\n",
417
+ "mouth_right_slider.observe(update, 'value')\n",
418
+ "iris_small_left_slider.observe(update, 'value')\n",
419
+ "iris_small_right_slider.observe(update, 'value')\n",
420
+ "iris_rotation_x_slider.observe(update, 'value')\n",
421
+ "iris_rotation_y_slider.observe(update, 'value')\n",
422
+ "head_x_slider.observe(update, 'value')\n",
423
+ "head_y_slider.observe(update, 'value')\n",
424
+ "neck_z_slider.observe(update, 'value')\n",
425
+ "body_y_slider.observe(update, 'value')\n",
426
+ "body_z_slider.observe(update, 'value')\n",
427
+ "breathing_slider.observe(update, 'value')"
428
+ ]
429
+ }
430
+ ],
431
+ "metadata": {
432
+ "accelerator": "GPU",
433
+ "colab": {
434
+ "name": "tha3.ipynb",
435
+ "provenance": []
436
+ },
437
+ "interpreter": {
438
+ "hash": "684906ad716c90e6f3397644b72c2a23821e93080f6b0264e4cd74aee22032ce"
439
+ },
440
+ "kernelspec": {
441
+ "display_name": "Python 3 (ipykernel)",
442
+ "language": "python",
443
+ "name": "python3"
444
+ },
445
+ "language_info": {
446
+ "codemirror_mode": {
447
+ "name": "ipython",
448
+ "version": 3
449
+ },
450
+ "file_extension": ".py",
451
+ "mimetype": "text/x-python",
452
+ "name": "python",
453
+ "nbconvert_exporter": "python",
454
+ "pygments_lexer": "ipython3",
455
+ "version": "3.8.13"
456
+ }
457
+ },
458
+ "nbformat": 4,
459
+ "nbformat_minor": 5
460
+ }
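The cell above drives poser.pose only through ipywidgets callbacks. The same call chain can be exercised without widgets; a minimal sketch, assuming poser, device, and an RGBA torch_input_image from the earlier notebook cells (the parameter names follow the f"eye_{...}_left" construction used in get_pose above):

import torch
from tha3.poser.modes.pose_parameters import get_pose_parameters

pose_parameters = get_pose_parameters()
pose = torch.zeros(1, poser.get_num_parameters(), dtype=poser.get_dtype())
pose[0, pose_parameters.get_parameter_index("head_y")] = 0.5          # turn the head
pose[0, pose_parameters.get_parameter_index("eye_wink_left")] = 1.0   # close the left eye
output_image = poser.pose(torch_input_image, pose.to(device))[0]      # same call as update() above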
tha3/__init__.py ADDED
File without changes
tha3/app/__init__.py ADDED
File without changes
tha3/app/ifacialmocap_puppeteer.py ADDED
@@ -0,0 +1,439 @@
+ import argparse
+ import os
+ import socket
+ import sys
+ import threading
+ import time
+ from typing import Optional
+
+ sys.path.append(os.getcwd())
+
+ from tha3.mocap.ifacialmocap_pose import create_default_ifacialmocap_pose
+ from tha3.mocap.ifacialmocap_v2 import IFACIALMOCAP_PORT, IFACIALMOCAP_START_STRING, parse_ifacialmocap_v2_pose, \
+     parse_ifacialmocap_v1_pose
+ from tha3.poser.modes.load_poser import load_poser
+
+ import torch
+ import wx
+
+ from tha3.poser.poser import Poser
+ from tha3.mocap.ifacialmocap_constants import *
+ from tha3.mocap.ifacialmocap_pose_converter import IFacialMocapPoseConverter
+ from tha3.util import torch_linear_to_srgb, resize_PIL_image, extract_PIL_image_from_filelike, \
+     extract_pytorch_image_from_PIL_image
+
+
+ def convert_linear_to_srgb(image: torch.Tensor) -> torch.Tensor:
+     rgb_image = torch_linear_to_srgb(image[0:3, :, :])
+     return torch.cat([rgb_image, image[3:4, :, :]], dim=0)
+
+
+ class FpsStatistics:
+     def __init__(self):
+         self.count = 100
+         self.fps = []
+
+     def add_fps(self, fps):
+         self.fps.append(fps)
+         while len(self.fps) > self.count:
+             del self.fps[0]
+
+     def get_average_fps(self):
+         if len(self.fps) == 0:
+             return 0.0
+         else:
+             return sum(self.fps) / len(self.fps)
+
+
+ class MainFrame(wx.Frame):
+     def __init__(self, poser: Poser, pose_converter: IFacialMocapPoseConverter, device: torch.device):
+         super().__init__(None, wx.ID_ANY, "iFacialMocap Puppeteer (Marigold)")
+         self.pose_converter = pose_converter
+         self.poser = poser
+         self.device = device
+
+
+         self.ifacialmocap_pose = create_default_ifacialmocap_pose()
+         self.source_image_bitmap = wx.Bitmap(self.poser.get_image_size(), self.poser.get_image_size())
+         self.result_image_bitmap = wx.Bitmap(self.poser.get_image_size(), self.poser.get_image_size())
+         self.wx_source_image = None
+         self.torch_source_image = None
+         self.last_pose = None
+         self.fps_statistics = FpsStatistics()
+         self.last_update_time = None
+
+         self.create_receiving_socket()
+         self.create_ui()
+         self.create_timers()
+         self.Bind(wx.EVT_CLOSE, self.on_close)
+
+         self.update_source_image_bitmap()
+         self.update_result_image_bitmap()
+
+     def create_receiving_socket(self):
+         self.receiving_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+         self.receiving_socket.bind(("", IFACIALMOCAP_PORT))
+         self.receiving_socket.setblocking(False)
+
+     def create_timers(self):
+         self.capture_timer = wx.Timer(self, wx.ID_ANY)
+         self.Bind(wx.EVT_TIMER, self.update_capture_panel, id=self.capture_timer.GetId())
+         self.animation_timer = wx.Timer(self, wx.ID_ANY)
+         self.Bind(wx.EVT_TIMER, self.update_result_image_bitmap, id=self.animation_timer.GetId())
+
+     def on_close(self, event: wx.Event):
+         # Stop the timers
+         self.animation_timer.Stop()
+         self.capture_timer.Stop()
+
+         # Close receiving socket
+         self.receiving_socket.close()
+
+         # Destroy the windows
+         self.Destroy()
+         event.Skip()
+
+     def on_start_capture(self, event: wx.Event):
+         capture_device_ip_address = self.capture_device_ip_text_ctrl.GetValue()
+         out_socket = None
+         try:
+             address = (capture_device_ip_address, IFACIALMOCAP_PORT)
+             out_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+             out_socket.sendto(IFACIALMOCAP_START_STRING, address)
+         except Exception as e:
+             message_dialog = wx.MessageDialog(self, str(e), "Error!", wx.OK)
+             message_dialog.ShowModal()
+             message_dialog.Destroy()
+         finally:
+             if out_socket is not None:
+                 out_socket.close()
+
+     def read_ifacialmocap_pose(self):
+         if not self.animation_timer.IsRunning():
+             return self.ifacialmocap_pose
+         socket_bytes = None
+         while True:
+             try:
+                 socket_bytes = self.receiving_socket.recv(8192)
+             except socket.error as e:
+                 break
+         if socket_bytes is not None:
+             socket_string = socket_bytes.decode("utf-8")
+             self.ifacialmocap_pose = parse_ifacialmocap_v2_pose(socket_string)
+         return self.ifacialmocap_pose
+
+     def on_erase_background(self, event: wx.Event):
+         pass
+
+     def create_animation_panel(self, parent):
+         self.animation_panel = wx.Panel(parent, style=wx.RAISED_BORDER)
+         self.animation_panel_sizer = wx.BoxSizer(wx.HORIZONTAL)
+         self.animation_panel.SetSizer(self.animation_panel_sizer)
+         self.animation_panel.SetAutoLayout(1)
+
+         image_size = self.poser.get_image_size()
+
+         if True:
+             self.input_panel = wx.Panel(self.animation_panel, size=(image_size, image_size + 128),
+                                         style=wx.SIMPLE_BORDER)
+             self.input_panel_sizer = wx.BoxSizer(wx.VERTICAL)
+             self.input_panel.SetSizer(self.input_panel_sizer)
+             self.input_panel.SetAutoLayout(1)
+             self.animation_panel_sizer.Add(self.input_panel, 0, wx.FIXED_MINSIZE)
+
+             self.source_image_panel = wx.Panel(self.input_panel, size=(image_size, image_size), style=wx.SIMPLE_BORDER)
+             self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel)
+             self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
+             self.input_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE)
+
+             self.load_image_button = wx.Button(self.input_panel, wx.ID_ANY, "Load Image")
+             self.input_panel_sizer.Add(self.load_image_button, 1, wx.EXPAND)
+             self.load_image_button.Bind(wx.EVT_BUTTON, self.load_image)
+
+             self.input_panel_sizer.Fit(self.input_panel)
+
+         if True:
+             self.pose_converter.init_pose_converter_panel(self.animation_panel)
+
+         if True:
+             self.animation_left_panel = wx.Panel(self.animation_panel, style=wx.SIMPLE_BORDER)
+             self.animation_left_panel_sizer = wx.BoxSizer(wx.VERTICAL)
+             self.animation_left_panel.SetSizer(self.animation_left_panel_sizer)
+             self.animation_left_panel.SetAutoLayout(1)
+             self.animation_panel_sizer.Add(self.animation_left_panel, 0, wx.EXPAND)
+
+             self.result_image_panel = wx.Panel(self.animation_left_panel, size=(image_size, image_size),
+                                                style=wx.SIMPLE_BORDER)
+             self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel)
+             self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
+             self.animation_left_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE)
+
+             separator = wx.StaticLine(self.animation_left_panel, -1, size=(256, 5))
+             self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND)
+
+             background_text = wx.StaticText(self.animation_left_panel, label="--- Background ---",
+                                             style=wx.ALIGN_CENTER)
+             self.animation_left_panel_sizer.Add(background_text, 0, wx.EXPAND)
+
+             self.output_background_choice = wx.Choice(
+                 self.animation_left_panel,
+                 choices=[
+                     "TRANSPARENT",
+                     "GREEN",
+                     "BLUE",
+                     "BLACK",
+                     "WHITE"
+                 ])
+             self.output_background_choice.SetSelection(0)
+             self.animation_left_panel_sizer.Add(self.output_background_choice, 0, wx.EXPAND)
+
+             separator = wx.StaticLine(self.animation_left_panel, -1, size=(256, 5))
+             self.animation_left_panel_sizer.Add(separator, 0, wx.EXPAND)
+
+             self.fps_text = wx.StaticText(self.animation_left_panel, label="")
+             self.animation_left_panel_sizer.Add(self.fps_text, wx.SizerFlags().Border())
+
+             self.animation_left_panel_sizer.Fit(self.animation_left_panel)
+
+         self.animation_panel_sizer.Fit(self.animation_panel)
+
+     def create_ui(self):
+         self.main_sizer = wx.BoxSizer(wx.VERTICAL)
+         self.SetSizer(self.main_sizer)
+         self.SetAutoLayout(1)
+
+         self.capture_pose_lock = threading.Lock()
+
+         self.create_connection_panel(self)
+         self.main_sizer.Add(self.connection_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5))
+
+         self.create_animation_panel(self)
+         self.main_sizer.Add(self.animation_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5))
+
+         self.create_capture_panel(self)
+         self.main_sizer.Add(self.capture_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5))
+
+         self.main_sizer.Fit(self)
+
+     def create_connection_panel(self, parent):
+         self.connection_panel = wx.Panel(parent, style=wx.RAISED_BORDER)
+         self.connection_panel_sizer = wx.BoxSizer(wx.HORIZONTAL)
+         self.connection_panel.SetSizer(self.connection_panel_sizer)
+         self.connection_panel.SetAutoLayout(1)
+
+         capture_device_ip_text = wx.StaticText(self.connection_panel, label="Capture Device IP:", style=wx.ALIGN_RIGHT)
+         self.connection_panel_sizer.Add(capture_device_ip_text, wx.SizerFlags(0).FixedMinSize().Border(wx.ALL, 3))
+
+         self.capture_device_ip_text_ctrl = wx.TextCtrl(self.connection_panel, value="192.168.0.1")
+         self.connection_panel_sizer.Add(self.capture_device_ip_text_ctrl, wx.SizerFlags(1).Expand().Border(wx.ALL, 3))
+
+         self.start_capture_button = wx.Button(self.connection_panel, label="START CAPTURE!")
+         self.connection_panel_sizer.Add(self.start_capture_button, wx.SizerFlags(0).FixedMinSize().Border(wx.ALL, 3))
+         self.start_capture_button.Bind(wx.EVT_BUTTON, self.on_start_capture)
+
+     def create_capture_panel(self, parent):
+         self.capture_panel = wx.Panel(parent, style=wx.RAISED_BORDER)
+         self.capture_panel_sizer = wx.FlexGridSizer(cols=5)
+         for i in range(5):
+             self.capture_panel_sizer.AddGrowableCol(i)
+         self.capture_panel.SetSizer(self.capture_panel_sizer)
+         self.capture_panel.SetAutoLayout(1)
+
+         self.rotation_labels = {}
+         self.rotation_value_labels = {}
+         rotation_column_0 = self.create_rotation_column(self.capture_panel, RIGHT_EYE_BONE_ROTATIONS)
+         self.capture_panel_sizer.Add(rotation_column_0, wx.SizerFlags(0).Expand().Border(wx.ALL, 3))
+         rotation_column_1 = self.create_rotation_column(self.capture_panel, LEFT_EYE_BONE_ROTATIONS)
+         self.capture_panel_sizer.Add(rotation_column_1, wx.SizerFlags(0).Expand().Border(wx.ALL, 3))
+         rotation_column_2 = self.create_rotation_column(self.capture_panel, HEAD_BONE_ROTATIONS)
+         self.capture_panel_sizer.Add(rotation_column_2, wx.SizerFlags(0).Expand().Border(wx.ALL, 3))
+
+     def create_rotation_column(self, parent, rotation_names):
+         column_panel = wx.Panel(parent, style=wx.SIMPLE_BORDER)
+         column_panel_sizer = wx.FlexGridSizer(cols=2)
+         column_panel_sizer.AddGrowableCol(1)
+         column_panel.SetSizer(column_panel_sizer)
+         column_panel.SetAutoLayout(1)
+
+         for rotation_name in rotation_names:
+             self.rotation_labels[rotation_name] = wx.StaticText(
+                 column_panel, label=rotation_name, style=wx.ALIGN_RIGHT)
+             column_panel_sizer.Add(self.rotation_labels[rotation_name],
+                                    wx.SizerFlags(1).Expand().Border(wx.ALL, 3))
+
+             self.rotation_value_labels[rotation_name] = wx.TextCtrl(
+                 column_panel, style=wx.TE_RIGHT)
+             self.rotation_value_labels[rotation_name].SetValue("0.00")
+             self.rotation_value_labels[rotation_name].Disable()
+             column_panel_sizer.Add(self.rotation_value_labels[rotation_name],
+                                    wx.SizerFlags(1).Expand().Border(wx.ALL, 3))
+
+         column_panel.GetSizer().Fit(column_panel)
+         return column_panel
+
+     def paint_capture_panel(self, event: wx.Event):
+         self.update_capture_panel(event)
+
+     def update_capture_panel(self, event: wx.Event):
+         data = self.ifacialmocap_pose
+         for rotation_name in ROTATION_NAMES:
+             value = data[rotation_name]
+             self.rotation_value_labels[rotation_name].SetValue("%0.2f" % value)
+
+     @staticmethod
+     def convert_to_100(x):
+         return int(max(0.0, min(1.0, x)) * 100)
+
+     def paint_source_image_panel(self, event: wx.Event):
+         wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)
+
+     def update_source_image_bitmap(self):
+         dc = wx.MemoryDC()
+         dc.SelectObject(self.source_image_bitmap)
+         if self.wx_source_image is None:
+             self.draw_nothing_yet_string(dc)
+         else:
+             dc.Clear()
+             dc.DrawBitmap(self.wx_source_image, 0, 0, True)
+         del dc
+
+     def draw_nothing_yet_string(self, dc):
+         dc.Clear()
+         font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))
+         dc.SetFont(font)
+         w, h = dc.GetTextExtent("Nothing yet!")
+         dc.DrawText("Nothing yet!", (self.poser.get_image_size() - w) // 2, (self.poser.get_image_size() - h) // 2)
+
+     def paint_result_image_panel(self, event: wx.Event):
+         wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)
+
+     def update_result_image_bitmap(self, event: Optional[wx.Event] = None):
+         ifacialmocap_pose = self.read_ifacialmocap_pose()
+         current_pose = self.pose_converter.convert(ifacialmocap_pose)
+         if self.last_pose is not None and self.last_pose == current_pose:
+             return
+         self.last_pose = current_pose
+
+         if self.torch_source_image is None:
+             dc = wx.MemoryDC()
+             dc.SelectObject(self.result_image_bitmap)
+             self.draw_nothing_yet_string(dc)
+             del dc
+             return
+
+         pose = torch.tensor(current_pose, device=self.device, dtype=self.poser.get_dtype())
+
+         with torch.no_grad():
+             output_image = self.poser.pose(self.torch_source_image, pose)[0].float()
+             output_image = convert_linear_to_srgb((output_image + 1.0) / 2.0)
+
+             background_choice = self.output_background_choice.GetSelection()
+             if background_choice == 0:
+                 pass
+             else:
+                 background = torch.zeros(4, output_image.shape[1], output_image.shape[2], device=self.device)
+                 background[3, :, :] = 1.0
+                 if background_choice == 1:
+                     background[1, :, :] = 1.0
+                     output_image = self.blend_with_background(output_image, background)
+                 elif background_choice == 2:
+                     background[2, :, :] = 1.0
+                     output_image = self.blend_with_background(output_image, background)
+                 elif background_choice == 3:
+                     output_image = self.blend_with_background(output_image, background)
+                 else:
+                     background[0:3, :, :] = 1.0
+                     output_image = self.blend_with_background(output_image, background)
+
+             c, h, w = output_image.shape
+             output_image = 255.0 * torch.transpose(output_image.reshape(c, h * w), 0, 1).reshape(h, w, c)
+             output_image = output_image.byte()
+
+         numpy_image = output_image.detach().cpu().numpy()
+         wx_image = wx.ImageFromBuffer(numpy_image.shape[0],
+                                       numpy_image.shape[1],
+                                       numpy_image[:, :, 0:3].tobytes(),
+                                       numpy_image[:, :, 3].tobytes())
+         wx_bitmap = wx_image.ConvertToBitmap()
+
+         dc = wx.MemoryDC()
+         dc.SelectObject(self.result_image_bitmap)
+         dc.Clear()
+         dc.DrawBitmap(wx_bitmap,
+                       (self.poser.get_image_size() - numpy_image.shape[0]) // 2,
+                       (self.poser.get_image_size() - numpy_image.shape[1]) // 2, True)
+         del dc
+
+         time_now = time.time_ns()
+         if self.last_update_time is not None:
+             elapsed_time = time_now - self.last_update_time
+             fps = 1.0 / (elapsed_time / 10**9)
+             if self.torch_source_image is not None:
+                 self.fps_statistics.add_fps(fps)
+             self.fps_text.SetLabelText("FPS = %0.2f" % self.fps_statistics.get_average_fps())
+         self.last_update_time = time_now
+
+         self.Refresh()
+
+     def blend_with_background(self, numpy_image, background):
+         alpha = numpy_image[3:4, :, :]
+         color = numpy_image[0:3, :, :]
+         new_color = color * alpha + (1.0 - alpha) * background[0:3, :, :]
+         return torch.cat([new_color, background[3:4, :, :]], dim=0)
+
+     def load_image(self, event: wx.Event):
+         dir_name = "data/images"
+         file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_OPEN)
+         if file_dialog.ShowModal() == wx.ID_OK:
+             image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
+             try:
+                 pil_image = resize_PIL_image(
+                     extract_PIL_image_from_filelike(image_file_name),
+                     (self.poser.get_image_size(), self.poser.get_image_size()))
+                 w, h = pil_image.size
+                 if pil_image.mode != 'RGBA':
+                     self.source_image_string = "Image must have alpha channel!"
+                     self.wx_source_image = None
+                     self.torch_source_image = None
+                 else:
+                     self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes())
+                     self.torch_source_image = extract_pytorch_image_from_PIL_image(pil_image) \
+                         .to(self.device).to(self.poser.get_dtype())
+                 self.update_source_image_bitmap()
+             except:
+                 message_dialog = wx.MessageDialog(self, "Could not load image " + image_file_name, "Poser", wx.OK)
+                 message_dialog.ShowModal()
+                 message_dialog.Destroy()
+             file_dialog.Destroy()
+         self.Refresh()
+
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description='Control characters with movement captured by iFacialMocap.')
+     parser.add_argument(
+         '--model',
+         type=str,
+         required=False,
+         default='standard_float',
+         choices=['standard_float', 'separable_float', 'standard_half', 'separable_half'],
+         help='The model to use.')
+     args = parser.parse_args()
+
+     device = torch.device('cuda')
+     try:
+         poser = load_poser(args.model, device)
+     except RuntimeError as e:
+         print(e)
+         sys.exit()
+
+     from tha3.mocap.ifacialmocap_poser_converter_25 import create_ifacialmocap_pose_converter
+
+     pose_converter = create_ifacialmocap_pose_converter()
+
+     app = wx.App()
+     main_frame = MainFrame(poser, pose_converter, device)
+     main_frame.Show(True)
+     main_frame.capture_timer.Start(10)
+     main_frame.animation_timer.Start(10)
+     app.MainLoop()
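The capture path in this file is a plain UDP handshake: on_start_capture sends IFACIALMOCAP_START_STRING to the phone, and read_ifacialmocap_pose drains a non-blocking socket bound to IFACIALMOCAP_PORT. A minimal headless sketch of the same exchange, assuming iFacialMocap is running on a device at 192.168.0.1 (placeholder address):

import socket

from tha3.mocap.ifacialmocap_v2 import IFACIALMOCAP_PORT, IFACIALMOCAP_START_STRING, parse_ifacialmocap_v2_pose

# Ask the phone to start streaming (fire-and-forget datagram).
out_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
out_socket.sendto(IFACIALMOCAP_START_STRING, ("192.168.0.1", IFACIALMOCAP_PORT))
out_socket.close()

# Receive one packet and parse it into the blendshape/rotation dict
# that IFacialMocapPoseConverter.convert expects.
in_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
in_socket.bind(("", IFACIALMOCAP_PORT))
pose_dict = parse_ifacialmocap_v2_pose(in_socket.recv(8192).decode("utf-8"))
in_socket.close()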
tha3/app/manual_poser.py ADDED
@@ -0,0 +1,464 @@
+ import argparse
+ import logging
+ import os
+ import sys
+ from typing import List
+
+ sys.path.append(os.getcwd())
+
+ import PIL.Image
+ import numpy
+ import torch
+ import wx
+
+ from tha3.poser.modes.load_poser import load_poser
+ from tha3.poser.poser import Poser, PoseParameterCategory, PoseParameterGroup
+ from tha3.util import extract_pytorch_image_from_filelike, rgba_to_numpy_image, grid_change_to_numpy_image, \
+     rgb_to_numpy_image, resize_PIL_image, extract_PIL_image_from_filelike, extract_pytorch_image_from_PIL_image
+
+
+ class MorphCategoryControlPanel(wx.Panel):
+     def __init__(self,
+                  parent,
+                  title: str,
+                  pose_param_category: PoseParameterCategory,
+                  param_groups: List[PoseParameterGroup]):
+         super().__init__(parent, style=wx.SIMPLE_BORDER)
+         self.pose_param_category = pose_param_category
+         self.sizer = wx.BoxSizer(wx.VERTICAL)
+         self.SetSizer(self.sizer)
+         self.SetAutoLayout(1)
+
+         title_text = wx.StaticText(self, label=title, style=wx.ALIGN_CENTER)
+         self.sizer.Add(title_text, 0, wx.EXPAND)
+
+         self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]
+         self.choice = wx.Choice(self, choices=[group.get_group_name() for group in self.param_groups])
+         if len(self.param_groups) > 0:
+             self.choice.SetSelection(0)
+         self.choice.Bind(wx.EVT_CHOICE, self.on_choice_updated)
+         self.sizer.Add(self.choice, 0, wx.EXPAND)
+
+         self.left_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL)
+         self.sizer.Add(self.left_slider, 0, wx.EXPAND)
+
+         self.right_slider = wx.Slider(self, minValue=-1000, maxValue=1000, value=-1000, style=wx.HORIZONTAL)
+         self.sizer.Add(self.right_slider, 0, wx.EXPAND)
+
+         self.checkbox = wx.CheckBox(self, label="Show")
+         self.checkbox.SetValue(True)
+         self.sizer.Add(self.checkbox, 0, wx.SHAPED | wx.ALIGN_CENTER)
+
+         self.update_ui()
+
+         self.sizer.Fit(self)
+
+     def update_ui(self):
+         param_group = self.param_groups[self.choice.GetSelection()]
+         if param_group.is_discrete():
+             self.left_slider.Enable(False)
+             self.right_slider.Enable(False)
+             self.checkbox.Enable(True)
+         elif param_group.get_arity() == 1:
+             self.left_slider.Enable(True)
+             self.right_slider.Enable(False)
+             self.checkbox.Enable(False)
+         else:
+             self.left_slider.Enable(True)
+             self.right_slider.Enable(True)
+             self.checkbox.Enable(False)
+
+     def on_choice_updated(self, event: wx.Event):
+         param_group = self.param_groups[self.choice.GetSelection()]
+         if param_group.is_discrete():
+             self.checkbox.SetValue(True)
+         self.update_ui()
+
+     def set_param_value(self, pose: List[float]):
+         if len(self.param_groups) == 0:
+             return
+         selected_morph_index = self.choice.GetSelection()
+         param_group = self.param_groups[selected_morph_index]
+         param_index = param_group.get_parameter_index()
+         if param_group.is_discrete():
+             if self.checkbox.GetValue():
+                 for i in range(param_group.get_arity()):
+                     pose[param_index + i] = 1.0
+         else:
+             param_range = param_group.get_range()
+             alpha = (self.left_slider.GetValue() + 1000) / 2000.0
+             pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha
+             if param_group.get_arity() == 2:
+                 alpha = (self.right_slider.GetValue() + 1000) / 2000.0
+                 pose[param_index + 1] = param_range[0] + (param_range[1] - param_range[0]) * alpha
+
+
+ class SimpleParamGroupsControlPanel(wx.Panel):
+     def __init__(self, parent,
+                  pose_param_category: PoseParameterCategory,
+                  param_groups: List[PoseParameterGroup]):
+         super().__init__(parent, style=wx.SIMPLE_BORDER)
+         self.sizer = wx.BoxSizer(wx.VERTICAL)
+         self.SetSizer(self.sizer)
+         self.SetAutoLayout(1)
+
+         self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]
+         for param_group in self.param_groups:
+             assert not param_group.is_discrete()
+             assert param_group.get_arity() == 1
+
+         self.sliders = []
+         for param_group in self.param_groups:
+             static_text = wx.StaticText(
+                 self,
+                 label=" ------------ %s ------------ " % param_group.get_group_name(), style=wx.ALIGN_CENTER)
+             self.sizer.Add(static_text, 0, wx.EXPAND)
+             range = param_group.get_range()
+             min_value = int(range[0] * 1000)
+             max_value = int(range[1] * 1000)
+             slider = wx.Slider(self, minValue=min_value, maxValue=max_value, value=0, style=wx.HORIZONTAL)
+             self.sizer.Add(slider, 0, wx.EXPAND)
+             self.sliders.append(slider)
+
+         self.sizer.Fit(self)
+
+     def set_param_value(self, pose: List[float]):
+         if len(self.param_groups) == 0:
+             return
+         for param_group_index in range(len(self.param_groups)):
+             param_group = self.param_groups[param_group_index]
+             slider = self.sliders[param_group_index]
+             param_range = param_group.get_range()
+             param_index = param_group.get_parameter_index()
+             alpha = (slider.GetValue() - slider.GetMin()) * 1.0 / (slider.GetMax() - slider.GetMin())
+             pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha
+
+
+ def convert_output_image_from_torch_to_numpy(output_image):
+     if output_image.shape[2] == 2:
+         h, w, c = output_image.shape
+         numpy_image = torch.transpose(output_image.reshape(h * w, c), 0, 1).reshape(c, h, w)
+     elif output_image.shape[0] == 4:
+         numpy_image = rgba_to_numpy_image(output_image)
+     elif output_image.shape[0] == 3:
+         numpy_image = rgb_to_numpy_image(output_image)
+     elif output_image.shape[0] == 1:
+         c, h, w = output_image.shape
+         alpha_image = torch.cat([output_image.repeat(3, 1, 1) * 2.0 - 1.0, torch.ones(1, h, w)], dim=0)
+         numpy_image = rgba_to_numpy_image(alpha_image)
+     elif output_image.shape[0] == 2:
+         numpy_image = grid_change_to_numpy_image(output_image, num_channels=4)
+     else:
+         raise RuntimeError("Unsupported # image channels: %d" % output_image.shape[0])
+     numpy_image = numpy.uint8(numpy.rint(numpy_image * 255.0))
+     return numpy_image
+
+
+ class MainFrame(wx.Frame):
+     def __init__(self, poser: Poser, device: torch.device):
+         super().__init__(None, wx.ID_ANY, "Poser")
+         self.poser = poser
+         self.dtype = self.poser.get_dtype()
+         self.device = device
+         self.image_size = self.poser.get_image_size()
+
+         self.wx_source_image = None
+         self.torch_source_image = None
+
+         self.main_sizer = wx.BoxSizer(wx.HORIZONTAL)
+         self.SetSizer(self.main_sizer)
+         self.SetAutoLayout(1)
+         self.init_left_panel()
+         self.init_control_panel()
+         self.init_right_panel()
+         self.main_sizer.Fit(self)
+
+         self.timer = wx.Timer(self, wx.ID_ANY)
+         self.Bind(wx.EVT_TIMER, self.update_images, self.timer)
+
+         save_image_id = wx.NewIdRef()
+         self.Bind(wx.EVT_MENU, self.on_save_image, id=save_image_id)
+         accelerator_table = wx.AcceleratorTable([
+             (wx.ACCEL_CTRL, ord('S'), save_image_id)
+         ])
+         self.SetAcceleratorTable(accelerator_table)
+
+         self.last_pose = None
+         self.last_output_index = self.output_index_choice.GetSelection()
+         self.last_output_numpy_image = None
+
+         self.wx_source_image = None
+         self.torch_source_image = None
+         self.source_image_bitmap = wx.Bitmap(self.image_size, self.image_size)
+         self.result_image_bitmap = wx.Bitmap(self.image_size, self.image_size)
+         self.source_image_dirty = True
+
+     def init_left_panel(self):
+         self.control_panel = wx.Panel(self, style=wx.SIMPLE_BORDER, size=(self.image_size, -1))
+         self.left_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)
+         left_panel_sizer = wx.BoxSizer(wx.VERTICAL)
+         self.left_panel.SetSizer(left_panel_sizer)
+         self.left_panel.SetAutoLayout(1)
+
+         self.source_image_panel = wx.Panel(self.left_panel, size=(self.image_size, self.image_size),
+                                            style=wx.SIMPLE_BORDER)
+         self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel)
+         self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
+         left_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE)
+
+         self.load_image_button = wx.Button(self.left_panel, wx.ID_ANY, "\nLoad Image\n\n")
+         left_panel_sizer.Add(self.load_image_button, 1, wx.EXPAND)
+         self.load_image_button.Bind(wx.EVT_BUTTON, self.load_image)
+
+         left_panel_sizer.Fit(self.left_panel)
+         self.main_sizer.Add(self.left_panel, 0, wx.FIXED_MINSIZE)
+
+     def on_erase_background(self, event: wx.Event):
+         pass
+
+     def init_control_panel(self):
+         self.control_panel_sizer = wx.BoxSizer(wx.VERTICAL)
+         self.control_panel.SetSizer(self.control_panel_sizer)
+         self.control_panel.SetMinSize(wx.Size(256, 1))
+
+         morph_categories = [
+             PoseParameterCategory.EYEBROW,
+             PoseParameterCategory.EYE,
+             PoseParameterCategory.MOUTH,
+             PoseParameterCategory.IRIS_MORPH
+         ]
+         morph_category_titles = {
+             PoseParameterCategory.EYEBROW: " ------------ Eyebrow ------------ ",
+             PoseParameterCategory.EYE: " ------------ Eye ------------ ",
+             PoseParameterCategory.MOUTH: " ------------ Mouth ------------ ",
+             PoseParameterCategory.IRIS_MORPH: " ------------ Iris morphs ------------ ",
+         }
+         self.morph_control_panels = {}
+         for category in morph_categories:
+             param_groups = self.poser.get_pose_parameter_groups()
+             filtered_param_groups = [group for group in param_groups if group.get_category() == category]
+             if len(filtered_param_groups) == 0:
+                 continue
+             control_panel = MorphCategoryControlPanel(
+                 self.control_panel,
+                 morph_category_titles[category],
+                 category,
+                 self.poser.get_pose_parameter_groups())
+             self.morph_control_panels[category] = control_panel
+             self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)
+
+         self.non_morph_control_panels = {}
+         non_morph_categories = [
+             PoseParameterCategory.IRIS_ROTATION,
+             PoseParameterCategory.FACE_ROTATION,
+             PoseParameterCategory.BODY_ROTATION,
+             PoseParameterCategory.BREATHING
+         ]
+         for category in non_morph_categories:
+             param_groups = self.poser.get_pose_parameter_groups()
+             filtered_param_groups = [group for group in param_groups if group.get_category() == category]
+             if len(filtered_param_groups) == 0:
+                 continue
+             control_panel = SimpleParamGroupsControlPanel(
+                 self.control_panel,
+                 category,
+                 self.poser.get_pose_parameter_groups())
+             self.non_morph_control_panels[category] = control_panel
+             self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)
+
+         self.control_panel_sizer.Fit(self.control_panel)
+         self.main_sizer.Add(self.control_panel, 1, wx.FIXED_MINSIZE)
+
+     def init_right_panel(self):
+         self.right_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)
+         right_panel_sizer = wx.BoxSizer(wx.VERTICAL)
+         self.right_panel.SetSizer(right_panel_sizer)
+         self.right_panel.SetAutoLayout(1)
+
+         self.result_image_panel = wx.Panel(self.right_panel,
+                                            size=(self.image_size, self.image_size),
+                                            style=wx.SIMPLE_BORDER)
+         self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel)
+         self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
+         self.output_index_choice = wx.Choice(
+             self.right_panel,
+             choices=[str(i) for i in range(self.poser.get_output_length())])
+         self.output_index_choice.SetSelection(0)
+         right_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE)
+         right_panel_sizer.Add(self.output_index_choice, 0, wx.EXPAND)
+
+         self.save_image_button = wx.Button(self.right_panel, wx.ID_ANY, "\nSave Image\n\n")
+         right_panel_sizer.Add(self.save_image_button, 1, wx.EXPAND)
+         self.save_image_button.Bind(wx.EVT_BUTTON, self.on_save_image)
+
+         right_panel_sizer.Fit(self.right_panel)
+         self.main_sizer.Add(self.right_panel, 0, wx.FIXED_MINSIZE)
+
+     def create_param_category_choice(self, param_category: PoseParameterCategory):
+         params = []
+         for param_group in self.poser.get_pose_parameter_groups():
+             if param_group.get_category() == param_category:
+                 params.append(param_group.get_group_name())
+         choice = wx.Choice(self.control_panel, choices=params)
+         if len(params) > 0:
+             choice.SetSelection(0)
+         return choice
+
+     def load_image(self, event: wx.Event):
+         dir_name = "data/images"
+         file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_OPEN)
+         if file_dialog.ShowModal() == wx.ID_OK:
+             image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
+             try:
+                 pil_image = resize_PIL_image(extract_PIL_image_from_filelike(image_file_name),
+                                              (self.poser.get_image_size(), self.poser.get_image_size()))
+                 w, h = pil_image.size
+                 if pil_image.mode != 'RGBA':
+                     self.source_image_string = "Image must have alpha channel!"
+                     self.wx_source_image = None
+                     self.torch_source_image = None
+                 else:
+                     self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes())
+                     self.torch_source_image = extract_pytorch_image_from_PIL_image(pil_image)\
+                         .to(self.device).to(self.dtype)
+                 self.source_image_dirty = True
+                 self.Refresh()
+                 self.Update()
+             except:
+                 message_dialog = wx.MessageDialog(self, "Could not load image " + image_file_name, "Poser", wx.OK)
+                 message_dialog.ShowModal()
+                 message_dialog.Destroy()
+             file_dialog.Destroy()
+
+     def paint_source_image_panel(self, event: wx.Event):
+         wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)
+
+     def paint_result_image_panel(self, event: wx.Event):
+         wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)
+
+     def draw_nothing_yet_string_to_bitmap(self, bitmap):
+         dc = wx.MemoryDC()
+         dc.SelectObject(bitmap)
+
+         dc.Clear()
+         font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))
+         dc.SetFont(font)
+         w, h = dc.GetTextExtent("Nothing yet!")
+ dc.DrawText("Nothing yet!", (self.image_size - w) // 2, (self.image_size - - h) // 2)
348
+
+         del dc
+
+     def get_current_pose(self):
+         current_pose = [0.0 for i in range(self.poser.get_num_parameters())]
+         for morph_control_panel in self.morph_control_panels.values():
+             morph_control_panel.set_param_value(current_pose)
+         for rotation_control_panel in self.non_morph_control_panels.values():
+             rotation_control_panel.set_param_value(current_pose)
+         return current_pose
+
+     def update_images(self, event: wx.Event):
+         current_pose = self.get_current_pose()
+         if not self.source_image_dirty \
+                 and self.last_pose is not None \
+                 and self.last_pose == current_pose \
+                 and self.last_output_index == self.output_index_choice.GetSelection():
+             return
+         self.last_pose = current_pose
+         self.last_output_index = self.output_index_choice.GetSelection()
+
+         if self.torch_source_image is None:
+             self.draw_nothing_yet_string_to_bitmap(self.source_image_bitmap)
+             self.draw_nothing_yet_string_to_bitmap(self.result_image_bitmap)
+             self.source_image_dirty = False
+             self.Refresh()
+             self.Update()
+             return
+
+         if self.source_image_dirty:
+             dc = wx.MemoryDC()
+             dc.SelectObject(self.source_image_bitmap)
+             dc.Clear()
+             dc.DrawBitmap(self.wx_source_image, 0, 0)
+             self.source_image_dirty = False
+
+         pose = torch.tensor(current_pose, device=self.device, dtype=self.dtype)
+         output_index = self.output_index_choice.GetSelection()
+         with torch.no_grad():
+             output_image = self.poser.pose(self.torch_source_image, pose, output_index)[0].detach().cpu()
+
+         numpy_image = convert_output_image_from_torch_to_numpy(output_image)
+         self.last_output_numpy_image = numpy_image
+         wx_image = wx.ImageFromBuffer(
+             numpy_image.shape[0],
+             numpy_image.shape[1],
+             numpy_image[:, :, 0:3].tobytes(),
+             numpy_image[:, :, 3].tobytes())
+         wx_bitmap = wx_image.ConvertToBitmap()
+
+         dc = wx.MemoryDC()
+         dc.SelectObject(self.result_image_bitmap)
+         dc.Clear()
+         dc.DrawBitmap(wx_bitmap,
+                       (self.image_size - numpy_image.shape[0]) // 2,
+                       (self.image_size - numpy_image.shape[1]) // 2,
+                       True)
+         del dc
+
+         self.Refresh()
+         self.Update()
+
+     def on_save_image(self, event: wx.Event):
+         if self.last_output_numpy_image is None:
+             logging.info("There is no output image to save!!!")
+             return
+
+         dir_name = "data/images"
+         file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_SAVE)
+         if file_dialog.ShowModal() == wx.ID_OK:
+             image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
+             try:
+                 if os.path.exists(image_file_name):
+                     message_dialog = wx.MessageDialog(self, f"Override {image_file_name}", "Manual Poser",
+                                                       wx.YES_NO | wx.ICON_QUESTION)
+                     result = message_dialog.ShowModal()
+                     if result == wx.ID_YES:
+                         self.save_last_numpy_image(image_file_name)
+                     message_dialog.Destroy()
+                 else:
+                     self.save_last_numpy_image(image_file_name)
+             except:
+                 message_dialog = wx.MessageDialog(self, f"Could not save {image_file_name}", "Manual Poser", wx.OK)
+                 message_dialog.ShowModal()
+                 message_dialog.Destroy()
+         file_dialog.Destroy()
+
+     def save_last_numpy_image(self, image_file_name):
+         numpy_image = self.last_output_numpy_image
+         pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA')
+         os.makedirs(os.path.dirname(image_file_name), exist_ok=True)
+         pil_image.save(image_file_name)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description='Manually pose a character image.')
+     parser.add_argument(
+         '--model',
+         type=str,
+         required=False,
+         default='standard_float',
+         choices=['standard_float', 'separable_float', 'standard_half', 'separable_half'],
+         help='The model to use.')
+     args = parser.parse_args()
+
+     device = torch.device('cuda')
+     try:
+         poser = load_poser(args.model, device)
+     except RuntimeError as e:
+         print(e)
+         sys.exit()
+
+     app = wx.App()
+     main_frame = MainFrame(poser, device)
+     main_frame.Show(True)
+     main_frame.timer.Start(30)
+     app.MainLoop()
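The wx UI above is only a front end for load_poser plus poser.pose. A headless sketch of the same pipeline, assuming a CUDA device and an RGBA input image (the file path is a placeholder), reusing convert_output_image_from_torch_to_numpy from this file:

import PIL.Image
import torch

from tha3.poser.modes.load_poser import load_poser
from tha3.util import extract_PIL_image_from_filelike, extract_pytorch_image_from_PIL_image, resize_PIL_image

device = torch.device('cuda')
poser = load_poser('standard_float', device)

pil_image = resize_PIL_image(extract_PIL_image_from_filelike("data/images/character.png"),
                             (poser.get_image_size(), poser.get_image_size()))
torch_image = extract_pytorch_image_from_PIL_image(pil_image).to(device).to(poser.get_dtype())

pose = torch.zeros(poser.get_num_parameters(), device=device, dtype=poser.get_dtype())  # neutral pose
with torch.no_grad():
    output_image = poser.pose(torch_image, pose)[0].detach().cpu()

numpy_image = convert_output_image_from_torch_to_numpy(output_image)  # defined in manual_poser.py above
PIL.Image.fromarray(numpy_image, mode='RGBA').save("output.png")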
tha3/compute/__init__.py ADDED
File without changes
tha3/compute/cached_computation_func.py ADDED
@@ -0,0 +1,9 @@
+ from typing import Callable, Dict, List
+
+ from torch import Tensor
+ from torch.nn import Module
+
+ TensorCachedComputationFunc = Callable[
+     [Dict[str, Module], List[Tensor], Dict[str, List[Tensor]]], Tensor]
+ TensorListCachedComputationFunc = Callable[
+     [Dict[str, Module], List[Tensor], Dict[str, List[Tensor]]], List[Tensor]]
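Both aliases describe callables that receive the module dict, the input batch, and a memo dict of already-computed outputs. A tiny conforming function (the body is illustrative only):

from typing import Dict, List

from torch import Tensor
from torch.nn import Module

from tha3.compute.cached_computation_func import TensorCachedComputationFunc


def first_batch_tensor(modules: Dict[str, Module],
                       batch: List[Tensor],
                       outputs: Dict[str, List[Tensor]]) -> Tensor:
    # Ignores the modules and the memo; just echoes the first batch element.
    return batch[0]


func: TensorCachedComputationFunc = first_batch_tensor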
tha3/compute/cached_computation_protocol.py ADDED
@@ -0,0 +1,43 @@
+ from abc import ABC, abstractmethod
+ from typing import Dict, List
+
+ from torch import Tensor
+ from torch.nn import Module
+
+ from tha3.compute.cached_computation_func import TensorCachedComputationFunc, TensorListCachedComputationFunc
+
+
+ class CachedComputationProtocol(ABC):
+     def get_output(self,
+                    key: str,
+                    modules: Dict[str, Module],
+                    batch: List[Tensor],
+                    outputs: Dict[str, List[Tensor]]):
+         if key in outputs:
+             return outputs[key]
+         else:
+             output = self.compute_output(key, modules, batch, outputs)
+             outputs[key] = output
+             return outputs[key]
+
+     @abstractmethod
+     def compute_output(self,
+                        key: str,
+                        modules: Dict[str, Module],
+                        batch: List[Tensor],
+                        outputs: Dict[str, List[Tensor]]) -> List[Tensor]:
+         pass
+
+     def get_output_tensor_func(self, key: str, index: int) -> TensorCachedComputationFunc:
+         def func(modules: Dict[str, Module],
+                  batch: List[Tensor],
+                  outputs: Dict[str, List[Tensor]]):
+             return self.get_output(key, modules, batch, outputs)[index]
+         return func
+
+     def get_output_tensor_list_func(self, key: str) -> TensorListCachedComputationFunc:
+         def func(modules: Dict[str, Module],
+                  batch: List[Tensor],
+                  outputs: Dict[str, List[Tensor]]):
+             return self.get_output(key, modules, batch, outputs)
+         return func
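get_output memoizes per key in the outputs dict, so a subclass only implements compute_output for the keys it knows. A minimal sketch (the "identity" key is made up for illustration):

from typing import Dict, List

from torch import Tensor
from torch.nn import Module

from tha3.compute.cached_computation_protocol import CachedComputationProtocol


class IdentityProtocol(CachedComputationProtocol):
    def compute_output(self,
                       key: str,
                       modules: Dict[str, Module],
                       batch: List[Tensor],
                       outputs: Dict[str, List[Tensor]]) -> List[Tensor]:
        if key == "identity":
            return [batch[0]]
        raise RuntimeError("Unsupported key: " + key)


# Later lookups for "identity" are served from the outputs memo, not recomputed;
# index 0 selects the first tensor in the cached list.
func = IdentityProtocol().get_output_tensor_func("identity", 0)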
tha3/mocap/__init__.py ADDED
File without changes
tha3/mocap/ifacialmocap_constants.py ADDED
@@ -0,0 +1,239 @@
1
+ EYE_LOOK_IN_LEFT = "eyeLookInLeft"
2
+ EYE_LOOK_OUT_LEFT = "eyeLookOutLeft"
3
+ EYE_LOOK_DOWN_LEFT = "eyeLookDownLeft"
4
+ EYE_LOOK_UP_LEFT = "eyeLookUpLeft"
5
+ EYE_BLINK_LEFT = "eyeBlinkLeft"
6
+ EYE_SQUINT_LEFT = "eyeSquintLeft"
7
+ EYE_WIDE_LEFT = "eyeWideLeft"
8
+ EYE_LOOK_IN_RIGHT = "eyeLookInRight"
9
+ EYE_LOOK_OUT_RIGHT = "eyeLookOutRight"
10
+ EYE_LOOK_DOWN_RIGHT = "eyeLookDownRight"
11
+ EYE_LOOK_UP_RIGHT = "eyeLookUpRight"
12
+ EYE_BLINK_RIGHT = "eyeBlinkRight"
13
+ EYE_SQUINT_RIGHT = "eyeSquintRight"
14
+ EYE_WIDE_RIGHT = "eyeWideRight"
15
+ BROW_DOWN_LEFT = "browDownLeft"
16
+ BROW_OUTER_UP_LEFT = "browOuterUpLeft"
17
+ BROW_DOWN_RIGHT = "browDownRight"
18
+ BROW_OUTER_UP_RIGHT = "browOuterUpRight"
19
+ BROW_INNER_UP = "browInnerUp"
20
+ NOSE_SNEER_LEFT = "noseSneerLeft"
21
+ NOSE_SNEER_RIGHT = "noseSneerRight"
22
+ CHEEK_SQUINT_LEFT = "cheekSquintLeft"
23
+ CHEEK_SQUINT_RIGHT = "cheekSquintRight"
24
+ CHEEK_PUFF = "cheekPuff"
25
+ MOUTH_LEFT = "mouthLeft"
26
+ MOUTH_DIMPLE_LEFT = "mouthDimpleLeft"
27
+ MOUTH_FROWN_LEFT = "mouthFrownLeft"
28
+ MOUTH_LOWER_DOWN_LEFT = "mouthLowerDownLeft"
29
+ MOUTH_PRESS_LEFT = "mouthPressLeft"
30
+ MOUTH_SMILE_LEFT = "mouthSmileLeft"
31
+ MOUTH_STRETCH_LEFT = "mouthStretchLeft"
32
+ MOUTH_UPPER_UP_LEFT = "mouthUpperUpLeft"
33
+ MOUTH_RIGHT = "mouthRight"
34
+ MOUTH_DIMPLE_RIGHT = "mouthDimpleRight"
35
+ MOUTH_FROWN_RIGHT = "mouthFrownRight"
36
+ MOUTH_LOWER_DOWN_RIGHT = "mouthLowerDownRight"
37
+ MOUTH_PRESS_RIGHT = "mouthPressRight"
38
+ MOUTH_SMILE_RIGHT = "mouthSmileRight"
39
+ MOUTH_STRETCH_RIGHT = "mouthStretchRight"
40
+ MOUTH_UPPER_UP_RIGHT = "mouthUpperUpRight"
41
+ MOUTH_CLOSE = "mouthClose"
42
+ MOUTH_FUNNEL = "mouthFunnel"
43
+ MOUTH_PUCKER = "mouthPucker"
44
+ MOUTH_ROLL_LOWER = "mouthRollLower"
45
+ MOUTH_ROLL_UPPER = "mouthRollUpper"
+ MOUTH_SHRUG_LOWER = "mouthShrugLower"
+ MOUTH_SHRUG_UPPER = "mouthShrugUpper"
+ JAW_LEFT = "jawLeft"
+ JAW_RIGHT = "jawRight"
+ JAW_FORWARD = "jawForward"
+ JAW_OPEN = "jawOpen"
+ TONGUE_OUT = "tongueOut"
+
+ BLENDSHAPE_NAMES = [
+     EYE_LOOK_IN_LEFT,  # 0
+     EYE_LOOK_OUT_LEFT,  # 1
+     EYE_LOOK_DOWN_LEFT,  # 2
+     EYE_LOOK_UP_LEFT,  # 3
+     EYE_BLINK_LEFT,  # 4
+     EYE_SQUINT_LEFT,  # 5
+     EYE_WIDE_LEFT,  # 6
+     EYE_LOOK_IN_RIGHT,  # 7
+     EYE_LOOK_OUT_RIGHT,  # 8
+     EYE_LOOK_DOWN_RIGHT,  # 9
+     EYE_LOOK_UP_RIGHT,  # 10
+     EYE_BLINK_RIGHT,  # 11
+     EYE_SQUINT_RIGHT,  # 12
+     EYE_WIDE_RIGHT,  # 13
+     BROW_DOWN_LEFT,  # 14
+     BROW_OUTER_UP_LEFT,  # 15
+     BROW_DOWN_RIGHT,  # 16
+     BROW_OUTER_UP_RIGHT,  # 17
+     BROW_INNER_UP,  # 18
+     NOSE_SNEER_LEFT,  # 19
+     NOSE_SNEER_RIGHT,  # 20
+     CHEEK_SQUINT_LEFT,  # 21
+     CHEEK_SQUINT_RIGHT,  # 22
+     CHEEK_PUFF,  # 23
+     MOUTH_LEFT,  # 24
+     MOUTH_DIMPLE_LEFT,  # 25
+     MOUTH_FROWN_LEFT,  # 26
+     MOUTH_LOWER_DOWN_LEFT,  # 27
+     MOUTH_PRESS_LEFT,  # 28
+     MOUTH_SMILE_LEFT,  # 29
+     MOUTH_STRETCH_LEFT,  # 30
+     MOUTH_UPPER_UP_LEFT,  # 31
+     MOUTH_RIGHT,  # 32
+     MOUTH_DIMPLE_RIGHT,  # 33
+     MOUTH_FROWN_RIGHT,  # 34
+     MOUTH_LOWER_DOWN_RIGHT,  # 35
+     MOUTH_PRESS_RIGHT,  # 36
+     MOUTH_SMILE_RIGHT,  # 37
+     MOUTH_STRETCH_RIGHT,  # 38
+     MOUTH_UPPER_UP_RIGHT,  # 39
+     MOUTH_CLOSE,  # 40
+     MOUTH_FUNNEL,  # 41
+     MOUTH_PUCKER,  # 42
+     MOUTH_ROLL_LOWER,  # 43
+     MOUTH_ROLL_UPPER,  # 44
+     MOUTH_SHRUG_LOWER,  # 45
+     MOUTH_SHRUG_UPPER,  # 46
+     JAW_LEFT,  # 47
+     JAW_RIGHT,  # 48
+     JAW_FORWARD,  # 49
+     JAW_OPEN,  # 50
+     TONGUE_OUT,  # 51
+ ]
+
+ EYE_LEFT_BLENDSHAPES = [
+     EYE_LOOK_IN_LEFT,  # 0
+     EYE_LOOK_OUT_LEFT,  # 1
+     EYE_LOOK_DOWN_LEFT,  # 2
+     EYE_LOOK_UP_LEFT,  # 3
+     EYE_BLINK_LEFT,  # 4
+     EYE_SQUINT_LEFT,  # 5
+     EYE_WIDE_LEFT,  # 6
+ ]
+
+ EYE_RIGHT_BLENDSHAPES = [
+     EYE_LOOK_IN_RIGHT,  # 7
+     EYE_LOOK_OUT_RIGHT,  # 8
+     EYE_LOOK_DOWN_RIGHT,  # 9
+     EYE_LOOK_UP_RIGHT,  # 10
+     EYE_BLINK_RIGHT,  # 11
+     EYE_SQUINT_RIGHT,  # 12
+     EYE_WIDE_RIGHT,  # 13
+ ]
+
+ BROW_LEFT_BLENDSHAPES = [
+     BROW_DOWN_LEFT,  # 14
+     BROW_OUTER_UP_LEFT,  # 15
+ ]
+
+ BROW_RIGHT_BLENDSHAPES = [
+     BROW_DOWN_RIGHT,  # 16
+     BROW_OUTER_UP_RIGHT,  # 17
+ ]
+
+ BROW_BOTH_BLENDSHAPES = [
+     BROW_INNER_UP,  # 18
+ ]
+
+ NOSE_BLENDSHAPES = [
+     NOSE_SNEER_LEFT,  # 19
+     NOSE_SNEER_RIGHT,  # 20
+ ]
+
+ CHECK_BLENDSHAPES = [
+     CHEEK_SQUINT_LEFT,  # 21
+     CHEEK_SQUINT_RIGHT,  # 22
+     CHEEK_PUFF,  # 23
+ ]
+
+ MOUTH_LEFT_BLENDSHAPES = [
+     MOUTH_LEFT,  # 24
+     MOUTH_DIMPLE_LEFT,  # 25
+     MOUTH_FROWN_LEFT,  # 26
+     MOUTH_LOWER_DOWN_LEFT,  # 27
+     MOUTH_PRESS_LEFT,  # 28
+     MOUTH_SMILE_LEFT,  # 29
+     MOUTH_STRETCH_LEFT,  # 30
+     MOUTH_UPPER_UP_LEFT,  # 31
+ ]
+
+ MOUTH_RIGHT_BLENDSHAPES = [
+     MOUTH_RIGHT,  # 32
+     MOUTH_DIMPLE_RIGHT,  # 33
+     MOUTH_FROWN_RIGHT,  # 34
+     MOUTH_LOWER_DOWN_RIGHT,  # 35
+     MOUTH_PRESS_RIGHT,  # 36
+     MOUTH_SMILE_RIGHT,  # 37
+     MOUTH_STRETCH_RIGHT,  # 38
+     MOUTH_UPPER_UP_RIGHT,  # 39
+ ]
+
+ MOUTH_BOTH_BLENDSHAPES = [
+     MOUTH_CLOSE,  # 40
+     MOUTH_FUNNEL,  # 41
+     MOUTH_PUCKER,  # 42
+     MOUTH_ROLL_LOWER,  # 43
+     MOUTH_ROLL_UPPER,  # 44
+     MOUTH_SHRUG_LOWER,  # 45
+     MOUTH_SHRUG_UPPER,  # 46
+ ]
+
+ JAW_BLENDSHAPES = [
+     JAW_LEFT,  # 47
+     JAW_RIGHT,  # 48
+     JAW_FORWARD,  # 49
+     JAW_OPEN,  # 50
+ ]
+
+ TONGUE_BLENDSHAPES = [
+     TONGUE_OUT,  # 51
+ ]
+
+ COLUMN_0_BLENDSHAPES = EYE_RIGHT_BLENDSHAPES + BROW_RIGHT_BLENDSHAPES + [NOSE_SNEER_RIGHT, CHEEK_SQUINT_RIGHT]
+ COLUMN_1_BLENDSHAPES = EYE_LEFT_BLENDSHAPES + BROW_LEFT_BLENDSHAPES + [NOSE_SNEER_LEFT, CHEEK_SQUINT_LEFT]
+ COLUMN_2_BLENDSHAPES = MOUTH_RIGHT_BLENDSHAPES + [JAW_RIGHT]
+ COLUMN_3_BLENDSHAPES = MOUTH_LEFT_BLENDSHAPES + [JAW_LEFT]
+ COLUMN_4_BLENDSHAPES = [BROW_INNER_UP, CHEEK_PUFF] + MOUTH_BOTH_BLENDSHAPES + [JAW_FORWARD, JAW_OPEN, TONGUE_OUT]
+
+ BLENDSHAPE_COLUMNS = [
+     COLUMN_0_BLENDSHAPES,
+     COLUMN_1_BLENDSHAPES,
+     COLUMN_2_BLENDSHAPES,
+     COLUMN_3_BLENDSHAPES,
+     COLUMN_4_BLENDSHAPES,
+ ]
+
+ RIGHT_EYE_BONE_X = "rightEyeBoneX"
+ RIGHT_EYE_BONE_Y = "rightEyeBoneY"
+ RIGHT_EYE_BONE_Z = "rightEyeBoneZ"
+ RIGHT_EYE_BONE_ROTATIONS = [RIGHT_EYE_BONE_X, RIGHT_EYE_BONE_Y, RIGHT_EYE_BONE_Z]
+
+ LEFT_EYE_BONE_X = "leftEyeBoneX"
+ LEFT_EYE_BONE_Y = "leftEyeBoneY"
+ LEFT_EYE_BONE_Z = "leftEyeBoneZ"
+ LEFT_EYE_BONE_ROTATIONS = [LEFT_EYE_BONE_X, LEFT_EYE_BONE_Y, LEFT_EYE_BONE_Z]
+
+ HEAD_BONE_X = "headBoneX"
+ HEAD_BONE_Y = "headBoneY"
+ HEAD_BONE_Z = "headBoneZ"
+ HEAD_BONE_ROTATIONS = [HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z]
+
+ ROTATION_NAMES = RIGHT_EYE_BONE_ROTATIONS + LEFT_EYE_BONE_ROTATIONS + HEAD_BONE_ROTATIONS
+
+ RIGHT_EYE_BONE_QUAT = "rightEyeBoneQuat"
+ LEFT_EYE_BONE_QUAT = "leftEyeBoneQuat"
+ HEAD_BONE_QUAT = "headBoneQuat"
+ QUATERNION_NAMES = [
+     RIGHT_EYE_BONE_QUAT,
+     LEFT_EYE_BONE_QUAT,
+     HEAD_BONE_QUAT
+ ]
+
+ IFACIALMOCAP_DATETIME_FORMAT = "%Y/%m/%d-%H:%M:%S.%f"
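As a quick sanity check of the constants above (a minimal sketch, not part of the diff; the sample timestamp is invented):

from datetime import datetime
from tha3.mocap.ifacialmocap_constants import BLENDSHAPE_NAMES, IFACIALMOCAP_DATETIME_FORMAT

# BLENDSHAPE_NAMES fixes the index order 0..51 that the pose converters below rely on.
assert len(BLENDSHAPE_NAMES) == 52
assert BLENDSHAPE_NAMES.index("jawOpen") == 50

# IFACIALMOCAP_DATETIME_FORMAT is a standard strptime pattern.
t = datetime.strptime("2023/01/02-03:04:05.678901", IFACIALMOCAP_DATETIME_FORMAT)
print(t.microsecond)  # 678901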
tha3/mocap/ifacialmocap_pose.py ADDED
@@ -0,0 +1,27 @@
+ from tha3.mocap.ifacialmocap_constants import BLENDSHAPE_NAMES, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z, \
+     HEAD_BONE_QUAT, LEFT_EYE_BONE_X, LEFT_EYE_BONE_Y, LEFT_EYE_BONE_Z, LEFT_EYE_BONE_QUAT, RIGHT_EYE_BONE_X, \
+     RIGHT_EYE_BONE_Y, RIGHT_EYE_BONE_Z, RIGHT_EYE_BONE_QUAT
+
+
+ def create_default_ifacialmocap_pose():
+     data = {}
+
+     for blendshape_name in BLENDSHAPE_NAMES:
+         data[blendshape_name] = 0.0
+
+     data[HEAD_BONE_X] = 0.0
+     data[HEAD_BONE_Y] = 0.0
+     data[HEAD_BONE_Z] = 0.0
+     data[HEAD_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
+
+     data[LEFT_EYE_BONE_X] = 0.0
+     data[LEFT_EYE_BONE_Y] = 0.0
+     data[LEFT_EYE_BONE_Z] = 0.0
+     data[LEFT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
+
+     data[RIGHT_EYE_BONE_X] = 0.0
+     data[RIGHT_EYE_BONE_Y] = 0.0
+     data[RIGHT_EYE_BONE_Z] = 0.0
+     data[RIGHT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
+
+     return data
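The default pose is just a flat dict keyed by the constant names; a minimal usage sketch (not part of the diff):

from tha3.mocap.ifacialmocap_pose import create_default_ifacialmocap_pose
from tha3.mocap.ifacialmocap_constants import JAW_OPEN, HEAD_BONE_QUAT

pose = create_default_ifacialmocap_pose()
assert pose[JAW_OPEN] == 0.0                          # all 52 blendshapes start at zero
assert pose[HEAD_BONE_QUAT] == [0.0, 0.0, 0.0, 1.0]   # identity quaternion
pose[JAW_OPEN] = 0.5                                  # drive a single blendshape by hand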
tha3/mocap/ifacialmocap_pose_converter.py ADDED
@@ -0,0 +1,12 @@
+ from abc import ABC, abstractmethod
+ from typing import Dict, List
+
+
+ class IFacialMocapPoseConverter(ABC):
+     @abstractmethod
+     def convert(self, ifacialmocap_pose: Dict[str, float]) -> List[float]:
+         pass
+
+     @abstractmethod
+     def init_pose_converter_panel(self, parent):
+         pass
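A converter maps the iFacialMocap dict to THA3's flat pose vector. A hypothetical no-op subclass, only to show the contract (the real implementation follows in the next file):

from typing import Dict, List
from tha3.mocap.ifacialmocap_pose_converter import IFacialMocapPoseConverter

class ZeroPoseConverter(IFacialMocapPoseConverter):  # hypothetical, for illustration only
    def convert(self, ifacialmocap_pose: Dict[str, float]) -> List[float]:
        return [0.0] * 45  # 45 matches pose_size in the converter below

    def init_pose_converter_panel(self, parent):
        pass  # no GUI controls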
tha3/mocap/ifacialmocap_poser_converter_25.py ADDED
@@ -0,0 +1,463 @@
+ import math
+ import time
+ from enum import Enum
+ from typing import Optional, Dict, List
+
+ import numpy
+ import scipy.optimize
+ import wx
+
+ from tha3.mocap.ifacialmocap_constants import MOUTH_SMILE_LEFT, MOUTH_SHRUG_UPPER, MOUTH_SMILE_RIGHT, \
+     BROW_INNER_UP, BROW_OUTER_UP_RIGHT, BROW_OUTER_UP_LEFT, BROW_DOWN_LEFT, BROW_DOWN_RIGHT, EYE_WIDE_LEFT, \
+     EYE_WIDE_RIGHT, EYE_BLINK_LEFT, EYE_BLINK_RIGHT, CHEEK_SQUINT_LEFT, CHEEK_SQUINT_RIGHT, EYE_LOOK_IN_LEFT, \
+     EYE_LOOK_OUT_LEFT, EYE_LOOK_IN_RIGHT, EYE_LOOK_OUT_RIGHT, EYE_LOOK_UP_LEFT, EYE_LOOK_UP_RIGHT, EYE_LOOK_DOWN_RIGHT, \
+     EYE_LOOK_DOWN_LEFT, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z, JAW_OPEN, MOUTH_FROWN_LEFT, MOUTH_FROWN_RIGHT, \
+     MOUTH_LOWER_DOWN_LEFT, MOUTH_LOWER_DOWN_RIGHT, MOUTH_FUNNEL, MOUTH_PUCKER
+ from tha3.mocap.ifacialmocap_pose_converter import IFacialMocapPoseConverter
+ from tha3.poser.modes.pose_parameters import get_pose_parameters
+
+
+ class EyebrowDownMode(Enum):
+     TROUBLED = 1
+     ANGRY = 2
+     LOWERED = 3
+     SERIOUS = 4
+
+
+ class WinkMode(Enum):
+     NORMAL = 1
+     RELAXED = 2
+
+
+ def rad_to_deg(rad):
+     return rad * 180.0 / math.pi
+
+
+ def deg_to_rad(deg):
+     return deg * math.pi / 180.0
+
+
+ def clamp(x, min_value, max_value):
+     return max(min_value, min(max_value, x))
+
+
+ class IFacialMocapPoseConverter25Args:
+     def __init__(self,
+                  lower_smile_threshold: float = 0.4,
+                  upper_smile_threshold: float = 0.6,
+                  eyebrow_down_mode: EyebrowDownMode = EyebrowDownMode.ANGRY,
+                  wink_mode: WinkMode = WinkMode.NORMAL,
+                  eye_surprised_max_value: float = 0.5,
+                  eye_wink_max_value: float = 0.8,
+                  eyebrow_down_max_value: float = 0.4,
+                  cheek_squint_min_value: float = 0.1,
+                  cheek_squint_max_value: float = 0.7,
+                  eye_rotation_factor: float = 1.0 / 0.75,
+                  jaw_open_min_value: float = 0.1,
+                  jaw_open_max_value: float = 0.4,
+                  mouth_frown_max_value: float = 0.6,
+                  mouth_funnel_min_value: float = 0.25,
+                  mouth_funnel_max_value: float = 0.5,
+                  iris_small_left=0.0,
+                  iris_small_right=0.0):
+         self.iris_small_right = iris_small_right
+         self.iris_small_left = iris_small_left
+         self.wink_mode = wink_mode
+         self.mouth_funnel_max_value = mouth_funnel_max_value
+         self.mouth_funnel_min_value = mouth_funnel_min_value
+         self.mouth_frown_max_value = mouth_frown_max_value
+         self.jaw_open_max_value = jaw_open_max_value
+         self.jaw_open_min_value = jaw_open_min_value
+         self.eye_rotation_factor = eye_rotation_factor
+         self.cheek_squint_max_value = cheek_squint_max_value
+         self.cheek_squint_min_value = cheek_squint_min_value
+         self.eyebrow_down_max_value = eyebrow_down_max_value
+         self.eye_blink_max_value = eye_wink_max_value
+         self.eye_wide_max_value = eye_surprised_max_value
+         self.eyebrow_down_mode = eyebrow_down_mode
+         self.lower_smile_threshold = lower_smile_threshold
+         self.upper_smile_threshold = upper_smile_threshold
+
+
+ class IFacialMocapPoseConverter25(IFacialMocapPoseConverter):
+     def __init__(self, args: Optional[IFacialMocapPoseConverter25Args] = None):
+         super().__init__()
+         if args is None:
+             args = IFacialMocapPoseConverter25Args()
+         self.args = args
+         pose_parameters = get_pose_parameters()
+         self.pose_size = 45
+
+         self.eyebrow_troubled_left_index = pose_parameters.get_parameter_index("eyebrow_troubled_left")
+         self.eyebrow_troubled_right_index = pose_parameters.get_parameter_index("eyebrow_troubled_right")
+         self.eyebrow_angry_left_index = pose_parameters.get_parameter_index("eyebrow_angry_left")
+         self.eyebrow_angry_right_index = pose_parameters.get_parameter_index("eyebrow_angry_right")
+         self.eyebrow_happy_left_index = pose_parameters.get_parameter_index("eyebrow_happy_left")
+         self.eyebrow_happy_right_index = pose_parameters.get_parameter_index("eyebrow_happy_right")
+         self.eyebrow_raised_left_index = pose_parameters.get_parameter_index("eyebrow_raised_left")
+         self.eyebrow_raised_right_index = pose_parameters.get_parameter_index("eyebrow_raised_right")
+         self.eyebrow_lowered_left_index = pose_parameters.get_parameter_index("eyebrow_lowered_left")
+         self.eyebrow_lowered_right_index = pose_parameters.get_parameter_index("eyebrow_lowered_right")
+         self.eyebrow_serious_left_index = pose_parameters.get_parameter_index("eyebrow_serious_left")
+         self.eyebrow_serious_right_index = pose_parameters.get_parameter_index("eyebrow_serious_right")
+
+         self.eye_surprised_left_index = pose_parameters.get_parameter_index("eye_surprised_left")
+         self.eye_surprised_right_index = pose_parameters.get_parameter_index("eye_surprised_right")
+         self.eye_wink_left_index = pose_parameters.get_parameter_index("eye_wink_left")
+         self.eye_wink_right_index = pose_parameters.get_parameter_index("eye_wink_right")
+         self.eye_happy_wink_left_index = pose_parameters.get_parameter_index("eye_happy_wink_left")
+         self.eye_happy_wink_right_index = pose_parameters.get_parameter_index("eye_happy_wink_right")
+         self.eye_relaxed_left_index = pose_parameters.get_parameter_index("eye_relaxed_left")
+         self.eye_relaxed_right_index = pose_parameters.get_parameter_index("eye_relaxed_right")
+         self.eye_raised_lower_eyelid_left_index = pose_parameters.get_parameter_index("eye_raised_lower_eyelid_left")
+         self.eye_raised_lower_eyelid_right_index = pose_parameters.get_parameter_index("eye_raised_lower_eyelid_right")
+
+         self.iris_small_left_index = pose_parameters.get_parameter_index("iris_small_left")
+         self.iris_small_right_index = pose_parameters.get_parameter_index("iris_small_right")
+
+         self.iris_rotation_x_index = pose_parameters.get_parameter_index("iris_rotation_x")
+         self.iris_rotation_y_index = pose_parameters.get_parameter_index("iris_rotation_y")
+
+         self.head_x_index = pose_parameters.get_parameter_index("head_x")
+         self.head_y_index = pose_parameters.get_parameter_index("head_y")
+         self.neck_z_index = pose_parameters.get_parameter_index("neck_z")
+
+         self.mouth_aaa_index = pose_parameters.get_parameter_index("mouth_aaa")
+         self.mouth_iii_index = pose_parameters.get_parameter_index("mouth_iii")
+         self.mouth_uuu_index = pose_parameters.get_parameter_index("mouth_uuu")
+         self.mouth_eee_index = pose_parameters.get_parameter_index("mouth_eee")
+         self.mouth_ooo_index = pose_parameters.get_parameter_index("mouth_ooo")
+
+         self.mouth_lowered_corner_left_index = pose_parameters.get_parameter_index("mouth_lowered_corner_left")
+         self.mouth_lowered_corner_right_index = pose_parameters.get_parameter_index("mouth_lowered_corner_right")
+         self.mouth_raised_corner_left_index = pose_parameters.get_parameter_index("mouth_raised_corner_left")
+         self.mouth_raised_corner_right_index = pose_parameters.get_parameter_index("mouth_raised_corner_right")
+
+         self.body_y_index = pose_parameters.get_parameter_index("body_y")
+         self.body_z_index = pose_parameters.get_parameter_index("body_z")
+         self.breathing_index = pose_parameters.get_parameter_index("breathing")
+
+         self.breathing_start_time = time.time()
+
+         self.panel = None
+
+     def init_pose_converter_panel(self, parent):
+         self.panel = wx.Panel(parent, style=wx.SIMPLE_BORDER)
+         self.panel_sizer = wx.BoxSizer(wx.VERTICAL)
+         self.panel.SetSizer(self.panel_sizer)
+         self.panel.SetAutoLayout(1)
+         parent.GetSizer().Add(self.panel, 0, wx.EXPAND)
+
+         if True:
+             eyebrow_down_mode_text = wx.StaticText(self.panel, label=" --- Eyebrow Down Mode --- ",
+                                                    style=wx.ALIGN_CENTER)
+             self.panel_sizer.Add(eyebrow_down_mode_text, 0, wx.EXPAND)
+
+             self.eyebrow_down_mode_choice = wx.Choice(
+                 self.panel,
+                 choices=[
+                     "ANGRY",
+                     "TROUBLED",
+                     "SERIOUS",
+                     "LOWERED",
+                 ])
+             self.eyebrow_down_mode_choice.SetSelection(0)
+             self.panel_sizer.Add(self.eyebrow_down_mode_choice, 0, wx.EXPAND)
+             self.eyebrow_down_mode_choice.Bind(wx.EVT_CHOICE, self.change_eyebrow_down_mode)
+
+         separator = wx.StaticLine(self.panel, -1, size=(256, 5))
+         self.panel_sizer.Add(separator, 0, wx.EXPAND)
+
+         if True:
+             wink_mode_text = wx.StaticText(self.panel, label=" --- Wink Mode --- ", style=wx.ALIGN_CENTER)
+             self.panel_sizer.Add(wink_mode_text, 0, wx.EXPAND)
+
+             self.wink_mode_choice = wx.Choice(
+                 self.panel,
+                 choices=[
+                     "NORMAL",
+                     "RELAXED",
+                 ])
+             self.wink_mode_choice.SetSelection(0)
+             self.panel_sizer.Add(self.wink_mode_choice, 0, wx.EXPAND)
+             self.wink_mode_choice.Bind(wx.EVT_CHOICE, self.change_wink_mode)
+
+         separator = wx.StaticLine(self.panel, -1, size=(256, 5))
+         self.panel_sizer.Add(separator, 0, wx.EXPAND)
+
+         if True:
+             iris_size_text = wx.StaticText(self.panel, label=" --- Iris Size --- ", style=wx.ALIGN_CENTER)
+             self.panel_sizer.Add(iris_size_text, 0, wx.EXPAND)
+
+             self.iris_left_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL)
+             self.panel_sizer.Add(self.iris_left_slider, 0, wx.EXPAND)
+             self.iris_left_slider.Bind(wx.EVT_SLIDER, self.change_iris_size)
+
+             self.iris_right_slider = wx.Slider(self.panel, minValue=0, maxValue=1000, value=0, style=wx.HORIZONTAL)
+             self.panel_sizer.Add(self.iris_right_slider, 0, wx.EXPAND)
+             self.iris_right_slider.Bind(wx.EVT_SLIDER, self.change_iris_size)
+             self.iris_right_slider.Enable(False)
+
+             self.link_left_right_irises = wx.CheckBox(
+                 self.panel, label="Use same value for both sides")
+             self.link_left_right_irises.SetValue(True)
+             self.panel_sizer.Add(self.link_left_right_irises, wx.SizerFlags().CenterHorizontal().Border())
+             self.link_left_right_irises.Bind(wx.EVT_CHECKBOX, self.link_left_right_irises_clicked)
+
+         separator = wx.StaticLine(self.panel, -1, size=(256, 5))
+         self.panel_sizer.Add(separator, 0, wx.EXPAND)
+
+         if True:
+             breathing_frequency_text = wx.StaticText(
+                 self.panel, label=" --- Breathing --- ", style=wx.ALIGN_CENTER)
+             self.panel_sizer.Add(breathing_frequency_text, 0, wx.EXPAND)
+
+             self.restart_breathing_cycle_button = wx.Button(self.panel, label="Restart Breathing Cycle")
+             self.restart_breathing_cycle_button.Bind(wx.EVT_BUTTON, self.restart_breathing_cycle_clicked)
+             self.panel_sizer.Add(self.restart_breathing_cycle_button, 0, wx.EXPAND)
+
+             self.breathing_frequency_slider = wx.Slider(
+                 self.panel, minValue=0, maxValue=60, value=20, style=wx.HORIZONTAL)
+             self.panel_sizer.Add(self.breathing_frequency_slider, 0, wx.EXPAND)
+
+             self.breathing_gauge = wx.Gauge(self.panel, style=wx.GA_HORIZONTAL, range=1000)
+             self.panel_sizer.Add(self.breathing_gauge, 0, wx.EXPAND)
+
+         self.panel_sizer.Fit(self.panel)
+
+     def restart_breathing_cycle_clicked(self, event: wx.Event):
+         self.breathing_start_time = time.time()
+
+     def change_eyebrow_down_mode(self, event: wx.Event):
+         selected_index = self.eyebrow_down_mode_choice.GetSelection()
+         if selected_index == 0:
+             self.args.eyebrow_down_mode = EyebrowDownMode.ANGRY
+         elif selected_index == 1:
+             self.args.eyebrow_down_mode = EyebrowDownMode.TROUBLED
+         elif selected_index == 2:
+             self.args.eyebrow_down_mode = EyebrowDownMode.SERIOUS
+         else:
+             self.args.eyebrow_down_mode = EyebrowDownMode.LOWERED
+
+     def change_wink_mode(self, event: wx.Event):
+         selected_index = self.wink_mode_choice.GetSelection()
+         if selected_index == 0:
+             self.args.wink_mode = WinkMode.NORMAL
+         else:
+             self.args.wink_mode = WinkMode.RELAXED
+
+     def change_iris_size(self, event: wx.Event):
+         if self.link_left_right_irises.GetValue():
+             left_value = self.iris_left_slider.GetValue()
+             right_value = self.iris_right_slider.GetValue()
+             if left_value != right_value:
+                 self.iris_right_slider.SetValue(left_value)
+             self.args.iris_small_left = left_value / 1000.0
+             self.args.iris_small_right = left_value / 1000.0
+         else:
+             self.args.iris_small_left = self.iris_left_slider.GetValue() / 1000.0
+             self.args.iris_small_right = self.iris_right_slider.GetValue() / 1000.0
+
+     def link_left_right_irises_clicked(self, event: wx.Event):
+         if self.link_left_right_irises.GetValue():
+             self.iris_right_slider.Enable(False)
+         else:
+             self.iris_right_slider.Enable(True)
+         self.change_iris_size(event)
+
+     def decompose_head_body_param(self, param, threshold=2.0 / 3):
+         if abs(param) < threshold:
+             return (param, 0.0)
+         else:
+             if param < 0:
+                 sign = -1.0
+             else:
+                 sign = 1.0
+             return (threshold * sign, (abs(param) - threshold) * sign)
+
+     def convert(self, ifacialmocap_pose: Dict[str, float]) -> List[float]:
+         pose = [0.0 for i in range(self.pose_size)]
+
+         smile_value = \
+             (ifacialmocap_pose[MOUTH_SMILE_LEFT] + ifacialmocap_pose[MOUTH_SMILE_RIGHT]) / 2.0 \
+             + ifacialmocap_pose[MOUTH_SHRUG_UPPER]
+         if smile_value < self.args.lower_smile_threshold:
+             smile_degree = 0.0
+         elif smile_value > self.args.upper_smile_threshold:
+             smile_degree = 1.0
+         else:
+             smile_degree = (smile_value - self.args.lower_smile_threshold) / (
+                     self.args.upper_smile_threshold - self.args.lower_smile_threshold)
+
+         # Eyebrow
+         if True:
+             brow_inner_up = ifacialmocap_pose[BROW_INNER_UP]
+             brow_outer_up_right = ifacialmocap_pose[BROW_OUTER_UP_RIGHT]
+             brow_outer_up_left = ifacialmocap_pose[BROW_OUTER_UP_LEFT]
+
+             brow_up_left = clamp(brow_inner_up + brow_outer_up_left, 0.0, 1.0)
+             brow_up_right = clamp(brow_inner_up + brow_outer_up_right, 0.0, 1.0)
+             pose[self.eyebrow_raised_left_index] = brow_up_left
+             pose[self.eyebrow_raised_right_index] = brow_up_right
+
+             brow_down_left = (1.0 - smile_degree) \
+                              * clamp(ifacialmocap_pose[BROW_DOWN_LEFT] / self.args.eyebrow_down_max_value, 0.0, 1.0)
+             brow_down_right = (1.0 - smile_degree) \
+                               * clamp(ifacialmocap_pose[BROW_DOWN_RIGHT] / self.args.eyebrow_down_max_value, 0.0, 1.0)
+             if self.args.eyebrow_down_mode == EyebrowDownMode.TROUBLED:
+                 pose[self.eyebrow_troubled_left_index] = brow_down_left
+                 pose[self.eyebrow_troubled_right_index] = brow_down_right
+             elif self.args.eyebrow_down_mode == EyebrowDownMode.ANGRY:
+                 pose[self.eyebrow_angry_left_index] = brow_down_left
+                 pose[self.eyebrow_angry_right_index] = brow_down_right
+             elif self.args.eyebrow_down_mode == EyebrowDownMode.LOWERED:
+                 pose[self.eyebrow_lowered_left_index] = brow_down_left
+                 pose[self.eyebrow_lowered_right_index] = brow_down_right
+             elif self.args.eyebrow_down_mode == EyebrowDownMode.SERIOUS:
+                 pose[self.eyebrow_serious_left_index] = brow_down_left
+                 pose[self.eyebrow_serious_right_index] = brow_down_right
+
+             brow_happy_value = clamp(smile_value, 0.0, 1.0) * smile_degree
+             pose[self.eyebrow_happy_left_index] = brow_happy_value
+             pose[self.eyebrow_happy_right_index] = brow_happy_value
+
+         # Eye
+         if True:
+             # Surprised
+             pose[self.eye_surprised_left_index] = clamp(
+                 ifacialmocap_pose[EYE_WIDE_LEFT] / self.args.eye_wide_max_value, 0.0, 1.0)
+             pose[self.eye_surprised_right_index] = clamp(
+                 ifacialmocap_pose[EYE_WIDE_RIGHT] / self.args.eye_wide_max_value, 0.0, 1.0)
+
+             # Wink
+             if self.args.wink_mode == WinkMode.NORMAL:
+                 wink_left_index = self.eye_wink_left_index
+                 wink_right_index = self.eye_wink_right_index
+             else:
+                 wink_left_index = self.eye_relaxed_left_index
+                 wink_right_index = self.eye_relaxed_right_index
+             pose[wink_left_index] = (1.0 - smile_degree) * clamp(
+                 ifacialmocap_pose[EYE_BLINK_LEFT] / self.args.eye_blink_max_value, 0.0, 1.0)
+             pose[wink_right_index] = (1.0 - smile_degree) * clamp(
+                 ifacialmocap_pose[EYE_BLINK_RIGHT] / self.args.eye_blink_max_value, 0.0, 1.0)
+             pose[self.eye_happy_wink_left_index] = smile_degree * clamp(
+                 ifacialmocap_pose[EYE_BLINK_LEFT] / self.args.eye_blink_max_value, 0.0, 1.0)
+             pose[self.eye_happy_wink_right_index] = smile_degree * clamp(
+                 ifacialmocap_pose[EYE_BLINK_RIGHT] / self.args.eye_blink_max_value, 0.0, 1.0)
+
+             # Lower eyelid
+             cheek_squint_denom = self.args.cheek_squint_max_value - self.args.cheek_squint_min_value
+             pose[self.eye_raised_lower_eyelid_left_index] = \
+                 clamp(
+                     (ifacialmocap_pose[CHEEK_SQUINT_LEFT] - self.args.cheek_squint_min_value) / cheek_squint_denom,
+                     0.0, 1.0)
+             pose[self.eye_raised_lower_eyelid_right_index] = \
+                 clamp(
+                     (ifacialmocap_pose[CHEEK_SQUINT_RIGHT] - self.args.cheek_squint_min_value) / cheek_squint_denom,
+                     0.0, 1.0)
+
+         # Iris rotation
+         if True:
+             eye_rotation_y = (ifacialmocap_pose[EYE_LOOK_IN_LEFT]
+                               - ifacialmocap_pose[EYE_LOOK_OUT_LEFT]
+                               - ifacialmocap_pose[EYE_LOOK_IN_RIGHT]
+                               + ifacialmocap_pose[EYE_LOOK_OUT_RIGHT]) / 2.0 * self.args.eye_rotation_factor
+             pose[self.iris_rotation_y_index] = clamp(eye_rotation_y, -1.0, 1.0)
+
+             eye_rotation_x = (ifacialmocap_pose[EYE_LOOK_UP_LEFT]
+                               + ifacialmocap_pose[EYE_LOOK_UP_RIGHT]
+                               - ifacialmocap_pose[EYE_LOOK_DOWN_LEFT]
+                               - ifacialmocap_pose[EYE_LOOK_DOWN_RIGHT]) / 2.0 * self.args.eye_rotation_factor
+             pose[self.iris_rotation_x_index] = clamp(eye_rotation_x, -1.0, 1.0)
+
+         # Iris size
+         if True:
+             pose[self.iris_small_left_index] = self.args.iris_small_left
+             pose[self.iris_small_right_index] = self.args.iris_small_right
+
+         # Head rotation
+         if True:
+             x_param = clamp(-ifacialmocap_pose[HEAD_BONE_X] * 180.0 / math.pi, -15.0, 15.0) / 15.0
+             pose[self.head_x_index] = x_param
+
+             y_param = clamp(-ifacialmocap_pose[HEAD_BONE_Y] * 180.0 / math.pi, -10.0, 10.0) / 10.0
+             pose[self.head_y_index] = y_param
+             pose[self.body_y_index] = y_param
+
+             z_param = clamp(ifacialmocap_pose[HEAD_BONE_Z] * 180.0 / math.pi, -15.0, 15.0) / 15.0
+             pose[self.neck_z_index] = z_param
+             pose[self.body_z_index] = z_param
+
+         # Mouth
+         if True:
+             jaw_open_denom = self.args.jaw_open_max_value - self.args.jaw_open_min_value
+             mouth_open = clamp((ifacialmocap_pose[JAW_OPEN] - self.args.jaw_open_min_value) / jaw_open_denom, 0.0, 1.0)
+             pose[self.mouth_aaa_index] = mouth_open
+             pose[self.mouth_raised_corner_left_index] = clamp(smile_value, 0.0, 1.0)
+             pose[self.mouth_raised_corner_right_index] = clamp(smile_value, 0.0, 1.0)
+
+             is_mouth_open = mouth_open > 0.0
+             if not is_mouth_open:
+                 mouth_frown_value = clamp(
+                     (ifacialmocap_pose[MOUTH_FROWN_LEFT] + ifacialmocap_pose[
+                         MOUTH_FROWN_RIGHT]) / self.args.mouth_frown_max_value, 0.0, 1.0)
+                 pose[self.mouth_lowered_corner_left_index] = mouth_frown_value
+                 pose[self.mouth_lowered_corner_right_index] = mouth_frown_value
+             else:
+                 mouth_lower_down = clamp(
+                     ifacialmocap_pose[MOUTH_LOWER_DOWN_LEFT] + ifacialmocap_pose[MOUTH_LOWER_DOWN_RIGHT], 0.0, 1.0)
+                 mouth_funnel = ifacialmocap_pose[MOUTH_FUNNEL]
+                 mouth_pucker = ifacialmocap_pose[MOUTH_PUCKER]
+
+                 mouth_point = [mouth_open, mouth_lower_down, mouth_funnel, mouth_pucker]
+
+                 aaa_point = [1.0, 1.0, 0.0, 0.0]
+                 iii_point = [0.0, 1.0, 0.0, 0.0]
+                 uuu_point = [0.5, 0.3, 0.25, 0.75]
+                 ooo_point = [1.0, 0.5, 0.5, 0.4]
+
+                 decomp = numpy.array([0, 0, 0, 0])
+                 M = numpy.array([
+                     aaa_point,
+                     iii_point,
+                     uuu_point,
+                     ooo_point
+                 ])
+
+                 def loss(decomp):
+                     return numpy.linalg.norm(numpy.matmul(decomp, M) - mouth_point) \
+                            + 0.01 * numpy.linalg.norm(decomp, ord=1)
+
+                 opt_result = scipy.optimize.minimize(
+                     loss, decomp, bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0), (0.0, 1.0)])
+                 decomp = opt_result["x"]
+                 restricted_decomp = [decomp.item(0), decomp.item(1), decomp.item(2), decomp.item(3)]
+                 pose[self.mouth_aaa_index] = restricted_decomp[0]
+                 pose[self.mouth_iii_index] = restricted_decomp[1]
+                 mouth_funnel_denom = self.args.mouth_funnel_max_value - self.args.mouth_funnel_min_value
+                 ooo_alpha = clamp((mouth_funnel - self.args.mouth_funnel_min_value) / mouth_funnel_denom, 0.0, 1.0)
+                 uo_value = clamp(restricted_decomp[2] + restricted_decomp[3], 0.0, 1.0)
+                 pose[self.mouth_uuu_index] = uo_value * (1.0 - ooo_alpha)
+                 pose[self.mouth_ooo_index] = uo_value * ooo_alpha
+
+         if self.panel is not None:
+             frequency = self.breathing_frequency_slider.GetValue()
+             if frequency == 0:
+                 value = 0.0
+                 pose[self.breathing_index] = value
+                 self.breathing_start_time = time.time()
+             else:
+                 period = 60.0 / frequency
+                 now = time.time()
+                 diff = now - self.breathing_start_time
+                 frac = (diff % period) / period
+                 value = (-math.cos(2 * math.pi * frac) + 1.0) / 2.0
+                 pose[self.breathing_index] = value
+             self.breathing_gauge.SetValue(int(1000 * value))
+
+         return pose
+
+
+ def create_ifacialmocap_pose_converter(
+         args: Optional[IFacialMocapPoseConverter25Args] = None) -> IFacialMocapPoseConverter:
+     return IFacialMocapPoseConverter25(args)
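End to end, the converter turns the raw iFacialMocap dict into the 45-dimensional THA3 pose vector. A minimal sketch (not part of the diff; requires wx and scipy to be installed, since this module imports both; with no GUI panel, the breathing branch is skipped):

from tha3.mocap.ifacialmocap_pose import create_default_ifacialmocap_pose
from tha3.mocap.ifacialmocap_poser_converter_25 import create_ifacialmocap_pose_converter

converter = create_ifacialmocap_pose_converter()  # default IFacialMocapPoseConverter25Args
pose = converter.convert(create_default_ifacialmocap_pose())
assert len(pose) == 45                            # pose_size in the class above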
tha3/mocap/ifacialmocap_v2.py ADDED
@@ -0,0 +1,89 @@
+ import math
+
+ from tha3.mocap.ifacialmocap_constants import BLENDSHAPE_NAMES, HEAD_BONE_X, HEAD_BONE_Y, HEAD_BONE_Z, \
+     RIGHT_EYE_BONE_X, RIGHT_EYE_BONE_Y, RIGHT_EYE_BONE_Z, LEFT_EYE_BONE_X, LEFT_EYE_BONE_Y, LEFT_EYE_BONE_Z, \
+     HEAD_BONE_QUAT, LEFT_EYE_BONE_QUAT, RIGHT_EYE_BONE_QUAT
+
+ IFACIALMOCAP_PORT = 49983
+ IFACIALMOCAP_START_STRING = "iFacialMocap_sahuasouryya9218sauhuiayeta91555dy3719|sendDataVersion=v2".encode('utf-8')
+
+
+ def parse_ifacialmocap_v2_pose(ifacialmocap_output):
+     output = {}
+     parts = ifacialmocap_output.split("|")
+     for part in parts:
+         part = part.strip()
+         if len(part) == 0:
+             continue
+         if "&" in part:
+             components = part.split("&")
+             assert len(components) == 2
+             key = components[0]
+             value = float(components[1]) / 100.0
+             if key.endswith("_L"):
+                 key = key[:-2] + "Left"
+             elif key.endswith("_R"):
+                 key = key[:-2] + "Right"
+             if key in BLENDSHAPE_NAMES:
+                 output[key] = value
+         elif part.startswith("=head#"):
+             components = part[len("=head#"):].split(",")
+             assert len(components) == 6
+             output[HEAD_BONE_X] = float(components[0]) * math.pi / 180
+             output[HEAD_BONE_Y] = float(components[1]) * math.pi / 180
+             output[HEAD_BONE_Z] = float(components[2]) * math.pi / 180
+         elif part.startswith("rightEye#"):
+             components = part[len("rightEye#"):].split(",")
+             output[RIGHT_EYE_BONE_X] = float(components[0]) * math.pi / 180
+             output[RIGHT_EYE_BONE_Y] = float(components[1]) * math.pi / 180
+             output[RIGHT_EYE_BONE_Z] = float(components[2]) * math.pi / 180
+         elif part.startswith("leftEye#"):
+             components = part[len("leftEye#"):].split(",")
+             output[LEFT_EYE_BONE_X] = float(components[0]) * math.pi / 180
+             output[LEFT_EYE_BONE_Y] = float(components[1]) * math.pi / 180
+             output[LEFT_EYE_BONE_Z] = float(components[2]) * math.pi / 180
+     output[HEAD_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
+     output[LEFT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
+     output[RIGHT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
+     return output
+
+
+ def parse_ifacialmocap_v1_pose(ifacialmocap_output):
+     output = {}
+     parts = ifacialmocap_output.split("|")
+     for part in parts:
+         part = part.strip()
+         if len(part) == 0:
+             continue
+         if part.startswith("=head#"):
+             components = part[len("=head#"):].split(",")
+             assert len(components) == 6
+             output[HEAD_BONE_X] = float(components[0]) * math.pi / 180
+             output[HEAD_BONE_Y] = float(components[1]) * math.pi / 180
+             output[HEAD_BONE_Z] = float(components[2]) * math.pi / 180
+         elif part.startswith("rightEye#"):
+             components = part[len("rightEye#"):].split(",")
+             output[RIGHT_EYE_BONE_X] = float(components[0]) * math.pi / 180
+             output[RIGHT_EYE_BONE_Y] = float(components[1]) * math.pi / 180
+             output[RIGHT_EYE_BONE_Z] = float(components[2]) * math.pi / 180
+         elif part.startswith("leftEye#"):
+             components = part[len("leftEye#"):].split(",")
+             output[LEFT_EYE_BONE_X] = float(components[0]) * math.pi / 180
+             output[LEFT_EYE_BONE_Y] = float(components[1]) * math.pi / 180
+             output[LEFT_EYE_BONE_Z] = float(components[2]) * math.pi / 180
+         else:
+             components = part.split("-")
+             assert len(components) == 2
+             key = components[0]
+             value = float(components[1]) / 100.0
+             if key.endswith("_L"):
+                 key = key[:-2] + "Left"
+             elif key.endswith("_R"):
+                 key = key[:-2] + "Right"
+             if key in BLENDSHAPE_NAMES:
+                 output[key] = value
+     output[HEAD_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
+     output[LEFT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
+     output[RIGHT_EYE_BONE_QUAT] = [0.0, 0.0, 0.0, 1.0]
+     return output
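The v2 wire format is pipe-separated: blendshape entries use "name&value" with the value in percent, and "=head#" carries six comma-separated numbers (Euler angles in degrees, then positions). A minimal sketch with an invented packet (not part of the diff):

import math
from tha3.mocap.ifacialmocap_v2 import parse_ifacialmocap_v2_pose

packet = "mouthSmile_L&75|jawOpen&20|=head#10.0,0.0,0.0,0.0,0.0,0.0"
pose = parse_ifacialmocap_v2_pose(packet)
assert pose["mouthSmileLeft"] == 0.75                          # "_L" suffix is rewritten to "Left"
assert pose["jawOpen"] == 0.2                                  # percent -> fraction
assert abs(pose["headBoneX"] - 10.0 * math.pi / 180) < 1e-9    # degrees -> radians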
tha3/module/__init__.py ADDED
File without changes
tha3/module/module_factory.py ADDED
@@ -0,0 +1,9 @@
+ from abc import ABC, abstractmethod
+
+ from torch.nn import Module
+
+
+ class ModuleFactory(ABC):
+     @abstractmethod
+     def create(self) -> Module:
+         pass
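ModuleFactory is the seam the network builders use to defer layer construction. A hypothetical factory, purely for illustration (the repo ships its own factories, e.g. in tha3.nn.nonlinearity_factory):

from torch.nn import Identity, Module
from tha3.module.module_factory import ModuleFactory

class IdentityFactory(ModuleFactory):  # hypothetical, for illustration only
    def create(self) -> Module:
        return Identity()

layer = IdentityFactory().create()  # a fresh Module per call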
tha3/nn/__init__.py ADDED
File without changes
tha3/nn/common/__init__.py ADDED
File without changes
tha3/nn/common/conv_block_factory.py ADDED
@@ -0,0 +1,55 @@
+ from typing import Optional
+
+ from tha3.nn.conv import create_conv7_block_from_block_args, create_conv3_block_from_block_args, \
+     create_downsample_block_from_block_args, create_conv3
+ from tha3.nn.resnet_block import ResnetBlock
+ from tha3.nn.resnet_block_seperable import ResnetBlockSeparable
+ from tha3.nn.separable_conv import create_separable_conv7_block, create_separable_conv3_block, \
+     create_separable_downsample_block, create_separable_conv3
+ from tha3.nn.util import BlockArgs
+
+
+ class ConvBlockFactory:
+     def __init__(self,
+                  block_args: BlockArgs,
+                  use_separable_convolution: bool = False):
+         self.use_separable_convolution = use_separable_convolution
+         self.block_args = block_args
+
+     def create_conv3(self,
+                      in_channels: int,
+                      out_channels: int,
+                      bias: bool,
+                      initialization_method: Optional[str] = None):
+         if initialization_method is None:
+             initialization_method = self.block_args.initialization_method
+         if self.use_separable_convolution:
+             return create_separable_conv3(
+                 in_channels, out_channels, bias, initialization_method, self.block_args.use_spectral_norm)
+         else:
+             return create_conv3(
+                 in_channels, out_channels, bias, initialization_method, self.block_args.use_spectral_norm)
+
+     def create_conv7_block(self, in_channels: int, out_channels: int):
+         if self.use_separable_convolution:
+             return create_separable_conv7_block(in_channels, out_channels, self.block_args)
+         else:
+             return create_conv7_block_from_block_args(in_channels, out_channels, self.block_args)
+
+     def create_conv3_block(self, in_channels: int, out_channels: int):
+         if self.use_separable_convolution:
+             return create_separable_conv3_block(in_channels, out_channels, self.block_args)
+         else:
+             return create_conv3_block_from_block_args(in_channels, out_channels, self.block_args)
+
+     def create_downsample_block(self, in_channels: int, out_channels: int, is_output_1x1: bool):
+         if self.use_separable_convolution:
+             return create_separable_downsample_block(in_channels, out_channels, is_output_1x1, self.block_args)
+         else:
+             return create_downsample_block_from_block_args(in_channels, out_channels, is_output_1x1,
+                                                            block_args=self.block_args)
+
+     def create_resnet_block(self, num_channels: int, is_1x1: bool):
+         if self.use_separable_convolution:
+             return ResnetBlockSeparable.create(num_channels, is_1x1, block_args=self.block_args)
+         else:
+             return ResnetBlock.create(num_channels, is_1x1, block_args=self.block_args)
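The factory routes to the regular or depthwise-separable builders based on a single flag. A minimal sketch (not part of the diff; the BlockArgs keyword arguments mirror the ones used in the __main__ demos further down):

from tha3.nn.common.conv_block_factory import ConvBlockFactory
from tha3.nn.nonlinearity_factory import ReLUFactory
from tha3.nn.normalization import InstanceNorm2dFactory
from tha3.nn.util import BlockArgs

block_args = BlockArgs(
    initialization_method='he',
    use_spectral_norm=False,
    normalization_layer_factory=InstanceNorm2dFactory(),
    nonlinearity_factory=ReLUFactory(inplace=False))
factory = ConvBlockFactory(block_args, use_separable_convolution=True)
block = factory.create_conv3_block(in_channels=4, out_channels=32)  # separable conv3 + norm + ReLU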
tha3/nn/common/poser_args.py ADDED
@@ -0,0 +1,68 @@
+ from typing import Optional
+
+ from torch.nn import Sigmoid, Sequential, Tanh
+
+ from tha3.nn.conv import create_conv3, create_conv3_from_block_args
+ from tha3.nn.nonlinearity_factory import ReLUFactory
+ from tha3.nn.normalization import InstanceNorm2dFactory
+ from tha3.nn.util import BlockArgs
+
+
+ class PoserArgs00:
+     def __init__(self,
+                  image_size: int,
+                  input_image_channels: int,
+                  output_image_channels: int,
+                  start_channels: int,
+                  num_pose_params: int,
+                  block_args: Optional[BlockArgs] = None):
+         self.num_pose_params = num_pose_params
+         self.start_channels = start_channels
+         self.output_image_channels = output_image_channels
+         self.input_image_channels = input_image_channels
+         self.image_size = image_size
+         if block_args is None:
+             self.block_args = BlockArgs(
+                 normalization_layer_factory=InstanceNorm2dFactory(),
+                 nonlinearity_factory=ReLUFactory(inplace=True))
+         else:
+             self.block_args = block_args
+
+     def create_alpha_block(self):
+         return Sequential(
+             create_conv3(
+                 in_channels=self.start_channels,
+                 out_channels=1,
+                 bias=True,
+                 initialization_method=self.block_args.initialization_method,
+                 use_spectral_norm=False),
+             Sigmoid())
+
+     def create_all_channel_alpha_block(self):
+         return Sequential(
+             create_conv3(
+                 in_channels=self.start_channels,
+                 out_channels=self.output_image_channels,
+                 bias=True,
+                 initialization_method=self.block_args.initialization_method,
+                 use_spectral_norm=False),
+             Sigmoid())
+
+     def create_color_change_block(self):
+         return Sequential(
+             create_conv3_from_block_args(
+                 in_channels=self.start_channels,
+                 out_channels=self.output_image_channels,
+                 bias=True,
+                 block_args=self.block_args),
+             Tanh())
+
+     def create_grid_change_block(self):
+         return create_conv3(
+             in_channels=self.start_channels,
+             out_channels=2,
+             bias=False,
+             initialization_method='zero',
+             use_spectral_norm=False)
tha3/nn/common/poser_encoder_decoder_00.py ADDED
@@ -0,0 +1,121 @@
+ import math
+ from typing import Optional, List
+
+ import torch
+ from torch import Tensor
+ from torch.nn import ModuleList, Module
+
+ from tha3.nn.common.poser_args import PoserArgs00
+ from tha3.nn.conv import create_conv3_block_from_block_args, create_downsample_block_from_block_args, \
+     create_upsample_block_from_block_args
+ from tha3.nn.nonlinearity_factory import ReLUFactory
+ from tha3.nn.normalization import InstanceNorm2dFactory
+ from tha3.nn.resnet_block import ResnetBlock
+ from tha3.nn.util import BlockArgs
+
+
+ class PoserEncoderDecoder00Args(PoserArgs00):
+     def __init__(self,
+                  image_size: int,
+                  input_image_channels: int,
+                  output_image_channels: int,
+                  num_pose_params: int,
+                  start_channels: int,
+                  bottleneck_image_size,
+                  num_bottleneck_blocks,
+                  max_channels: int,
+                  block_args: Optional[BlockArgs] = None):
+         super().__init__(
+             image_size, input_image_channels, output_image_channels, start_channels, num_pose_params, block_args)
+         self.max_channels = max_channels
+         self.num_bottleneck_blocks = num_bottleneck_blocks
+         self.bottleneck_image_size = bottleneck_image_size
+         assert bottleneck_image_size > 1
+
+         if block_args is None:
+             self.block_args = BlockArgs(
+                 normalization_layer_factory=InstanceNorm2dFactory(),
+                 nonlinearity_factory=ReLUFactory(inplace=True))
+         else:
+             self.block_args = block_args
+
+
+ class PoserEncoderDecoder00(Module):
+     def __init__(self, args: PoserEncoderDecoder00Args):
+         super().__init__()
+         self.args = args
+
+         self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1
+
+         self.downsample_blocks = ModuleList()
+         self.downsample_blocks.append(
+             create_conv3_block_from_block_args(
+                 args.input_image_channels,
+                 args.start_channels,
+                 args.block_args))
+         current_image_size = args.image_size
+         current_num_channels = args.start_channels
+         while current_image_size > args.bottleneck_image_size:
+             next_image_size = current_image_size // 2
+             next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
+             self.downsample_blocks.append(create_downsample_block_from_block_args(
+                 in_channels=current_num_channels,
+                 out_channels=next_num_channels,
+                 is_output_1x1=False,
+                 block_args=args.block_args))
+             current_image_size = next_image_size
+             current_num_channels = next_num_channels
+         assert len(self.downsample_blocks) == self.num_levels
+
+         self.bottleneck_blocks = ModuleList()
+         self.bottleneck_blocks.append(create_conv3_block_from_block_args(
+             in_channels=current_num_channels + args.num_pose_params,
+             out_channels=current_num_channels,
+             block_args=args.block_args))
+         for i in range(1, args.num_bottleneck_blocks):
+             self.bottleneck_blocks.append(
+                 ResnetBlock.create(
+                     num_channels=current_num_channels,
+                     is1x1=False,
+                     block_args=args.block_args))
+
+         self.upsample_blocks = ModuleList()
+         while current_image_size < args.image_size:
+             next_image_size = current_image_size * 2
+             next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
+             self.upsample_blocks.append(create_upsample_block_from_block_args(
+                 in_channels=current_num_channels,
+                 out_channels=next_num_channels,
+                 block_args=args.block_args))
+             current_image_size = next_image_size
+             current_num_channels = next_num_channels
+
+     def get_num_output_channels_from_level(self, level: int):
+         return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))
+
+     def get_num_output_channels_from_image_size(self, image_size: int):
+         return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)
+
+     def forward(self, image: Tensor, pose: Optional[Tensor] = None) -> List[Tensor]:
+         if self.args.num_pose_params != 0:
+             assert pose is not None
+         else:
+             assert pose is None
+         outputs = []
+         feature = image
+         outputs.append(feature)
+         for block in self.downsample_blocks:
+             feature = block(feature)
+             outputs.append(feature)
+         if pose is not None:
+             n, c = pose.shape
+             pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size)
+             feature = torch.cat([feature, pose], dim=1)
+         for block in self.bottleneck_blocks:
+             feature = block(feature)
+         outputs.append(feature)
+         for block in self.upsample_blocks:
+             feature = block(feature)
+             outputs.append(feature)
+         outputs.reverse()
+         return outputs
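The channel schedule doubles as the image size halves, capped at max_channels; the arithmetic in get_num_output_channels_from_image_size, spelled out (values taken from the __main__ demos further down):

# start_channels=32, image_size=512, max_channels=512
start_channels, image_size, max_channels = 32, 512, 512
for size in [512, 256, 128, 64, 32]:
    print(size, min(start_channels * (image_size // size), max_channels))
# 512 32
# 256 64
# 128 128
# 64 256
# 32 512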
tha3/nn/common/poser_encoder_decoder_00_separable.py ADDED
@@ -0,0 +1,92 @@
+ import math
+ from typing import Optional, List
+
+ import torch
+ from torch import Tensor
+ from torch.nn import ModuleList, Module
+
+ from tha3.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args
+ from tha3.nn.resnet_block_seperable import ResnetBlockSeparable
+ from tha3.nn.separable_conv import create_separable_conv3_block, create_separable_downsample_block, \
+     create_separable_upsample_block
+
+
+ class PoserEncoderDecoder00Separable(Module):
+     def __init__(self, args: PoserEncoderDecoder00Args):
+         super().__init__()
+         self.args = args
+
+         self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1
+
+         self.downsample_blocks = ModuleList()
+         self.downsample_blocks.append(
+             create_separable_conv3_block(
+                 args.input_image_channels,
+                 args.start_channels,
+                 args.block_args))
+         current_image_size = args.image_size
+         current_num_channels = args.start_channels
+         while current_image_size > args.bottleneck_image_size:
+             next_image_size = current_image_size // 2
+             next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
+             self.downsample_blocks.append(create_separable_downsample_block(
+                 in_channels=current_num_channels,
+                 out_channels=next_num_channels,
+                 is_output_1x1=False,
+                 block_args=args.block_args))
+             current_image_size = next_image_size
+             current_num_channels = next_num_channels
+         assert len(self.downsample_blocks) == self.num_levels
+
+         self.bottleneck_blocks = ModuleList()
+         self.bottleneck_blocks.append(create_separable_conv3_block(
+             in_channels=current_num_channels + args.num_pose_params,
+             out_channels=current_num_channels,
+             block_args=args.block_args))
+         for i in range(1, args.num_bottleneck_blocks):
+             self.bottleneck_blocks.append(
+                 ResnetBlockSeparable.create(
+                     num_channels=current_num_channels,
+                     is1x1=False,
+                     block_args=args.block_args))
+
+         self.upsample_blocks = ModuleList()
+         while current_image_size < args.image_size:
+             next_image_size = current_image_size * 2
+             next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
+             self.upsample_blocks.append(create_separable_upsample_block(
+                 in_channels=current_num_channels,
+                 out_channels=next_num_channels,
+                 block_args=args.block_args))
+             current_image_size = next_image_size
+             current_num_channels = next_num_channels
+
+     def get_num_output_channels_from_level(self, level: int):
+         return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))
+
+     def get_num_output_channels_from_image_size(self, image_size: int):
+         return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)
+
+     def forward(self, image: Tensor, pose: Optional[Tensor] = None) -> List[Tensor]:
+         if self.args.num_pose_params != 0:
+             assert pose is not None
+         else:
+             assert pose is None
+         outputs = []
+         feature = image
+         outputs.append(feature)
+         for block in self.downsample_blocks:
+             feature = block(feature)
+             outputs.append(feature)
+         if pose is not None:
+             n, c = pose.shape
+             pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size)
+             feature = torch.cat([feature, pose], dim=1)
+         for block in self.bottleneck_blocks:
+             feature = block(feature)
+         outputs.append(feature)
+         for block in self.upsample_blocks:
+             feature = block(feature)
+             outputs.append(feature)
+         outputs.reverse()
+         return outputs
tha3/nn/common/resize_conv_encoder_decoder.py ADDED
@@ -0,0 +1,125 @@
+ import math
+ from typing import Optional, List
+
+ import torch
+ from torch import Tensor
+ from torch.nn import Module, ModuleList, Sequential, Upsample
+
+ from tha3.nn.common.conv_block_factory import ConvBlockFactory
+ from tha3.nn.nonlinearity_factory import LeakyReLUFactory
+ from tha3.nn.normalization import InstanceNorm2dFactory
+ from tha3.nn.util import BlockArgs
+
+
+ class ResizeConvEncoderDecoderArgs:
+     def __init__(self,
+                  image_size: int,
+                  input_channels: int,
+                  start_channels: int,
+                  bottleneck_image_size,
+                  num_bottleneck_blocks,
+                  max_channels: int,
+                  block_args: Optional[BlockArgs] = None,
+                  upsample_mode: str = 'bilinear',
+                  use_separable_convolution=False):
+         self.use_separable_convolution = use_separable_convolution
+         self.upsample_mode = upsample_mode
+         self.block_args = block_args
+         self.max_channels = max_channels
+         self.num_bottleneck_blocks = num_bottleneck_blocks
+         self.bottleneck_image_size = bottleneck_image_size
+         self.start_channels = start_channels
+         self.image_size = image_size
+         self.input_channels = input_channels
+
+
+ class ResizeConvEncoderDecoder(Module):
+     def __init__(self, args: ResizeConvEncoderDecoderArgs):
+         super().__init__()
+         self.args = args
+
+         self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1
+
+         conv_block_factory = ConvBlockFactory(args.block_args, args.use_separable_convolution)
+
+         self.downsample_blocks = ModuleList()
+         self.downsample_blocks.append(conv_block_factory.create_conv7_block(args.input_channels, args.start_channels))
+         current_image_size = args.image_size
+         current_num_channels = args.start_channels
+         while current_image_size > args.bottleneck_image_size:
+             next_image_size = current_image_size // 2
+             next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
+             self.downsample_blocks.append(conv_block_factory.create_downsample_block(
+                 in_channels=current_num_channels,
+                 out_channels=next_num_channels,
+                 is_output_1x1=False))
+             current_image_size = next_image_size
+             current_num_channels = next_num_channels
+         assert len(self.downsample_blocks) == self.num_levels
+
+         self.bottleneck_blocks = ModuleList()
+         for i in range(args.num_bottleneck_blocks):
+             self.bottleneck_blocks.append(conv_block_factory.create_resnet_block(current_num_channels, is_1x1=False))
+
+         self.output_image_sizes = [current_image_size]
+         self.output_num_channels = [current_num_channels]
+         self.upsample_blocks = ModuleList()
+         if args.upsample_mode == 'nearest':
+             align_corners = None
+         else:
+             align_corners = False
+         while current_image_size < args.image_size:
+             next_image_size = current_image_size * 2
+             next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
+             self.upsample_blocks.append(
+                 Sequential(
+                     Upsample(scale_factor=2, mode=args.upsample_mode, align_corners=align_corners),
+                     conv_block_factory.create_conv3_block(
+                         in_channels=current_num_channels, out_channels=next_num_channels)))
+             current_image_size = next_image_size
+             current_num_channels = next_num_channels
+             self.output_image_sizes.append(current_image_size)
+             self.output_num_channels.append(current_num_channels)
+
+     def get_num_output_channels_from_level(self, level: int):
+         return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))
+
+     def get_num_output_channels_from_image_size(self, image_size: int):
+         return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)
+
+     def forward(self, feature: Tensor) -> List[Tensor]:
+         outputs = []
+         for block in self.downsample_blocks:
+             feature = block(feature)
+         for block in self.bottleneck_blocks:
+             feature = block(feature)
+         outputs.append(feature)
+         for block in self.upsample_blocks:
+             feature = block(feature)
+             outputs.append(feature)
+         return outputs
+
+
+ if __name__ == "__main__":
+     device = torch.device('cuda')
+     args = ResizeConvEncoderDecoderArgs(
+         image_size=512,
+         input_channels=4 + 6,
+         start_channels=32,
+         bottleneck_image_size=32,
+         num_bottleneck_blocks=6,
+         max_channels=512,
+         use_separable_convolution=True,
+         block_args=BlockArgs(
+             initialization_method='he',
+             use_spectral_norm=False,
+             normalization_layer_factory=InstanceNorm2dFactory(),
+             nonlinearity_factory=LeakyReLUFactory(inplace=False, negative_slope=0.1)))
+     module = ResizeConvEncoderDecoder(args).to(device)
+     print(module.output_image_sizes)
+     print(module.output_num_channels)
+
+     input = torch.zeros(8, 4 + 6, 512, 512, device=device)
+     outputs = module(input)
+     for output in outputs:
+         print(output.shape)
tha3/nn/common/resize_conv_unet.py ADDED
@@ -0,0 +1,155 @@
+ from typing import Optional, List
+
+ import torch
+ from torch import Tensor
+ from torch.nn import ModuleList, Module, Upsample
+
+ from tha3.nn.common.conv_block_factory import ConvBlockFactory
+ from tha3.nn.nonlinearity_factory import ReLUFactory
+ from tha3.nn.normalization import InstanceNorm2dFactory
+ from tha3.nn.util import BlockArgs
+
+
+ class ResizeConvUNetArgs:
+     def __init__(self,
+                  image_size: int,
+                  input_channels: int,
+                  start_channels: int,
+                  bottleneck_image_size: int,
+                  num_bottleneck_blocks: int,
+                  max_channels: int,
+                  upsample_mode: str = 'bilinear',
+                  block_args: Optional[BlockArgs] = None,
+                  use_separable_convolution: bool = False):
+         if block_args is None:
+             block_args = BlockArgs(
+                 normalization_layer_factory=InstanceNorm2dFactory(),
+                 nonlinearity_factory=ReLUFactory(inplace=False))
+
+         self.use_separable_convolution = use_separable_convolution
+         self.block_args = block_args
+         self.upsample_mode = upsample_mode
+         self.max_channels = max_channels
+         self.num_bottleneck_blocks = num_bottleneck_blocks
+         self.bottleneck_image_size = bottleneck_image_size
+         self.input_channels = input_channels
+         self.start_channels = start_channels
+         self.image_size = image_size
+
+
+ class ResizeConvUNet(Module):
+     def __init__(self, args: ResizeConvUNetArgs):
+         super().__init__()
+         self.args = args
+         conv_block_factory = ConvBlockFactory(args.block_args, args.use_separable_convolution)
+
+         self.downsample_blocks = ModuleList()
+         self.downsample_blocks.append(conv_block_factory.create_conv3_block(
+             self.args.input_channels,
+             self.args.start_channels))
+         current_channels = self.args.start_channels
+         current_size = self.args.image_size
+
+         size_to_channel = {
+             current_size: current_channels
+         }
+         while current_size > self.args.bottleneck_image_size:
+             next_size = current_size // 2
+             next_channels = min(self.args.max_channels, current_channels * 2)
+             self.downsample_blocks.append(conv_block_factory.create_downsample_block(
+                 current_channels,
+                 next_channels,
+                 is_output_1x1=False))
+             current_size = next_size
+             current_channels = next_channels
+             size_to_channel[current_size] = current_channels
+
+         self.bottleneck_blocks = ModuleList()
+         for i in range(self.args.num_bottleneck_blocks):
+             self.bottleneck_blocks.append(conv_block_factory.create_resnet_block(current_channels, is_1x1=False))
+
+         self.output_image_sizes = [current_size]
+         self.output_num_channels = [current_channels]
+         self.upsample_blocks = ModuleList()
+         while current_size < self.args.image_size:
+             next_size = current_size * 2
+             next_channels = size_to_channel[next_size]
+             self.upsample_blocks.append(conv_block_factory.create_conv3_block(
+                 current_channels + next_channels,
+                 next_channels))
+             current_size = next_size
+             current_channels = next_channels
+             self.output_image_sizes.append(current_size)
+             self.output_num_channels.append(current_channels)
+
+         if args.upsample_mode == 'nearest':
+             align_corners = None
+         else:
+             align_corners = False
+         self.double_resolution = Upsample(scale_factor=2, mode=args.upsample_mode, align_corners=align_corners)
+
+     def forward(self, feature: Tensor) -> List[Tensor]:
+         downsampled_features = []
+         for block in self.downsample_blocks:
+             feature = block(feature)
+             downsampled_features.append(feature)
+
+         for block in self.bottleneck_blocks:
+             feature = block(feature)
+
+         outputs = [feature]
+         for i in range(0, len(self.upsample_blocks)):
+             feature = self.double_resolution(feature)
+             feature = torch.cat([feature, downsampled_features[-i - 2]], dim=1)
+             feature = self.upsample_blocks[i](feature)
+             outputs.append(feature)
+
+         return outputs
+
+
+ if __name__ == "__main__":
+     device = torch.device('cuda')
+
+     image_size = 512
+     image_channels = 4
+     num_pose_params = 6
+     args = ResizeConvUNetArgs(
+         image_size=512,
+         input_channels=10,
+         start_channels=32,
+         bottleneck_image_size=32,
+         num_bottleneck_blocks=6,
+         max_channels=512,
+         upsample_mode='nearest',
+         use_separable_convolution=False,
+         block_args=BlockArgs(
+             initialization_method='he',
+             use_spectral_norm=False,
+             normalization_layer_factory=InstanceNorm2dFactory(),
+             nonlinearity_factory=ReLUFactory(inplace=False)))
+     module = ResizeConvUNet(args).to(device)
+
+     image_count = 8
+     input = torch.zeros(image_count, 10, 512, 512, device=device)
+     outputs = module.forward(input)
+     for output in outputs:
+         print(output.shape)
+
+     if True:
+         repeat = 100
+         acc = 0.0
+         for i in range(repeat + 2):
+             start = torch.cuda.Event(enable_timing=True)
+             end = torch.cuda.Event(enable_timing=True)
+
+             start.record()
+             module.forward(input)
+             end.record()
+             torch.cuda.synchronize()
+             if i >= 2:
+                 elapsed_time = start.elapsed_time(end)
+                 print("%d:" % i, elapsed_time)
+                 acc = acc + elapsed_time
+
+         print("average:", acc / repeat)
tha3/nn/conv.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ from typing import Optional, Union, Callable
+
+ from torch.nn import Conv2d, Module, Sequential, ConvTranspose2d
+
+ from tha3.module.module_factory import ModuleFactory
+ from tha3.nn.nonlinearity_factory import resolve_nonlinearity_factory
+ from tha3.nn.normalization import NormalizationLayerFactory
+ from tha3.nn.util import wrap_conv_or_linear_module, BlockArgs
+
+
+ def create_conv7(in_channels: int, out_channels: int,
+                  bias: bool = False,
+                  initialization_method: Union[str, Callable[[Module], Module]] = 'he',
+                  use_spectral_norm: bool = False) -> Module:
+     return wrap_conv_or_linear_module(
+         Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=bias),
+         initialization_method,
+         use_spectral_norm)
+
+
+ def create_conv7_from_block_args(in_channels: int,
+                                  out_channels: int,
+                                  bias: bool = False,
+                                  block_args: Optional[BlockArgs] = None) -> Module:
+     if block_args is None:
+         block_args = BlockArgs()
+     return create_conv7(
+         in_channels, out_channels, bias,
+         block_args.initialization_method,
+         block_args.use_spectral_norm)
+
+
+ def create_conv3(in_channels: int,
+                  out_channels: int,
+                  bias: bool = False,
+                  initialization_method: Union[str, Callable[[Module], Module]] = 'he',
+                  use_spectral_norm: bool = False) -> Module:
+     return wrap_conv_or_linear_module(
+         Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias),
+         initialization_method,
+         use_spectral_norm)
+
+
+ def create_conv3_from_block_args(in_channels: int, out_channels: int,
+                                  bias: bool = False,
+                                  block_args: Optional[BlockArgs] = None) -> Module:
+     if block_args is None:
+         block_args = BlockArgs()
+     return create_conv3(in_channels, out_channels, bias,
+                         block_args.initialization_method,
+                         block_args.use_spectral_norm)
+
+
+ def create_conv1(in_channels: int, out_channels: int,
+                  initialization_method: Union[str, Callable[[Module], Module]] = 'he',
+                  bias: bool = False,
+                  use_spectral_norm: bool = False) -> Module:
+     return wrap_conv_or_linear_module(
+         Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
+         initialization_method,
+         use_spectral_norm)
+
+
+ def create_conv1_from_block_args(in_channels: int,
+                                  out_channels: int,
+                                  bias: bool = False,
+                                  block_args: Optional[BlockArgs] = None) -> Module:
+     if block_args is None:
+         block_args = BlockArgs()
+     return create_conv1(
+         in_channels=in_channels,
+         out_channels=out_channels,
+         initialization_method=block_args.initialization_method,
+         bias=bias,
+         use_spectral_norm=block_args.use_spectral_norm)
+
+
+ def create_conv7_block(in_channels: int, out_channels: int,
+                        initialization_method: Union[str, Callable[[Module], Module]] = 'he',
+                        nonlinearity_factory: Optional[ModuleFactory] = None,
+                        normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
+                        use_spectral_norm: bool = False) -> Module:
+     nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
+     return Sequential(
+         create_conv7(in_channels, out_channels,
+                      bias=False, initialization_method=initialization_method, use_spectral_norm=use_spectral_norm),
+         NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),
+         nonlinearity_factory.create())
+
+
+ def create_conv7_block_from_block_args(
+         in_channels: int, out_channels: int,
+         block_args: Optional[BlockArgs] = None) -> Module:
+     if block_args is None:
+         block_args = BlockArgs()
+     return create_conv7_block(in_channels, out_channels,
+                               block_args.initialization_method,
+                               block_args.nonlinearity_factory,
+                               block_args.normalization_layer_factory,
+                               block_args.use_spectral_norm)
+
+
+ def create_conv3_block(in_channels: int, out_channels: int,
+                        initialization_method: Union[str, Callable[[Module], Module]] = 'he',
+                        nonlinearity_factory: Optional[ModuleFactory] = None,
+                        normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
+                        use_spectral_norm: bool = False) -> Module:
+     nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
+     return Sequential(
+         create_conv3(in_channels, out_channels,
+                      bias=False, initialization_method=initialization_method, use_spectral_norm=use_spectral_norm),
+         NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),
+         nonlinearity_factory.create())
+
+
+ def create_conv3_block_from_block_args(
+         in_channels: int, out_channels: int, block_args: Optional[BlockArgs] = None) -> Module:
+     if block_args is None:
+         block_args = BlockArgs()
+     return create_conv3_block(in_channels, out_channels,
+                               block_args.initialization_method,
+                               block_args.nonlinearity_factory,
+                               block_args.normalization_layer_factory,
+                               block_args.use_spectral_norm)
+
+
+ def create_downsample_block(in_channels: int, out_channels: int,
+                             is_output_1x1: bool = False,
+                             initialization_method: Union[str, Callable[[Module], Module]] = 'he',
+                             nonlinearity_factory: Optional[ModuleFactory] = None,
+                             normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
+                             use_spectral_norm: bool = False) -> Module:
+     if is_output_1x1:
+         return Sequential(
+             wrap_conv_or_linear_module(
+                 Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
+                 initialization_method,
+                 use_spectral_norm),
+             resolve_nonlinearity_factory(nonlinearity_factory).create())
+     else:
+         return Sequential(
+             wrap_conv_or_linear_module(
+                 Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
+                 initialization_method,
+                 use_spectral_norm),
+             NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),
+             resolve_nonlinearity_factory(nonlinearity_factory).create())
+
+
+ def create_downsample_block_from_block_args(in_channels: int, out_channels: int,
+                                             is_output_1x1: bool = False,
+                                             block_args: Optional[BlockArgs] = None) -> Module:
+     if block_args is None:
+         block_args = BlockArgs()
+     return create_downsample_block(
+         in_channels, out_channels,
+         is_output_1x1,
+         block_args.initialization_method,
+         block_args.nonlinearity_factory,
+         block_args.normalization_layer_factory,
+         block_args.use_spectral_norm)
+
+
+ def create_upsample_block(in_channels: int,
+                           out_channels: int,
+                           initialization_method: Union[str, Callable[[Module], Module]] = 'he',
+                           nonlinearity_factory: Optional[ModuleFactory] = None,
+                           normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
+                           use_spectral_norm: bool = False) -> Module:
+     nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
+     return Sequential(
+         wrap_conv_or_linear_module(
+             ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
+             initialization_method,
+             use_spectral_norm),
+         NormalizationLayerFactory.resolve_2d(normalization_layer_factory).create(out_channels, affine=True),
+         nonlinearity_factory.create())
+
+
+ def create_upsample_block_from_block_args(in_channels: int,
+                                           out_channels: int,
+                                           block_args: Optional[BlockArgs] = None) -> Module:
+     if block_args is None:
+         block_args = BlockArgs()
+     return create_upsample_block(in_channels, out_channels,
+                                  block_args.initialization_method,
+                                  block_args.nonlinearity_factory,
+                                  block_args.normalization_layer_factory,
+                                  block_args.use_spectral_norm)
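A minimal usage sketch for the factories above, assuming the `BlockArgs` construction shown elsewhere in this commit; the channel counts are arbitrary:

import torch
from tha3.nn.conv import create_conv3_block_from_block_args, create_downsample_block_from_block_args
from tha3.nn.nonlinearity_factory import ReLUFactory
from tha3.nn.normalization import InstanceNorm2dFactory
from tha3.nn.util import BlockArgs

block_args = BlockArgs(
    initialization_method='he',
    use_spectral_norm=False,
    normalization_layer_factory=InstanceNorm2dFactory(),
    nonlinearity_factory=ReLUFactory(inplace=False))
block = create_conv3_block_from_block_args(4, 32, block_args)                   # 3x3 conv + norm + ReLU
down = create_downsample_block_from_block_args(32, 64, block_args=block_args)   # halves the resolution
x = torch.zeros(1, 4, 128, 128)
print(down(block(x)).shape)  # expected: torch.Size([1, 64, 64, 64])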
tha3/nn/editor/__init__.py ADDED
File without changes
tha3/nn/editor/editor_07.py ADDED
@@ -0,0 +1,180 @@
+ from typing import Optional, List
+
+ import torch
+ from matplotlib import pyplot
+ from torch import Tensor
+ from torch.nn import Module, Sequential, Tanh, Sigmoid
+
+ from tha3.nn.image_processing_util import GridChangeApplier, apply_color_change
+ from tha3.nn.common.resize_conv_unet import ResizeConvUNet, ResizeConvUNetArgs
+ from tha3.util import numpy_linear_to_srgb
+ from tha3.module.module_factory import ModuleFactory
+ from tha3.nn.conv import create_conv3_from_block_args, create_conv3
+ from tha3.nn.nonlinearity_factory import ReLUFactory
+ from tha3.nn.normalization import InstanceNorm2dFactory
+ from tha3.nn.util import BlockArgs
+
+
+ class Editor07Args:
+     def __init__(self,
+                  image_size: int = 512,
+                  image_channels: int = 4,
+                  num_pose_params: int = 6,
+                  start_channels: int = 32,
+                  bottleneck_image_size=32,
+                  num_bottleneck_blocks=6,
+                  max_channels: int = 512,
+                  upsampling_mode: str = 'nearest',
+                  block_args: Optional[BlockArgs] = None,
+                  use_separable_convolution: bool = False):
+         if block_args is None:
+             block_args = BlockArgs(
+                 normalization_layer_factory=InstanceNorm2dFactory(),
+                 nonlinearity_factory=ReLUFactory(inplace=False))
+
+         self.block_args = block_args
+         self.upsampling_mode = upsampling_mode
+         self.max_channels = max_channels
+         self.num_bottleneck_blocks = num_bottleneck_blocks
+         self.bottleneck_image_size = bottleneck_image_size
+         self.start_channels = start_channels
+         self.num_pose_params = num_pose_params
+         self.image_channels = image_channels
+         self.image_size = image_size
+         self.use_separable_convolution = use_separable_convolution
+
+
+ class Editor07(Module):
+     def __init__(self, args: Editor07Args):
+         super().__init__()
+         self.args = args
+
+         self.body = ResizeConvUNet(ResizeConvUNetArgs(
+             image_size=args.image_size,
+             input_channels=2 * args.image_channels + args.num_pose_params + 2,
+             start_channels=args.start_channels,
+             bottleneck_image_size=args.bottleneck_image_size,
+             num_bottleneck_blocks=args.num_bottleneck_blocks,
+             max_channels=args.max_channels,
+             upsample_mode=args.upsampling_mode,
+             block_args=args.block_args,
+             use_separable_convolution=args.use_separable_convolution))
+         self.color_change_creator = Sequential(
+             create_conv3_from_block_args(
+                 in_channels=self.args.start_channels,
+                 out_channels=self.args.image_channels,
+                 bias=True,
+                 block_args=self.args.block_args),
+             Tanh())
+         self.alpha_creator = Sequential(
+             create_conv3_from_block_args(
+                 in_channels=self.args.start_channels,
+                 out_channels=self.args.image_channels,
+                 bias=True,
+                 block_args=self.args.block_args),
+             Sigmoid())
+         self.grid_change_creator = create_conv3(
+             in_channels=self.args.start_channels,
+             out_channels=2,
+             bias=False,
+             initialization_method='zero',
+             use_spectral_norm=False)
+         self.grid_change_applier = GridChangeApplier()
+
+     def forward(self,
+                 input_original_image: Tensor,
+                 input_warped_image: Tensor,
+                 input_grid_change: Tensor,
+                 pose: Tensor,
+                 *args) -> List[Tensor]:
+         n, c = pose.shape
+         pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.image_size, self.args.image_size)
+         feature = torch.cat([input_original_image, input_warped_image, input_grid_change, pose], dim=1)
+
+         feature = self.body.forward(feature)[-1]
+         output_grid_change = input_grid_change + self.grid_change_creator(feature)
+
+         output_color_change = self.color_change_creator(feature)
+         output_color_change_alpha = self.alpha_creator(feature)
+         output_warped_image = self.grid_change_applier.apply(output_grid_change, input_original_image)
+         output_color_changed = apply_color_change(output_color_change_alpha, output_color_change, output_warped_image)
+
+         return [
+             output_color_changed,
+             output_color_change_alpha,
+             output_color_change,
+             output_warped_image,
+             output_grid_change,
+         ]
+
+     COLOR_CHANGED_IMAGE_INDEX = 0
+     COLOR_CHANGE_ALPHA_INDEX = 1
+     COLOR_CHANGE_IMAGE_INDEX = 2
+     WARPED_IMAGE_INDEX = 3
+     GRID_CHANGE_INDEX = 4
+     OUTPUT_LENGTH = 5
+
+
+ class Editor07Factory(ModuleFactory):
+     def __init__(self, args: Editor07Args):
+         super().__init__()
+         self.args = args
+
+     def create(self) -> Module:
+         return Editor07(self.args)
+
+
+ def show_image(pytorch_image):
+     numpy_image = ((pytorch_image + 1.0) / 2.0).squeeze(0).numpy()
+     numpy_image[0:3, :, :] = numpy_linear_to_srgb(numpy_image[0:3, :, :])
+     c, h, w = numpy_image.shape
+     numpy_image = numpy_image.reshape((c, h * w)).transpose().reshape((h, w, c))
+     pyplot.imshow(numpy_image)
+     pyplot.show()
+
+
+ if __name__ == "__main__":
+     cuda = torch.device('cuda')
+
+     image_size = 512
+     image_channels = 4
+     num_pose_params = 6
+     args = Editor07Args(
+         image_size=512,
+         image_channels=4,
+         start_channels=32,
+         num_pose_params=6,
+         bottleneck_image_size=32,
+         num_bottleneck_blocks=6,
+         max_channels=512,
+         upsampling_mode='nearest',
+         block_args=BlockArgs(
+             initialization_method='he',
+             use_spectral_norm=False,
+             normalization_layer_factory=InstanceNorm2dFactory(),
+             nonlinearity_factory=ReLUFactory(inplace=False)))
+     module = Editor07(args).to(cuda)
+
+     image_count = 1
+     input_image = torch.zeros(image_count, 4, image_size, image_size, device=cuda)
+     direct_image = torch.zeros(image_count, 4, image_size, image_size, device=cuda)
+     warped_image = torch.zeros(image_count, 4, image_size, image_size, device=cuda)
+     grid_change = torch.zeros(image_count, 2, image_size, image_size, device=cuda)
+     pose = torch.zeros(image_count, num_pose_params, device=cuda)
+
+     repeat = 100
+     acc = 0.0
+     for i in range(repeat + 2):
+         start = torch.cuda.Event(enable_timing=True)
+         end = torch.cuda.Event(enable_timing=True)
+
+         start.record()
+         module.forward(input_image, warped_image, grid_change, pose)
+         end.record()
+         torch.cuda.synchronize()
+         if i >= 2:
+             elapsed_time = start.elapsed_time(end)
+             print("%d:" % i, elapsed_time)
+             acc = acc + elapsed_time
+
+     print("average:", acc / repeat)
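A minimal usage sketch: the editor takes the original image, the warped image and grid change produced upstream, and a 6-dimensional pose vector, and its output list is indexed by the class constants above. This uses default `Editor07Args`, untrained weights, and zero inputs, for shape checking only:

import torch
from tha3.nn.editor.editor_07 import Editor07, Editor07Args

device = torch.device('cuda')
editor = Editor07(Editor07Args()).to(device)

original = torch.zeros(1, 4, 512, 512, device=device)
warped = torch.zeros(1, 4, 512, 512, device=device)
grid_change = torch.zeros(1, 2, 512, 512, device=device)
pose = torch.zeros(1, 6, device=device)

outputs = editor(original, warped, grid_change, pose)
final = outputs[Editor07.COLOR_CHANGED_IMAGE_INDEX]  # refined RGBA image, shape [1, 4, 512, 512]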
tha3/nn/eyebrow_decomposer/__init__.py ADDED
File without changes
tha3/nn/eyebrow_decomposer/eyebrow_decomposer_00.py ADDED
@@ -0,0 +1,102 @@
+ from typing import List, Optional
+
+ import torch
+ from torch import Tensor
+ from torch.nn import Module
+
+ from tha3.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args, PoserEncoderDecoder00
+ from tha3.nn.image_processing_util import apply_color_change
+ from tha3.module.module_factory import ModuleFactory
+ from tha3.nn.nonlinearity_factory import ReLUFactory
+ from tha3.nn.normalization import InstanceNorm2dFactory
+ from tha3.nn.util import BlockArgs
+
+
+ class EyebrowDecomposer00Args(PoserEncoderDecoder00Args):
+     def __init__(self,
+                  image_size: int = 128,
+                  image_channels: int = 4,
+                  start_channels: int = 64,
+                  bottleneck_image_size=16,
+                  num_bottleneck_blocks=6,
+                  max_channels: int = 512,
+                  block_args: Optional[BlockArgs] = None):
+         super().__init__(
+             image_size,
+             image_channels,
+             image_channels,
+             0,
+             start_channels,
+             bottleneck_image_size,
+             num_bottleneck_blocks,
+             max_channels,
+             block_args)
+
+
+ class EyebrowDecomposer00(Module):
+     def __init__(self, args: EyebrowDecomposer00Args):
+         super().__init__()
+         self.args = args
+         self.body = PoserEncoderDecoder00(args)
+         self.background_layer_alpha = self.args.create_alpha_block()
+         self.background_layer_color_change = self.args.create_color_change_block()
+         self.eyebrow_layer_alpha = self.args.create_alpha_block()
+         self.eyebrow_layer_color_change = self.args.create_color_change_block()
+
+     def forward(self, image: Tensor, *args) -> List[Tensor]:
+         feature = self.body(image)[0]
+
+         background_layer_alpha = self.background_layer_alpha(feature)
+         background_layer_color_change = self.background_layer_color_change(feature)
+         background_layer_1 = apply_color_change(background_layer_alpha, background_layer_color_change, image)
+
+         eyebrow_layer_alpha = self.eyebrow_layer_alpha(feature)
+         eyebrow_layer_color_change = self.eyebrow_layer_color_change(feature)
+         eyebrow_layer = apply_color_change(eyebrow_layer_alpha, image, eyebrow_layer_color_change)
+
+         return [
+             eyebrow_layer,  # 0
+             eyebrow_layer_alpha,  # 1
+             eyebrow_layer_color_change,  # 2
+             background_layer_1,  # 3
+             background_layer_alpha,  # 4
+             background_layer_color_change,  # 5
+         ]
+
+     EYEBROW_LAYER_INDEX = 0
+     EYEBROW_LAYER_ALPHA_INDEX = 1
+     EYEBROW_LAYER_COLOR_CHANGE_INDEX = 2
+     BACKGROUND_LAYER_INDEX = 3
+     BACKGROUND_LAYER_ALPHA_INDEX = 4
+     BACKGROUND_LAYER_COLOR_CHANGE_INDEX = 5
+     OUTPUT_LENGTH = 6
+
+
+ class EyebrowDecomposer00Factory(ModuleFactory):
+     def __init__(self, args: EyebrowDecomposer00Args):
+         super().__init__()
+         self.args = args
+
+     def create(self) -> Module:
+         return EyebrowDecomposer00(self.args)
+
+
+ if __name__ == "__main__":
+     cuda = torch.device('cuda')
+     args = EyebrowDecomposer00Args(
+         image_size=128,
+         image_channels=4,
+         start_channels=64,
+         bottleneck_image_size=16,
+         num_bottleneck_blocks=3,
+         block_args=BlockArgs(
+             initialization_method='xavier',
+             use_spectral_norm=False,
+             normalization_layer_factory=InstanceNorm2dFactory(),
+             nonlinearity_factory=ReLUFactory(inplace=True)))
+     face_morpher = EyebrowDecomposer00(args).to(cuda)
+
+     image = torch.randn(8, 4, 128, 128, device=cuda)
+     outputs = face_morpher.forward(image)
+     for i in range(len(outputs)):
+         print(i, outputs[i].shape)
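A minimal sketch of consuming the output list through the index constants rather than bare integers, assuming default `EyebrowDecomposer00Args` (untrained weights, random input, CPU):

import torch
from tha3.nn.eyebrow_decomposer.eyebrow_decomposer_00 import EyebrowDecomposer00, EyebrowDecomposer00Args

decomposer = EyebrowDecomposer00(EyebrowDecomposer00Args())
image = torch.randn(1, 4, 128, 128)  # RGBA crop of the eyebrow region
outputs = decomposer(image)
eyebrow_layer = outputs[EyebrowDecomposer00.EYEBROW_LAYER_INDEX]
background_layer = outputs[EyebrowDecomposer00.BACKGROUND_LAYER_INDEX]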
tha3/nn/eyebrow_decomposer/eyebrow_decomposer_03.py ADDED
@@ -0,0 +1,109 @@
+ from typing import List, Optional
+
+ import torch
+ from torch import Tensor
+ from torch.nn import Module
+
+ from tha3.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args
+ from tha3.nn.common.poser_encoder_decoder_00_separable import PoserEncoderDecoder00Separable
+ from tha3.nn.image_processing_util import apply_color_change
+ from tha3.module.module_factory import ModuleFactory
+ from tha3.nn.nonlinearity_factory import ReLUFactory
+ from tha3.nn.normalization import InstanceNorm2dFactory
+ from tha3.nn.util import BlockArgs
+
+
+ class EyebrowDecomposer03Args(PoserEncoderDecoder00Args):
+     def __init__(self,
+                  image_size: int = 128,
+                  image_channels: int = 4,
+                  start_channels: int = 64,
+                  bottleneck_image_size=16,
+                  num_bottleneck_blocks=6,
+                  max_channels: int = 512,
+                  block_args: Optional[BlockArgs] = None):
+         super().__init__(
+             image_size,
+             image_channels,
+             image_channels,
+             0,
+             start_channels,
+             bottleneck_image_size,
+             num_bottleneck_blocks,
+             max_channels,
+             block_args)
+
+
+ class EyebrowDecomposer03(Module):
+     def __init__(self, args: EyebrowDecomposer03Args):
+         super().__init__()
+         self.args = args
+         self.body = PoserEncoderDecoder00Separable(args)
+         self.background_layer_alpha = self.args.create_alpha_block()
+         self.background_layer_color_change = self.args.create_color_change_block()
+         self.eyebrow_layer_alpha = self.args.create_alpha_block()
+         self.eyebrow_layer_color_change = self.args.create_color_change_block()
+
+     def forward(self, image: Tensor, *args) -> List[Tensor]:
+         feature = self.body(image)[0]
+
+         background_layer_alpha = self.background_layer_alpha(feature)
+         background_layer_color_change = self.background_layer_color_change(feature)
+         background_layer_1 = apply_color_change(background_layer_alpha, background_layer_color_change, image)
+
+         eyebrow_layer_alpha = self.eyebrow_layer_alpha(feature)
+         eyebrow_layer_color_change = self.eyebrow_layer_color_change(feature)
+         eyebrow_layer = apply_color_change(eyebrow_layer_alpha, image, eyebrow_layer_color_change)
+
+         return [
+             eyebrow_layer,  # 0
+             eyebrow_layer_alpha,  # 1
+             eyebrow_layer_color_change,  # 2
+             background_layer_1,  # 3
+             background_layer_alpha,  # 4
+             background_layer_color_change,  # 5
+         ]
+
+     EYEBROW_LAYER_INDEX = 0
+     EYEBROW_LAYER_ALPHA_INDEX = 1
+     EYEBROW_LAYER_COLOR_CHANGE_INDEX = 2
+     BACKGROUND_LAYER_INDEX = 3
+     BACKGROUND_LAYER_ALPHA_INDEX = 4
+     BACKGROUND_LAYER_COLOR_CHANGE_INDEX = 5
+     OUTPUT_LENGTH = 6
+
+
+ class EyebrowDecomposer03Factory(ModuleFactory):
+     def __init__(self, args: EyebrowDecomposer03Args):
+         super().__init__()
+         self.args = args
+
+     def create(self) -> Module:
+         return EyebrowDecomposer03(self.args)
+
+
+ if __name__ == "__main__":
+     cuda = torch.device('cuda')
+     args = EyebrowDecomposer03Args(
+         image_size=128,
+         image_channels=4,
+         start_channels=64,
+         bottleneck_image_size=16,
+         num_bottleneck_blocks=6,
+         block_args=BlockArgs(
+             initialization_method='xavier',
+             use_spectral_norm=False,
+             normalization_layer_factory=InstanceNorm2dFactory(),
+             nonlinearity_factory=ReLUFactory(inplace=True)))
+     face_morpher = EyebrowDecomposer03(args).to(cuda)
+
+     #image = torch.randn(8, 4, 128, 128, device=cuda)
+     #outputs = face_morpher.forward(image)
+     #for i in range(len(outputs)):
+     #    print(i, outputs[i].shape)
+
+     state_dict = face_morpher.state_dict()
+     index = 0
+     for key in state_dict:
+         print(f"[{index}]", key, state_dict[key].shape)
+         index += 1
tha3/nn/eyebrow_morphing_combiner/__init__.py ADDED
File without changes
tha3/nn/eyebrow_morphing_combiner/eyebrow_morphing_combiner_00.py ADDED
@@ -0,0 +1,115 @@
+ from typing import List, Optional
+
+ import torch
+ from torch import Tensor
+ from torch.nn import Module
+
+ from tha3.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args, PoserEncoderDecoder00
+ from tha3.nn.image_processing_util import apply_color_change, apply_grid_change, apply_rgb_change
+ from tha3.module.module_factory import ModuleFactory
+ from tha3.nn.nonlinearity_factory import ReLUFactory
+ from tha3.nn.normalization import InstanceNorm2dFactory
+ from tha3.nn.util import BlockArgs
+
+
+ class EyebrowMorphingCombiner00Args(PoserEncoderDecoder00Args):
+     def __init__(self,
+                  image_size: int = 128,
+                  image_channels: int = 4,
+                  num_pose_params: int = 12,
+                  start_channels: int = 64,
+                  bottleneck_image_size=16,
+                  num_bottleneck_blocks=6,
+                  max_channels: int = 512,
+                  block_args: Optional[BlockArgs] = None):
+         super().__init__(
+             image_size,
+             2 * image_channels,
+             image_channels,
+             num_pose_params,
+             start_channels,
+             bottleneck_image_size,
+             num_bottleneck_blocks,
+             max_channels,
+             block_args)
+
+
+ class EyebrowMorphingCombiner00(Module):
+     def __init__(self, args: EyebrowMorphingCombiner00Args):
+         super().__init__()
+         self.args = args
+         self.body = PoserEncoderDecoder00(args)
+         self.morphed_eyebrow_layer_grid_change = self.args.create_grid_change_block()
+         self.morphed_eyebrow_layer_alpha = self.args.create_alpha_block()
+         self.morphed_eyebrow_layer_color_change = self.args.create_color_change_block()
+         self.combine_alpha = self.args.create_alpha_block()
+
+     def forward(self, background_layer: Tensor, eyebrow_layer: Tensor, pose: Tensor, *args) -> List[Tensor]:
+         combined_image = torch.cat([background_layer, eyebrow_layer], dim=1)
+         feature = self.body(combined_image, pose)[0]
+
+         morphed_eyebrow_layer_grid_change = self.morphed_eyebrow_layer_grid_change(feature)
+         morphed_eyebrow_layer_alpha = self.morphed_eyebrow_layer_alpha(feature)
+         morphed_eyebrow_layer_color_change = self.morphed_eyebrow_layer_color_change(feature)
+         warped_eyebrow_layer = apply_grid_change(morphed_eyebrow_layer_grid_change, eyebrow_layer)
+         morphed_eyebrow_layer = apply_color_change(
+             morphed_eyebrow_layer_alpha, morphed_eyebrow_layer_color_change, warped_eyebrow_layer)
+
+         combine_alpha = self.combine_alpha(feature)
+         eyebrow_image = apply_rgb_change(combine_alpha, morphed_eyebrow_layer, background_layer)
+         # The morphed layer's alpha channel lives in [-1, 1]; rescale it to [0, 1] before blending.
+         eyebrow_image_no_combine_alpha = apply_rgb_change(
+             (morphed_eyebrow_layer[:, 3:4, :, :] + 1.0) / 2.0, morphed_eyebrow_layer, background_layer)
+
+         return [
+             eyebrow_image,  # 0
+             combine_alpha,  # 1
+             eyebrow_image_no_combine_alpha,  # 2
+             morphed_eyebrow_layer,  # 3
+             morphed_eyebrow_layer_alpha,  # 4
+             morphed_eyebrow_layer_color_change,  # 5
+             warped_eyebrow_layer,  # 6
+             morphed_eyebrow_layer_grid_change,  # 7
+         ]
+
+     EYEBROW_IMAGE_INDEX = 0
+     COMBINE_ALPHA_INDEX = 1
+     EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX = 2
+     MORPHED_EYEBROW_LAYER_INDEX = 3
+     MORPHED_EYEBROW_LAYER_ALPHA_INDEX = 4
+     MORPHED_EYEBROW_LAYER_COLOR_CHANGE_INDEX = 5
+     WARPED_EYEBROW_LAYER_INDEX = 6
+     MORPHED_EYEBROW_LAYER_GRID_CHANGE_INDEX = 7
+     OUTPUT_LENGTH = 8
+
+
+ class EyebrowMorphingCombiner00Factory(ModuleFactory):
+     def __init__(self, args: EyebrowMorphingCombiner00Args):
+         super().__init__()
+         self.args = args
+
+     def create(self) -> Module:
+         return EyebrowMorphingCombiner00(self.args)
+
+
+ if __name__ == "__main__":
+     cuda = torch.device('cuda')
+     args = EyebrowMorphingCombiner00Args(
+         image_size=128,
+         image_channels=4,
+         num_pose_params=12,
+         start_channels=64,
+         bottleneck_image_size=16,
+         num_bottleneck_blocks=3,
+         block_args=BlockArgs(
+             initialization_method='xavier',
+             use_spectral_norm=False,
+             normalization_layer_factory=InstanceNorm2dFactory(),
+             nonlinearity_factory=ReLUFactory(inplace=True)))
+     face_morpher = EyebrowMorphingCombiner00(args).to(cuda)
+
+     background_layer = torch.randn(8, 4, 128, 128, device=cuda)
+     eyebrow_layer = torch.randn(8, 4, 128, 128, device=cuda)
+     pose = torch.randn(8, 12, device=cuda)
+     outputs = face_morpher.forward(background_layer, eyebrow_layer, pose)
+     for i in range(len(outputs)):
+         print(i, outputs[i].shape)
tha3/nn/eyebrow_morphing_combiner/eyebrow_morphing_combiner_03.py ADDED
@@ -0,0 +1,117 @@
+ from typing import List, Optional
+
+ import torch
+ from torch import Tensor
+ from torch.nn import Module
+
+ from tha3.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args
+ from tha3.nn.common.poser_encoder_decoder_00_separable import PoserEncoderDecoder00Separable
+ from tha3.nn.image_processing_util import apply_color_change, apply_rgb_change, GridChangeApplier
+ from tha3.module.module_factory import ModuleFactory
+ from tha3.nn.nonlinearity_factory import ReLUFactory
+ from tha3.nn.normalization import InstanceNorm2dFactory
+ from tha3.nn.util import BlockArgs
+
+
+ class EyebrowMorphingCombiner03Args(PoserEncoderDecoder00Args):
+     def __init__(self,
+                  image_size: int = 128,
+                  image_channels: int = 4,
+                  num_pose_params: int = 12,
+                  start_channels: int = 64,
+                  bottleneck_image_size=16,
+                  num_bottleneck_blocks=6,
+                  max_channels: int = 512,
+                  block_args: Optional[BlockArgs] = None):
+         super().__init__(
+             image_size,
+             2 * image_channels,
+             image_channels,
+             num_pose_params,
+             start_channels,
+             bottleneck_image_size,
+             num_bottleneck_blocks,
+             max_channels,
+             block_args)
+
+
+ class EyebrowMorphingCombiner03(Module):
+     def __init__(self, args: EyebrowMorphingCombiner03Args):
+         super().__init__()
+         self.args = args
+         self.body = PoserEncoderDecoder00Separable(args)
+         self.morphed_eyebrow_layer_grid_change = self.args.create_grid_change_block()
+         self.morphed_eyebrow_layer_alpha = self.args.create_alpha_block()
+         self.morphed_eyebrow_layer_color_change = self.args.create_color_change_block()
+         self.combine_alpha = self.args.create_alpha_block()
+         self.grid_change_applier = GridChangeApplier()
+
+     def forward(self, background_layer: Tensor, eyebrow_layer: Tensor, pose: Tensor, *args) -> List[Tensor]:
+         combined_image = torch.cat([background_layer, eyebrow_layer], dim=1)
+         feature = self.body(combined_image, pose)[0]
+
+         morphed_eyebrow_layer_grid_change = self.morphed_eyebrow_layer_grid_change(feature)
+         morphed_eyebrow_layer_alpha = self.morphed_eyebrow_layer_alpha(feature)
+         morphed_eyebrow_layer_color_change = self.morphed_eyebrow_layer_color_change(feature)
+         warped_eyebrow_layer = self.grid_change_applier.apply(morphed_eyebrow_layer_grid_change, eyebrow_layer)
+         morphed_eyebrow_layer = apply_color_change(
+             morphed_eyebrow_layer_alpha, morphed_eyebrow_layer_color_change, warped_eyebrow_layer)
+
+         combine_alpha = self.combine_alpha(feature)
+         eyebrow_image = apply_rgb_change(combine_alpha, morphed_eyebrow_layer, background_layer)
+         # The morphed layer's alpha channel lives in [-1, 1]; rescale it to [0, 1] before blending.
+         eyebrow_image_no_combine_alpha = apply_rgb_change(
+             (morphed_eyebrow_layer[:, 3:4, :, :] + 1.0) / 2.0, morphed_eyebrow_layer, background_layer)
+
+         return [
+             eyebrow_image,  # 0
+             combine_alpha,  # 1
+             eyebrow_image_no_combine_alpha,  # 2
+             morphed_eyebrow_layer,  # 3
+             morphed_eyebrow_layer_alpha,  # 4
+             morphed_eyebrow_layer_color_change,  # 5
+             warped_eyebrow_layer,  # 6
+             morphed_eyebrow_layer_grid_change,  # 7
+         ]
+
+     EYEBROW_IMAGE_INDEX = 0
+     COMBINE_ALPHA_INDEX = 1
+     EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX = 2
+     MORPHED_EYEBROW_LAYER_INDEX = 3
+     MORPHED_EYEBROW_LAYER_ALPHA_INDEX = 4
+     MORPHED_EYEBROW_LAYER_COLOR_CHANGE_INDEX = 5
+     WARPED_EYEBROW_LAYER_INDEX = 6
+     MORPHED_EYEBROW_LAYER_GRID_CHANGE_INDEX = 7
+     OUTPUT_LENGTH = 8
+
+
+ class EyebrowMorphingCombiner03Factory(ModuleFactory):
+     def __init__(self, args: EyebrowMorphingCombiner03Args):
+         super().__init__()
+         self.args = args
+
+     def create(self) -> Module:
+         return EyebrowMorphingCombiner03(self.args)
+
+
+ if __name__ == "__main__":
+     cuda = torch.device('cuda')
+     args = EyebrowMorphingCombiner03Args(
+         image_size=128,
+         image_channels=4,
+         num_pose_params=12,
+         start_channels=64,
+         bottleneck_image_size=16,
+         num_bottleneck_blocks=3,
+         block_args=BlockArgs(
+             initialization_method='xavier',
+             use_spectral_norm=False,
+             normalization_layer_factory=InstanceNorm2dFactory(),
+             nonlinearity_factory=ReLUFactory(inplace=True)))
+     face_morpher = EyebrowMorphingCombiner03(args).to(cuda)
+
+     background_layer = torch.randn(8, 4, 128, 128, device=cuda)
+     eyebrow_layer = torch.randn(8, 4, 128, 128, device=cuda)
+     pose = torch.randn(8, 12, device=cuda)
+     outputs = face_morpher.forward(background_layer, eyebrow_layer, pose)
+     for i in range(len(outputs)):
+         print(i, outputs[i].shape)
tha3/nn/face_morpher/__init__.py ADDED
File without changes
tha3/nn/face_morpher/face_morpher_08.py ADDED
@@ -0,0 +1,241 @@
+ import math
+ from typing import List, Optional
+
+ import torch
+ from torch import Tensor
+ from torch.nn import ModuleList, Sequential, Sigmoid, Tanh, Module
+ from torch.nn.functional import affine_grid, grid_sample
+
+ from tha3.module.module_factory import ModuleFactory
+ from tha3.nn.conv import create_conv3_block_from_block_args, \
+     create_downsample_block_from_block_args, create_upsample_block_from_block_args, create_conv3_from_block_args, \
+     create_conv3
+ from tha3.nn.nonlinearity_factory import LeakyReLUFactory
+ from tha3.nn.normalization import InstanceNorm2dFactory
+ from tha3.nn.resnet_block import ResnetBlock
+ from tha3.nn.util import BlockArgs
+
+
+ class FaceMorpher08Args:
+     def __init__(self,
+                  image_size: int = 256,
+                  image_channels: int = 4,
+                  num_expression_params: int = 67,
+                  start_channels: int = 16,
+                  bottleneck_image_size=4,
+                  num_bottleneck_blocks=3,
+                  max_channels: int = 512,
+                  block_args: Optional[BlockArgs] = None):
+         self.max_channels = max_channels
+         self.num_bottleneck_blocks = num_bottleneck_blocks
+         assert bottleneck_image_size > 1
+         self.bottleneck_image_size = bottleneck_image_size
+         self.start_channels = start_channels
+         self.image_channels = image_channels
+         self.num_expression_params = num_expression_params
+         self.image_size = image_size
+
+         if block_args is None:
+             self.block_args = BlockArgs(
+                 normalization_layer_factory=InstanceNorm2dFactory(),
+                 nonlinearity_factory=LeakyReLUFactory(negative_slope=0.2, inplace=True))
+         else:
+             self.block_args = block_args
+
+
+ class FaceMorpher08(Module):
+     def __init__(self, args: FaceMorpher08Args):
+         super().__init__()
+         self.args = args
+         self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1
+
+         self.downsample_blocks = ModuleList()
+         self.downsample_blocks.append(
+             create_conv3_block_from_block_args(
+                 args.image_channels,
+                 args.start_channels,
+                 args.block_args))
+         current_image_size = args.image_size
+         current_num_channels = args.start_channels
+         while current_image_size > args.bottleneck_image_size:
+             next_image_size = current_image_size // 2
+             next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
+             self.downsample_blocks.append(create_downsample_block_from_block_args(
+                 in_channels=current_num_channels,
+                 out_channels=next_num_channels,
+                 is_output_1x1=False,
+                 block_args=args.block_args))
+             current_image_size = next_image_size
+             current_num_channels = next_num_channels
+         assert len(self.downsample_blocks) == self.num_levels
+
+         self.bottleneck_blocks = ModuleList()
+         self.bottleneck_blocks.append(create_conv3_block_from_block_args(
+             in_channels=current_num_channels + args.num_expression_params,
+             out_channels=current_num_channels,
+             block_args=args.block_args))
+         for i in range(1, args.num_bottleneck_blocks):
+             self.bottleneck_blocks.append(
+                 ResnetBlock.create(
+                     num_channels=current_num_channels,
+                     is1x1=False,
+                     block_args=args.block_args))
+
+         self.upsample_blocks = ModuleList()
+         while current_image_size < args.image_size:
+             next_image_size = current_image_size * 2
+             next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
+             self.upsample_blocks.append(create_upsample_block_from_block_args(
+                 in_channels=current_num_channels,
+                 out_channels=next_num_channels,
+                 block_args=args.block_args))
+             current_image_size = next_image_size
+             current_num_channels = next_num_channels
+
+         self.iris_mouth_grid_change = self.create_grid_change_block()
+         self.iris_mouth_color_change = self.create_color_change_block()
+         self.iris_mouth_alpha = self.create_alpha_block()
+
+         self.eye_color_change = self.create_color_change_block()
+         self.eye_alpha = self.create_alpha_block()
+
+     def create_alpha_block(self):
+         return Sequential(
+             create_conv3(
+                 in_channels=self.args.start_channels,
+                 out_channels=1,
+                 bias=True,
+                 initialization_method=self.args.block_args.initialization_method,
+                 use_spectral_norm=False),
+             Sigmoid())
+
+     def create_color_change_block(self):
+         return Sequential(
+             create_conv3_from_block_args(
+                 in_channels=self.args.start_channels,
+                 out_channels=self.args.image_channels,
+                 bias=True,
+                 block_args=self.args.block_args),
+             Tanh())
+
+     def create_grid_change_block(self):
+         return create_conv3(
+             in_channels=self.args.start_channels,
+             out_channels=2,
+             bias=False,
+             initialization_method='zero',
+             use_spectral_norm=False)
+
+     def get_num_output_channels_from_level(self, level: int):
+         return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))
+
+     def get_num_output_channels_from_image_size(self, image_size: int):
+         return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)
+
+     def merge_down(self, top_layer: Tensor, bottom_layer: Tensor):
+         top_layer_rgb = top_layer[:, 0:3, :, :]
+         top_layer_a = top_layer[:, 3:4, :, :]
+         return bottom_layer * (1 - top_layer_a) + torch.cat([top_layer_rgb * top_layer_a, top_layer_a], dim=1)
+
+     def apply_grid_change(self, grid_change, image: Tensor) -> Tensor:
+         n, c, h, w = image.shape
+         device = grid_change.device
+         grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)
+         identity = torch.tensor(
+             [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
+             device=device,
+             dtype=grid_change.dtype).unsqueeze(0).repeat(n, 1, 1)
+         base_grid = affine_grid(identity, [n, c, h, w], align_corners=False)
+         grid = base_grid + grid_change
+         resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=False)
+         return resampled_image
+
+     def apply_color_change(self, alpha, color_change, image: Tensor) -> Tensor:
+         return color_change * alpha + image * (1 - alpha)
+
+     def forward(self, image: Tensor, pose: Tensor, *args) -> List[Tensor]:
+         feature = image
+         for block in self.downsample_blocks:
+             feature = block(feature)
+         n, c = pose.shape
+         pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size)
+         feature = torch.cat([feature, pose], dim=1)
+         for block in self.bottleneck_blocks:
+             feature = block(feature)
+         for block in self.upsample_blocks:
+             feature = block(feature)
+
+         iris_mouth_grid_change = self.iris_mouth_grid_change(feature)
+         iris_mouth_image_0 = self.apply_grid_change(iris_mouth_grid_change, image)
+         iris_mouth_color_change = self.iris_mouth_color_change(feature)
+         iris_mouth_alpha = self.iris_mouth_alpha(feature)
+         iris_mouth_image_1 = self.apply_color_change(iris_mouth_alpha, iris_mouth_color_change, iris_mouth_image_0)
+
+         eye_color_change = self.eye_color_change(feature)
+         eye_alpha = self.eye_alpha(feature)
+         # detach() keeps gradients from the eye branch from flowing back into the iris/mouth branch.
+         output_image = self.apply_color_change(eye_alpha, eye_color_change, iris_mouth_image_1.detach())
+
+         return [
+             output_image,  # 0
+             eye_alpha,  # 1
+             eye_color_change,  # 2
+             iris_mouth_image_1,  # 3
+             iris_mouth_alpha,  # 4
+             iris_mouth_color_change,  # 5
+             iris_mouth_image_0,  # 6
+         ]
+
+     OUTPUT_IMAGE_INDEX = 0
+     EYE_ALPHA_INDEX = 1
+     EYE_COLOR_CHANGE_INDEX = 2
+     IRIS_MOUTH_IMAGE_1_INDEX = 3
+     IRIS_MOUTH_ALPHA_INDEX = 4
+     IRIS_MOUTH_COLOR_CHANGE_INDEX = 5
+     IRIS_MOUTH_IMAGE_0_INDEX = 6
+
+
+ class FaceMorpher08Factory(ModuleFactory):
+     def __init__(self, args: FaceMorpher08Args):
+         super().__init__()
+         self.args = args
+
+     def create(self) -> Module:
+         return FaceMorpher08(self.args)
+
+
+ if __name__ == "__main__":
+     cuda = torch.device('cuda')
+     args = FaceMorpher08Args(
+         image_size=256,
+         image_channels=4,
+         num_expression_params=12,
+         start_channels=64,
+         bottleneck_image_size=32,
+         num_bottleneck_blocks=6,
+         block_args=BlockArgs(
+             initialization_method='he',
+             use_spectral_norm=False,
+             normalization_layer_factory=InstanceNorm2dFactory(),
+             nonlinearity_factory=LeakyReLUFactory(inplace=True, negative_slope=0.2)))
+     module = FaceMorpher08(args).to(cuda)
+
+     image = torch.zeros(16, 4, 256, 256, device=cuda)
+     pose = torch.zeros(16, 12, device=cuda)
+
+     repeat = 100
+     acc = 0.0
+     for i in range(repeat + 2):
+         start = torch.cuda.Event(enable_timing=True)
+         end = torch.cuda.Event(enable_timing=True)
+
+         start.record()
+         module.forward(image, pose)
+         end.record()
+         torch.cuda.synchronize()
+
+         if i >= 2:
+             elapsed_time = start.elapsed_time(end)
+             print("%d:" % i, elapsed_time)
+             acc += elapsed_time
+
+     print("average:", acc / repeat)
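The `merge_down` method above composites a top RGBA layer over a bottom one: the output alpha follows the standard over rule a_top + a_bottom * (1 - a_top), and the RGB is blended with the top layer's color weighted by its alpha. A standalone numeric check of the formula, with hypothetical one-pixel layers:

import torch

top = torch.zeros(1, 4, 1, 1)
top[:, 0, :, :] = 1.0   # pure red ...
top[:, 3, :, :] = 1.0   # ... fully opaque
bottom = torch.zeros(1, 4, 1, 1)
bottom[:, 2, :, :] = 1.0  # pure blue, fully opaque
bottom[:, 3, :, :] = 1.0

top_rgb, top_a = top[:, 0:3], top[:, 3:4]
merged = bottom * (1 - top_a) + torch.cat([top_rgb * top_a, top_a], dim=1)
print(merged.view(-1))  # tensor([1., 0., 0., 1.]): the opaque red top hides the blue bottom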
tha3/nn/face_morpher/face_morpher_09.py ADDED
@@ -0,0 +1,187 @@
+ from typing import List, Optional
+
+ import torch
+ from torch import Tensor
+ from torch.nn import Sequential, Sigmoid, Tanh, Module
+ from torch.nn.functional import affine_grid, grid_sample
+
+ from tha3.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args
+ from tha3.nn.common.poser_encoder_decoder_00_separable import PoserEncoderDecoder00Separable
+ from tha3.nn.image_processing_util import GridChangeApplier
+ from tha3.module.module_factory import ModuleFactory
+ from tha3.nn.conv import create_conv3_from_block_args, create_conv3
+ from tha3.nn.nonlinearity_factory import LeakyReLUFactory
+ from tha3.nn.normalization import InstanceNorm2dFactory
+ from tha3.nn.util import BlockArgs
+
+
+ class FaceMorpher09Args(PoserEncoderDecoder00Args):
+     def __init__(self,
+                  image_size: int = 256,
+                  image_channels: int = 4,
+                  num_pose_params: int = 67,
+                  start_channels: int = 16,
+                  bottleneck_image_size=4,
+                  num_bottleneck_blocks=3,
+                  max_channels: int = 512,
+                  block_args: Optional[BlockArgs] = None):
+         super().__init__(
+             image_size,
+             image_channels,
+             image_channels,
+             num_pose_params,
+             start_channels,
+             bottleneck_image_size,
+             num_bottleneck_blocks,
+             max_channels,
+             block_args)
+
+
+ class FaceMorpher09(Module):
+     def __init__(self, args: FaceMorpher09Args):
+         super().__init__()
+         self.args = args
+         self.body = PoserEncoderDecoder00Separable(args)
+
+         self.iris_mouth_grid_change = self.create_grid_change_block()
+         self.iris_mouth_color_change = self.create_color_change_block()
+         self.iris_mouth_alpha = self.create_alpha_block()
+
+         self.eye_color_change = self.create_color_change_block()
+         self.eye_alpha = self.create_alpha_block()
+
+         self.grid_change_applier = GridChangeApplier()
+
+     def create_alpha_block(self):
+         return Sequential(
+             create_conv3(
+                 in_channels=self.args.start_channels,
+                 out_channels=1,
+                 bias=True,
+                 initialization_method=self.args.block_args.initialization_method,
+                 use_spectral_norm=False),
+             Sigmoid())
+
+     def create_color_change_block(self):
+         return Sequential(
+             create_conv3_from_block_args(
+                 in_channels=self.args.start_channels,
+                 out_channels=self.args.input_image_channels,
+                 bias=True,
+                 block_args=self.args.block_args),
+             Tanh())
+
+     def create_grid_change_block(self):
+         return create_conv3(
+             in_channels=self.args.start_channels,
+             out_channels=2,
+             bias=False,
+             initialization_method='zero',
+             use_spectral_norm=False)
+
+     def get_num_output_channels_from_level(self, level: int):
+         return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))
+
+     def get_num_output_channels_from_image_size(self, image_size: int):
+         return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)
+
+     def forward(self, image: Tensor, pose: Tensor, *args) -> List[Tensor]:
+         feature = self.body(image, pose)[0]
+
+         iris_mouth_grid_change = self.iris_mouth_grid_change(feature)
+         iris_mouth_image_0 = self.grid_change_applier.apply(iris_mouth_grid_change, image)
+         iris_mouth_color_change = self.iris_mouth_color_change(feature)
+         iris_mouth_alpha = self.iris_mouth_alpha(feature)
+         iris_mouth_image_1 = self.apply_color_change(iris_mouth_alpha, iris_mouth_color_change, iris_mouth_image_0)
+
+         eye_color_change = self.eye_color_change(feature)
+         eye_alpha = self.eye_alpha(feature)
+         output_image = self.apply_color_change(eye_alpha, eye_color_change, iris_mouth_image_1.detach())
+
+         return [
+             output_image,  # 0
+             eye_alpha,  # 1
+             eye_color_change,  # 2
+             iris_mouth_image_1,  # 3
+             iris_mouth_alpha,  # 4
+             iris_mouth_color_change,  # 5
+             iris_mouth_image_0,  # 6
+         ]
+
+     OUTPUT_IMAGE_INDEX = 0
+     EYE_ALPHA_INDEX = 1
+     EYE_COLOR_CHANGE_INDEX = 2
+     IRIS_MOUTH_IMAGE_1_INDEX = 3
+     IRIS_MOUTH_ALPHA_INDEX = 4
+     IRIS_MOUTH_COLOR_CHANGE_INDEX = 5
+     IRIS_MOUTH_IMAGE_0_INDEX = 6
+
+     def merge_down(self, top_layer: Tensor, bottom_layer: Tensor):
+         top_layer_rgb = top_layer[:, 0:3, :, :]
+         top_layer_a = top_layer[:, 3:4, :, :]
+         return bottom_layer * (1 - top_layer_a) + torch.cat([top_layer_rgb * top_layer_a, top_layer_a], dim=1)
+
+     def apply_grid_change(self, grid_change, image: Tensor) -> Tensor:
+         n, c, h, w = image.shape
+         device = grid_change.device
+         grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)
+         identity = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device).unsqueeze(0).repeat(n, 1, 1)
+         base_grid = affine_grid(identity, [n, c, h, w], align_corners=False)
+         grid = base_grid + grid_change
+         resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=False)
+         return resampled_image
+
+     def apply_color_change(self, alpha, color_change, image: Tensor) -> Tensor:
+         return color_change * alpha + image * (1 - alpha)
+
+
+ class FaceMorpher09Factory(ModuleFactory):
+     def __init__(self, args: FaceMorpher09Args):
+         super().__init__()
+         self.args = args
+
+     def create(self) -> Module:
+         return FaceMorpher09(self.args)
+
+
+ if __name__ == "__main__":
+     cuda = torch.device('cuda')
+     args = FaceMorpher09Args(
+         image_size=256,
+         image_channels=4,
+         num_pose_params=12,
+         start_channels=64,
+         bottleneck_image_size=32,
+         num_bottleneck_blocks=6,
+         block_args=BlockArgs(
+             initialization_method='xavier',
+             use_spectral_norm=False,
+             normalization_layer_factory=InstanceNorm2dFactory(),
+             nonlinearity_factory=LeakyReLUFactory(inplace=True, negative_slope=0.2)))
+     module = FaceMorpher09(args).to(cuda)
+
+     image = torch.zeros(16, 4, 256, 256, device=cuda)
+     pose = torch.zeros(16, 12, device=cuda)
+
+     state_dict = module.state_dict()
+     for key in state_dict:
+         print(key, state_dict[key].shape)
+
+     if False:
+         repeat = 100
+         acc = 0.0
+         for i in range(repeat + 2):
+             start = torch.cuda.Event(enable_timing=True)
+             end = torch.cuda.Event(enable_timing=True)
+
+             start.record()
+             module.forward(image, pose)
+             end.record()
+             torch.cuda.synchronize()
+
+             if i >= 2:
+                 elapsed_time = start.elapsed_time(end)
+                 print("%d:" % i, elapsed_time)
+                 acc += elapsed_time
+
+         print("average:", acc / repeat)
tha3/nn/image_processing_util.py ADDED
@@ -0,0 +1,58 @@
+ import torch
+ from torch import Tensor
+ from torch.nn.functional import affine_grid, grid_sample
+
+
+ def apply_rgb_change(alpha: Tensor, color_change: Tensor, image: Tensor):
+     image_rgb = image[:, 0:3, :, :]
+     color_change_rgb = color_change[:, 0:3, :, :]
+     output_rgb = color_change_rgb * alpha + image_rgb * (1 - alpha)
+     return torch.cat([output_rgb, image[:, 3:4, :, :]], dim=1)
+
+
+ def apply_grid_change(grid_change, image: Tensor) -> Tensor:
+     n, c, h, w = image.shape
+     device = grid_change.device
+     grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)
+     identity = torch.tensor(
+         [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
+         dtype=grid_change.dtype,
+         device=device).unsqueeze(0).repeat(n, 1, 1)
+     base_grid = affine_grid(identity, [n, c, h, w], align_corners=False)
+     grid = base_grid + grid_change
+     resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=False)
+     return resampled_image
+
+
+ class GridChangeApplier:
+     def __init__(self):
+         self.last_n = None
+         self.last_device = None
+         self.last_identity = None
+
+     def apply(self, grid_change: Tensor, image: Tensor, align_corners: bool = False) -> Tensor:
+         n, c, h, w = image.shape
+         device = grid_change.device
+         grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)
+
+         # Reuse the cached identity matrix when the batch size and device are unchanged.
+         if n == self.last_n and device == self.last_device:
+             identity = self.last_identity
+         else:
+             identity = torch.tensor(
+                 [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
+                 dtype=grid_change.dtype,
+                 device=device,
+                 requires_grad=False).unsqueeze(0).repeat(n, 1, 1)
+             self.last_identity = identity
+             self.last_n = n
+             self.last_device = device
+         base_grid = affine_grid(identity, [n, c, h, w], align_corners=align_corners)
+
+         grid = base_grid + grid_change
+         resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=align_corners)
+         return resampled_image
+
+
+ def apply_color_change(alpha, color_change, image: Tensor) -> Tensor:
+     return color_change * alpha + image * (1 - alpha)
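`apply_grid_change` and `GridChangeApplier.apply` warp an image by adding per-pixel offsets, expressed in `grid_sample`'s normalized [-1, 1] coordinates, to the identity sampling grid, so an all-zero grid change should reproduce the input. A quick standalone check:

import torch
from tha3.nn.image_processing_util import apply_grid_change

image = torch.randn(1, 4, 64, 64)
no_change = torch.zeros(1, 2, 64, 64)  # zero offset on the identity grid
warped = apply_grid_change(no_change, image)
print(torch.allclose(warped, image, atol=1e-5))  # True: the identity warp is a no-op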
tha3/nn/init_function.py ADDED
@@ -0,0 +1,76 @@
+ from typing import Callable
+
+ import torch
+ from torch import zero_
+ from torch.nn import Module
+ from torch.nn.init import kaiming_normal_, xavier_normal_, normal_
+
+
+ def create_init_function(method: str = 'none') -> Callable[[Module], Module]:
+     def init(module: Module):
+         if method == 'none':
+             return module
+         elif method == 'he':
+             kaiming_normal_(module.weight)
+             return module
+         elif method == 'xavier':
+             xavier_normal_(module.weight)
+             return module
+         elif method == 'dcgan':
+             normal_(module.weight, 0.0, 0.02)
+             return module
+         elif method == 'dcgan_001':
+             normal_(module.weight, 0.0, 0.01)
+             return module
+         elif method == "zero":
+             with torch.no_grad():
+                 zero_(module.weight)
+             return module
+         else:
+             raise ValueError("Invalid initialization method: %s" % method)
+
+     return init
+
+
+ class HeInitialization:
+     def __init__(self, a: int = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'):
+         self.nonlinearity = nonlinearity
+         self.mode = mode
+         self.a = a
+
+     def __call__(self, module: Module) -> Module:
+         with torch.no_grad():
+             kaiming_normal_(module.weight, a=self.a, mode=self.mode, nonlinearity=self.nonlinearity)
+         return module
+
+
+ class NormalInitialization:
+     def __init__(self, mean: float = 0.0, std: float = 1.0):
+         self.std = std
+         self.mean = mean
+
+     def __call__(self, module: Module) -> Module:
+         with torch.no_grad():
+             normal_(module.weight, self.mean, self.std)
+         return module
+
+
+ class XavierInitialization:
+     def __init__(self, gain: float = 1.0):
+         self.gain = gain
+
+     def __call__(self, module: Module) -> Module:
+         with torch.no_grad():
+             xavier_normal_(module.weight, self.gain)
+         return module
+
+
+ class ZeroInitialization:
+     def __call__(self, module: Module) -> Module:
+         with torch.no_grad():
+             zero_(module.weight)
+         return module
+
+
+ class NoInitialization:
+     def __call__(self, module: Module) -> Module:
+         return module
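A minimal usage sketch of the initializers above; the layer shapes are arbitrary:

import torch
from torch.nn import Conv2d
from tha3.nn.init_function import create_init_function, XavierInitialization

init = create_init_function('he')            # Callable[[Module], Module]
conv = init(Conv2d(3, 16, kernel_size=3))    # weight re-initialized in place with Kaiming normal
conv2 = XavierInitialization()(Conv2d(3, 16, kernel_size=3))
print(conv.weight.std(), conv2.weight.std())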
tha3/nn/nonlinearity_factory.py ADDED
@@ -0,0 +1,72 @@
+ from typing import Optional
+
+ from torch.nn import Module, ReLU, LeakyReLU, ELU, ReLU6, Hardswish, SiLU, Tanh, Sigmoid
+
+ from tha3.module.module_factory import ModuleFactory
+
+
+ class ReLUFactory(ModuleFactory):
+     def __init__(self, inplace: bool = False):
+         self.inplace = inplace
+
+     def create(self) -> Module:
+         return ReLU(self.inplace)
+
+
+ class LeakyReLUFactory(ModuleFactory):
+     def __init__(self, inplace: bool = False, negative_slope: float = 1e-2):
+         self.negative_slope = negative_slope
+         self.inplace = inplace
+
+     def create(self) -> Module:
+         return LeakyReLU(inplace=self.inplace, negative_slope=self.negative_slope)
+
+
+ class ELUFactory(ModuleFactory):
+     def __init__(self, inplace: bool = False, alpha: float = 1.0):
+         self.alpha = alpha
+         self.inplace = inplace
+
+     def create(self) -> Module:
+         return ELU(inplace=self.inplace, alpha=self.alpha)
+
+
+ class ReLU6Factory(ModuleFactory):
+     def __init__(self, inplace: bool = False):
+         self.inplace = inplace
+
+     def create(self) -> Module:
+         return ReLU6(inplace=self.inplace)
+
+
+ class SiLUFactory(ModuleFactory):
+     def __init__(self, inplace: bool = False):
+         self.inplace = inplace
+
+     def create(self) -> Module:
+         return SiLU(inplace=self.inplace)
+
+
+ class HardswishFactory(ModuleFactory):
+     def __init__(self, inplace: bool = False):
+         self.inplace = inplace
+
+     def create(self) -> Module:
+         return Hardswish(inplace=self.inplace)
+
+
+ class TanhFactory(ModuleFactory):
+     def create(self) -> Module:
+         return Tanh()
+
+
+ class SigmoidFactory(ModuleFactory):
+     def create(self) -> Module:
+         return Sigmoid()
+
+
+ def resolve_nonlinearity_factory(nonlinearity_factory: Optional[ModuleFactory]) -> ModuleFactory:
+     if nonlinearity_factory is None:
+         return ReLUFactory(inplace=False)
+     else:
+         return nonlinearity_factory
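A short sketch of the default-resolution behavior: passing None yields ReLU(inplace=False), while an explicit factory is returned unchanged:

from tha3.nn.nonlinearity_factory import resolve_nonlinearity_factory, LeakyReLUFactory

default_act = resolve_nonlinearity_factory(None).create()  # ReLU(inplace=False)
leaky = resolve_nonlinearity_factory(LeakyReLUFactory(negative_slope=0.2)).create()
print(default_act, leaky)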