{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"source": [
"#@title set font and wrap\n",
"from IPython.display import HTML, display\n",
"\n",
"def set_css():\n",
" display(HTML('''\n",
" \n",
" '''))\n",
"get_ipython().events.register('pre_run_cell', set_css)"
],
"metadata": {
"cellView": "form",
"id": "2eK7PTTQiAWv"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LBBjwIn0ROsF"
},
"outputs": [],
"source": [
"#!git clone https://github.com/yuangan/EAT_code.git\n",
"!wget https://huggingface.co/waveydaveygravy/styletalk/resolve/main/EAT_code.zip\n",
"!unzip /content/EAT_code.zip\n",
"%cd /content/EAT_code\n",
"!pip install -r requirements.txt\n",
"!pip install resampy\n",
"#!pip install face_detection\n",
"!pip install python_speech_features\n",
"#@title make directories if not done yet\n",
"!mkdir tensorflow\n",
"!mkdir Results\n",
"!mkdir ckpt\n",
"!mkdir demo\n",
"!mkdir Utils\n",
"%cd /content/EAT_code/tensorflow\n",
"!mkdir models\n",
"%cd /content/EAT_code\n",
"print(\"done\")"
]
},
{
"cell_type": "code",
"source": [
"#@title in /content/EAT_code/config/deepprompt_eam3d_st_tanh_304_3090_all.yaml change batch_size to 1 line 70"
],
"metadata": {
"cellView": "form",
"id": "_Mxxg6d-WakT"
},
"execution_count": null,
"outputs": []
},
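{
"cell_type": "code",
"source": [
"#@title (optional) apply the batch_size edit with sed\n",
"# A minimal sketch of the manual edit described in the cell above. It assumes the key\n",
"# appears as 'batch_size:' in this yaml (around line 70); adjust the pattern if your\n",
"# config differs, and check the grep output to confirm the change.\n",
"!sed -i 's/batch_size:.*/batch_size: 1/' /content/EAT_code/config/deepprompt_eam3d_st_tanh_304_3090_all.yaml\n",
"!grep -n 'batch_size' /content/EAT_code/config/deepprompt_eam3d_st_tanh_304_3090_all.yaml"
],
"metadata": {
"id": "applyBatchSizeEdit"
},
"execution_count": null,
"outputs": []
},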
{
"cell_type": "code",
"source": [
"#@title make directories if not done yet\n",
"!mkdir tensorflow\n",
"!mkdir ckpt\n",
"!mkdir demo\n",
"!mkdir Utils\n",
"%cd /content/EAT_code/demo\n",
"!mkdir imgs1\n",
"!mkdir imgs_cropped1\n",
"!mkdir imgs_latent1\n",
"%cd /content/EAT_code/tensorflow\n",
"!mkdir models\n",
"%cd /content/EAT_code"
],
"metadata": {
"id": "3bf2Yff2DEdH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title donwload and place deepspeech model\n",
"%cd /content/EAT_code\n",
"!gdown --id 1KK15n2fOdfLECWN5wvX54mVyDt18IZCo && unzip -q ckpt.zip -d ckpt\n",
"!gdown --id 1MeFGC7ig-vgpDLdhh2vpTIiElrhzZmgT && unzip -q demo.zip -d demo\n",
"!gdown --id 1HGVzckXh-vYGZEUUKMntY1muIbkbnRcd && unzip -q Utils.zip -d Utils\n",
"%cd /content/EAT_code/tensorflow/models\n",
"!wget https://github.com/osmr/deepspeech_features/releases/download/v0.0.1/deepspeech-0_1_0-b90017e8.pb.zip\n",
"%cd /content/EAT_code"
],
"metadata": {
"id": "ucK4MQbq0yHx"
},
"execution_count": null,
"outputs": []
},
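{
"cell_type": "code",
"source": [
"#@title quick check that the downloads landed (optional)\n",
"# Lists the directories populated by the cells above so a missing or failed download is\n",
"# easy to spot. The paths are the ones used elsewhere in this notebook.\n",
"import os\n",
"for p in ['/content/EAT_code/ckpt', '/content/EAT_code/demo', '/content/EAT_code/Utils', '/content/EAT_code/tensorflow/models']:\n",
"    print(p, '->', sorted(os.listdir(p))[:10] if os.path.isdir(p) else 'missing')"
],
"metadata": {
"id": "checkDownloads"
},
"execution_count": null,
"outputs": []
},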
{
"cell_type": "code",
"source": [
"#@title download models\n",
"%cd /content/EAT_code\n",
"\n",
"!gdown --id 1KK15n2fOdfLECWN5wvX54mVyDt18IZCo && unzip -q ckpt.zip -d ckpt\n",
"!gdown --id 1MeFGC7ig-vgpDLdhh2vpTIiElrhzZmgT && unzip -q demo.zip -d demo\n",
"!gdown --id 1HGVzckXh-vYGZEUUKMntY1muIbkbnRcd && unzip -q Utils.zip -d Utils"
],
"metadata": {
"id": "eDDFgToSSEgj",
"cellView": "form"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title upload custom mp4 to videos\n",
"%cd /content/EAT_code/preprocess/video\n",
"from google.colab import files\n",
"uploaded = files.upload()\n",
"%cd /content/EAT_code"
],
"metadata": {
"id": "rlfKO73uUXFT",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 106
},
"outputId": "e8116307-0a3f-4eea-ebd1-925f3e17a4bf"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content/EAT_code/preprocess/video\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Saving bo_1resized.mp4 to bo_1resized.mp4\n",
"/content/EAT_code\n"
]
}
]
},
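{
"cell_type": "code",
"source": [
"#@title inspect the uploaded clip (optional)\n",
"# Quick sanity check on the file(s) uploaded above: print path, resolution and frame rate\n",
"# with ffprobe (preinstalled on Colab). Relies on the `uploaded` dict from the previous\n",
"# cell, so run that cell first.\n",
"import os\n",
"for fname in uploaded:\n",
"    path = os.path.join('/content/EAT_code/preprocess/video', fname)\n",
"    print(path)\n",
"    !ffprobe -v error -select_streams v:0 -show_entries stream=width,height,r_frame_rate -of csv=p=0 \"$path\""
],
"metadata": {
"id": "inspectUploadedClip"
},
"execution_count": null,
"outputs": []
},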
{
"cell_type": "code",
"source": [
"#@title extract boundary boxes\n",
"!python /content/EAT_code/preprocess/extract_bbox.py"
],
"metadata": {
"cellView": "form",
"id": "llzj0RprSPu6"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title place custom video in preprocess/video\n",
"%cd /content/EAT_code/preprocess\n",
"!python /content/EAT_code/preprocess/preprocess_video.py # --deepspeech \"/content/EAT_code/tensorflow/models/deepspeech-0_1_0-b90017e8.pb\""
],
"metadata": {
"id": "p99eLmnjW-e1"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
" print(f\"Number of paths: {len(self.A_paths)}\")\n",
" print(\"Paths:\")\n",
" for path in self.A_paths:\n",
" print(path)"
],
"metadata": {
"id": "B4LWLwEvJOsp"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"# Replace \"path/to/your/file.npy\" with the actual path to your file\n",
"data = np.load(\"/content/EAT_code/demo/video_processed/obama/latent_evp_25/obama.npy\")\n",
"\n",
"print(f\"Type: {type(data)}\")\n",
"print(f\"Shape: {data.shape}\")"
],
"metadata": {
"id": "V-hPYJBgYIKx"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!python /content/EAT_code/demo.py --root_wav /content/EAT_code/demo/video_processed1/bo_1resized --emo ang --save_dir /content/EAT_code/Results"
],
"metadata": {
"id": "APJUoQkgaqa9",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "e8150705-1bae-4420-8306-069aa274fb1d"
},
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"deepprompt_eam3d_all_final_313\n",
"cuda is available\n",
"/usr/local/lib/python3.10/dist-packages/torch/functional.py:568: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:2228.)\n",
" return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
" 0% 0/1 [00:00, ?it/s]\n",
" 0% 0/1 [00:00, ?it/s]\u001b[A/content/EAT_code/modules/model_transformer.py:158: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" pred = F.softmax(pred)\n",
"\n",
"\n",
" 0% 0/177 [00:00, ?it/s]\u001b[A\u001b[A/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py:4193: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py:1944: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n",
" warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n",
"\n",
"\n",
" 1% 1/177 [00:00<00:57, 3.04it/s]\u001b[A\u001b[A\n",
"\n",
" 1% 2/177 [00:00<00:58, 3.01it/s]\u001b[A\u001b[A\n",
"\n",
" 2% 3/177 [00:00<00:58, 2.99it/s]\u001b[A\u001b[A\n",
"\n",
" 2% 4/177 [00:01<00:57, 2.99it/s]\u001b[A\u001b[A\n",
"\n",
" 3% 5/177 [00:01<00:57, 2.98it/s]\u001b[A\u001b[A\n",
"\n",
" 3% 6/177 [00:02<00:57, 2.98it/s]\u001b[A\u001b[A\n",
"\n",
" 4% 7/177 [00:02<00:57, 2.98it/s]\u001b[A\u001b[A\n",
"\n",
" 5% 8/177 [00:02<00:56, 2.98it/s]\u001b[A\u001b[A\n",
"\n",
" 5% 9/177 [00:03<00:56, 2.98it/s]\u001b[A\u001b[A\n",
"\n",
" 6% 10/177 [00:03<00:56, 2.98it/s]\u001b[A\u001b[A\n",
"\n",
" 6% 11/177 [00:03<00:55, 2.98it/s]\u001b[A\u001b[A\n",
"\n",
" 7% 12/177 [00:04<00:55, 2.97it/s]\u001b[A\u001b[A\n",
"\n",
" 7% 13/177 [00:04<00:55, 2.97it/s]\u001b[A\u001b[A\n",
"\n",
" 8% 14/177 [00:04<00:54, 2.97it/s]\u001b[A\u001b[A\n",
"\n",
" 8% 15/177 [00:05<00:54, 2.97it/s]\u001b[A\u001b[A\n",
"\n",
" 9% 16/177 [00:05<00:54, 2.97it/s]\u001b[A\u001b[A\n",
"\n",
" 10% 17/177 [00:05<00:53, 2.97it/s]\u001b[A\u001b[A\n",
"\n",
" 10% 18/177 [00:06<00:53, 2.97it/s]\u001b[A\u001b[A\n",
"\n",
" 11% 19/177 [00:06<00:53, 2.97it/s]\u001b[A\u001b[A\n",
"\n",
" 11% 20/177 [00:06<00:52, 2.97it/s]\u001b[A\u001b[A\n",
"\n",
" 12% 21/177 [00:07<00:52, 2.97it/s]\u001b[A\u001b[A\n",
"\n",
" 12% 22/177 [00:07<00:52, 2.97it/s]\u001b[A\u001b[A\n",
"\n",
" 13% 23/177 [00:07<00:51, 2.96it/s]\u001b[A\u001b[A\n",
"\n",
" 14% 24/177 [00:08<00:51, 2.96it/s]\u001b[A\u001b[A\n",
"\n",
" 14% 25/177 [00:08<00:51, 2.96it/s]\u001b[A\u001b[A\n",
"\n",
" 15% 26/177 [00:08<00:51, 2.95it/s]\u001b[A\u001b[A\n",
"\n",
" 15% 27/177 [00:09<00:50, 2.95it/s]\u001b[A\u001b[A\n",
"\n",
" 16% 28/177 [00:09<00:50, 2.95it/s]\u001b[A\u001b[A\n",
"\n",
" 16% 29/177 [00:09<00:50, 2.95it/s]\u001b[A\u001b[A\n",
"\n",
" 17% 30/177 [00:10<00:49, 2.95it/s]\u001b[A\u001b[A\n",
"\n",
" 18% 31/177 [00:10<00:49, 2.95it/s]\u001b[A\u001b[A\n",
"\n",
" 18% 32/177 [00:10<00:49, 2.95it/s]\u001b[A\u001b[A\n",
"\n",
" 19% 33/177 [00:11<00:48, 2.95it/s]\u001b[A\u001b[A\n",
"\n",
" 19% 34/177 [00:11<00:48, 2.95it/s]\u001b[A\u001b[A\n",
"\n",
" 20% 35/177 [00:11<00:48, 2.95it/s]\u001b[A\u001b[A\n",
"\n",
" 20% 36/177 [00:12<00:47, 2.94it/s]\u001b[A\u001b[A\n",
"\n",
" 21% 37/177 [00:12<00:47, 2.94it/s]\u001b[A\u001b[A\n",
"\n",
" 21% 38/177 [00:12<00:47, 2.95it/s]\u001b[A\u001b[A\n",
"\n",
" 22% 39/177 [00:13<00:46, 2.94it/s]\u001b[A\u001b[A\n",
"\n",
" 23% 40/177 [00:13<00:46, 2.94it/s]\u001b[A\u001b[A\n",
"\n",
" 23% 41/177 [00:13<00:46, 2.94it/s]\u001b[A\u001b[A\n",
"\n",
" 24% 42/177 [00:14<00:45, 2.94it/s]\u001b[A\u001b[A\n",
"\n",
" 24% 43/177 [00:14<00:45, 2.94it/s]\u001b[A\u001b[A\n",
"\n",
" 25% 44/177 [00:14<00:45, 2.94it/s]\u001b[A\u001b[A\n",
"\n",
" 25% 45/177 [00:15<00:45, 2.93it/s]\u001b[A\u001b[A\n",
"\n",
" 26% 46/177 [00:15<00:44, 2.93it/s]\u001b[A\u001b[A\n",
"\n",
" 27% 47/177 [00:15<00:44, 2.93it/s]\u001b[A\u001b[A\n",
"\n",
" 27% 48/177 [00:16<00:44, 2.92it/s]\u001b[A\u001b[A\n",
"\n",
" 28% 49/177 [00:16<00:43, 2.93it/s]\u001b[A\u001b[A\n",
"\n",
" 28% 50/177 [00:16<00:43, 2.92it/s]\u001b[A\u001b[A\n",
"\n",
" 29% 51/177 [00:17<00:43, 2.92it/s]\u001b[A\u001b[A\n",
"\n",
" 29% 52/177 [00:17<00:42, 2.92it/s]\u001b[A\u001b[A\n",
"\n",
" 30% 53/177 [00:17<00:42, 2.92it/s]\u001b[A\u001b[A\n",
"\n",
" 31% 54/177 [00:18<00:42, 2.92it/s]\u001b[A\u001b[A\n",
"\n",
" 31% 55/177 [00:18<00:41, 2.92it/s]\u001b[A\u001b[A\n",
"\n",
" 32% 56/177 [00:18<00:41, 2.92it/s]\u001b[A\u001b[A\n",
"\n",
" 32% 57/177 [00:19<00:41, 2.91it/s]\u001b[A\u001b[A\n",
"\n",
" 33% 58/177 [00:19<00:40, 2.91it/s]\u001b[A\u001b[A\n",
"\n",
" 33% 59/177 [00:20<00:40, 2.91it/s]\u001b[A\u001b[A\n",
"\n",
" 34% 60/177 [00:20<00:40, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 34% 61/177 [00:20<00:39, 2.91it/s]\u001b[A\u001b[A\n",
"\n",
" 35% 62/177 [00:21<00:39, 2.91it/s]\u001b[A\u001b[A\n",
"\n",
" 36% 63/177 [00:21<00:39, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 36% 64/177 [00:21<00:38, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 37% 65/177 [00:22<00:38, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 37% 66/177 [00:22<00:38, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 38% 67/177 [00:22<00:37, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 38% 68/177 [00:23<00:37, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 39% 69/177 [00:23<00:37, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 40% 70/177 [00:23<00:36, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 40% 71/177 [00:24<00:36, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 41% 72/177 [00:24<00:36, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 41% 73/177 [00:24<00:35, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 42% 74/177 [00:25<00:35, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 42% 75/177 [00:25<00:35, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 43% 76/177 [00:25<00:34, 2.91it/s]\u001b[A\u001b[A\n",
"\n",
" 44% 77/177 [00:26<00:34, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 44% 78/177 [00:26<00:34, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 45% 79/177 [00:26<00:33, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 45% 80/177 [00:27<00:33, 2.90it/s]\u001b[A\u001b[A\n",
"\n",
" 46% 81/177 [00:27<00:33, 2.89it/s]\u001b[A\u001b[A\n",
"\n",
" 46% 82/177 [00:27<00:32, 2.89it/s]\u001b[A\u001b[A\n",
"\n",
" 47% 83/177 [00:28<00:32, 2.89it/s]\u001b[A\u001b[A\n",
"\n",
" 47% 84/177 [00:28<00:32, 2.88it/s]\u001b[A\u001b[A\n",
"\n",
" 48% 85/177 [00:28<00:32, 2.87it/s]\u001b[A\u001b[A\n",
"\n",
" 49% 86/177 [00:29<00:31, 2.87it/s]\u001b[A\u001b[A\n",
"\n",
" 49% 87/177 [00:29<00:31, 2.87it/s]\u001b[A\u001b[A\n",
"\n",
" 50% 88/177 [00:30<00:30, 2.87it/s]\u001b[A\u001b[A\n",
"\n",
" 50% 89/177 [00:30<00:30, 2.87it/s]\u001b[A\u001b[A\n",
"\n",
" 51% 90/177 [00:30<00:30, 2.87it/s]\u001b[A\u001b[A\n",
"\n",
" 51% 91/177 [00:31<00:29, 2.87it/s]\u001b[A\u001b[A\n",
"\n",
" 52% 92/177 [00:31<00:29, 2.87it/s]\u001b[A\u001b[A\n",
"\n",
" 53% 93/177 [00:31<00:29, 2.87it/s]\u001b[A\u001b[A\n",
"\n",
" 53% 94/177 [00:32<00:28, 2.87it/s]\u001b[A\u001b[A\n",
"\n",
" 54% 95/177 [00:32<00:28, 2.87it/s]\u001b[A\u001b[A\n",
"\n",
" 54% 96/177 [00:32<00:28, 2.86it/s]\u001b[A\u001b[A\n",
"\n",
" 55% 97/177 [00:33<00:27, 2.86it/s]\u001b[A\u001b[A\n",
"\n",
" 55% 98/177 [00:33<00:27, 2.86it/s]\u001b[A\u001b[A\n",
"\n",
" 56% 99/177 [00:33<00:27, 2.86it/s]\u001b[A\u001b[A\n",
"\n",
" 56% 100/177 [00:34<00:26, 2.86it/s]\u001b[A\u001b[A\n",
"\n",
" 57% 101/177 [00:34<00:26, 2.86it/s]\u001b[A\u001b[A\n",
"\n",
" 58% 102/177 [00:34<00:26, 2.86it/s]\u001b[A\u001b[A\n",
"\n",
" 58% 103/177 [00:35<00:25, 2.86it/s]\u001b[A\u001b[A\n",
"\n",
" 59% 104/177 [00:35<00:25, 2.86it/s]\u001b[A\u001b[A\n",
"\n",
" 59% 105/177 [00:35<00:25, 2.86it/s]\u001b[A\u001b[A\n",
"\n",
" 60% 106/177 [00:36<00:24, 2.86it/s]\u001b[A\u001b[A\n",
"\n",
" 60% 107/177 [00:36<00:24, 2.86it/s]\u001b[A\u001b[A\n",
"\n",
" 61% 108/177 [00:37<00:24, 2.85it/s]\u001b[A\u001b[A\n",
"\n",
" 62% 109/177 [00:37<00:23, 2.85it/s]\u001b[A\u001b[A\n",
"\n",
" 62% 110/177 [00:37<00:23, 2.85it/s]\u001b[A\u001b[A\n",
"\n",
" 63% 111/177 [00:38<00:23, 2.85it/s]\u001b[A\u001b[A\n",
"\n",
" 63% 112/177 [00:38<00:22, 2.85it/s]\u001b[A\u001b[A\n",
"\n",
" 64% 113/177 [00:38<00:22, 2.85it/s]\u001b[A\u001b[A\n",
"\n",
" 64% 114/177 [00:39<00:22, 2.85it/s]\u001b[A\u001b[A\n",
"\n",
" 65% 115/177 [00:39<00:21, 2.85it/s]\u001b[A\u001b[A\n",
"\n",
" 66% 116/177 [00:39<00:21, 2.85it/s]\u001b[A\u001b[A\n",
"\n",
" 66% 117/177 [00:40<00:21, 2.85it/s]\u001b[A\u001b[A\n",
"\n",
" 67% 118/177 [00:40<00:20, 2.85it/s]\u001b[A\u001b[A\n",
"\n",
" 67% 119/177 [00:40<00:20, 2.84it/s]\u001b[A\u001b[A\n",
"\n",
" 68% 120/177 [00:41<00:20, 2.84it/s]\u001b[A\u001b[A\n",
"\n",
" 68% 121/177 [00:41<00:19, 2.84it/s]\u001b[A\u001b[A\n",
"\n",
" 69% 122/177 [00:41<00:19, 2.83it/s]\u001b[A\u001b[A\n",
"\n",
" 69% 123/177 [00:42<00:19, 2.84it/s]\u001b[A\u001b[A\n",
"\n",
" 70% 124/177 [00:42<00:18, 2.83it/s]\u001b[A\u001b[A\n",
"\n",
" 71% 125/177 [00:43<00:18, 2.83it/s]\u001b[A\u001b[A\n",
"\n",
" 71% 126/177 [00:43<00:18, 2.81it/s]\u001b[A\u001b[A\n",
"\n",
" 72% 127/177 [00:43<00:17, 2.82it/s]\u001b[A\u001b[A\n",
"\n",
" 72% 128/177 [00:44<00:17, 2.82it/s]\u001b[A\u001b[A\n",
"\n",
" 73% 129/177 [00:44<00:17, 2.82it/s]\u001b[A\u001b[A\n",
"\n",
" 73% 130/177 [00:44<00:16, 2.82it/s]\u001b[A\u001b[A\n",
"\n",
" 74% 131/177 [00:45<00:16, 2.82it/s]\u001b[A\u001b[A\n",
"\n",
" 75% 132/177 [00:45<00:15, 2.82it/s]\u001b[A\u001b[A\n",
"\n",
" 75% 133/177 [00:45<00:15, 2.81it/s]\u001b[A\u001b[A\n",
"\n",
" 76% 134/177 [00:46<00:15, 2.81it/s]\u001b[A\u001b[A\n",
"\n",
" 76% 135/177 [00:46<00:14, 2.81it/s]\u001b[A\u001b[A\n",
"\n",
" 77% 136/177 [00:46<00:14, 2.81it/s]\u001b[A\u001b[A\n",
"\n",
" 77% 137/177 [00:47<00:14, 2.81it/s]\u001b[A\u001b[A\n",
"\n",
" 78% 138/177 [00:47<00:13, 2.81it/s]\u001b[A\u001b[A\n",
"\n",
" 79% 139/177 [00:47<00:13, 2.81it/s]\u001b[A\u001b[A\n",
"\n",
" 79% 140/177 [00:48<00:13, 2.80it/s]\u001b[A\u001b[A\n",
"\n",
" 80% 141/177 [00:48<00:12, 2.81it/s]\u001b[A\u001b[A\n",
"\n",
" 80% 142/177 [00:49<00:12, 2.81it/s]\u001b[A\u001b[A\n",
"\n",
" 81% 143/177 [00:49<00:12, 2.81it/s]\u001b[A\u001b[A\n",
"\n",
" 81% 144/177 [00:49<00:11, 2.80it/s]\u001b[A\u001b[A\n",
"\n",
" 82% 145/177 [00:50<00:11, 2.80it/s]\u001b[A\u001b[A\n",
"\n",
" 82% 146/177 [00:50<00:11, 2.81it/s]\u001b[A\u001b[A\n",
"\n",
" 83% 147/177 [00:50<00:10, 2.80it/s]\u001b[A\u001b[A\n",
"\n",
" 84% 148/177 [00:51<00:10, 2.80it/s]\u001b[A\u001b[A\n",
"\n",
" 84% 149/177 [00:51<00:09, 2.80it/s]\u001b[A\u001b[A\n",
"\n",
" 85% 150/177 [00:51<00:09, 2.80it/s]\u001b[A\u001b[A\n",
"\n",
" 85% 151/177 [00:52<00:09, 2.80it/s]\u001b[A\u001b[A\n",
"\n",
" 86% 152/177 [00:52<00:08, 2.80it/s]\u001b[A\u001b[A\n",
"\n",
" 86% 153/177 [00:52<00:08, 2.80it/s]\u001b[A\u001b[A\n",
"\n",
" 87% 154/177 [00:53<00:08, 2.80it/s]\u001b[A\u001b[A\n",
"\n",
" 88% 155/177 [00:53<00:07, 2.80it/s]\u001b[A\u001b[A\n",
"\n",
" 88% 156/177 [00:54<00:07, 2.79it/s]\u001b[A\u001b[A\n",
"\n",
" 89% 157/177 [00:54<00:07, 2.79it/s]\u001b[A\u001b[A\n",
"\n",
" 89% 158/177 [00:54<00:06, 2.79it/s]\u001b[A\u001b[A\n",
"\n",
" 90% 159/177 [00:55<00:06, 2.79it/s]\u001b[A\u001b[A\n",
"\n",
" 90% 160/177 [00:55<00:06, 2.79it/s]\u001b[A\u001b[A\n",
"\n",
" 91% 161/177 [00:55<00:05, 2.79it/s]\u001b[A\u001b[A\n",
"\n",
" 92% 162/177 [00:56<00:05, 2.77it/s]\u001b[A\u001b[A\n",
"\n",
" 92% 163/177 [00:56<00:05, 2.77it/s]\u001b[A\u001b[A\n",
"\n",
" 93% 164/177 [00:56<00:04, 2.78it/s]\u001b[A\u001b[A\n",
"\n",
" 93% 165/177 [00:57<00:04, 2.78it/s]\u001b[A\u001b[A\n",
"\n",
" 94% 166/177 [00:57<00:03, 2.77it/s]\u001b[A\u001b[A\n",
"\n",
" 94% 167/177 [00:58<00:03, 2.77it/s]\u001b[A\u001b[A\n",
"\n",
" 95% 168/177 [00:58<00:03, 2.77it/s]\u001b[A\u001b[A\n",
"\n",
" 95% 169/177 [00:58<00:02, 2.77it/s]\u001b[A\u001b[A\n",
"\n",
" 96% 170/177 [00:59<00:02, 2.78it/s]\u001b[A\u001b[A\n",
"\n",
" 97% 171/177 [00:59<00:02, 2.77it/s]\u001b[A\u001b[A\n",
"\n",
" 97% 172/177 [00:59<00:01, 2.77it/s]\u001b[A\u001b[A\n",
"\n",
" 98% 173/177 [01:00<00:01, 2.76it/s]\u001b[A\u001b[A\n",
"\n",
" 98% 174/177 [01:00<00:01, 2.76it/s]\u001b[A\u001b[A\n",
"\n",
" 99% 175/177 [01:00<00:00, 2.76it/s]\u001b[A\u001b[A\n",
"\n",
" 99% 176/177 [01:01<00:00, 2.76it/s]\u001b[A\u001b[A\n",
"\n",
"100% 177/177 [01:01<00:00, 2.87it/s]\n",
"\n",
"100% 1/1 [01:04<00:00, 64.70s/it]\n",
"100% 1/1 [01:04<00:00, 64.70s/it]\n"
]
}
]
},
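{
"cell_type": "code",
"source": [
"#@title preview a generated result (optional)\n",
"# A minimal sketch for playing one of the generated mp4s inline. It assumes results were\n",
"# written under /content/EAT_code/Results (the --save_dir passed above); widen the glob if\n",
"# your run used demo.py's default ./demo/output/<name>/ directory instead.\n",
"import glob\n",
"from base64 import b64encode\n",
"from IPython.display import HTML, display\n",
"\n",
"candidates = sorted(glob.glob('/content/EAT_code/Results/**/*.mp4', recursive=True))\n",
"if candidates:\n",
"    with open(candidates[-1], 'rb') as f:\n",
"        data_url = 'data:video/mp4;base64,' + b64encode(f.read()).decode()\n",
"    display(HTML(f'<video width=256 controls src=\"{data_url}\"></video>'))\n",
"else:\n",
"    print('no mp4 found yet - run the demo cell above first')"
],
"metadata": {
"id": "previewResult"
},
"execution_count": null,
"outputs": []
},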
{
"cell_type": "code",
"source": [
"%cd /content\n",
"!zip -r EAT_code.zip /EAT_code"
],
"metadata": {
"id": "XZuRqjR0EGuY"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title demo.py with fixed paths\n",
"\n",
"import os\n",
"import numpy as np\n",
"import torch\n",
"import yaml\n",
"from modules.generator import OcclusionAwareSPADEGeneratorEam\n",
"from modules.keypoint_detector import KPDetector, HEEstimator\n",
"import argparse\n",
"import imageio\n",
"from modules.transformer import Audio2kpTransformerBBoxQDeepPrompt as Audio2kpTransformer\n",
"from modules.prompt import EmotionDeepPrompt, EmotionalDeformationTransformer\n",
"from scipy.io import wavfile\n",
"\n",
"from modules.model_transformer import get_rotation_matrix, keypoint_transformation\n",
"from skimage import io, img_as_float32\n",
"from skimage.transform import resize\n",
"import torchaudio\n",
"import soundfile as sf\n",
"from scipy.spatial import ConvexHull\n",
"\n",
"import torch.nn.functional as F\n",
"import glob\n",
"from tqdm import tqdm\n",
"import gzip\n",
"\n",
"emo_label = ['ang', 'con', 'dis', 'fea', 'hap', 'neu', 'sad', 'sur']\n",
"emo_label_full = ['angry', 'contempt', 'disgusted', 'fear', 'happy', 'neutral', 'sad', 'surprised']\n",
"latent_dim = 16\n",
"\n",
"MEL_PARAMS_25 = {\n",
" \"n_mels\": 80,\n",
" \"n_fft\": 2048,\n",
" \"win_length\": 640,\n",
" \"hop_length\": 640\n",
"}\n",
"\n",
"to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS_25)\n",
"mean, std = -4, 4\n",
"\n",
"expU = torch.from_numpy(np.load('/content/EAT_code/expPCAnorm_fin/U_mead.npy')[:,:32])\n",
"expmean = torch.from_numpy(np.load('/content/EAT_code/expPCAnorm_fin/mean_mead.npy'))\n",
"\n",
"root_wav = './demo/video_processed/M003_neu_1_001'\n",
"def normalize_kp(kp_source, kp_driving, kp_driving_initial,\n",
" use_relative_movement=True, use_relative_jacobian=True):\n",
"\n",
" kp_new = {k: v for k, v in kp_driving.items()}\n",
" if use_relative_movement:\n",
" kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])\n",
" kp_new['value'] = kp_value_diff + kp_source['value']\n",
"\n",
" if use_relative_jacobian:\n",
" jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))\n",
" kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])\n",
"\n",
" return kp_new\n",
"\n",
"def _load_tensor(data):\n",
" wave_path = data\n",
" wave, sr = sf.read(wave_path)\n",
" wave_tensor = torch.from_numpy(wave).float()\n",
" return wave_tensor\n",
"\n",
"def build_model(config, device_ids=[0]):\n",
" generator = OcclusionAwareSPADEGeneratorEam(**config['model_params']['generator_params'],\n",
" **config['model_params']['common_params'])\n",
" if torch.cuda.is_available():\n",
" print('cuda is available')\n",
" generator.to(device_ids[0])\n",
"\n",
" kp_detector = KPDetector(**config['model_params']['kp_detector_params'],\n",
" **config['model_params']['common_params'])\n",
"\n",
" if torch.cuda.is_available():\n",
" kp_detector.to(device_ids[0])\n",
"\n",
"\n",
" audio2kptransformer = Audio2kpTransformer(**config['model_params']['audio2kp_params'], face_ea=True)\n",
"\n",
" if torch.cuda.is_available():\n",
" audio2kptransformer.to(device_ids[0])\n",
"\n",
" sidetuning = EmotionalDeformationTransformer(**config['model_params']['audio2kp_params'])\n",
"\n",
" if torch.cuda.is_available():\n",
" sidetuning.to(device_ids[0])\n",
"\n",
" emotionprompt = EmotionDeepPrompt()\n",
"\n",
" if torch.cuda.is_available():\n",
" emotionprompt.to(device_ids[0])\n",
"\n",
" return generator, kp_detector, audio2kptransformer, sidetuning, emotionprompt\n",
"\n",
"\n",
"def prepare_test_data(img_path, audio_path, opt, emotype, use_otherimg=True):\n",
" # sr,_ = wavfile.read(audio_path)\n",
"\n",
" if use_otherimg:\n",
" source_latent = np.load(img_path.replace('cropped', 'latent')[:-4]+'.npy', allow_pickle=True)\n",
" else:\n",
" source_latent = np.load(img_path.replace('images', 'latent')[:-9]+'.npy', allow_pickle=True)\n",
" he_source = {}\n",
" for k in source_latent[1].keys():\n",
" he_source[k] = torch.from_numpy(source_latent[1][k][0]).unsqueeze(0).cuda()\n",
"\n",
" # source images\n",
" source_img = img_as_float32(io.imread(img_path)).transpose((2, 0, 1))\n",
" asp = os.path.basename(audio_path)[:-4]\n",
"\n",
" # latent code\n",
" y_trg = emo_label.index(emotype)\n",
" z_trg = torch.randn(latent_dim)\n",
"\n",
" # driving latent\n",
" latent_path_driving = f'{root_wav}/latent_evp_25/{asp}.npy'\n",
" pose_gz = gzip.GzipFile(f'{root_wav}/poseimg/{asp}.npy.gz', 'r')\n",
" poseimg = np.load(pose_gz)\n",
" deepfeature = np.load(f'{root_wav}/deepfeature32/{asp}.npy')\n",
" driving_latent = np.load(latent_path_driving[:-4]+'.npy', allow_pickle=True)\n",
" he_driving = driving_latent[1]\n",
"\n",
" # gt frame number\n",
" frames = glob.glob(f'{root_wav}/images_evp_25/cropped/*.jpg')\n",
" num_frames = len(frames)\n",
"\n",
" wave_tensor = _load_tensor(audio_path)\n",
" if len(wave_tensor.shape) > 1:\n",
" wave_tensor = wave_tensor[:, 0]\n",
" mel_tensor = to_melspec(wave_tensor)\n",
" mel_tensor = (torch.log(1e-5 + mel_tensor) - mean) / std\n",
" name_len = min(mel_tensor.shape[1], poseimg.shape[0], deepfeature.shape[0])\n",
"\n",
" audio_frames = []\n",
" poseimgs = []\n",
" deep_feature = []\n",
"\n",
" pad, deep_pad = np.load('pad.npy', allow_pickle=True)\n",
"\n",
" if name_len < num_frames:\n",
" diff = num_frames - name_len\n",
" if diff > 2:\n",
" print(f\"Attention: the frames are {diff} more than name_len, we will use name_len to replace num_frames\")\n",
" num_frames=name_len\n",
" for k in he_driving.keys():\n",
" he_driving[k] = he_driving[k][:name_len, :]\n",
" for rid in range(0, num_frames):\n",
" audio = []\n",
" poses = []\n",
" deeps = []\n",
" for i in range(rid - opt['num_w'], rid + opt['num_w'] + 1):\n",
" if i < 0:\n",
" audio.append(pad)\n",
" poses.append(poseimg[0])\n",
" deeps.append(deep_pad)\n",
" elif i >= name_len:\n",
" audio.append(pad)\n",
" poses.append(poseimg[-1])\n",
" deeps.append(deep_pad)\n",
" else:\n",
" audio.append(mel_tensor[:, i])\n",
" poses.append(poseimg[i])\n",
" deeps.append(deepfeature[i])\n",
"\n",
" audio_frames.append(torch.stack(audio, dim=1))\n",
" poseimgs.append(poses)\n",
" deep_feature.append(deeps)\n",
" audio_frames = torch.stack(audio_frames, dim=0)\n",
" poseimgs = torch.from_numpy(np.array(poseimgs))\n",
" deep_feature = torch.from_numpy(np.array(deep_feature)).to(torch.float)\n",
" return audio_frames, poseimgs, deep_feature, source_img, he_source, he_driving, num_frames, y_trg, z_trg, latent_path_driving\n",
"\n",
"def load_ckpt(ckpt, kp_detector, generator, audio2kptransformer, sidetuning, emotionprompt):\n",
" checkpoint = torch.load(ckpt, map_location=torch.device('cpu'))\n",
" if audio2kptransformer is not None:\n",
" audio2kptransformer.load_state_dict(checkpoint['audio2kptransformer'])\n",
" if generator is not None:\n",
" generator.load_state_dict(checkpoint['generator'])\n",
" if kp_detector is not None:\n",
" kp_detector.load_state_dict(checkpoint['kp_detector'])\n",
" if sidetuning is not None:\n",
" sidetuning.load_state_dict(checkpoint['sidetuning'])\n",
" if emotionprompt is not None:\n",
" emotionprompt.load_state_dict(checkpoint['emotionprompt'])\n",
"\n",
"import cv2\n",
"import dlib\n",
"from tqdm import tqdm\n",
"from skimage import transform as tf\n",
"detector = dlib.get_frontal_face_detector()\n",
"predictor = dlib.shape_predictor('/content/EAT_code/demo/shape_predictor_68_face_landmarks.dat')\n",
"\n",
"def shape_to_np(shape, dtype=\"int\"):\n",
" # initialize the list of (x, y)-coordinates\n",
" coords = np.zeros((shape.num_parts, 2), dtype=dtype)\n",
"\n",
" # loop over all facial landmarks and convert them\n",
" # to a 2-tuple of (x, y)-coordinates\n",
" for i in range(0, shape.num_parts):\n",
" coords[i] = (shape.part(i).x, shape.part(i).y)\n",
"\n",
" # return the list of (x, y)-coordinates\n",
" return coords\n",
"\n",
"def crop_image(image_path, out_path):\n",
" template = np.load('./demo/M003_template.npy')\n",
" image = cv2.imread(image_path)\n",
" gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n",
" rects = detector(gray, 1) #detect human face\n",
" if len(rects) != 1:\n",
" return 0\n",
" for (j, rect) in enumerate(rects):\n",
" shape = predictor(gray, rect) #detect 68 points\n",
" shape = shape_to_np(shape)\n",
"\n",
" pts2 = np.float32(template[:47,:])\n",
" pts1 = np.float32(shape[:47,:]) #eye and nose\n",
" tform = tf.SimilarityTransform()\n",
" tform.estimate( pts2, pts1) #Set the transformation matrix with the explicit parameters.\n",
"\n",
" dst = tf.warp(image, tform, output_shape=(256, 256))\n",
"\n",
" dst = np.array(dst * 255, dtype=np.uint8)\n",
"\n",
" cv2.imwrite(out_path, dst)\n",
"\n",
"def preprocess_imgs(allimgs, tmp_allimgs_cropped):\n",
" name_cropped = []\n",
" for path in tmp_allimgs_cropped:\n",
" name_cropped.append(os.path.basename(path))\n",
" for path in allimgs:\n",
" if os.path.basename(path) in name_cropped:\n",
" continue\n",
" else:\n",
" out_path = path.replace('imgs/', 'imgs_cropped/')\n",
" crop_image(path, out_path)\n",
"\n",
"from sync_batchnorm import DataParallelWithCallback\n",
"def load_checkpoints_extractor(config_path, checkpoint_path, cpu=False):\n",
"\n",
" with open(config_path) as f:\n",
" config = yaml.load(f, Loader=yaml.FullLoader)\n",
"\n",
" kp_detector = KPDetector(**config['model_params']['kp_detector_params'],\n",
" **config['model_params']['common_params'])\n",
" if not cpu:\n",
" kp_detector.cuda()\n",
"\n",
" he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],\n",
" **config['model_params']['common_params'])\n",
" if not cpu:\n",
" he_estimator.cuda()\n",
"\n",
" if cpu:\n",
" checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))\n",
" else:\n",
" checkpoint = torch.load(checkpoint_path)\n",
"\n",
" kp_detector.load_state_dict(checkpoint['kp_detector'])\n",
" he_estimator.load_state_dict(checkpoint['he_estimator'])\n",
"\n",
" if not cpu:\n",
" kp_detector = DataParallelWithCallback(kp_detector)\n",
" he_estimator = DataParallelWithCallback(he_estimator)\n",
"\n",
" kp_detector.eval()\n",
" he_estimator.eval()\n",
"\n",
" return kp_detector, he_estimator\n",
"\n",
"def estimate_latent(driving_video, kp_detector, he_estimator):\n",
" with torch.no_grad():\n",
" predictions = []\n",
" driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).cuda()\n",
" kp_canonical = kp_detector(driving[:, :, 0])\n",
" he_drivings = {'yaw': [], 'pitch': [], 'roll': [], 't': [], 'exp': []}\n",
"\n",
" for frame_idx in range(driving.shape[2]):\n",
" driving_frame = driving[:, :, frame_idx]\n",
" he_driving = he_estimator(driving_frame)\n",
" for k in he_drivings.keys():\n",
" he_drivings[k].append(he_driving[k])\n",
" return [kp_canonical, he_drivings]\n",
"\n",
"def extract_keypoints(extract_list):\n",
" kp_detector, he_estimator = load_checkpoints_extractor(config_path='config/vox-256-spade.yaml', checkpoint_path='./ckpt/pretrain_new_274.pth.tar')\n",
" if not os.path.exists('./demo/imgs_latent/'):\n",
" os.makedirs('./demo/imgs_latent/')\n",
" for imgname in tqdm(extract_list):\n",
" path_frames = [imgname]\n",
" filesname=os.path.basename(imgname)[:-4]\n",
" if os.path.exists(f'./demo/imgs_latent/'+filesname+'.npy'):\n",
" continue\n",
" driving_frames = []\n",
" for im in path_frames:\n",
" driving_frames.append(imageio.imread(im))\n",
" driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_frames]\n",
"\n",
" kc, he = estimate_latent(driving_video, kp_detector, he_estimator)\n",
" kc = kc['value'].cpu().numpy()\n",
" for k in he:\n",
" he[k] = torch.cat(he[k]).cpu().numpy()\n",
" np.save('./demo/imgs_latent/'+filesname, [kc, he])\n",
"\n",
"def preprocess_cropped_imgs(allimgs_cropped):\n",
" extract_list = []\n",
" for img_path in allimgs_cropped:\n",
" if not os.path.exists(img_path.replace('cropped', 'latent')[:-4]+'.npy'):\n",
" extract_list.append(img_path)\n",
" if len(extract_list) > 0:\n",
" print('=========', \"Extract latent keypoints from New image\", '======')\n",
" extract_keypoints(extract_list)\n",
"\n",
"def test(ckpt, emotype, save_dir=\" \"):\n",
" # with open(\"config/vox-transformer2.yaml\") as f:\n",
" with open(\"/content/EAT_code/config/deepprompt_eam3d_st_tanh_304_3090_all.yaml\") as f:\n",
" config = yaml.load(f, Loader=yaml.FullLoader)\n",
" cur_path = os.getcwd()\n",
" generator, kp_detector, audio2kptransformer, sidetuning, emotionprompt = build_model(config)\n",
" load_ckpt(ckpt, kp_detector=kp_detector, generator=generator, audio2kptransformer=audio2kptransformer, sidetuning=sidetuning, emotionprompt=emotionprompt)\n",
"\n",
" audio2kptransformer.eval()\n",
" generator.eval()\n",
" kp_detector.eval()\n",
" sidetuning.eval()\n",
" emotionprompt.eval()\n",
"\n",
" all_wavs2 = [f'{root_wav}/{os.path.basename(root_wav)}.wav']\n",
" allimg = glob.glob('./demo/imgs/*.jpg')\n",
" tmp_allimg_cropped = glob.glob('./demo/imgs_cropped/*.jpg')\n",
" preprocess_imgs(allimg, tmp_allimg_cropped) # crop and align images\n",
"\n",
" allimg_cropped = glob.glob('./demo/imgs_cropped/*.jpg')\n",
" preprocess_cropped_imgs(allimg_cropped) # extract latent keypoints if necessary\n",
"\n",
" for ind in tqdm(range(len(all_wavs2))):\n",
" for img_path in tqdm(allimg_cropped):\n",
" audio_path = all_wavs2[ind]\n",
" # read in data\n",
" audio_frames, poseimgs, deep_feature, source_img, he_source, he_driving, num_frames, y_trg, z_trg, latent_path_driving = prepare_test_data(img_path, audio_path, config['model_params']['audio2kp_params'], emotype)\n",
"\n",
"\n",
" with torch.no_grad():\n",
" source_img = torch.from_numpy(source_img).unsqueeze(0).cuda()\n",
" kp_canonical = kp_detector(source_img, with_feature=True) # {'value': value, 'jacobian': jacobian}\n",
" kp_cano = kp_canonical['value']\n",
"\n",
" x = {}\n",
" x['mel'] = audio_frames.unsqueeze(1).unsqueeze(0).cuda()\n",
" x['z_trg'] = z_trg.unsqueeze(0).cuda()\n",
" x['y_trg'] = torch.tensor(y_trg, dtype=torch.long).cuda().reshape(1)\n",
" x['pose'] = poseimgs.cuda()\n",
" x['deep'] = deep_feature.cuda().unsqueeze(0)\n",
" x['he_driving'] = {'yaw': torch.from_numpy(he_driving['yaw']).cuda().unsqueeze(0),\n",
" 'pitch': torch.from_numpy(he_driving['pitch']).cuda().unsqueeze(0),\n",
" 'roll': torch.from_numpy(he_driving['roll']).cuda().unsqueeze(0),\n",
" 't': torch.from_numpy(he_driving['t']).cuda().unsqueeze(0),\n",
" }\n",
"\n",
" ### emotion prompt\n",
" emoprompt, deepprompt = emotionprompt(x)\n",
" a2kp_exps = []\n",
" emo_exps = []\n",
" T = 5\n",
" if T == 1:\n",
" for i in range(x['mel'].shape[1]):\n",
" xi = {}\n",
" xi['mel'] = x['mel'][:,i,:,:,:].unsqueeze(1)\n",
" xi['z_trg'] = x['z_trg']\n",
" xi['y_trg'] = x['y_trg']\n",
" xi['pose'] = x['pose'][i,:,:,:,:].unsqueeze(0)\n",
" xi['deep'] = x['deep'][:,i,:,:,:].unsqueeze(1)\n",
" xi['he_driving'] = {'yaw': x['he_driving']['yaw'][:,i,:].unsqueeze(0),\n",
" 'pitch': x['he_driving']['pitch'][:,i,:].unsqueeze(0),\n",
" 'roll': x['he_driving']['roll'][:,i,:].unsqueeze(0),\n",
" 't': x['he_driving']['t'][:,i,:].unsqueeze(0),\n",
" }\n",
" he_driving_emo_xi, input_st_xi = audio2kptransformer(xi, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}\n",
" emo_exp = sidetuning(input_st_xi, emoprompt, deepprompt)\n",
" a2kp_exps.append(he_driving_emo_xi['emo'])\n",
" emo_exps.append(emo_exp)\n",
" elif T is not None:\n",
" for i in range(x['mel'].shape[1]//T+1):\n",
" if i*T >= x['mel'].shape[1]:\n",
" break\n",
" xi = {}\n",
" xi['mel'] = x['mel'][:,i*T:(i+1)*T,:,:,:]\n",
" xi['z_trg'] = x['z_trg']\n",
" xi['y_trg'] = x['y_trg']\n",
" xi['pose'] = x['pose'][i*T:(i+1)*T,:,:,:,:]\n",
" xi['deep'] = x['deep'][:,i*T:(i+1)*T,:,:,:]\n",
" xi['he_driving'] = {'yaw': x['he_driving']['yaw'][:,i*T:(i+1)*T,:],\n",
" 'pitch': x['he_driving']['pitch'][:,i*T:(i+1)*T,:],\n",
" 'roll': x['he_driving']['roll'][:,i*T:(i+1)*T,:],\n",
" 't': x['he_driving']['t'][:,i*T:(i+1)*T,:],\n",
" }\n",
" he_driving_emo_xi, input_st_xi = audio2kptransformer(xi, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}\n",
" emo_exp = sidetuning(input_st_xi, emoprompt, deepprompt)\n",
" a2kp_exps.append(he_driving_emo_xi['emo'])\n",
" emo_exps.append(emo_exp)\n",
"\n",
" if T is None:\n",
" he_driving_emo, input_st = audio2kptransformer(x, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}\n",
" emo_exps = sidetuning(input_st, emoprompt, deepprompt).reshape(-1, 45)\n",
" else:\n",
" he_driving_emo = {}\n",
" he_driving_emo['emo'] = torch.cat(a2kp_exps, dim=0)\n",
" emo_exps = torch.cat(emo_exps, dim=0).reshape(-1, 45)\n",
"\n",
" exp = he_driving_emo['emo']\n",
" device = exp.get_device()\n",
" exp = torch.mm(exp, expU.t().to(device))\n",
" exp = exp + expmean.expand_as(exp).to(device)\n",
" exp = exp + emo_exps\n",
"\n",
"\n",
" source_area = ConvexHull(kp_cano[0].cpu().numpy()).volume\n",
" exp = exp * source_area\n",
"\n",
" he_new_driving = {'yaw': torch.from_numpy(he_driving['yaw']).cuda(),\n",
" 'pitch': torch.from_numpy(he_driving['pitch']).cuda(),\n",
" 'roll': torch.from_numpy(he_driving['roll']).cuda(),\n",
" 't': torch.from_numpy(he_driving['t']).cuda(),\n",
" 'exp': exp}\n",
" he_driving['exp'] = torch.from_numpy(he_driving['exp']).cuda()\n",
"\n",
" kp_source = keypoint_transformation(kp_canonical, he_source, False)\n",
" mean_source = torch.mean(kp_source['value'], dim=1)[0]\n",
" kp_driving = keypoint_transformation(kp_canonical, he_new_driving, False)\n",
" mean_driving = torch.mean(torch.mean(kp_driving['value'], dim=1), dim=0)\n",
" kp_driving['value'] = kp_driving['value']+(mean_source-mean_driving).unsqueeze(0).unsqueeze(0)\n",
" bs = kp_source['value'].shape[0]\n",
" predictions_gen = []\n",
" for i in tqdm(range(num_frames)):\n",
" kp_si = {}\n",
" kp_si['value'] = kp_source['value'][0].unsqueeze(0)\n",
" kp_di = {}\n",
" kp_di['value'] = kp_driving['value'][i].unsqueeze(0)\n",
" generated = generator(source_img, kp_source=kp_si, kp_driving=kp_di, prompt=emoprompt)\n",
" predictions_gen.append(\n",
" (np.transpose(generated['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0] * 255).astype(np.uint8))\n",
"\n",
" log_dir = save_dir\n",
" os.makedirs(os.path.join(log_dir, \"temp\"), exist_ok=True)\n",
"\n",
" f_name = os.path.basename(img_path[:-4]) + \"_\" + emotype + \"_\" + os.path.basename(latent_path_driving)[:-4] + \".mp4\"\n",
" video_path = os.path.join(log_dir, \"temp\", f_name)\n",
" imageio.mimsave(video_path, predictions_gen, fps=25.0)\n",
"\n",
" save_video = os.path.join(log_dir, f_name)\n",
" cmd = r'ffmpeg -loglevel error -y -i \"%s\" -i \"%s\" -vcodec copy -shortest \"%s\"' % (video_path, audio_path, save_video)\n",
" os.system(cmd)\n",
" os.remove(video_path)\n",
"\n",
"if __name__ == '__main__':\n",
" argparser = argparse.ArgumentParser()\n",
" argparser.add_argument(\"--save_dir\", type=str, default=\" \", help=\"path of the output video\")\n",
" argparser.add_argument(\"--name\", type=str, default=\"deepprompt_eam3d_all_final_313\", help=\"path of the output video\")\n",
" argparser.add_argument(\"--emo\", type=str, default=\"hap\", help=\"emotion type ('ang', 'con', 'dis', 'fea', 'hap', 'neu', 'sad', 'sur')\")\n",
" argparser.add_argument(\"--root_wav\", type=str, default='./demo/video_processed/M003_neu_1_001', help=\"emotion type ('ang', 'con', 'dis', 'fea', 'hap', 'neu', 'sad', 'sur')\")\n",
" args = argparser.parse_args()\n",
"\n",
" root_wav=args.root_wav\n",
"\n",
" if len(args.name) > 1:\n",
" name = args.name\n",
" print(name)\n",
" test(f'/content/EAT_code/ckpt/deepprompt_eam3d_all_final_313.pth.tar', args.emo, save_dir=f'./demo/output/{name}/')\n",
"\n"
],
"metadata": {
"cellView": "form",
"id": "cMz72QBbgRsc"
},
"execution_count": null,
"outputs": []
},
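{
"cell_type": "code",
"source": [
"#@title call the fixed demo directly (optional)\n",
"# When the cell above runs inside the notebook, argparse trips over the kernel's own\n",
"# command-line flags, but the functions it defines (test, build_model, ...) remain in the\n",
"# namespace. This sketch sets the root_wav global and calls test() directly; the paths are\n",
"# assumptions based on earlier cells, so point root_wav at your own processed video.\n",
"%cd /content/EAT_code\n",
"root_wav = '/content/EAT_code/demo/video_processed1/bo_1resized'\n",
"test('/content/EAT_code/ckpt/deepprompt_eam3d_all_final_313.pth.tar', 'ang', save_dir='/content/EAT_code/Results/')"
],
"metadata": {
"id": "callFixedDemo"
},
"execution_count": null,
"outputs": []
},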
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "CmeH0D2ayRrf"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title trasnformer.py with fixed paths\n",
"import torch.nn as nn\n",
"import torch\n",
"from modules.util import mydownres2Dblock\n",
"import numpy as np\n",
"from modules.util import AntiAliasInterpolation2d, make_coordinate_grid\n",
"from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d\n",
"import torch.nn.functional as F\n",
"import copy\n",
"\n",
"\n",
"class PositionalEncoding(nn.Module):\n",
"\n",
" def __init__(self, d_hid, n_position=200):\n",
" super(PositionalEncoding, self).__init__()\n",
"\n",
" # Not a parameter\n",
" self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))\n",
"\n",
" def _get_sinusoid_encoding_table(self, n_position, d_hid):\n",
" ''' Sinusoid position encoding table '''\n",
" # TODO: make it with torch instead of numpy\n",
"\n",
" def get_position_angle_vec(position):\n",
" return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]\n",
"\n",
" sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])\n",
" sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i\n",
" sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1\n",
"\n",
" return torch.FloatTensor(sinusoid_table).unsqueeze(0)\n",
"\n",
" def forward(self, winsize):\n",
" return self.pos_table[:, :winsize].clone().detach()\n",
"\n",
"def _get_activation_fn(activation):\n",
" \"\"\"Return an activation function given a string\"\"\"\n",
" if activation == \"relu\":\n",
" return F.relu\n",
" if activation == \"gelu\":\n",
" return F.gelu\n",
" if activation == \"glu\":\n",
" return F.glu\n",
" raise RuntimeError(F\"activation should be relu/gelu, not {activation}.\")\n",
"\n",
"def _get_clones(module, N):\n",
" return nn.ModuleList([copy.deepcopy(module) for i in range(N)])\n",
"\n",
"### light weight transformer encoder\n",
"class TransformerST(nn.Module):\n",
"\n",
" def __init__(self, d_model=128, nhead=8, num_encoder_layers=6,\n",
" num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,\n",
" activation=\"relu\", normalize_before=False):\n",
" super().__init__()\n",
"\n",
" encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,\n",
" dropout, activation, normalize_before)\n",
" encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n",
" self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)\n",
"\n",
" self._reset_parameters()\n",
"\n",
" self.d_model = d_model\n",
" self.nhead = nhead\n",
"\n",
" def _reset_parameters(self):\n",
" for p in self.parameters():\n",
" if p.dim() > 1:\n",
" nn.init.xavier_uniform_(p)\n",
"\n",
" def forward(self, src, pos_embed):\n",
" # flatten NxCxHxW to HWxNxC\n",
"\n",
" src = src.permute(1, 0, 2)\n",
" pos_embed = pos_embed.permute(1, 0, 2)\n",
"\n",
" memory = self.encoder(src, pos=pos_embed)\n",
"\n",
" return memory\n",
"\n",
"class TransformerEncoder(nn.Module):\n",
"\n",
" def __init__(self, encoder_layer, num_layers, norm=None):\n",
" super().__init__()\n",
" self.layers = _get_clones(encoder_layer, num_layers)\n",
" self.num_layers = num_layers\n",
" self.norm = norm\n",
"\n",
" def forward(self, src, mask = None, src_key_padding_mask = None, pos = None):\n",
" output = src+pos\n",
"\n",
" for layer in self.layers:\n",
" output = layer(output, src_mask=mask,\n",
" src_key_padding_mask=src_key_padding_mask, pos=pos)\n",
"\n",
" if self.norm is not None:\n",
" output = self.norm(output)\n",
"\n",
" return output\n",
"\n",
"class TransformerDecoder(nn.Module):\n",
"\n",
" def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):\n",
" super().__init__()\n",
" self.layers = _get_clones(decoder_layer, num_layers)\n",
" self.num_layers = num_layers\n",
" self.norm = norm\n",
" self.return_intermediate = return_intermediate\n",
"\n",
" def forward(self, tgt, memory, tgt_mask = None, memory_mask = None, tgt_key_padding_mask = None,\n",
" memory_key_padding_mask = None,\n",
" pos = None,\n",
" query_pos = None):\n",
" output = tgt+pos+query_pos\n",
"\n",
" intermediate = []\n",
"\n",
" for layer in self.layers:\n",
" output = layer(output, memory, tgt_mask=tgt_mask,\n",
" memory_mask=memory_mask,\n",
" tgt_key_padding_mask=tgt_key_padding_mask,\n",
" memory_key_padding_mask=memory_key_padding_mask,\n",
" pos=pos, query_pos=query_pos)\n",
" if self.return_intermediate:\n",
" intermediate.append(self.norm(output))\n",
"\n",
" if self.norm is not None:\n",
" output = self.norm(output)\n",
" if self.return_intermediate:\n",
" intermediate.pop()\n",
" intermediate.append(output)\n",
"\n",
" if self.return_intermediate:\n",
" return torch.stack(intermediate)\n",
"\n",
" return output.unsqueeze(0)\n",
"\n",
"\n",
"class Transformer(nn.Module):\n",
"\n",
" def __init__(self, d_model=128, nhead=8, num_encoder_layers=6,\n",
" num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,\n",
" activation=\"relu\", normalize_before=False,\n",
" return_intermediate_dec=True):\n",
" super().__init__()\n",
"\n",
" encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,\n",
" dropout, activation, normalize_before)\n",
" encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n",
" self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)\n",
"\n",
" decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,\n",
" dropout, activation, normalize_before)\n",
" decoder_norm = nn.LayerNorm(d_model)\n",
" self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,\n",
" return_intermediate=return_intermediate_dec)\n",
"\n",
" self._reset_parameters()\n",
"\n",
" self.d_model = d_model\n",
" self.nhead = nhead\n",
"\n",
" def _reset_parameters(self):\n",
" for p in self.parameters():\n",
" if p.dim() > 1:\n",
" nn.init.xavier_uniform_(p)\n",
"\n",
" def forward(self, src, query_embed, pos_embed):\n",
" # flatten NxCxHxW to HWxNxC\n",
"\n",
" src = src.permute(1, 0, 2)\n",
" pos_embed = pos_embed.permute(1, 0, 2)\n",
" query_embed = query_embed.permute(1, 0, 2)\n",
"\n",
" tgt = torch.zeros_like(query_embed)\n",
" memory = self.encoder(src, pos=pos_embed)\n",
"\n",
" hs = self.decoder(tgt, memory,\n",
" pos=pos_embed, query_pos=query_embed)\n",
" return hs, memory\n",
"\n",
"\n",
"class TransformerDeep(nn.Module):\n",
"\n",
" def __init__(self, d_model=128, nhead=8, num_encoder_layers=6,\n",
" num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,\n",
" activation=\"relu\", normalize_before=False,\n",
" return_intermediate_dec=True):\n",
" super().__init__()\n",
"\n",
" encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,\n",
" dropout, activation, normalize_before)\n",
" encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n",
" self.encoder = TransformerEncoderDeep(encoder_layer, num_encoder_layers, encoder_norm)\n",
"\n",
" decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,\n",
" dropout, activation, normalize_before)\n",
" decoder_norm = nn.LayerNorm(d_model)\n",
" self.decoder = TransformerDecoderDeep(decoder_layer, num_decoder_layers, decoder_norm,\n",
" return_intermediate=return_intermediate_dec)\n",
"\n",
" self._reset_parameters()\n",
"\n",
" self.d_model = d_model\n",
" self.nhead = nhead\n",
"\n",
" def _reset_parameters(self):\n",
" for p in self.parameters():\n",
" if p.dim() > 1:\n",
" nn.init.xavier_uniform_(p)\n",
"\n",
" def forward(self, src, query_embed, pos_embed, deepprompt):\n",
" # flatten NxCxHxW to HWxNxC\n",
"\n",
" # print('src before permute: ', src.shape) # 5, 12, 128\n",
" src = src.permute(1, 0, 2)\n",
" # print('src after permute: ', src.shape) # 12, 5, 128\n",
" pos_embed = pos_embed.permute(1, 0, 2)\n",
" query_embed = query_embed.permute(1, 0, 2)\n",
"\n",
" tgt = torch.zeros_like(query_embed) # actually is tgt + query_embed\n",
" memory = self.encoder(src, deepprompt, pos=pos_embed)\n",
"\n",
" hs = self.decoder(tgt, deepprompt, memory,\n",
" pos=pos_embed, query_pos=query_embed)\n",
" return hs, memory\n",
"\n",
"class TransformerEncoderDeep(nn.Module):\n",
"\n",
" def __init__(self, encoder_layer, num_layers, norm=None):\n",
" super().__init__()\n",
" self.layers = _get_clones(encoder_layer, num_layers)\n",
" self.num_layers = num_layers\n",
" self.norm = norm\n",
"\n",
" def forward(self, src, deepprompt, mask = None, src_key_padding_mask = None, pos = None):\n",
" # print('input: ', src.shape) # 12 5 128\n",
" # print('deepprompt:', deepprompt.shape) # 1 6 128\n",
" ### TODO: add deep prompt in encoder\n",
" bs = src.shape[1]\n",
" bbs = deepprompt.shape[0]\n",
" idx=0\n",
" emoprompt = deepprompt[:,idx,:]\n",
" emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n",
" # print(emoprompt.shape) # 1 5 128\n",
" src = torch.cat([src, emoprompt], dim=0)\n",
" # print(src.shape) # 13 5 128\n",
" output = src+pos\n",
"\n",
" for layer in self.layers:\n",
" output = layer(output, src_mask=mask,\n",
" src_key_padding_mask=src_key_padding_mask, pos=pos)\n",
"\n",
" ### deep prompt\n",
" if idx+1 < len(self.layers):\n",
" idx = idx + 1\n",
" # print(idx)\n",
" emoprompt = deepprompt[:,idx,:]\n",
" emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n",
" # print(output.shape) # 13 5 128\n",
" output = torch.cat([output[:-1], emoprompt], dim=0)\n",
" # print(output.shape) # 13 5 128\n",
"\n",
" if self.norm is not None:\n",
" output = self.norm(output)\n",
"\n",
" return output\n",
"\n",
"class TransformerDecoderDeep(nn.Module):\n",
"\n",
" def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):\n",
" super().__init__()\n",
" self.layers = _get_clones(decoder_layer, num_layers)\n",
" self.num_layers = num_layers\n",
" self.norm = norm\n",
" self.return_intermediate = return_intermediate\n",
"\n",
" def forward(self, tgt, deepprompt, memory, tgt_mask = None, memory_mask = None, tgt_key_padding_mask = None,\n",
" memory_key_padding_mask = None,\n",
" pos = None,\n",
" query_pos = None):\n",
" # print('input: ', query_pos.shape) 12 5 128\n",
" ### TODO: add deep prompt in encoder\n",
" bs = query_pos.shape[1]\n",
" bbs = deepprompt.shape[0]\n",
" idx=0\n",
" emoprompt = deepprompt[:,idx,:]\n",
" emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n",
" query_pos = torch.cat([query_pos, emoprompt], dim=0)\n",
" # print(query_pos.shape) # 13 5 128\n",
" # print(torch.sum(tgt)) # 0\n",
" output = pos+query_pos\n",
"\n",
" intermediate = []\n",
"\n",
" for layer in self.layers:\n",
" output = layer(output, memory, tgt_mask=tgt_mask,\n",
" memory_mask=memory_mask,\n",
" tgt_key_padding_mask=tgt_key_padding_mask,\n",
" memory_key_padding_mask=memory_key_padding_mask,\n",
" pos=pos, query_pos=query_pos)\n",
" if self.return_intermediate:\n",
" intermediate.append(self.norm(output))\n",
"\n",
" ### deep prompt\n",
" if idx+1 < len(self.layers):\n",
" idx = idx + 1\n",
" # print(idx)\n",
" emoprompt = deepprompt[:,idx,:]\n",
" emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n",
" # print(output.shape) # 13 5 128\n",
" output = torch.cat([output[:-1], emoprompt], dim=0)\n",
" # print(output.shape) # 13 5 128\n",
"\n",
" if self.norm is not None:\n",
" output = self.norm(output)\n",
" if self.return_intermediate:\n",
" intermediate.pop()\n",
" intermediate.append(output)\n",
"\n",
" if self.return_intermediate:\n",
" return torch.stack(intermediate)\n",
"\n",
" return output.unsqueeze(0)\n",
"\n",
"\n",
"class TransformerEncoderLayer(nn.Module):\n",
"\n",
" def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,\n",
" activation=\"relu\", normalize_before=False):\n",
" super().__init__()\n",
" self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n",
" # Implementation of Feedforward model\n",
" self.linear1 = nn.Linear(d_model, dim_feedforward)\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.linear2 = nn.Linear(dim_feedforward, d_model)\n",
"\n",
" self.norm1 = nn.LayerNorm(d_model)\n",
" self.norm2 = nn.LayerNorm(d_model)\n",
" self.dropout1 = nn.Dropout(dropout)\n",
" self.dropout2 = nn.Dropout(dropout)\n",
"\n",
" self.activation = _get_activation_fn(activation)\n",
" self.normalize_before = normalize_before\n",
"\n",
" def with_pos_embed(self, tensor, pos):\n",
" return tensor if pos is None else tensor + pos\n",
"\n",
" def forward_post(self,\n",
" src,\n",
" src_mask = None,\n",
" src_key_padding_mask = None,\n",
" pos = None):\n",
" # q = k = self.with_pos_embed(src, pos)\n",
" src2 = self.self_attn(src, src, value=src, attn_mask=src_mask,\n",
" key_padding_mask=src_key_padding_mask)[0]\n",
" src = src + self.dropout1(src2)\n",
" src = self.norm1(src)\n",
" src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))\n",
" src = src + self.dropout2(src2)\n",
" src = self.norm2(src)\n",
" return src\n",
"\n",
" def forward_pre(self, src,\n",
" src_mask = None,\n",
" src_key_padding_mask = None,\n",
" pos = None):\n",
" src2 = self.norm1(src)\n",
" # q = k = self.with_pos_embed(src2, pos)\n",
" src2 = self.self_attn(src2, src2, value=src2, attn_mask=src_mask,\n",
" key_padding_mask=src_key_padding_mask)[0]\n",
" src = src + self.dropout1(src2)\n",
" src2 = self.norm2(src)\n",
" src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))\n",
" src = src + self.dropout2(src2)\n",
" return src\n",
"\n",
" def forward(self, src,\n",
" src_mask = None,\n",
" src_key_padding_mask = None,\n",
" pos = None):\n",
" if self.normalize_before:\n",
" return self.forward_pre(src, src_mask, src_key_padding_mask, pos)\n",
" return self.forward_post(src, src_mask, src_key_padding_mask, pos)\n",
"\n",
"class TransformerDecoderLayer(nn.Module):\n",
"\n",
" def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,\n",
" activation=\"relu\", normalize_before=False):\n",
" super().__init__()\n",
" self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n",
" self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n",
" # Implementation of Feedforward model\n",
" self.linear1 = nn.Linear(d_model, dim_feedforward)\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.linear2 = nn.Linear(dim_feedforward, d_model)\n",
"\n",
" self.norm1 = nn.LayerNorm(d_model)\n",
" self.norm2 = nn.LayerNorm(d_model)\n",
" self.norm3 = nn.LayerNorm(d_model)\n",
" self.dropout1 = nn.Dropout(dropout)\n",
" self.dropout2 = nn.Dropout(dropout)\n",
" self.dropout3 = nn.Dropout(dropout)\n",
"\n",
" self.activation = _get_activation_fn(activation)\n",
" self.normalize_before = normalize_before\n",
"\n",
" def with_pos_embed(self, tensor, pos):\n",
" return tensor if pos is None else tensor + pos\n",
"\n",
" def forward_post(self, tgt, memory,\n",
" tgt_mask = None,\n",
" memory_mask = None,\n",
" tgt_key_padding_mask = None,\n",
" memory_key_padding_mask = None,\n",
" pos = None,\n",
" query_pos = None):\n",
" # q = k = self.with_pos_embed(tgt, query_pos)\n",
" tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask,\n",
" key_padding_mask=tgt_key_padding_mask)[0]\n",
" tgt = tgt + self.dropout1(tgt2)\n",
" tgt = self.norm1(tgt)\n",
" tgt2 = self.multihead_attn(query=tgt,\n",
" key=memory,\n",
" value=memory, attn_mask=memory_mask,\n",
" key_padding_mask=memory_key_padding_mask)[0]\n",
" tgt = tgt + self.dropout2(tgt2)\n",
" tgt = self.norm2(tgt)\n",
" tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))\n",
" tgt = tgt + self.dropout3(tgt2)\n",
" tgt = self.norm3(tgt)\n",
" return tgt\n",
"\n",
" def forward_pre(self, tgt, memory,\n",
" tgt_mask = None,\n",
" memory_mask = None,\n",
" tgt_key_padding_mask = None,\n",
" memory_key_padding_mask = None,\n",
" pos = None,\n",
" query_pos = None):\n",
" tgt2 = self.norm1(tgt)\n",
" # q = k = self.with_pos_embed(tgt2, query_pos)\n",
" tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask,\n",
" key_padding_mask=tgt_key_padding_mask)[0]\n",
" tgt = tgt + self.dropout1(tgt2)\n",
" tgt2 = self.norm2(tgt)\n",
" tgt2 = self.multihead_attn(query=tgt2,\n",
" key=memory,\n",
" value=memory, attn_mask=memory_mask,\n",
" key_padding_mask=memory_key_padding_mask)[0]\n",
" tgt = tgt + self.dropout2(tgt2)\n",
" tgt2 = self.norm3(tgt)\n",
" tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))\n",
" tgt = tgt + self.dropout3(tgt2)\n",
" return tgt\n",
"\n",
" def forward(self, tgt, memory,\n",
" tgt_mask = None,\n",
" memory_mask = None,\n",
" tgt_key_padding_mask = None,\n",
" memory_key_padding_mask = None,\n",
" pos = None,\n",
" query_pos = None):\n",
" if self.normalize_before:\n",
" return self.forward_pre(tgt, memory, tgt_mask, memory_mask,\n",
" tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)\n",
" return self.forward_post(tgt, memory, tgt_mask, memory_mask,\n",
" tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)\n",
"\n",
"from Utils.JDC.model import JDCNet\n",
"from modules.audioencoder import AudioEncoder, MappingNetwork, StyleEncoder, AdaIN, EAModule\n",
"\n",
"def headpose_pred_to_degree(pred):\n",
" device = pred.device\n",
" idx_tensor = [idx for idx in range(66)]\n",
" idx_tensor = torch.FloatTensor(idx_tensor).to(device)\n",
" pred = F.softmax(pred, dim=1)\n",
" degree = torch.sum(pred*idx_tensor, axis=1)\n",
" # degree = F.one_hot(degree.to(torch.int64), num_classes=66)\n",
" return degree\n",
"\n",
"def get_rotation_matrix(yaw, pitch, roll):\n",
" yaw = yaw / 180 * 3.14\n",
" pitch = pitch / 180 * 3.14\n",
" roll = roll / 180 * 3.14\n",
"\n",
" roll = roll.unsqueeze(1)\n",
" pitch = pitch.unsqueeze(1)\n",
" yaw = yaw.unsqueeze(1)\n",
"\n",
" pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),\n",
" torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch),\n",
" torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1)\n",
"\n",
" pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)\n",
"\n",
" yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw),\n",
" torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),\n",
" -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1)\n",
" yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)\n",
"\n",
" roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll),\n",
" torch.sin(roll), torch.cos(roll), torch.zeros_like(roll),\n",
" torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1)\n",
" roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)\n",
"\n",
" rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat)\n",
"\n",
" return yaw, pitch, roll, yaw_mat.view(yaw_mat.shape[0], 9), pitch_mat.view(pitch_mat.shape[0], 9), roll_mat.view(roll_mat.shape[0], 9), rot_mat.view(rot_mat.shape[0], 9)\n",
"\n",
"class Audio2kpTransformerBBoxQDeep(nn.Module):\n",
" def __init__(self, embedding_dim, num_kp, num_w, face_adain=False):\n",
" super(Audio2kpTransformerBBoxQDeep, self).__init__()\n",
" self.embedding_dim = embedding_dim\n",
" self.num_kp = num_kp\n",
" self.num_w = num_w\n",
"\n",
"\n",
" self.embedding = nn.Embedding(41, embedding_dim)\n",
"\n",
" self.face_shrink = nn.Linear(240, 32)\n",
" self.hp_extractor = nn.Linear(45, 128)\n",
"\n",
" self.pos_enc = PositionalEncoding(128,20)\n",
" input_dim = 1\n",
"\n",
" self.decode_dim = 64\n",
" self.audio_embedding = nn.Sequential( # n x 29 x 16\n",
" nn.Conv1d(29, 32, kernel_size=3, stride=2,\n",
" padding=1, bias=True), # n x 32 x 8\n",
" nn.LeakyReLU(0.02, True),\n",
" nn.Conv1d(32, 32, kernel_size=3, stride=2,\n",
" padding=1, bias=True), # n x 32 x 4\n",
" nn.LeakyReLU(0.02, True),\n",
" nn.Conv1d(32, 64, kernel_size=3, stride=2,\n",
" padding=1, bias=True), # n x 64 x 2\n",
" nn.LeakyReLU(0.02, True),\n",
" nn.Conv1d(64, 64, kernel_size=3, stride=2,\n",
" padding=1, bias=True), # n x 64 x 1\n",
" nn.LeakyReLU(0.02, True),\n",
" )\n",
" self.encoder_fc1 = nn.Sequential(\n",
" nn.Linear(192, 128),\n",
" nn.LeakyReLU(0.02, True),\n",
" nn.Linear(128, 128),\n",
" )\n",
"\n",
" self.audio_embedding2 = nn.Sequential(nn.Conv2d(1, 8, (3, 17), stride=(1, 1), padding=(1, 0)),\n",
" # nn.GroupNorm(4, 8, affine=True),\n",
" BatchNorm2d(8),\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(8, 32, (13, 13), stride=(1, 1), padding=(6, 6)))\n",
"\n",
" self.audioencoder = AudioEncoder(dim_in=64, style_dim=128, max_conv_dim=512, w_hpf=0, F0_channel=256)\n",
" # self.mappingnet = MappingNetwork(latent_dim=16, style_dim=128, num_domains=8, hidden_dim=512)\n",
" # self.stylenet = StyleEncoder(dim_in=64, style_dim=64, num_domains=8, max_conv_dim=512)\n",
" self.face_adain = face_adain\n",
" if self.face_adain:\n",
" self.fadain = AdaIN(style_dim=128, num_features=32)\n",
" # norm = 'layer_2d' #\n",
" norm = 'batch'\n",
"\n",
" self.decodefeature_extract = nn.Sequential(mydownres2Dblock(self.decode_dim,32, normalize = norm),\n",
" mydownres2Dblock(32,48, normalize = norm),\n",
" mydownres2Dblock(48,64, normalize = norm),\n",
" mydownres2Dblock(64,96, normalize = norm),\n",
" mydownres2Dblock(96,128, normalize = norm),\n",
" nn.AvgPool2d(2))\n",
"\n",
" self.feature_extract = nn.Sequential(mydownres2Dblock(input_dim,32),\n",
" mydownres2Dblock(32,64),\n",
" mydownres2Dblock(64,128),\n",
" mydownres2Dblock(128,128),\n",
" mydownres2Dblock(128,128),\n",
" nn.AvgPool2d(2))\n",
" self.transformer = Transformer()\n",
" self.kp = nn.Linear(128, 32)\n",
"\n",
" # for m in self.modules():\n",
" # if isinstance(m, nn.Conv2d):\n",
" # # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n",
" # nn.init.xavier_uniform_(m.weight, gain=1)\n",
"\n",
" # if isinstance(m, nn.Linear):\n",
" # # trunc_normal_(m.weight, std=.03)\n",
" # nn.init.xavier_uniform_(m.weight, gain=1)\n",
" # if isinstance(m, nn.Linear) and m.bias is not None:\n",
" # nn.init.constant_(m.bias, 0)\n",
"\n",
" F0_path = './Utils/JDC/bst.t7'\n",
" F0_model = JDCNet(num_class=1, seq_len=32)\n",
" params = torch.load(F0_path, map_location='cpu')['net']\n",
" F0_model.load_state_dict(params)\n",
" self.f0_model = F0_model\n",
"\n",
" def rotation_and_translation(self, headpose, bbs, bs):\n",
" # print(headpose['roll'].shape, headpose['yaw'].shape, headpose['pitch'].shape, headpose['t'].shape)\n",
"\n",
" yaw = headpose_pred_to_degree(headpose['yaw'].reshape(bbs*bs, -1))\n",
" pitch = headpose_pred_to_degree(headpose['pitch'].reshape(bbs*bs, -1))\n",
" roll = headpose_pred_to_degree(headpose['roll'].reshape(bbs*bs, -1))\n",
" yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v = get_rotation_matrix(yaw, pitch, roll)\n",
" t = headpose['t'].reshape(bbs*bs, -1)\n",
" # hp = torch.cat([yaw, pitch, roll, yaw_v, pitch_v, roll_v, t], dim=1)\n",
" hp = torch.cat([yaw.unsqueeze(1), pitch.unsqueeze(1), roll.unsqueeze(1), yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v, t], dim=1)\n",
" # hp = torch.cat([yaw, pitch, roll, torch.sin(yaw), torch.sin(pitch), torch.sin(roll), torch.cos(yaw), torch.cos(pitch), torch.cos(roll), t], dim=1)\n",
" return hp\n",
"\n",
" def forward(self, x, initial_kp = None, return_strg=False, emoprompt=None, hp=None, side=False):\n",
" bbs, bs, seqlen, _, _ = x['deep'].shape\n",
" # ph = x[\"pho\"].reshape(bbs*bs*seqlen, 1)\n",
" if hp is None:\n",
" hp = self.rotation_and_translation(x['he_driving'], bbs, bs)\n",
" hp = self.hp_extractor(hp)\n",
"\n",
" pose_feature = x[\"pose\"].reshape(bbs*bs*seqlen,1,64,64)\n",
" # pose_feature = self.down_pose(pose).contiguous()\n",
" ### phoneme input feature\n",
" # phoneme_embedding = self.embedding(ph.long())\n",
" # phoneme_embedding = phoneme_embedding.reshape(bbs*bs*seqlen, 1, 16, 16)\n",
" # phoneme_embedding = F.interpolate(phoneme_embedding, scale_factor=4)\n",
" # input_feature = torch.cat((pose_feature, phoneme_embedding), dim=1)\n",
" # print('input_feature: ', input_feature.shape)\n",
" # input_feature = phoneme_embedding\n",
"\n",
" audio = x['deep'].reshape(bbs*bs*seqlen, 16, 29).permute(0, 2, 1)\n",
" deep_feature = self.audio_embedding(audio).squeeze(-1)# ([264, 32, 16, 16])\n",
" # print(deep_feature.shape)\n",
"\n",
" input_feature = pose_feature\n",
" # print(input_feature.shape)\n",
" # assert(0)\n",
" input_feature = self.feature_extract(input_feature).reshape(bbs*bs*seqlen, 128)\n",
" input_feature = torch.cat([input_feature, deep_feature], dim=1)\n",
" input_feature = self.encoder_fc1(input_feature).reshape(bbs*bs, seqlen, 128)\n",
" # phoneme_embedding = self.phoneme_shrink(phoneme_embedding.squeeze(1))# 24*11, 128\n",
" input_feature = torch.cat([input_feature, hp.unsqueeze(1)], dim=1)\n",
"\n",
" ### decode audio feature\n",
" ### use iteration to avoid batchnorm2d in different audio sequence\n",
" decoder_features = []\n",
" for i in range(bbs):\n",
" F0 = self.f0_model.get_feature_GAN(x['mel'][i].reshape(bs, 1, 80, seqlen))\n",
" if emoprompt is None:\n",
" audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), s=None, masks=None, F0=F0))\n",
" else:\n",
" audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), s=emoprompt[i].unsqueeze(0), masks=None, F0=F0))\n",
" audio2 = torch.permute(audio_feature, (0, 3, 1, 2)).reshape(bs*seqlen, 1, 64, 80)\n",
" decoder_feature = self.audio_embedding2(audio2)\n",
"\n",
" # decoder_feature = torch.cat([decoder_feature, audio2], dim=1)\n",
" # decoder_feature = F.interpolate(decoder_feature, scale_factor=2)# ([264, 35, 64, 64])\n",
" face_map = initial_kp[\"prediction_map\"][i].reshape(15*16, 64*64).permute(1, 0).reshape(64*64, 15*16)\n",
" feature_map = self.face_shrink(face_map).permute(1, 0).reshape(1, 32, 64, 64)\n",
" if self.face_adain:\n",
" feature_map = self.fadain(feature_map, emoprompt)\n",
" decoder_feature = self.decodefeature_extract(torch.cat(\n",
" (decoder_feature,\n",
" feature_map.repeat(bs, seqlen, 1, 1, 1).reshape(bs * seqlen, 32, 64, 64)),\n",
" dim=1)).reshape(bs, seqlen, 128)\n",
" decoder_features.append(decoder_feature)\n",
" decoder_feature = torch.cat(decoder_features, dim=0)\n",
"\n",
" decoder_feature = torch.cat([decoder_feature, hp.unsqueeze(1)], dim=1)\n",
"\n",
" # decoder_feature = torch.cat([decoder_feature], dim=1)\n",
"\n",
" # a2kp transformer\n",
" # position embedding\n",
" if emoprompt is None:\n",
" posi_em = self.pos_enc(self.num_w*2+1+1) # 11 + headpose token\n",
" else:\n",
" posi_em = self.pos_enc(self.num_w*2+1+1+1) # 11 + headpose token + emotion prompt\n",
" emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(1)\n",
" input_feature = torch.cat([input_feature, emoprompt], dim=1)\n",
" decoder_feature = torch.cat([decoder_feature, emoprompt], dim=1)\n",
" out = {}\n",
" output_feature, memory = self.transformer(input_feature, decoder_feature, posi_em, )\n",
" output_feature = output_feature[-1, self.num_w] # returned intermediate output [6, 13, bbs*bs, 128]\n",
" out[\"emo\"] = self.kp(output_feature)\n",
" if side:\n",
" input_st = {}\n",
" input_st['hp'] = hp\n",
" input_st['decoder_feature'] = decoder_feature\n",
" input_st['memory'] = memory\n",
" return out, input_st\n",
" else:\n",
" return out\n",
"\n",
"\n",
"class Audio2kpTransformerBBoxQDeepPrompt(nn.Module):\n",
" def __init__(self, embedding_dim, num_kp, num_w, face_ea=False):\n",
" super(Audio2kpTransformerBBoxQDeepPrompt, self).__init__()\n",
" self.embedding_dim = embedding_dim\n",
" self.num_kp = num_kp\n",
" self.num_w = num_w\n",
"\n",
"\n",
" self.embedding = nn.Embedding(41, embedding_dim)\n",
"\n",
" self.face_shrink = nn.Linear(240, 32)\n",
" self.hp_extractor = nn.Linear(45, 128)\n",
"\n",
" self.pos_enc = PositionalEncoding(128,20)\n",
" input_dim = 1\n",
"\n",
" self.decode_dim = 64\n",
" self.audio_embedding = nn.Sequential( # n x 29 x 16\n",
" nn.Conv1d(29, 32, kernel_size=3, stride=2,\n",
" padding=1, bias=True), # n x 32 x 8\n",
" nn.LeakyReLU(0.02, True),\n",
" nn.Conv1d(32, 32, kernel_size=3, stride=2,\n",
" padding=1, bias=True), # n x 32 x 4\n",
" nn.LeakyReLU(0.02, True),\n",
" nn.Conv1d(32, 64, kernel_size=3, stride=2,\n",
" padding=1, bias=True), # n x 64 x 2\n",
" nn.LeakyReLU(0.02, True),\n",
" nn.Conv1d(64, 64, kernel_size=3, stride=2,\n",
" padding=1, bias=True), # n x 64 x 1\n",
" nn.LeakyReLU(0.02, True),\n",
" )\n",
" self.encoder_fc1 = nn.Sequential(\n",
" nn.Linear(192, 128),\n",
" nn.LeakyReLU(0.02, True),\n",
" nn.Linear(128, 128),\n",
" )\n",
"\n",
" self.audio_embedding2 = nn.Sequential(nn.Conv2d(1, 8, (3, 17), stride=(1, 1), padding=(1, 0)),\n",
" # nn.GroupNorm(4, 8, affine=True),\n",
" BatchNorm2d(8),\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(8, 32, (13, 13), stride=(1, 1), padding=(6, 6)))\n",
"\n",
" self.audioencoder = AudioEncoder(dim_in=64, style_dim=128, max_conv_dim=512, w_hpf=0, F0_channel=256)\n",
" self.face_ea = face_ea\n",
" if self.face_ea:\n",
" self.fea = EAModule(style_dim=128, num_features=32)\n",
" norm = 'batch'\n",
"\n",
" self.decodefeature_extract = nn.Sequential(mydownres2Dblock(self.decode_dim,32, normalize = norm),\n",
" mydownres2Dblock(32,48, normalize = norm),\n",
" mydownres2Dblock(48,64, normalize = norm),\n",
" mydownres2Dblock(64,96, normalize = norm),\n",
" mydownres2Dblock(96,128, normalize = norm),\n",
" nn.AvgPool2d(2))\n",
"\n",
" self.feature_extract = nn.Sequential(mydownres2Dblock(input_dim,32),\n",
" mydownres2Dblock(32,64),\n",
" mydownres2Dblock(64,128),\n",
" mydownres2Dblock(128,128),\n",
" mydownres2Dblock(128,128),\n",
" nn.AvgPool2d(2))\n",
" self.transformer = TransformerDeep()\n",
" self.kp = nn.Linear(128, 32)\n",
"\n",
" F0_path = '/content/EAT_code/Utils/JDC/bst.t7'\n",
" F0_model = JDCNet(num_class=1, seq_len=32)\n",
" params = torch.load(F0_path, map_location='cpu')['net']\n",
" F0_model.load_state_dict(params)\n",
" self.f0_model = F0_model\n",
"\n",
" def rotation_and_translation(self, headpose, bbs, bs):\n",
" yaw = headpose_pred_to_degree(headpose['yaw'].reshape(bbs*bs, -1))\n",
" pitch = headpose_pred_to_degree(headpose['pitch'].reshape(bbs*bs, -1))\n",
" roll = headpose_pred_to_degree(headpose['roll'].reshape(bbs*bs, -1))\n",
" yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v = get_rotation_matrix(yaw, pitch, roll)\n",
" t = headpose['t'].reshape(bbs*bs, -1)\n",
" hp = torch.cat([yaw.unsqueeze(1), pitch.unsqueeze(1), roll.unsqueeze(1), yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v, t], dim=1)\n",
" return hp\n",
"\n",
" def forward(self, x, initial_kp = None, return_strg=False, emoprompt=None, deepprompt=None, hp=None, side=False):\n",
" bbs, bs, seqlen, _, _ = x['deep'].shape\n",
" # ph = x[\"pho\"].reshape(bbs*bs*seqlen, 1)\n",
" if hp is None:\n",
" hp = self.rotation_and_translation(x['he_driving'], bbs, bs)\n",
" hp = self.hp_extractor(hp)\n",
"\n",
" pose_feature = x[\"pose\"].reshape(bbs*bs*seqlen,1,64,64)\n",
"\n",
" audio = x['deep'].reshape(bbs*bs*seqlen, 16, 29).permute(0, 2, 1)\n",
" deep_feature = self.audio_embedding(audio).squeeze(-1)# ([264, 32, 16, 16])\n",
"\n",
" input_feature = pose_feature\n",
" input_feature = self.feature_extract(input_feature).reshape(bbs*bs*seqlen, 128)\n",
" input_feature = torch.cat([input_feature, deep_feature], dim=1)\n",
" input_feature = self.encoder_fc1(input_feature).reshape(bbs*bs, seqlen, 128)\n",
" input_feature = torch.cat([input_feature, hp.unsqueeze(1)], dim=1)\n",
"\n",
" ### decode audio feature\n",
" ### use iteration to avoid batchnorm2d in different audio sequence\n",
" decoder_features = []\n",
" for i in range(bbs):\n",
" F0 = self.f0_model.get_feature_GAN(x['mel'][i].reshape(bs, 1, 80, seqlen))\n",
" if emoprompt is None:\n",
" audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), s=None, masks=None, F0=F0))\n",
" else:\n",
" audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), s=emoprompt[i].unsqueeze(0), masks=None, F0=F0))\n",
" audio2 = torch.permute(audio_feature, (0, 3, 1, 2)).reshape(bs*seqlen, 1, 64, 80)\n",
" decoder_feature = self.audio_embedding2(audio2)\n",
"\n",
" face_map = initial_kp[\"prediction_map\"][i].reshape(15*16, 64*64).permute(1, 0).reshape(64*64, 15*16)\n",
" face_feature_map = self.face_shrink(face_map).permute(1, 0).reshape(1, 32, 64, 64)\n",
" if self.face_ea:\n",
" face_feature_map = self.fea(face_feature_map, emoprompt)\n",
" decoder_feature = self.decodefeature_extract(torch.cat(\n",
" (decoder_feature,\n",
" face_feature_map.repeat(bs, seqlen, 1, 1, 1).reshape(bs * seqlen, 32, 64, 64)),\n",
" dim=1)).reshape(bs, seqlen, 128)\n",
" decoder_features.append(decoder_feature)\n",
" decoder_feature = torch.cat(decoder_features, dim=0)\n",
"\n",
" decoder_feature = torch.cat([decoder_feature, hp.unsqueeze(1)], dim=1)\n",
"\n",
" # a2kp transformer\n",
" # position embedding\n",
" if emoprompt is None:\n",
" posi_em = self.pos_enc(self.num_w*2+1+1) # 11 + headpose token\n",
" else:\n",
" posi_em = self.pos_enc(self.num_w*2+1+1+1) # 11 + headpose token + deep emotion prompt\n",
" out = {}\n",
" output_feature, memory = self.transformer(input_feature, decoder_feature, posi_em, deepprompt)\n",
" output_feature = output_feature[-1, self.num_w] # returned intermediate output [6, 13, bbs*bs, 128]\n",
" out[\"emo\"] = self.kp(output_feature)\n",
" if side:\n",
" input_st = {}\n",
" input_st['hp'] = hp\n",
" input_st['face_feature_map'] = face_feature_map\n",
" input_st['bs'] = bs\n",
" input_st['bbs'] = bbs\n",
" return out, input_st\n",
" else:\n",
" return out\n",
"\n",
"\n"
],
"metadata": {
"cellView": "form",
"id": "DZwVMAPDgZvO"
},
"execution_count": null,
"outputs": []
},
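{
"cell_type": "code",
"source": [
"#@title optional sanity check: instantiate the prompt-conditioned A2KP transformer\n",
"# A minimal sketch, not part of the original EAT pipeline. The constructor values below are\n",
"# assumptions inferred from the shapes used in the cell above (num_w=5 matches the 11-frame\n",
"# audio window; embedding_dim only feeds the unused phoneme embedding). Utils/JDC/bst.t7 must\n",
"# already have been downloaded by the earlier setup cells.\n",
"a2kp = Audio2kpTransformerBBoxQDeepPrompt(embedding_dim=32, num_kp=15, num_w=5, face_ea=True)\n",
"print(sum(p.numel() for p in a2kp.parameters()), 'parameters')"
],
"metadata": {
"id": "a2kpSanityCheck1"
},
"execution_count": null,
"outputs": []
},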
{
"cell_type": "code",
"source": [
"#@title deepspeechfeatures.py fixed\n",
"\n",
"\"\"\"\n",
" DeepSpeech features processing routines.\n",
" NB: Based on VOCA code. See the corresponding license restrictions.\n",
"\"\"\"\n",
"\n",
"__all__ = ['conv_audios_to_deepspeech']\n",
"\n",
"import numpy as np\n",
"import warnings\n",
"import resampy\n",
"from scipy.io import wavfile\n",
"from python_speech_features import mfcc\n",
"import tensorflow as tf\n",
"\n",
"\n",
"def conv_audios_to_deepspeech(audios,\n",
" out_files,\n",
" num_frames_info,\n",
" deepspeech_pb_path,\n",
" audio_window_size=16,\n",
" audio_window_stride=1):\n",
" \"\"\"\n",
" Convert list of audio files into files with DeepSpeech features.\n",
"\n",
" Parameters\n",
" ----------\n",
" audios : list of str or list of None\n",
" Paths to input audio files.\n",
" out_files : list of str\n",
" Paths to output files with DeepSpeech features.\n",
" num_frames_info : list of int\n",
" List of numbers of frames.\n",
" deepspeech_pb_path : str\n",
" Path to DeepSpeech 0.1.0 frozen model.\n",
" audio_window_size : int, default 16\n",
" Audio window size.\n",
" audio_window_stride : int, default 1\n",
" Audio window stride.\n",
" \"\"\"\n",
" graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net(deepspeech_pb_path)\n",
"\n",
" with tf.compat.v1.Session(graph=graph) as sess:\n",
" for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info):\n",
" audio_sample_rate, audio = wavfile.read(audio_file_path)\n",
" if audio.ndim != 1:\n",
" warnings.warn(\"Audio has multiple channels, the first channel is used\")\n",
" audio = audio[:, 0]\n",
" ds_features = pure_conv_audio_to_deepspeech(\n",
" audio=audio,\n",
" audio_sample_rate=audio_sample_rate,\n",
" audio_window_size=audio_window_size,\n",
" audio_window_stride=audio_window_stride,\n",
" num_frames=num_frames,\n",
" net_fn=lambda x: sess.run(\n",
" logits_ph,\n",
" feed_dict={\n",
" input_node_ph: x[np.newaxis, ...],\n",
" input_lengths_ph: [x.shape[0]]}))\n",
" np.save(out_file_path, ds_features)\n",
"\n",
"\n",
"# data_util/deepspeech_features/deepspeech_features.py\n",
"def prepare_deepspeech_net(deepspeech_pb_path):\n",
" # Load graph and place_holders:\n",
" with tf.io.gfile.GFile(deepspeech_pb_path, \"rb\") as f:\n",
" graph_def = tf.compat.v1.GraphDef()\n",
" graph_def.ParseFromString(f.read())\n",
"\n",
" graph = tf.compat.v1.get_default_graph()\n",
"\n",
" tf.import_graph_def(graph_def, name=\"deepspeech\")\n",
" # check all graphs\n",
" # print('~'*50, [tensor for tensor in graph._nodes_by_name], '~'*50)\n",
" # print('~'*50, [tensor.name for tensor in graph.get_operations()], '~'*50)\n",
" # i modified\n",
" logits_ph = graph.get_tensor_by_name(\"logits:0\")\n",
" input_node_ph = graph.get_tensor_by_name(\"input_node:0\")\n",
" input_lengths_ph = graph.get_tensor_by_name(\"input_lengths:0\")\n",
" # original\n",
" # logits_ph = graph.get_tensor_by_name(\"deepspeech/logits:0\")\n",
" # input_node_ph = graph.get_tensor_by_name(\"deepspeech/input_node:0\")\n",
" # input_lengths_ph = graph.get_tensor_by_name(\"deepspeech/input_lengths:0\")\n",
"\n",
" return graph, logits_ph, input_node_ph, input_lengths_ph\n",
"\n",
"\n",
"def pure_conv_audio_to_deepspeech(audio,\n",
" audio_sample_rate,\n",
" audio_window_size,\n",
" audio_window_stride,\n",
" num_frames,\n",
" net_fn):\n",
" \"\"\"\n",
" Core routine for converting audion into DeepSpeech features.\n",
"\n",
" Parameters\n",
" ----------\n",
" audio : np.array\n",
" Audio data.\n",
" audio_sample_rate : int\n",
" Audio sample rate.\n",
" audio_window_size : int\n",
" Audio window size.\n",
" audio_window_stride : int\n",
" Audio window stride.\n",
" num_frames : int or None\n",
" Numbers of frames.\n",
" net_fn : func\n",
" Function for DeepSpeech model call.\n",
"\n",
" Returns\n",
" -------\n",
" np.array\n",
" DeepSpeech features.\n",
" \"\"\"\n",
" target_sample_rate = 16000\n",
" if audio_sample_rate != target_sample_rate:\n",
" resampled_audio = resampy.resample(\n",
" x=audio.astype(np.float),\n",
" sr_orig=audio_sample_rate,\n",
" sr_new=target_sample_rate)\n",
" else:\n",
" resampled_audio = audio.astype(np.float32)\n",
" input_vector = conv_audio_to_deepspeech_input_vector(\n",
" audio=resampled_audio.astype(np.int16),\n",
" sample_rate=target_sample_rate,\n",
" num_cepstrum=26,\n",
" num_context=9)\n",
"\n",
" network_output = net_fn(input_vector)\n",
"\n",
" deepspeech_fps = 50\n",
" video_fps = 60\n",
" audio_len_s = float(audio.shape[0]) / audio_sample_rate\n",
" if num_frames is None:\n",
" num_frames = int(round(audio_len_s * video_fps))\n",
" else:\n",
" video_fps = num_frames / audio_len_s\n",
" network_output = interpolate_features(\n",
" features=network_output[:, 0],\n",
" input_rate=deepspeech_fps,\n",
" output_rate=video_fps,\n",
" output_len=num_frames)\n",
"\n",
" # Make windows:\n",
" zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1]))\n",
" network_output = np.concatenate((zero_pad, network_output, zero_pad), axis=0)\n",
" windows = []\n",
" for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride):\n",
" windows.append(network_output[window_index:window_index + audio_window_size])\n",
"\n",
" return np.array(windows)\n",
"\n",
"\n",
"def conv_audio_to_deepspeech_input_vector(audio,\n",
" sample_rate,\n",
" num_cepstrum,\n",
" num_context):\n",
" \"\"\"\n",
" Convert audio raw data into DeepSpeech input vector.\n",
"\n",
" Parameters\n",
" ----------\n",
" audio : np.array\n",
" Audio data.\n",
" audio_sample_rate : int\n",
" Audio sample rate.\n",
" num_cepstrum : int\n",
" Number of cepstrum.\n",
" num_context : int\n",
" Number of context.\n",
"\n",
" Returns\n",
" -------\n",
" np.array\n",
" DeepSpeech input vector.\n",
" \"\"\"\n",
" # Get mfcc coefficients:\n",
" features = mfcc(\n",
" signal=audio,\n",
" samplerate=sample_rate,\n",
" numcep=num_cepstrum)\n",
"\n",
" # We only keep every second feature (BiRNN stride = 2):\n",
" features = features[::2]\n",
"\n",
" # One stride per time step in the input:\n",
" num_strides = len(features)\n",
"\n",
" # Add empty initial and final contexts:\n",
" empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype)\n",
" features = np.concatenate((empty_context, features, empty_context))\n",
"\n",
" # Create a view into the array with overlapping strides of size\n",
" # numcontext (past) + 1 (present) + numcontext (future):\n",
" window_size = 2 * num_context + 1\n",
" train_inputs = np.lib.stride_tricks.as_strided(\n",
" features,\n",
" shape=(num_strides, window_size, num_cepstrum),\n",
" strides=(features.strides[0], features.strides[0], features.strides[1]),\n",
" writeable=False)\n",
"\n",
" # Flatten the second and third dimensions:\n",
" train_inputs = np.reshape(train_inputs, [num_strides, -1])\n",
"\n",
" train_inputs = np.copy(train_inputs)\n",
" train_inputs = (train_inputs - np.mean(train_inputs)) / np.std(train_inputs)\n",
"\n",
" return train_inputs\n",
"\n",
"\n",
"def interpolate_features(features,\n",
" input_rate,\n",
" output_rate,\n",
" output_len):\n",
" \"\"\"\n",
" Interpolate DeepSpeech features.\n",
"\n",
" Parameters\n",
" ----------\n",
" features : np.array\n",
" DeepSpeech features.\n",
" input_rate : int\n",
" input rate (FPS).\n",
" output_rate : int\n",
" Output rate (FPS).\n",
" output_len : int\n",
" Output data length.\n",
"\n",
" Returns\n",
" -------\n",
" np.array\n",
" Interpolated data.\n",
" \"\"\"\n",
" input_len = features.shape[0]\n",
" num_features = features.shape[1]\n",
" input_timestamps = np.arange(input_len) / float(input_rate)\n",
" output_timestamps = np.arange(output_len) / float(output_rate)\n",
" output_features = np.zeros((output_len, num_features))\n",
" for feature_idx in range(num_features):\n",
" output_features[:, feature_idx] = np.interp(\n",
" x=output_timestamps,\n",
" xp=input_timestamps,\n",
" fp=features[:, feature_idx])\n",
" return output_features\n"
],
"metadata": {
"cellView": "form",
"id": "Jq0kqup1Ogqd"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title deepspeechstore fixed\n",
"\n",
"\"\"\"\n",
" Routines for loading DeepSpeech model.\n",
"\"\"\"\n",
"\n",
"__all__ = ['get_deepspeech_model_file']\n",
"\n",
"import os\n",
"import zipfile\n",
"import logging\n",
"import hashlib\n",
"\n",
"\n",
"deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'\n",
"\n",
"\n",
"def get_deepspeech_model_file(local_model_store_dir_path=os.path.join(\"~\", \"/content/EAT_code/tensorflow\", \"models\")):\n",
" \"\"\"\n",
" Return location for the pretrained on local file system. This function will download from online model zoo when\n",
" model cannot be found or has mismatch. The root directory will be created if it doesn't exist.\n",
"\n",
" Parameters\n",
" ----------\n",
" local_model_store_dir_path : str, default $TENSORFLOW_HOME/models\n",
" Location for keeping the model parameters.\n",
"\n",
" Returns\n",
" -------\n",
" file_path\n",
" Path to the requested pretrained model file.\n",
" \"\"\"\n",
" sha1_hash = \"b90017e816572ddce84f5843f1fa21e6a377975e\"\n",
" file_name = \"deepspeech-0_1_0-b90017e8.pb\"\n",
" local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)\n",
" file_path = os.path.join(local_model_store_dir_path, file_name)\n",
" if os.path.exists(file_path):\n",
" if _check_sha1(file_path, sha1_hash):\n",
" return file_path\n",
" else:\n",
" logging.warning(\"Mismatch in the content of model file detected. Downloading again.\")\n",
" else:\n",
" logging.info(\"Model file not found. Downloading to {}.\".format(file_path))\n",
"\n",
" if not os.path.exists(local_model_store_dir_path):\n",
" os.makedirs(local_model_store_dir_path)\n",
"\n",
" zip_file_path = file_path + \".zip\"\n",
" _download(\n",
" url=\"{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip\".format(\n",
" repo_url=deepspeech_features_repo_url,\n",
" repo_release_tag=\"v0.0.1\",\n",
" file_name=file_name),\n",
" path=zip_file_path,\n",
" overwrite=False)\n",
" with zipfile.ZipFile(zip_file_path) as zf:\n",
" zf.extractall(local_model_store_dir_path)\n",
" os.remove(zip_file_path)\n",
"\n",
" if _check_sha1(file_path, sha1_hash):\n",
" return file_path\n",
" else:\n",
" raise ValueError(\"Downloaded file has different hash. Please try again.\")\n",
"\n",
"\n",
"def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):\n",
" \"\"\"\n",
" Download an given URL\n",
"\n",
" Parameters\n",
" ----------\n",
" url : str\n",
" URL to download\n",
" path : str, optional\n",
" Destination path to store downloaded file. By default stores to the\n",
" current directory with same name as in url.\n",
" overwrite : bool, optional\n",
" Whether to overwrite destination file if already exists.\n",
" sha1_hash : str, optional\n",
" Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified\n",
" but doesn't match.\n",
" retries : integer, default 5\n",
" The number of times to attempt the download in case of failure or non 200 return codes\n",
" verify_ssl : bool, default True\n",
" Verify SSL certificates.\n",
"\n",
" Returns\n",
" -------\n",
" str\n",
" The file path of the downloaded file.\n",
" \"\"\"\n",
" import warnings\n",
" try:\n",
" import requests\n",
" except ImportError:\n",
" class requests_failed_to_import(object):\n",
" pass\n",
" requests = requests_failed_to_import\n",
"\n",
" if path is None:\n",
" fname = url.split(\"/\")[-1]\n",
" # Empty filenames are invalid\n",
" assert fname, \"Can't construct file-name from this URL. Please set the `path` option manually.\"\n",
" else:\n",
" path = os.path.expanduser(path)\n",
" if os.path.isdir(path):\n",
" fname = os.path.join(path, url.split(\"/\")[-1])\n",
" else:\n",
" fname = path\n",
" assert retries >= 0, \"Number of retries should be at least 0\"\n",
"\n",
" if not verify_ssl:\n",
" warnings.warn(\n",
" \"Unverified HTTPS request is being made (verify_ssl=False). \"\n",
" \"Adding certificate verification is strongly advised.\")\n",
"\n",
" if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):\n",
" dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))\n",
" if not os.path.exists(dirname):\n",
" os.makedirs(dirname)\n",
" while retries + 1 > 0:\n",
" # Disable pyling too broad Exception\n",
" # pylint: disable=W0703\n",
" try:\n",
" print(\"Downloading {} from {}...\".format(fname, url))\n",
" #r = requests.get(url, stream=True, verify=verify_ssl)\n",
" #if r.status_code != 200:\n",
" # raise RuntimeError(\"Failed downloading url {}\".format(url))\n",
" #with open(fname, \"wb\") as f:\n",
" # for chunk in r.iter_content(chunk_size=1024):\n",
" # if chunk: # filter out keep-alive new chunks\n",
" # f.write(chunk)\n",
" if sha1_hash and not _check_sha1(fname, sha1_hash):\n",
" raise UserWarning(\"File {} is downloaded but the content hash does not match.\"\n",
" \" The repo may be outdated or download may be incomplete. \"\n",
" \"If the `repo_url` is overridden, consider switching to \"\n",
" \"the default repo.\".format(fname))\n",
" break\n",
" except Exception as e:\n",
" retries -= 1\n",
" if retries <= 0:\n",
" raise e\n",
" else:\n",
" print(\"download failed, retrying, {} attempt{} left\"\n",
" .format(retries, \"s\" if retries > 1 else \"\"))\n",
"\n",
" return fname\n",
"\n",
"\n",
"def _check_sha1(filename, sha1_hash):\n",
" \"\"\"\n",
" Check whether the sha1 hash of the file content matches the expected hash.\n",
"\n",
" Parameters\n",
" ----------\n",
" filename : str\n",
" Path to the file.\n",
" sha1_hash : str\n",
" Expected sha1 hash in hexadecimal digits.\n",
"\n",
" Returns\n",
" -------\n",
" bool\n",
" Whether the file content matches the expected hash.\n",
" \"\"\"\n",
" sha1 = hashlib.sha1()\n",
" with open(filename, \"rb\") as f:\n",
" while True:\n",
" data = f.read(1048576)\n",
" if not data:\n",
" break\n",
" sha1.update(data)\n",
"\n",
" return sha1.hexdigest() == sha1_hash\n"
],
"metadata": {
"cellView": "form",
"id": "BaZ1iwyuO_bl"
},
"execution_count": null,
"outputs": []
},
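{
"cell_type": "code",
"source": [
"#@title optional example: extract DeepSpeech features for one wav\n",
"# A minimal sketch, assuming the two helper cells above have been run. The wav path is a\n",
"# hypothetical placeholder rather than a file shipped with the repo; point it at any wav to\n",
"# actually run the extraction (the cell is guarded so it is safe to execute as-is).\n",
"import os\n",
"\n",
"wav_path = '/content/EAT_code/demo/audio.wav'  # hypothetical path, replace with a real wav\n",
"if os.path.exists(wav_path):\n",
"    pb_path = get_deepspeech_model_file()  # locates/extracts deepspeech-0_1_0-b90017e8.pb\n",
"    conv_audios_to_deepspeech(audios=[wav_path],\n",
"                              out_files=['/content/audio_deepspeech.npy'],\n",
"                              num_frames_info=[None],\n",
"                              deepspeech_pb_path=pb_path)\n",
"    print('Saved DeepSpeech features to /content/audio_deepspeech.npy')\n",
"else:\n",
"    print('Set wav_path to an existing wav file to run this example.')"
],
"metadata": {
"id": "dsFeatExample1"
},
"execution_count": null,
"outputs": []
},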
{
"cell_type": "code",
"source": [
"!zip -r EAT_code /content/EAT_code"
],
"metadata": {
"id": "wAFm27NITV_C"
},
"execution_count": null,
"outputs": []
}
]
}