diff --git "a/EAT_v2.ipynb" "b/EAT_v2.ipynb" new file mode 100644--- /dev/null +++ "b/EAT_v2.ipynb" @@ -0,0 +1,2659 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "code", + "source": [ + "#@title set font and wrap\n", + "from IPython.display import HTML, display\n", + "\n", + "def set_css():\n", + " display(HTML('''\n", + " \n", + " '''))\n", + "get_ipython().events.register('pre_run_cell', set_css)" + ], + "metadata": { + "cellView": "form", + "id": "2eK7PTTQiAWv" + }, + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LBBjwIn0ROsF" + }, + "outputs": [], + "source": [ + "#!git clone https://github.com/yuangan/EAT_code.git\n", + "!wget https://huggingface.co/waveydaveygravy/styletalk/resolve/main/EAT_code.zip\n", + "!unzip /content/EAT_code.zip\n", + "%cd /content/EAT_code\n", + "!pip install -r requirements.txt\n", + "!pip install resampy\n", + "#!pip install face_detection\n", + "!pip install python_speech_features\n", + "#@title make directories if not done yet\n", + "!mkdir tensorflow\n", + "!mkdir Results\n", + "!mkdir ckpt\n", + "!mkdir demo\n", + "!mkdir Utils\n", + "%cd /content/EAT_code/tensorflow\n", + "!mkdir models\n", + "%cd /content/EAT_code\n", + "print(\"done\")" + ] + }, + { + "cell_type": "code", + "source": [ + "#@title in /content/EAT_code/config/deepprompt_eam3d_st_tanh_304_3090_all.yaml change batch_size to 1 line 70" + ], + "metadata": { + "cellView": "form", + "id": "_Mxxg6d-WakT" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title make directories if not done yet\n", + "!mkdir tensorflow\n", + "!mkdir ckpt\n", + "!mkdir demo\n", + "!mkdir Utils\n", + "%cd /content/EAT_code/demo\n", + "!mkdir imgs1\n", + "!mkdir imgs_cropped1\n", + "!mkdir imgs_latent1\n", + "%cd /content/EAT_code/tensorflow\n", + "!mkdir models\n", + "%cd /content/EAT_code" + ], + "metadata": { + "id": "3bf2Yff2DEdH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title donwload and place deepspeech model\n", + "%cd /content/EAT_code\n", + "!gdown --id 1KK15n2fOdfLECWN5wvX54mVyDt18IZCo && unzip -q ckpt.zip -d ckpt\n", + "!gdown --id 1MeFGC7ig-vgpDLdhh2vpTIiElrhzZmgT && unzip -q demo.zip -d demo\n", + "!gdown --id 1HGVzckXh-vYGZEUUKMntY1muIbkbnRcd && unzip -q Utils.zip -d Utils\n", + "%cd /content/EAT_code/tensorflow/models\n", + "!wget https://github.com/osmr/deepspeech_features/releases/download/v0.0.1/deepspeech-0_1_0-b90017e8.pb.zip\n", + "%cd /content/EAT_code" + ], + "metadata": { + "id": "ucK4MQbq0yHx" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title download models\n", + "%cd /content/EAT_code\n", + "\n", + "!gdown --id 1KK15n2fOdfLECWN5wvX54mVyDt18IZCo && unzip -q ckpt.zip -d ckpt\n", + "!gdown --id 1MeFGC7ig-vgpDLdhh2vpTIiElrhzZmgT && unzip -q demo.zip -d demo\n", + "!gdown --id 1HGVzckXh-vYGZEUUKMntY1muIbkbnRcd && unzip -q Utils.zip -d Utils" + ], + "metadata": { + "id": "eDDFgToSSEgj", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title upload custom mp4 to videos\n", + "%cd /content/EAT_code/preprocess/video\n", + "from google.colab import 
files\n", + "uploaded = files.upload()\n", + "%cd /content/EAT_code" + ], + "metadata": { + "id": "rlfKO73uUXFT", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 106 + }, + "outputId": "e8116307-0a3f-4eea-ebd1-925f3e17a4bf" + }, + "execution_count": 5, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content/EAT_code/preprocess/video\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " \n", + " Upload widget is only available when the cell has been executed in the\n", + " current browser session. Please rerun this cell to enable.\n", + " \n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Saving bo_1resized.mp4 to bo_1resized.mp4\n", + "/content/EAT_code\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "#@title extract boundary boxes\n", + "!python /content/EAT_code/preprocess/extract_bbox.py" + ], + "metadata": { + "cellView": "form", + "id": "llzj0RprSPu6" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title place custom video in preprocess/video\n", + "%cd /content/EAT_code/preprocess\n", + "!python /content/EAT_code/preprocess/preprocess_video.py # --deepspeech \"/content/EAT_code/tensorflow/models/deepspeech-0_1_0-b90017e8.pb\"" + ], + "metadata": { + "id": "p99eLmnjW-e1" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + " print(f\"Number of paths: {len(self.A_paths)}\")\n", + " print(\"Paths:\")\n", + " for path in self.A_paths:\n", + " print(path)" + ], + "metadata": { + "id": "B4LWLwEvJOsp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import numpy as np\n", + "\n", + "# Replace \"path/to/your/file.npy\" with the actual path to your file\n", + "data = np.load(\"/content/EAT_code/demo/video_processed/obama/latent_evp_25/obama.npy\")\n", + "\n", + "print(f\"Type: {type(data)}\")\n", + "print(f\"Shape: {data.shape}\")" + ], + "metadata": { + "id": "V-hPYJBgYIKx" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!python /content/EAT_code/demo.py --root_wav /content/EAT_code/demo/video_processed1/bo_1resized --emo ang --save_dir /content/EAT_code/Results" + ], + "metadata": { + "id": "APJUoQkgaqa9", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "outputId": "e8150705-1bae-4420-8306-069aa274fb1d" + }, + "execution_count": 14, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "deepprompt_eam3d_all_final_313\n", + "cuda is available\n", + "/usr/local/lib/python3.10/dist-packages/torch/functional.py:568: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. 
(Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:2228.)\n", + " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", + " 0% 0/1 [00:00 1:\n", + " wave_tensor = wave_tensor[:, 0]\n", + " mel_tensor = to_melspec(wave_tensor)\n", + " mel_tensor = (torch.log(1e-5 + mel_tensor) - mean) / std\n", + " name_len = min(mel_tensor.shape[1], poseimg.shape[0], deepfeature.shape[0])\n", + "\n", + " audio_frames = []\n", + " poseimgs = []\n", + " deep_feature = []\n", + "\n", + " pad, deep_pad = np.load('pad.npy', allow_pickle=True)\n", + "\n", + " if name_len < num_frames:\n", + " diff = num_frames - name_len\n", + " if diff > 2:\n", + " print(f\"Attention: the frames are {diff} more than name_len, we will use name_len to replace num_frames\")\n", + " num_frames=name_len\n", + " for k in he_driving.keys():\n", + " he_driving[k] = he_driving[k][:name_len, :]\n", + " for rid in range(0, num_frames):\n", + " audio = []\n", + " poses = []\n", + " deeps = []\n", + " for i in range(rid - opt['num_w'], rid + opt['num_w'] + 1):\n", + " if i < 0:\n", + " audio.append(pad)\n", + " poses.append(poseimg[0])\n", + " deeps.append(deep_pad)\n", + " elif i >= name_len:\n", + " audio.append(pad)\n", + " poses.append(poseimg[-1])\n", + " deeps.append(deep_pad)\n", + " else:\n", + " audio.append(mel_tensor[:, i])\n", + " poses.append(poseimg[i])\n", + " deeps.append(deepfeature[i])\n", + "\n", + " audio_frames.append(torch.stack(audio, dim=1))\n", + " poseimgs.append(poses)\n", + " deep_feature.append(deeps)\n", + " audio_frames = torch.stack(audio_frames, dim=0)\n", + " poseimgs = torch.from_numpy(np.array(poseimgs))\n", + " deep_feature = torch.from_numpy(np.array(deep_feature)).to(torch.float)\n", + " return audio_frames, poseimgs, deep_feature, source_img, he_source, he_driving, num_frames, y_trg, z_trg, latent_path_driving\n", + "\n", + "def load_ckpt(ckpt, kp_detector, generator, audio2kptransformer, sidetuning, emotionprompt):\n", + " checkpoint = torch.load(ckpt, map_location=torch.device('cpu'))\n", + " if audio2kptransformer is not None:\n", + " audio2kptransformer.load_state_dict(checkpoint['audio2kptransformer'])\n", + " if generator is not None:\n", + " generator.load_state_dict(checkpoint['generator'])\n", + " if kp_detector is not None:\n", + " kp_detector.load_state_dict(checkpoint['kp_detector'])\n", + " if sidetuning is not None:\n", + " sidetuning.load_state_dict(checkpoint['sidetuning'])\n", + " if emotionprompt is not None:\n", + " emotionprompt.load_state_dict(checkpoint['emotionprompt'])\n", + "\n", + "import cv2\n", + "import dlib\n", + "from tqdm import tqdm\n", + "from skimage import transform as tf\n", + "detector = dlib.get_frontal_face_detector()\n", + "predictor = dlib.shape_predictor('/content/EAT_code/demo/shape_predictor_68_face_landmarks.dat')\n", + "\n", + "def shape_to_np(shape, dtype=\"int\"):\n", + " # initialize the list of (x, y)-coordinates\n", + " coords = np.zeros((shape.num_parts, 2), dtype=dtype)\n", + "\n", + " # loop over all facial landmarks and convert them\n", + " # to a 2-tuple of (x, y)-coordinates\n", + " for i in range(0, shape.num_parts):\n", + " coords[i] = (shape.part(i).x, shape.part(i).y)\n", + "\n", + " # return the list of (x, y)-coordinates\n", + " return coords\n", + "\n", + "def crop_image(image_path, out_path):\n", + " template = np.load('./demo/M003_template.npy')\n", + " image = cv2.imread(image_path)\n", + " gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n", + " rects = detector(gray, 1) #detect human face\n", + " if 
len(rects) != 1:\n", + " return 0\n", + " for (j, rect) in enumerate(rects):\n", + " shape = predictor(gray, rect) #detect 68 points\n", + " shape = shape_to_np(shape)\n", + "\n", + " pts2 = np.float32(template[:47,:])\n", + " pts1 = np.float32(shape[:47,:]) #eye and nose\n", + " tform = tf.SimilarityTransform()\n", + " tform.estimate( pts2, pts1) #Set the transformation matrix with the explicit parameters.\n", + "\n", + " dst = tf.warp(image, tform, output_shape=(256, 256))\n", + "\n", + " dst = np.array(dst * 255, dtype=np.uint8)\n", + "\n", + " cv2.imwrite(out_path, dst)\n", + "\n", + "def preprocess_imgs(allimgs, tmp_allimgs_cropped):\n", + " name_cropped = []\n", + " for path in tmp_allimgs_cropped:\n", + " name_cropped.append(os.path.basename(path))\n", + " for path in allimgs:\n", + " if os.path.basename(path) in name_cropped:\n", + " continue\n", + " else:\n", + " out_path = path.replace('imgs/', 'imgs_cropped/')\n", + " crop_image(path, out_path)\n", + "\n", + "from sync_batchnorm import DataParallelWithCallback\n", + "def load_checkpoints_extractor(config_path, checkpoint_path, cpu=False):\n", + "\n", + " with open(config_path) as f:\n", + " config = yaml.load(f, Loader=yaml.FullLoader)\n", + "\n", + " kp_detector = KPDetector(**config['model_params']['kp_detector_params'],\n", + " **config['model_params']['common_params'])\n", + " if not cpu:\n", + " kp_detector.cuda()\n", + "\n", + " he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],\n", + " **config['model_params']['common_params'])\n", + " if not cpu:\n", + " he_estimator.cuda()\n", + "\n", + " if cpu:\n", + " checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))\n", + " else:\n", + " checkpoint = torch.load(checkpoint_path)\n", + "\n", + " kp_detector.load_state_dict(checkpoint['kp_detector'])\n", + " he_estimator.load_state_dict(checkpoint['he_estimator'])\n", + "\n", + " if not cpu:\n", + " kp_detector = DataParallelWithCallback(kp_detector)\n", + " he_estimator = DataParallelWithCallback(he_estimator)\n", + "\n", + " kp_detector.eval()\n", + " he_estimator.eval()\n", + "\n", + " return kp_detector, he_estimator\n", + "\n", + "def estimate_latent(driving_video, kp_detector, he_estimator):\n", + " with torch.no_grad():\n", + " predictions = []\n", + " driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).cuda()\n", + " kp_canonical = kp_detector(driving[:, :, 0])\n", + " he_drivings = {'yaw': [], 'pitch': [], 'roll': [], 't': [], 'exp': []}\n", + "\n", + " for frame_idx in range(driving.shape[2]):\n", + " driving_frame = driving[:, :, frame_idx]\n", + " he_driving = he_estimator(driving_frame)\n", + " for k in he_drivings.keys():\n", + " he_drivings[k].append(he_driving[k])\n", + " return [kp_canonical, he_drivings]\n", + "\n", + "def extract_keypoints(extract_list):\n", + " kp_detector, he_estimator = load_checkpoints_extractor(config_path='config/vox-256-spade.yaml', checkpoint_path='./ckpt/pretrain_new_274.pth.tar')\n", + " if not os.path.exists('./demo/imgs_latent/'):\n", + " os.makedirs('./demo/imgs_latent/')\n", + " for imgname in tqdm(extract_list):\n", + " path_frames = [imgname]\n", + " filesname=os.path.basename(imgname)[:-4]\n", + " if os.path.exists(f'./demo/imgs_latent/'+filesname+'.npy'):\n", + " continue\n", + " driving_frames = []\n", + " for im in path_frames:\n", + " driving_frames.append(imageio.imread(im))\n", + " driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_frames]\n", + "\n", + " 
kc, he = estimate_latent(driving_video, kp_detector, he_estimator)\n", + " kc = kc['value'].cpu().numpy()\n", + " for k in he:\n", + " he[k] = torch.cat(he[k]).cpu().numpy()\n", + " np.save('./demo/imgs_latent/'+filesname, [kc, he])\n", + "\n", + "def preprocess_cropped_imgs(allimgs_cropped):\n", + " extract_list = []\n", + " for img_path in allimgs_cropped:\n", + " if not os.path.exists(img_path.replace('cropped', 'latent')[:-4]+'.npy'):\n", + " extract_list.append(img_path)\n", + " if len(extract_list) > 0:\n", + " print('=========', \"Extract latent keypoints from New image\", '======')\n", + " extract_keypoints(extract_list)\n", + "\n", + "def test(ckpt, emotype, save_dir=\" \"):\n", + " # with open(\"config/vox-transformer2.yaml\") as f:\n", + " with open(\"/content/EAT_code/config/deepprompt_eam3d_st_tanh_304_3090_all.yaml\") as f:\n", + " config = yaml.load(f, Loader=yaml.FullLoader)\n", + " cur_path = os.getcwd()\n", + " generator, kp_detector, audio2kptransformer, sidetuning, emotionprompt = build_model(config)\n", + " load_ckpt(ckpt, kp_detector=kp_detector, generator=generator, audio2kptransformer=audio2kptransformer, sidetuning=sidetuning, emotionprompt=emotionprompt)\n", + "\n", + " audio2kptransformer.eval()\n", + " generator.eval()\n", + " kp_detector.eval()\n", + " sidetuning.eval()\n", + " emotionprompt.eval()\n", + "\n", + " all_wavs2 = [f'{root_wav}/{os.path.basename(root_wav)}.wav']\n", + " allimg = glob.glob('./demo/imgs/*.jpg')\n", + " tmp_allimg_cropped = glob.glob('./demo/imgs_cropped/*.jpg')\n", + " preprocess_imgs(allimg, tmp_allimg_cropped) # crop and align images\n", + "\n", + " allimg_cropped = glob.glob('./demo/imgs_cropped/*.jpg')\n", + " preprocess_cropped_imgs(allimg_cropped) # extract latent keypoints if necessary\n", + "\n", + " for ind in tqdm(range(len(all_wavs2))):\n", + " for img_path in tqdm(allimg_cropped):\n", + " audio_path = all_wavs2[ind]\n", + " # read in data\n", + " audio_frames, poseimgs, deep_feature, source_img, he_source, he_driving, num_frames, y_trg, z_trg, latent_path_driving = prepare_test_data(img_path, audio_path, config['model_params']['audio2kp_params'], emotype)\n", + "\n", + "\n", + " with torch.no_grad():\n", + " source_img = torch.from_numpy(source_img).unsqueeze(0).cuda()\n", + " kp_canonical = kp_detector(source_img, with_feature=True) # {'value': value, 'jacobian': jacobian}\n", + " kp_cano = kp_canonical['value']\n", + "\n", + " x = {}\n", + " x['mel'] = audio_frames.unsqueeze(1).unsqueeze(0).cuda()\n", + " x['z_trg'] = z_trg.unsqueeze(0).cuda()\n", + " x['y_trg'] = torch.tensor(y_trg, dtype=torch.long).cuda().reshape(1)\n", + " x['pose'] = poseimgs.cuda()\n", + " x['deep'] = deep_feature.cuda().unsqueeze(0)\n", + " x['he_driving'] = {'yaw': torch.from_numpy(he_driving['yaw']).cuda().unsqueeze(0),\n", + " 'pitch': torch.from_numpy(he_driving['pitch']).cuda().unsqueeze(0),\n", + " 'roll': torch.from_numpy(he_driving['roll']).cuda().unsqueeze(0),\n", + " 't': torch.from_numpy(he_driving['t']).cuda().unsqueeze(0),\n", + " }\n", + "\n", + " ### emotion prompt\n", + " emoprompt, deepprompt = emotionprompt(x)\n", + " a2kp_exps = []\n", + " emo_exps = []\n", + " T = 5\n", + " if T == 1:\n", + " for i in range(x['mel'].shape[1]):\n", + " xi = {}\n", + " xi['mel'] = x['mel'][:,i,:,:,:].unsqueeze(1)\n", + " xi['z_trg'] = x['z_trg']\n", + " xi['y_trg'] = x['y_trg']\n", + " xi['pose'] = x['pose'][i,:,:,:,:].unsqueeze(0)\n", + " xi['deep'] = x['deep'][:,i,:,:,:].unsqueeze(1)\n", + " xi['he_driving'] = {'yaw': 
x['he_driving']['yaw'][:,i,:].unsqueeze(0),\n", + " 'pitch': x['he_driving']['pitch'][:,i,:].unsqueeze(0),\n", + " 'roll': x['he_driving']['roll'][:,i,:].unsqueeze(0),\n", + " 't': x['he_driving']['t'][:,i,:].unsqueeze(0),\n", + " }\n", + " he_driving_emo_xi, input_st_xi = audio2kptransformer(xi, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}\n", + " emo_exp = sidetuning(input_st_xi, emoprompt, deepprompt)\n", + " a2kp_exps.append(he_driving_emo_xi['emo'])\n", + " emo_exps.append(emo_exp)\n", + " elif T is not None:\n", + " for i in range(x['mel'].shape[1]//T+1):\n", + " if i*T >= x['mel'].shape[1]:\n", + " break\n", + " xi = {}\n", + " xi['mel'] = x['mel'][:,i*T:(i+1)*T,:,:,:]\n", + " xi['z_trg'] = x['z_trg']\n", + " xi['y_trg'] = x['y_trg']\n", + " xi['pose'] = x['pose'][i*T:(i+1)*T,:,:,:,:]\n", + " xi['deep'] = x['deep'][:,i*T:(i+1)*T,:,:,:]\n", + " xi['he_driving'] = {'yaw': x['he_driving']['yaw'][:,i*T:(i+1)*T,:],\n", + " 'pitch': x['he_driving']['pitch'][:,i*T:(i+1)*T,:],\n", + " 'roll': x['he_driving']['roll'][:,i*T:(i+1)*T,:],\n", + " 't': x['he_driving']['t'][:,i*T:(i+1)*T,:],\n", + " }\n", + " he_driving_emo_xi, input_st_xi = audio2kptransformer(xi, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}\n", + " emo_exp = sidetuning(input_st_xi, emoprompt, deepprompt)\n", + " a2kp_exps.append(he_driving_emo_xi['emo'])\n", + " emo_exps.append(emo_exp)\n", + "\n", + " if T is None:\n", + " he_driving_emo, input_st = audio2kptransformer(x, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}\n", + " emo_exps = sidetuning(input_st, emoprompt, deepprompt).reshape(-1, 45)\n", + " else:\n", + " he_driving_emo = {}\n", + " he_driving_emo['emo'] = torch.cat(a2kp_exps, dim=0)\n", + " emo_exps = torch.cat(emo_exps, dim=0).reshape(-1, 45)\n", + "\n", + " exp = he_driving_emo['emo']\n", + " device = exp.get_device()\n", + " exp = torch.mm(exp, expU.t().to(device))\n", + " exp = exp + expmean.expand_as(exp).to(device)\n", + " exp = exp + emo_exps\n", + "\n", + "\n", + " source_area = ConvexHull(kp_cano[0].cpu().numpy()).volume\n", + " exp = exp * source_area\n", + "\n", + " he_new_driving = {'yaw': torch.from_numpy(he_driving['yaw']).cuda(),\n", + " 'pitch': torch.from_numpy(he_driving['pitch']).cuda(),\n", + " 'roll': torch.from_numpy(he_driving['roll']).cuda(),\n", + " 't': torch.from_numpy(he_driving['t']).cuda(),\n", + " 'exp': exp}\n", + " he_driving['exp'] = torch.from_numpy(he_driving['exp']).cuda()\n", + "\n", + " kp_source = keypoint_transformation(kp_canonical, he_source, False)\n", + " mean_source = torch.mean(kp_source['value'], dim=1)[0]\n", + " kp_driving = keypoint_transformation(kp_canonical, he_new_driving, False)\n", + " mean_driving = torch.mean(torch.mean(kp_driving['value'], dim=1), dim=0)\n", + " kp_driving['value'] = kp_driving['value']+(mean_source-mean_driving).unsqueeze(0).unsqueeze(0)\n", + " bs = kp_source['value'].shape[0]\n", + " predictions_gen = []\n", + " for i in tqdm(range(num_frames)):\n", + " kp_si = {}\n", + " kp_si['value'] = kp_source['value'][0].unsqueeze(0)\n", + " kp_di = {}\n", + " kp_di['value'] = kp_driving['value'][i].unsqueeze(0)\n", + " generated = generator(source_img, kp_source=kp_si, kp_driving=kp_di, prompt=emoprompt)\n", + " predictions_gen.append(\n", + " 
(np.transpose(generated['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0] * 255).astype(np.uint8))\n", + "\n", + " log_dir = save_dir\n", + " os.makedirs(os.path.join(log_dir, \"temp\"), exist_ok=True)\n", + "\n", + " f_name = os.path.basename(img_path[:-4]) + \"_\" + emotype + \"_\" + os.path.basename(latent_path_driving)[:-4] + \".mp4\"\n", + " video_path = os.path.join(log_dir, \"temp\", f_name)\n", + " imageio.mimsave(video_path, predictions_gen, fps=25.0)\n", + "\n", + " save_video = os.path.join(log_dir, f_name)\n", + " cmd = r'ffmpeg -loglevel error -y -i \"%s\" -i \"%s\" -vcodec copy -shortest \"%s\"' % (video_path, audio_path, save_video)\n", + " os.system(cmd)\n", + " os.remove(video_path)\n", + "\n", + "if __name__ == '__main__':\n", + " argparser = argparse.ArgumentParser()\n", + " argparser.add_argument(\"--save_dir\", type=str, default=\" \", help=\"path of the output video\")\n", + " argparser.add_argument(\"--name\", type=str, default=\"deepprompt_eam3d_all_final_313\", help=\"path of the output video\")\n", + " argparser.add_argument(\"--emo\", type=str, default=\"hap\", help=\"emotion type ('ang', 'con', 'dis', 'fea', 'hap', 'neu', 'sad', 'sur')\")\n", + " argparser.add_argument(\"--root_wav\", type=str, default='./demo/video_processed/M003_neu_1_001', help=\"emotion type ('ang', 'con', 'dis', 'fea', 'hap', 'neu', 'sad', 'sur')\")\n", + " args = argparser.parse_args()\n", + "\n", + " root_wav=args.root_wav\n", + "\n", + " if len(args.name) > 1:\n", + " name = args.name\n", + " print(name)\n", + " test(f'/content/EAT_code/ckpt/deepprompt_eam3d_all_final_313.pth.tar', args.emo, save_dir=f'./demo/output/{name}/')\n", + "\n" + ], + "metadata": { + "cellView": "form", + "id": "cMz72QBbgRsc" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "CmeH0D2ayRrf" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title trasnformer.py with fixed paths\n", + "import torch.nn as nn\n", + "import torch\n", + "from modules.util import mydownres2Dblock\n", + "import numpy as np\n", + "from modules.util import AntiAliasInterpolation2d, make_coordinate_grid\n", + "from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d\n", + "import torch.nn.functional as F\n", + "import copy\n", + "\n", + "\n", + "class PositionalEncoding(nn.Module):\n", + "\n", + " def __init__(self, d_hid, n_position=200):\n", + " super(PositionalEncoding, self).__init__()\n", + "\n", + " # Not a parameter\n", + " self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))\n", + "\n", + " def _get_sinusoid_encoding_table(self, n_position, d_hid):\n", + " ''' Sinusoid position encoding table '''\n", + " # TODO: make it with torch instead of numpy\n", + "\n", + " def get_position_angle_vec(position):\n", + " return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]\n", + "\n", + " sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])\n", + " sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i\n", + " sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1\n", + "\n", + " return torch.FloatTensor(sinusoid_table).unsqueeze(0)\n", + "\n", + " def forward(self, winsize):\n", + " return self.pos_table[:, :winsize].clone().detach()\n", + "\n", + "def _get_activation_fn(activation):\n", + " \"\"\"Return an activation function given a string\"\"\"\n", + " if activation == \"relu\":\n", + 
" return F.relu\n", + " if activation == \"gelu\":\n", + " return F.gelu\n", + " if activation == \"glu\":\n", + " return F.glu\n", + " raise RuntimeError(F\"activation should be relu/gelu, not {activation}.\")\n", + "\n", + "def _get_clones(module, N):\n", + " return nn.ModuleList([copy.deepcopy(module) for i in range(N)])\n", + "\n", + "### light weight transformer encoder\n", + "class TransformerST(nn.Module):\n", + "\n", + " def __init__(self, d_model=128, nhead=8, num_encoder_layers=6,\n", + " num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,\n", + " activation=\"relu\", normalize_before=False):\n", + " super().__init__()\n", + "\n", + " encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,\n", + " dropout, activation, normalize_before)\n", + " encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n", + " self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)\n", + "\n", + " self._reset_parameters()\n", + "\n", + " self.d_model = d_model\n", + " self.nhead = nhead\n", + "\n", + " def _reset_parameters(self):\n", + " for p in self.parameters():\n", + " if p.dim() > 1:\n", + " nn.init.xavier_uniform_(p)\n", + "\n", + " def forward(self, src, pos_embed):\n", + " # flatten NxCxHxW to HWxNxC\n", + "\n", + " src = src.permute(1, 0, 2)\n", + " pos_embed = pos_embed.permute(1, 0, 2)\n", + "\n", + " memory = self.encoder(src, pos=pos_embed)\n", + "\n", + " return memory\n", + "\n", + "class TransformerEncoder(nn.Module):\n", + "\n", + " def __init__(self, encoder_layer, num_layers, norm=None):\n", + " super().__init__()\n", + " self.layers = _get_clones(encoder_layer, num_layers)\n", + " self.num_layers = num_layers\n", + " self.norm = norm\n", + "\n", + " def forward(self, src, mask = None, src_key_padding_mask = None, pos = None):\n", + " output = src+pos\n", + "\n", + " for layer in self.layers:\n", + " output = layer(output, src_mask=mask,\n", + " src_key_padding_mask=src_key_padding_mask, pos=pos)\n", + "\n", + " if self.norm is not None:\n", + " output = self.norm(output)\n", + "\n", + " return output\n", + "\n", + "class TransformerDecoder(nn.Module):\n", + "\n", + " def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):\n", + " super().__init__()\n", + " self.layers = _get_clones(decoder_layer, num_layers)\n", + " self.num_layers = num_layers\n", + " self.norm = norm\n", + " self.return_intermediate = return_intermediate\n", + "\n", + " def forward(self, tgt, memory, tgt_mask = None, memory_mask = None, tgt_key_padding_mask = None,\n", + " memory_key_padding_mask = None,\n", + " pos = None,\n", + " query_pos = None):\n", + " output = tgt+pos+query_pos\n", + "\n", + " intermediate = []\n", + "\n", + " for layer in self.layers:\n", + " output = layer(output, memory, tgt_mask=tgt_mask,\n", + " memory_mask=memory_mask,\n", + " tgt_key_padding_mask=tgt_key_padding_mask,\n", + " memory_key_padding_mask=memory_key_padding_mask,\n", + " pos=pos, query_pos=query_pos)\n", + " if self.return_intermediate:\n", + " intermediate.append(self.norm(output))\n", + "\n", + " if self.norm is not None:\n", + " output = self.norm(output)\n", + " if self.return_intermediate:\n", + " intermediate.pop()\n", + " intermediate.append(output)\n", + "\n", + " if self.return_intermediate:\n", + " return torch.stack(intermediate)\n", + "\n", + " return output.unsqueeze(0)\n", + "\n", + "\n", + "class Transformer(nn.Module):\n", + "\n", + " def __init__(self, d_model=128, nhead=8, num_encoder_layers=6,\n", + " 
num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,\n", + " activation=\"relu\", normalize_before=False,\n", + " return_intermediate_dec=True):\n", + " super().__init__()\n", + "\n", + " encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,\n", + " dropout, activation, normalize_before)\n", + " encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n", + " self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)\n", + "\n", + " decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,\n", + " dropout, activation, normalize_before)\n", + " decoder_norm = nn.LayerNorm(d_model)\n", + " self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,\n", + " return_intermediate=return_intermediate_dec)\n", + "\n", + " self._reset_parameters()\n", + "\n", + " self.d_model = d_model\n", + " self.nhead = nhead\n", + "\n", + " def _reset_parameters(self):\n", + " for p in self.parameters():\n", + " if p.dim() > 1:\n", + " nn.init.xavier_uniform_(p)\n", + "\n", + " def forward(self, src, query_embed, pos_embed):\n", + " # flatten NxCxHxW to HWxNxC\n", + "\n", + " src = src.permute(1, 0, 2)\n", + " pos_embed = pos_embed.permute(1, 0, 2)\n", + " query_embed = query_embed.permute(1, 0, 2)\n", + "\n", + " tgt = torch.zeros_like(query_embed)\n", + " memory = self.encoder(src, pos=pos_embed)\n", + "\n", + " hs = self.decoder(tgt, memory,\n", + " pos=pos_embed, query_pos=query_embed)\n", + " return hs, memory\n", + "\n", + "\n", + "class TransformerDeep(nn.Module):\n", + "\n", + " def __init__(self, d_model=128, nhead=8, num_encoder_layers=6,\n", + " num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,\n", + " activation=\"relu\", normalize_before=False,\n", + " return_intermediate_dec=True):\n", + " super().__init__()\n", + "\n", + " encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,\n", + " dropout, activation, normalize_before)\n", + " encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n", + " self.encoder = TransformerEncoderDeep(encoder_layer, num_encoder_layers, encoder_norm)\n", + "\n", + " decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,\n", + " dropout, activation, normalize_before)\n", + " decoder_norm = nn.LayerNorm(d_model)\n", + " self.decoder = TransformerDecoderDeep(decoder_layer, num_decoder_layers, decoder_norm,\n", + " return_intermediate=return_intermediate_dec)\n", + "\n", + " self._reset_parameters()\n", + "\n", + " self.d_model = d_model\n", + " self.nhead = nhead\n", + "\n", + " def _reset_parameters(self):\n", + " for p in self.parameters():\n", + " if p.dim() > 1:\n", + " nn.init.xavier_uniform_(p)\n", + "\n", + " def forward(self, src, query_embed, pos_embed, deepprompt):\n", + " # flatten NxCxHxW to HWxNxC\n", + "\n", + " # print('src before permute: ', src.shape) # 5, 12, 128\n", + " src = src.permute(1, 0, 2)\n", + " # print('src after permute: ', src.shape) # 12, 5, 128\n", + " pos_embed = pos_embed.permute(1, 0, 2)\n", + " query_embed = query_embed.permute(1, 0, 2)\n", + "\n", + " tgt = torch.zeros_like(query_embed) # actually is tgt + query_embed\n", + " memory = self.encoder(src, deepprompt, pos=pos_embed)\n", + "\n", + " hs = self.decoder(tgt, deepprompt, memory,\n", + " pos=pos_embed, query_pos=query_embed)\n", + " return hs, memory\n", + "\n", + "class TransformerEncoderDeep(nn.Module):\n", + "\n", + " def __init__(self, encoder_layer, num_layers, norm=None):\n", + " super().__init__()\n", + " self.layers = 
_get_clones(encoder_layer, num_layers)\n", + " self.num_layers = num_layers\n", + " self.norm = norm\n", + "\n", + " def forward(self, src, deepprompt, mask = None, src_key_padding_mask = None, pos = None):\n", + " # print('input: ', src.shape) # 12 5 128\n", + " # print('deepprompt:', deepprompt.shape) # 1 6 128\n", + " ### TODO: add deep prompt in encoder\n", + " bs = src.shape[1]\n", + " bbs = deepprompt.shape[0]\n", + " idx=0\n", + " emoprompt = deepprompt[:,idx,:]\n", + " emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n", + " # print(emoprompt.shape) # 1 5 128\n", + " src = torch.cat([src, emoprompt], dim=0)\n", + " # print(src.shape) # 13 5 128\n", + " output = src+pos\n", + "\n", + " for layer in self.layers:\n", + " output = layer(output, src_mask=mask,\n", + " src_key_padding_mask=src_key_padding_mask, pos=pos)\n", + "\n", + " ### deep prompt\n", + " if idx+1 < len(self.layers):\n", + " idx = idx + 1\n", + " # print(idx)\n", + " emoprompt = deepprompt[:,idx,:]\n", + " emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n", + " # print(output.shape) # 13 5 128\n", + " output = torch.cat([output[:-1], emoprompt], dim=0)\n", + " # print(output.shape) # 13 5 128\n", + "\n", + " if self.norm is not None:\n", + " output = self.norm(output)\n", + "\n", + " return output\n", + "\n", + "class TransformerDecoderDeep(nn.Module):\n", + "\n", + " def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):\n", + " super().__init__()\n", + " self.layers = _get_clones(decoder_layer, num_layers)\n", + " self.num_layers = num_layers\n", + " self.norm = norm\n", + " self.return_intermediate = return_intermediate\n", + "\n", + " def forward(self, tgt, deepprompt, memory, tgt_mask = None, memory_mask = None, tgt_key_padding_mask = None,\n", + " memory_key_padding_mask = None,\n", + " pos = None,\n", + " query_pos = None):\n", + " # print('input: ', query_pos.shape) 12 5 128\n", + " ### TODO: add deep prompt in encoder\n", + " bs = query_pos.shape[1]\n", + " bbs = deepprompt.shape[0]\n", + " idx=0\n", + " emoprompt = deepprompt[:,idx,:]\n", + " emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n", + " query_pos = torch.cat([query_pos, emoprompt], dim=0)\n", + " # print(query_pos.shape) # 13 5 128\n", + " # print(torch.sum(tgt)) # 0\n", + " output = pos+query_pos\n", + "\n", + " intermediate = []\n", + "\n", + " for layer in self.layers:\n", + " output = layer(output, memory, tgt_mask=tgt_mask,\n", + " memory_mask=memory_mask,\n", + " tgt_key_padding_mask=tgt_key_padding_mask,\n", + " memory_key_padding_mask=memory_key_padding_mask,\n", + " pos=pos, query_pos=query_pos)\n", + " if self.return_intermediate:\n", + " intermediate.append(self.norm(output))\n", + "\n", + " ### deep prompt\n", + " if idx+1 < len(self.layers):\n", + " idx = idx + 1\n", + " # print(idx)\n", + " emoprompt = deepprompt[:,idx,:]\n", + " emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n", + " # print(output.shape) # 13 5 128\n", + " output = torch.cat([output[:-1], emoprompt], dim=0)\n", + " # print(output.shape) # 13 5 128\n", + "\n", + " if self.norm is not None:\n", + " output = self.norm(output)\n", + " if self.return_intermediate:\n", + " intermediate.pop()\n", + " intermediate.append(output)\n", + "\n", + " if self.return_intermediate:\n", + " return torch.stack(intermediate)\n", + "\n", + " return output.unsqueeze(0)\n", + "\n", + "\n", + "class 
TransformerEncoderLayer(nn.Module):\n", + "\n", + " def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,\n", + " activation=\"relu\", normalize_before=False):\n", + " super().__init__()\n", + " self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n", + " # Implementation of Feedforward model\n", + " self.linear1 = nn.Linear(d_model, dim_feedforward)\n", + " self.dropout = nn.Dropout(dropout)\n", + " self.linear2 = nn.Linear(dim_feedforward, d_model)\n", + "\n", + " self.norm1 = nn.LayerNorm(d_model)\n", + " self.norm2 = nn.LayerNorm(d_model)\n", + " self.dropout1 = nn.Dropout(dropout)\n", + " self.dropout2 = nn.Dropout(dropout)\n", + "\n", + " self.activation = _get_activation_fn(activation)\n", + " self.normalize_before = normalize_before\n", + "\n", + " def with_pos_embed(self, tensor, pos):\n", + " return tensor if pos is None else tensor + pos\n", + "\n", + " def forward_post(self,\n", + " src,\n", + " src_mask = None,\n", + " src_key_padding_mask = None,\n", + " pos = None):\n", + " # q = k = self.with_pos_embed(src, pos)\n", + " src2 = self.self_attn(src, src, value=src, attn_mask=src_mask,\n", + " key_padding_mask=src_key_padding_mask)[0]\n", + " src = src + self.dropout1(src2)\n", + " src = self.norm1(src)\n", + " src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))\n", + " src = src + self.dropout2(src2)\n", + " src = self.norm2(src)\n", + " return src\n", + "\n", + " def forward_pre(self, src,\n", + " src_mask = None,\n", + " src_key_padding_mask = None,\n", + " pos = None):\n", + " src2 = self.norm1(src)\n", + " # q = k = self.with_pos_embed(src2, pos)\n", + " src2 = self.self_attn(src2, src2, value=src2, attn_mask=src_mask,\n", + " key_padding_mask=src_key_padding_mask)[0]\n", + " src = src + self.dropout1(src2)\n", + " src2 = self.norm2(src)\n", + " src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))\n", + " src = src + self.dropout2(src2)\n", + " return src\n", + "\n", + " def forward(self, src,\n", + " src_mask = None,\n", + " src_key_padding_mask = None,\n", + " pos = None):\n", + " if self.normalize_before:\n", + " return self.forward_pre(src, src_mask, src_key_padding_mask, pos)\n", + " return self.forward_post(src, src_mask, src_key_padding_mask, pos)\n", + "\n", + "class TransformerDecoderLayer(nn.Module):\n", + "\n", + " def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,\n", + " activation=\"relu\", normalize_before=False):\n", + " super().__init__()\n", + " self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n", + " self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n", + " # Implementation of Feedforward model\n", + " self.linear1 = nn.Linear(d_model, dim_feedforward)\n", + " self.dropout = nn.Dropout(dropout)\n", + " self.linear2 = nn.Linear(dim_feedforward, d_model)\n", + "\n", + " self.norm1 = nn.LayerNorm(d_model)\n", + " self.norm2 = nn.LayerNorm(d_model)\n", + " self.norm3 = nn.LayerNorm(d_model)\n", + " self.dropout1 = nn.Dropout(dropout)\n", + " self.dropout2 = nn.Dropout(dropout)\n", + " self.dropout3 = nn.Dropout(dropout)\n", + "\n", + " self.activation = _get_activation_fn(activation)\n", + " self.normalize_before = normalize_before\n", + "\n", + " def with_pos_embed(self, tensor, pos):\n", + " return tensor if pos is None else tensor + pos\n", + "\n", + " def forward_post(self, tgt, memory,\n", + " tgt_mask = None,\n", + " memory_mask = None,\n", + " tgt_key_padding_mask = None,\n", + " memory_key_padding_mask = 
None,\n", + " pos = None,\n", + " query_pos = None):\n", + " # q = k = self.with_pos_embed(tgt, query_pos)\n", + " tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask,\n", + " key_padding_mask=tgt_key_padding_mask)[0]\n", + " tgt = tgt + self.dropout1(tgt2)\n", + " tgt = self.norm1(tgt)\n", + " tgt2 = self.multihead_attn(query=tgt,\n", + " key=memory,\n", + " value=memory, attn_mask=memory_mask,\n", + " key_padding_mask=memory_key_padding_mask)[0]\n", + " tgt = tgt + self.dropout2(tgt2)\n", + " tgt = self.norm2(tgt)\n", + " tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))\n", + " tgt = tgt + self.dropout3(tgt2)\n", + " tgt = self.norm3(tgt)\n", + " return tgt\n", + "\n", + " def forward_pre(self, tgt, memory,\n", + " tgt_mask = None,\n", + " memory_mask = None,\n", + " tgt_key_padding_mask = None,\n", + " memory_key_padding_mask = None,\n", + " pos = None,\n", + " query_pos = None):\n", + " tgt2 = self.norm1(tgt)\n", + " # q = k = self.with_pos_embed(tgt2, query_pos)\n", + " tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask,\n", + " key_padding_mask=tgt_key_padding_mask)[0]\n", + " tgt = tgt + self.dropout1(tgt2)\n", + " tgt2 = self.norm2(tgt)\n", + " tgt2 = self.multihead_attn(query=tgt2,\n", + " key=memory,\n", + " value=memory, attn_mask=memory_mask,\n", + " key_padding_mask=memory_key_padding_mask)[0]\n", + " tgt = tgt + self.dropout2(tgt2)\n", + " tgt2 = self.norm3(tgt)\n", + " tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))\n", + " tgt = tgt + self.dropout3(tgt2)\n", + " return tgt\n", + "\n", + " def forward(self, tgt, memory,\n", + " tgt_mask = None,\n", + " memory_mask = None,\n", + " tgt_key_padding_mask = None,\n", + " memory_key_padding_mask = None,\n", + " pos = None,\n", + " query_pos = None):\n", + " if self.normalize_before:\n", + " return self.forward_pre(tgt, memory, tgt_mask, memory_mask,\n", + " tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)\n", + " return self.forward_post(tgt, memory, tgt_mask, memory_mask,\n", + " tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)\n", + "\n", + "from Utils.JDC.model import JDCNet\n", + "from modules.audioencoder import AudioEncoder, MappingNetwork, StyleEncoder, AdaIN, EAModule\n", + "\n", + "def headpose_pred_to_degree(pred):\n", + " device = pred.device\n", + " idx_tensor = [idx for idx in range(66)]\n", + " idx_tensor = torch.FloatTensor(idx_tensor).to(device)\n", + " pred = F.softmax(pred, dim=1)\n", + " degree = torch.sum(pred*idx_tensor, axis=1)\n", + " # degree = F.one_hot(degree.to(torch.int64), num_classes=66)\n", + " return degree\n", + "\n", + "def get_rotation_matrix(yaw, pitch, roll):\n", + " yaw = yaw / 180 * 3.14\n", + " pitch = pitch / 180 * 3.14\n", + " roll = roll / 180 * 3.14\n", + "\n", + " roll = roll.unsqueeze(1)\n", + " pitch = pitch.unsqueeze(1)\n", + " yaw = yaw.unsqueeze(1)\n", + "\n", + " pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),\n", + " torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch),\n", + " torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1)\n", + "\n", + " pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)\n", + "\n", + " yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw),\n", + " torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),\n", + " -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1)\n", + " yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)\n", + "\n", + " roll_mat = 
torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll),\n", + " torch.sin(roll), torch.cos(roll), torch.zeros_like(roll),\n", + " torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1)\n", + " roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)\n", + "\n", + " rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat)\n", + "\n", + " return yaw, pitch, roll, yaw_mat.view(yaw_mat.shape[0], 9), pitch_mat.view(pitch_mat.shape[0], 9), roll_mat.view(roll_mat.shape[0], 9), rot_mat.view(rot_mat.shape[0], 9)\n", + "\n", + "class Audio2kpTransformerBBoxQDeep(nn.Module):\n", + " def __init__(self, embedding_dim, num_kp, num_w, face_adain=False):\n", + " super(Audio2kpTransformerBBoxQDeep, self).__init__()\n", + " self.embedding_dim = embedding_dim\n", + " self.num_kp = num_kp\n", + " self.num_w = num_w\n", + "\n", + "\n", + " self.embedding = nn.Embedding(41, embedding_dim)\n", + "\n", + " self.face_shrink = nn.Linear(240, 32)\n", + " self.hp_extractor = nn.Linear(45, 128)\n", + "\n", + " self.pos_enc = PositionalEncoding(128,20)\n", + " input_dim = 1\n", + "\n", + " self.decode_dim = 64\n", + " self.audio_embedding = nn.Sequential( # n x 29 x 16\n", + " nn.Conv1d(29, 32, kernel_size=3, stride=2,\n", + " padding=1, bias=True), # n x 32 x 8\n", + " nn.LeakyReLU(0.02, True),\n", + " nn.Conv1d(32, 32, kernel_size=3, stride=2,\n", + " padding=1, bias=True), # n x 32 x 4\n", + " nn.LeakyReLU(0.02, True),\n", + " nn.Conv1d(32, 64, kernel_size=3, stride=2,\n", + " padding=1, bias=True), # n x 64 x 2\n", + " nn.LeakyReLU(0.02, True),\n", + " nn.Conv1d(64, 64, kernel_size=3, stride=2,\n", + " padding=1, bias=True), # n x 64 x 1\n", + " nn.LeakyReLU(0.02, True),\n", + " )\n", + " self.encoder_fc1 = nn.Sequential(\n", + " nn.Linear(192, 128),\n", + " nn.LeakyReLU(0.02, True),\n", + " nn.Linear(128, 128),\n", + " )\n", + "\n", + " self.audio_embedding2 = nn.Sequential(nn.Conv2d(1, 8, (3, 17), stride=(1, 1), padding=(1, 0)),\n", + " # nn.GroupNorm(4, 8, affine=True),\n", + " BatchNorm2d(8),\n", + " nn.ReLU(inplace=True),\n", + " nn.Conv2d(8, 32, (13, 13), stride=(1, 1), padding=(6, 6)))\n", + "\n", + " self.audioencoder = AudioEncoder(dim_in=64, style_dim=128, max_conv_dim=512, w_hpf=0, F0_channel=256)\n", + " # self.mappingnet = MappingNetwork(latent_dim=16, style_dim=128, num_domains=8, hidden_dim=512)\n", + " # self.stylenet = StyleEncoder(dim_in=64, style_dim=64, num_domains=8, max_conv_dim=512)\n", + " self.face_adain = face_adain\n", + " if self.face_adain:\n", + " self.fadain = AdaIN(style_dim=128, num_features=32)\n", + " # norm = 'layer_2d' #\n", + " norm = 'batch'\n", + "\n", + " self.decodefeature_extract = nn.Sequential(mydownres2Dblock(self.decode_dim,32, normalize = norm),\n", + " mydownres2Dblock(32,48, normalize = norm),\n", + " mydownres2Dblock(48,64, normalize = norm),\n", + " mydownres2Dblock(64,96, normalize = norm),\n", + " mydownres2Dblock(96,128, normalize = norm),\n", + " nn.AvgPool2d(2))\n", + "\n", + " self.feature_extract = nn.Sequential(mydownres2Dblock(input_dim,32),\n", + " mydownres2Dblock(32,64),\n", + " mydownres2Dblock(64,128),\n", + " mydownres2Dblock(128,128),\n", + " mydownres2Dblock(128,128),\n", + " nn.AvgPool2d(2))\n", + " self.transformer = Transformer()\n", + " self.kp = nn.Linear(128, 32)\n", + "\n", + " # for m in self.modules():\n", + " # if isinstance(m, nn.Conv2d):\n", + " # # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n", + " # nn.init.xavier_uniform_(m.weight, gain=1)\n", + "\n", + " # 
if isinstance(m, nn.Linear):\n", + " # # trunc_normal_(m.weight, std=.03)\n", + " # nn.init.xavier_uniform_(m.weight, gain=1)\n", + " # if isinstance(m, nn.Linear) and m.bias is not None:\n", + " # nn.init.constant_(m.bias, 0)\n", + "\n", + " F0_path = './Utils/JDC/bst.t7'\n", + " F0_model = JDCNet(num_class=1, seq_len=32)\n", + " params = torch.load(F0_path, map_location='cpu')['net']\n", + " F0_model.load_state_dict(params)\n", + " self.f0_model = F0_model\n", + "\n", + " def rotation_and_translation(self, headpose, bbs, bs):\n", + " # print(headpose['roll'].shape, headpose['yaw'].shape, headpose['pitch'].shape, headpose['t'].shape)\n", + "\n", + " yaw = headpose_pred_to_degree(headpose['yaw'].reshape(bbs*bs, -1))\n", + " pitch = headpose_pred_to_degree(headpose['pitch'].reshape(bbs*bs, -1))\n", + " roll = headpose_pred_to_degree(headpose['roll'].reshape(bbs*bs, -1))\n", + " yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v = get_rotation_matrix(yaw, pitch, roll)\n", + " t = headpose['t'].reshape(bbs*bs, -1)\n", + " # hp = torch.cat([yaw, pitch, roll, yaw_v, pitch_v, roll_v, t], dim=1)\n", + " hp = torch.cat([yaw.unsqueeze(1), pitch.unsqueeze(1), roll.unsqueeze(1), yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v, t], dim=1)\n", + " # hp = torch.cat([yaw, pitch, roll, torch.sin(yaw), torch.sin(pitch), torch.sin(roll), torch.cos(yaw), torch.cos(pitch), torch.cos(roll), t], dim=1)\n", + " return hp\n", + "\n", + " def forward(self, x, initial_kp = None, return_strg=False, emoprompt=None, hp=None, side=False):\n", + " bbs, bs, seqlen, _, _ = x['deep'].shape\n", + " # ph = x[\"pho\"].reshape(bbs*bs*seqlen, 1)\n", + " if hp is None:\n", + " hp = self.rotation_and_translation(x['he_driving'], bbs, bs)\n", + " hp = self.hp_extractor(hp)\n", + "\n", + " pose_feature = x[\"pose\"].reshape(bbs*bs*seqlen,1,64,64)\n", + " # pose_feature = self.down_pose(pose).contiguous()\n", + " ### phoneme input feature\n", + " # phoneme_embedding = self.embedding(ph.long())\n", + " # phoneme_embedding = phoneme_embedding.reshape(bbs*bs*seqlen, 1, 16, 16)\n", + " # phoneme_embedding = F.interpolate(phoneme_embedding, scale_factor=4)\n", + " # input_feature = torch.cat((pose_feature, phoneme_embedding), dim=1)\n", + " # print('input_feature: ', input_feature.shape)\n", + " # input_feature = phoneme_embedding\n", + "\n", + " audio = x['deep'].reshape(bbs*bs*seqlen, 16, 29).permute(0, 2, 1)\n", + " deep_feature = self.audio_embedding(audio).squeeze(-1)# ([264, 32, 16, 16])\n", + " # print(deep_feature.shape)\n", + "\n", + " input_feature = pose_feature\n", + " # print(input_feature.shape)\n", + " # assert(0)\n", + " input_feature = self.feature_extract(input_feature).reshape(bbs*bs*seqlen, 128)\n", + " input_feature = torch.cat([input_feature, deep_feature], dim=1)\n", + " input_feature = self.encoder_fc1(input_feature).reshape(bbs*bs, seqlen, 128)\n", + " # phoneme_embedding = self.phoneme_shrink(phoneme_embedding.squeeze(1))# 24*11, 128\n", + " input_feature = torch.cat([input_feature, hp.unsqueeze(1)], dim=1)\n", + "\n", + " ### decode audio feature\n", + " ### use iteration to avoid batchnorm2d in different audio sequence\n", + " decoder_features = []\n", + " for i in range(bbs):\n", + " F0 = self.f0_model.get_feature_GAN(x['mel'][i].reshape(bs, 1, 80, seqlen))\n", + " if emoprompt is None:\n", + " audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), s=None, masks=None, F0=F0))\n", + " else:\n", + " audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), 
s=emoprompt[i].unsqueeze(0), masks=None, F0=F0))\n", + " audio2 = torch.permute(audio_feature, (0, 3, 1, 2)).reshape(bs*seqlen, 1, 64, 80)\n", + " decoder_feature = self.audio_embedding2(audio2)\n", + "\n", + " # decoder_feature = torch.cat([decoder_feature, audio2], dim=1)\n", + " # decoder_feature = F.interpolate(decoder_feature, scale_factor=2)# ([264, 35, 64, 64])\n", + " face_map = initial_kp[\"prediction_map\"][i].reshape(15*16, 64*64).permute(1, 0).reshape(64*64, 15*16)\n", + " feature_map = self.face_shrink(face_map).permute(1, 0).reshape(1, 32, 64, 64)\n", + " if self.face_adain:\n", + " feature_map = self.fadain(feature_map, emoprompt)\n", + " decoder_feature = self.decodefeature_extract(torch.cat(\n", + " (decoder_feature,\n", + " feature_map.repeat(bs, seqlen, 1, 1, 1).reshape(bs * seqlen, 32, 64, 64)),\n", + " dim=1)).reshape(bs, seqlen, 128)\n", + " decoder_features.append(decoder_feature)\n", + " decoder_feature = torch.cat(decoder_features, dim=0)\n", + "\n", + " decoder_feature = torch.cat([decoder_feature, hp.unsqueeze(1)], dim=1)\n", + "\n", + " # decoder_feature = torch.cat([decoder_feature], dim=1)\n", + "\n", + " # a2kp transformer\n", + " # position embedding\n", + " if emoprompt is None:\n", + " posi_em = self.pos_enc(self.num_w*2+1+1) # 11 + headpose token\n", + " else:\n", + " posi_em = self.pos_enc(self.num_w*2+1+1+1) # 11 + headpose token + emotion prompt\n", + " emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(1)\n", + " input_feature = torch.cat([input_feature, emoprompt], dim=1)\n", + " decoder_feature = torch.cat([decoder_feature, emoprompt], dim=1)\n", + " out = {}\n", + " output_feature, memory = self.transformer(input_feature, decoder_feature, posi_em, )\n", + " output_feature = output_feature[-1, self.num_w] # returned intermediate output [6, 13, bbs*bs, 128]\n", + " out[\"emo\"] = self.kp(output_feature)\n", + " if side:\n", + " input_st = {}\n", + " input_st['hp'] = hp\n", + " input_st['decoder_feature'] = decoder_feature\n", + " input_st['memory'] = memory\n", + " return out, input_st\n", + " else:\n", + " return out\n", + "\n", + "\n", + "class Audio2kpTransformerBBoxQDeepPrompt(nn.Module):\n", + " def __init__(self, embedding_dim, num_kp, num_w, face_ea=False):\n", + " super(Audio2kpTransformerBBoxQDeepPrompt, self).__init__()\n", + " self.embedding_dim = embedding_dim\n", + " self.num_kp = num_kp\n", + " self.num_w = num_w\n", + "\n", + "\n", + " self.embedding = nn.Embedding(41, embedding_dim)\n", + "\n", + " self.face_shrink = nn.Linear(240, 32)\n", + " self.hp_extractor = nn.Linear(45, 128)\n", + "\n", + " self.pos_enc = PositionalEncoding(128,20)\n", + " input_dim = 1\n", + "\n", + " self.decode_dim = 64\n", + " self.audio_embedding = nn.Sequential( # n x 29 x 16\n", + " nn.Conv1d(29, 32, kernel_size=3, stride=2,\n", + " padding=1, bias=True), # n x 32 x 8\n", + " nn.LeakyReLU(0.02, True),\n", + " nn.Conv1d(32, 32, kernel_size=3, stride=2,\n", + " padding=1, bias=True), # n x 32 x 4\n", + " nn.LeakyReLU(0.02, True),\n", + " nn.Conv1d(32, 64, kernel_size=3, stride=2,\n", + " padding=1, bias=True), # n x 64 x 2\n", + " nn.LeakyReLU(0.02, True),\n", + " nn.Conv1d(64, 64, kernel_size=3, stride=2,\n", + " padding=1, bias=True), # n x 64 x 1\n", + " nn.LeakyReLU(0.02, True),\n", + " )\n", + " self.encoder_fc1 = nn.Sequential(\n", + " nn.Linear(192, 128),\n", + " nn.LeakyReLU(0.02, True),\n", + " nn.Linear(128, 128),\n", + " )\n", + "\n", + " self.audio_embedding2 = nn.Sequential(nn.Conv2d(1, 8, (3, 17), stride=(1, 1), 
padding=(1, 0)),\n", + " # nn.GroupNorm(4, 8, affine=True),\n", + " BatchNorm2d(8),\n", + " nn.ReLU(inplace=True),\n", + " nn.Conv2d(8, 32, (13, 13), stride=(1, 1), padding=(6, 6)))\n", + "\n", + " self.audioencoder = AudioEncoder(dim_in=64, style_dim=128, max_conv_dim=512, w_hpf=0, F0_channel=256)\n", + " self.face_ea = face_ea\n", + " if self.face_ea:\n", + " self.fea = EAModule(style_dim=128, num_features=32)\n", + " norm = 'batch'\n", + "\n", + " self.decodefeature_extract = nn.Sequential(mydownres2Dblock(self.decode_dim,32, normalize = norm),\n", + " mydownres2Dblock(32,48, normalize = norm),\n", + " mydownres2Dblock(48,64, normalize = norm),\n", + " mydownres2Dblock(64,96, normalize = norm),\n", + " mydownres2Dblock(96,128, normalize = norm),\n", + " nn.AvgPool2d(2))\n", + "\n", + " self.feature_extract = nn.Sequential(mydownres2Dblock(input_dim,32),\n", + " mydownres2Dblock(32,64),\n", + " mydownres2Dblock(64,128),\n", + " mydownres2Dblock(128,128),\n", + " mydownres2Dblock(128,128),\n", + " nn.AvgPool2d(2))\n", + " self.transformer = TransformerDeep()\n", + " self.kp = nn.Linear(128, 32)\n", + "\n", + " F0_path = '/content/EAT_code/Utils/JDC/bst.t7'\n", + " F0_model = JDCNet(num_class=1, seq_len=32)\n", + " params = torch.load(F0_path, map_location='cpu')['net']\n", + " F0_model.load_state_dict(params)\n", + " self.f0_model = F0_model\n", + "\n", + " def rotation_and_translation(self, headpose, bbs, bs):\n", + " yaw = headpose_pred_to_degree(headpose['yaw'].reshape(bbs*bs, -1))\n", + " pitch = headpose_pred_to_degree(headpose['pitch'].reshape(bbs*bs, -1))\n", + " roll = headpose_pred_to_degree(headpose['roll'].reshape(bbs*bs, -1))\n", + " yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v = get_rotation_matrix(yaw, pitch, roll)\n", + " t = headpose['t'].reshape(bbs*bs, -1)\n", + " hp = torch.cat([yaw.unsqueeze(1), pitch.unsqueeze(1), roll.unsqueeze(1), yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v, t], dim=1)\n", + " return hp\n", + "\n", + " def forward(self, x, initial_kp = None, return_strg=False, emoprompt=None, deepprompt=None, hp=None, side=False):\n", + " bbs, bs, seqlen, _, _ = x['deep'].shape\n", + " # ph = x[\"pho\"].reshape(bbs*bs*seqlen, 1)\n", + " if hp is None:\n", + " hp = self.rotation_and_translation(x['he_driving'], bbs, bs)\n", + " hp = self.hp_extractor(hp)\n", + "\n", + " pose_feature = x[\"pose\"].reshape(bbs*bs*seqlen,1,64,64)\n", + "\n", + " audio = x['deep'].reshape(bbs*bs*seqlen, 16, 29).permute(0, 2, 1)\n", + " deep_feature = self.audio_embedding(audio).squeeze(-1)# ([264, 32, 16, 16])\n", + "\n", + " input_feature = pose_feature\n", + " input_feature = self.feature_extract(input_feature).reshape(bbs*bs*seqlen, 128)\n", + " input_feature = torch.cat([input_feature, deep_feature], dim=1)\n", + " input_feature = self.encoder_fc1(input_feature).reshape(bbs*bs, seqlen, 128)\n", + " input_feature = torch.cat([input_feature, hp.unsqueeze(1)], dim=1)\n", + "\n", + " ### decode audio feature\n", + " ### use iteration to avoid batchnorm2d in different audio sequence\n", + " decoder_features = []\n", + " for i in range(bbs):\n", + " F0 = self.f0_model.get_feature_GAN(x['mel'][i].reshape(bs, 1, 80, seqlen))\n", + " if emoprompt is None:\n", + " audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), s=None, masks=None, F0=F0))\n", + " else:\n", + " audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), s=emoprompt[i].unsqueeze(0), masks=None, F0=F0))\n", + " audio2 = torch.permute(audio_feature, (0, 3, 1, 
2)).reshape(bs*seqlen, 1, 64, 80)\n", + " decoder_feature = self.audio_embedding2(audio2)\n", + "\n", + " face_map = initial_kp[\"prediction_map\"][i].reshape(15*16, 64*64).permute(1, 0).reshape(64*64, 15*16)\n", + " face_feature_map = self.face_shrink(face_map).permute(1, 0).reshape(1, 32, 64, 64)\n", + " if self.face_ea:\n", + " face_feature_map = self.fea(face_feature_map, emoprompt)\n", + " decoder_feature = self.decodefeature_extract(torch.cat(\n", + " (decoder_feature,\n", + " face_feature_map.repeat(bs, seqlen, 1, 1, 1).reshape(bs * seqlen, 32, 64, 64)),\n", + " dim=1)).reshape(bs, seqlen, 128)\n", + " decoder_features.append(decoder_feature)\n", + " decoder_feature = torch.cat(decoder_features, dim=0)\n", + "\n", + " decoder_feature = torch.cat([decoder_feature, hp.unsqueeze(1)], dim=1)\n", + "\n", + " # a2kp transformer\n", + " # position embedding\n", + " if emoprompt is None:\n", + " posi_em = self.pos_enc(self.num_w*2+1+1) # 11 + headpose token\n", + " else:\n", + " posi_em = self.pos_enc(self.num_w*2+1+1+1) # 11 + headpose token + deep emotion prompt\n", + " out = {}\n", + " output_feature, memory = self.transformer(input_feature, decoder_feature, posi_em, deepprompt)\n", + " output_feature = output_feature[-1, self.num_w] # returned intermediate output [6, 13, bbs*bs, 128]\n", + " out[\"emo\"] = self.kp(output_feature)\n", + " if side:\n", + " input_st = {}\n", + " input_st['hp'] = hp\n", + " input_st['face_feature_map'] = face_feature_map\n", + " input_st['bs'] = bs\n", + " input_st['bbs'] = bbs\n", + " return out, input_st\n", + " else:\n", + " return out\n", + "\n", + "\n" + ], + "metadata": { + "cellView": "form", + "id": "DZwVMAPDgZvO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title deepspeechfeatures.py fixed\n", + "\n", + "\"\"\"\n", + " DeepSpeech features processing routines.\n", + " NB: Based on VOCA code. 
See the corresponding license restrictions.\n", + "\"\"\"\n", + "\n", + "__all__ = ['conv_audios_to_deepspeech']\n", + "\n", + "import numpy as np\n", + "import warnings\n", + "import resampy\n", + "from scipy.io import wavfile\n", + "from python_speech_features import mfcc\n", + "import tensorflow as tf\n", + "\n", + "\n", + "def conv_audios_to_deepspeech(audios,\n", + " out_files,\n", + " num_frames_info,\n", + " deepspeech_pb_path,\n", + " audio_window_size=16,\n", + " audio_window_stride=1):\n", + " \"\"\"\n", + " Convert list of audio files into files with DeepSpeech features.\n", + "\n", + " Parameters\n", + " ----------\n", + " audios : list of str or list of None\n", + " Paths to input audio files.\n", + " out_files : list of str\n", + " Paths to output files with DeepSpeech features.\n", + " num_frames_info : list of int\n", + " List of numbers of frames.\n", + " deepspeech_pb_path : str\n", + " Path to DeepSpeech 0.1.0 frozen model.\n", + " audio_window_size : int, default 16\n", + " Audio window size.\n", + " audio_window_stride : int, default 1\n", + " Audio window stride.\n", + " \"\"\"\n", + " graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net(deepspeech_pb_path)\n", + "\n", + " with tf.compat.v1.Session(graph=graph) as sess:\n", + " for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info):\n", + " audio_sample_rate, audio = wavfile.read(audio_file_path)\n", + " if audio.ndim != 1:\n", + " warnings.warn(\"Audio has multiple channels, the first channel is used\")\n", + " audio = audio[:, 0]\n", + " ds_features = pure_conv_audio_to_deepspeech(\n", + " audio=audio,\n", + " audio_sample_rate=audio_sample_rate,\n", + " audio_window_size=audio_window_size,\n", + " audio_window_stride=audio_window_stride,\n", + " num_frames=num_frames,\n", + " net_fn=lambda x: sess.run(\n", + " logits_ph,\n", + " feed_dict={\n", + " input_node_ph: x[np.newaxis, ...],\n", + " input_lengths_ph: [x.shape[0]]}))\n", + " np.save(out_file_path, ds_features)\n", + "\n", + "\n", + "# data_util/deepspeech_features/deepspeech_features.py\n", + "def prepare_deepspeech_net(deepspeech_pb_path):\n", + " # Load graph and place_holders:\n", + " with tf.io.gfile.GFile(deepspeech_pb_path, \"rb\") as f:\n", + " graph_def = tf.compat.v1.GraphDef()\n", + " graph_def.ParseFromString(f.read())\n", + "\n", + " graph = tf.compat.v1.get_default_graph()\n", + "\n", + " tf.import_graph_def(graph_def, name=\"deepspeech\")\n", + " # check all graphs\n", + " # print('~'*50, [tensor for tensor in graph._nodes_by_name], '~'*50)\n", + " # print('~'*50, [tensor.name for tensor in graph.get_operations()], '~'*50)\n", + " # i modified\n", + " logits_ph = graph.get_tensor_by_name(\"logits:0\")\n", + " input_node_ph = graph.get_tensor_by_name(\"input_node:0\")\n", + " input_lengths_ph = graph.get_tensor_by_name(\"input_lengths:0\")\n", + " # original\n", + " # logits_ph = graph.get_tensor_by_name(\"deepspeech/logits:0\")\n", + " # input_node_ph = graph.get_tensor_by_name(\"deepspeech/input_node:0\")\n", + " # input_lengths_ph = graph.get_tensor_by_name(\"deepspeech/input_lengths:0\")\n", + "\n", + " return graph, logits_ph, input_node_ph, input_lengths_ph\n", + "\n", + "\n", + "def pure_conv_audio_to_deepspeech(audio,\n", + " audio_sample_rate,\n", + " audio_window_size,\n", + " audio_window_stride,\n", + " num_frames,\n", + " net_fn):\n", + " \"\"\"\n", + " Core routine for converting audion into DeepSpeech features.\n", + "\n", + " Parameters\n", + " ----------\n", + " audio : 
np.array\n", + " Audio data.\n", + " audio_sample_rate : int\n", + " Audio sample rate.\n", + " audio_window_size : int\n", + " Audio window size.\n", + " audio_window_stride : int\n", + " Audio window stride.\n", + " num_frames : int or None\n", + " Numbers of frames.\n", + " net_fn : func\n", + " Function for DeepSpeech model call.\n", + "\n", + " Returns\n", + " -------\n", + " np.array\n", + " DeepSpeech features.\n", + " \"\"\"\n", + " target_sample_rate = 16000\n", + " if audio_sample_rate != target_sample_rate:\n", + " resampled_audio = resampy.resample(\n", + " x=audio.astype(np.float),\n", + " sr_orig=audio_sample_rate,\n", + " sr_new=target_sample_rate)\n", + " else:\n", + " resampled_audio = audio.astype(np.float32)\n", + " input_vector = conv_audio_to_deepspeech_input_vector(\n", + " audio=resampled_audio.astype(np.int16),\n", + " sample_rate=target_sample_rate,\n", + " num_cepstrum=26,\n", + " num_context=9)\n", + "\n", + " network_output = net_fn(input_vector)\n", + "\n", + " deepspeech_fps = 50\n", + " video_fps = 60\n", + " audio_len_s = float(audio.shape[0]) / audio_sample_rate\n", + " if num_frames is None:\n", + " num_frames = int(round(audio_len_s * video_fps))\n", + " else:\n", + " video_fps = num_frames / audio_len_s\n", + " network_output = interpolate_features(\n", + " features=network_output[:, 0],\n", + " input_rate=deepspeech_fps,\n", + " output_rate=video_fps,\n", + " output_len=num_frames)\n", + "\n", + " # Make windows:\n", + " zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1]))\n", + " network_output = np.concatenate((zero_pad, network_output, zero_pad), axis=0)\n", + " windows = []\n", + " for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride):\n", + " windows.append(network_output[window_index:window_index + audio_window_size])\n", + "\n", + " return np.array(windows)\n", + "\n", + "\n", + "def conv_audio_to_deepspeech_input_vector(audio,\n", + " sample_rate,\n", + " num_cepstrum,\n", + " num_context):\n", + " \"\"\"\n", + " Convert audio raw data into DeepSpeech input vector.\n", + "\n", + " Parameters\n", + " ----------\n", + " audio : np.array\n", + " Audio data.\n", + " audio_sample_rate : int\n", + " Audio sample rate.\n", + " num_cepstrum : int\n", + " Number of cepstrum.\n", + " num_context : int\n", + " Number of context.\n", + "\n", + " Returns\n", + " -------\n", + " np.array\n", + " DeepSpeech input vector.\n", + " \"\"\"\n", + " # Get mfcc coefficients:\n", + " features = mfcc(\n", + " signal=audio,\n", + " samplerate=sample_rate,\n", + " numcep=num_cepstrum)\n", + "\n", + " # We only keep every second feature (BiRNN stride = 2):\n", + " features = features[::2]\n", + "\n", + " # One stride per time step in the input:\n", + " num_strides = len(features)\n", + "\n", + " # Add empty initial and final contexts:\n", + " empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype)\n", + " features = np.concatenate((empty_context, features, empty_context))\n", + "\n", + " # Create a view into the array with overlapping strides of size\n", + " # numcontext (past) + 1 (present) + numcontext (future):\n", + " window_size = 2 * num_context + 1\n", + " train_inputs = np.lib.stride_tricks.as_strided(\n", + " features,\n", + " shape=(num_strides, window_size, num_cepstrum),\n", + " strides=(features.strides[0], features.strides[0], features.strides[1]),\n", + " writeable=False)\n", + "\n", + " # Flatten the second and third dimensions:\n", + " train_inputs = 
np.reshape(train_inputs, [num_strides, -1])\n", + "\n", + " train_inputs = np.copy(train_inputs)\n", + " train_inputs = (train_inputs - np.mean(train_inputs)) / np.std(train_inputs)\n", + "\n", + " return train_inputs\n", + "\n", + "\n", + "def interpolate_features(features,\n", + " input_rate,\n", + " output_rate,\n", + " output_len):\n", + " \"\"\"\n", + " Interpolate DeepSpeech features.\n", + "\n", + " Parameters\n", + " ----------\n", + " features : np.array\n", + " DeepSpeech features.\n", + " input_rate : int\n", + " input rate (FPS).\n", + " output_rate : int\n", + " Output rate (FPS).\n", + " output_len : int\n", + " Output data length.\n", + "\n", + " Returns\n", + " -------\n", + " np.array\n", + " Interpolated data.\n", + " \"\"\"\n", + " input_len = features.shape[0]\n", + " num_features = features.shape[1]\n", + " input_timestamps = np.arange(input_len) / float(input_rate)\n", + " output_timestamps = np.arange(output_len) / float(output_rate)\n", + " output_features = np.zeros((output_len, num_features))\n", + " for feature_idx in range(num_features):\n", + " output_features[:, feature_idx] = np.interp(\n", + " x=output_timestamps,\n", + " xp=input_timestamps,\n", + " fp=features[:, feature_idx])\n", + " return output_features\n" + ], + "metadata": { + "cellView": "form", + "id": "Jq0kqup1Ogqd" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title deepspeechstore fixed\n", + "\n", + "\"\"\"\n", + " Routines for loading DeepSpeech model.\n", + "\"\"\"\n", + "\n", + "__all__ = ['get_deepspeech_model_file']\n", + "\n", + "import os\n", + "import zipfile\n", + "import logging\n", + "import hashlib\n", + "\n", + "\n", + "deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'\n", + "\n", + "\n", + "def get_deepspeech_model_file(local_model_store_dir_path=os.path.join(\"~\", \"/content/EAT_code/tensorflow\", \"models\")):\n", + " \"\"\"\n", + " Return location for the pretrained on local file system. This function will download from online model zoo when\n", + " model cannot be found or has mismatch. The root directory will be created if it doesn't exist.\n", + "\n", + " Parameters\n", + " ----------\n", + " local_model_store_dir_path : str, default $TENSORFLOW_HOME/models\n", + " Location for keeping the model parameters.\n", + "\n", + " Returns\n", + " -------\n", + " file_path\n", + " Path to the requested pretrained model file.\n", + " \"\"\"\n", + " sha1_hash = \"b90017e816572ddce84f5843f1fa21e6a377975e\"\n", + " file_name = \"deepspeech-0_1_0-b90017e8.pb\"\n", + " local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)\n", + " file_path = os.path.join(local_model_store_dir_path, file_name)\n", + " if os.path.exists(file_path):\n", + " if _check_sha1(file_path, sha1_hash):\n", + " return file_path\n", + " else:\n", + " logging.warning(\"Mismatch in the content of model file detected. Downloading again.\")\n", + " else:\n", + " logging.info(\"Model file not found. 
Downloading to {}.\".format(file_path))\n", + "\n", + " if not os.path.exists(local_model_store_dir_path):\n", + " os.makedirs(local_model_store_dir_path)\n", + "\n", + " zip_file_path = file_path + \".zip\"\n", + " _download(\n", + " url=\"{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip\".format(\n", + " repo_url=deepspeech_features_repo_url,\n", + " repo_release_tag=\"v0.0.1\",\n", + " file_name=file_name),\n", + " path=zip_file_path,\n", + " overwrite=False)\n", + " with zipfile.ZipFile(zip_file_path) as zf:\n", + " zf.extractall(local_model_store_dir_path)\n", + " os.remove(zip_file_path)\n", + "\n", + " if _check_sha1(file_path, sha1_hash):\n", + " return file_path\n", + " else:\n", + " raise ValueError(\"Downloaded file has different hash. Please try again.\")\n", + "\n", + "\n", + "def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):\n", + " \"\"\"\n", + " Download an given URL\n", + "\n", + " Parameters\n", + " ----------\n", + " url : str\n", + " URL to download\n", + " path : str, optional\n", + " Destination path to store downloaded file. By default stores to the\n", + " current directory with same name as in url.\n", + " overwrite : bool, optional\n", + " Whether to overwrite destination file if already exists.\n", + " sha1_hash : str, optional\n", + " Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified\n", + " but doesn't match.\n", + " retries : integer, default 5\n", + " The number of times to attempt the download in case of failure or non 200 return codes\n", + " verify_ssl : bool, default True\n", + " Verify SSL certificates.\n", + "\n", + " Returns\n", + " -------\n", + " str\n", + " The file path of the downloaded file.\n", + " \"\"\"\n", + " import warnings\n", + " try:\n", + " import requests\n", + " except ImportError:\n", + " class requests_failed_to_import(object):\n", + " pass\n", + " requests = requests_failed_to_import\n", + "\n", + " if path is None:\n", + " fname = url.split(\"/\")[-1]\n", + " # Empty filenames are invalid\n", + " assert fname, \"Can't construct file-name from this URL. Please set the `path` option manually.\"\n", + " else:\n", + " path = os.path.expanduser(path)\n", + " if os.path.isdir(path):\n", + " fname = os.path.join(path, url.split(\"/\")[-1])\n", + " else:\n", + " fname = path\n", + " assert retries >= 0, \"Number of retries should be at least 0\"\n", + "\n", + " if not verify_ssl:\n", + " warnings.warn(\n", + " \"Unverified HTTPS request is being made (verify_ssl=False). 
\"\n", + " \"Adding certificate verification is strongly advised.\")\n", + "\n", + " if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):\n", + " dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))\n", + " if not os.path.exists(dirname):\n", + " os.makedirs(dirname)\n", + " while retries + 1 > 0:\n", + " # Disable pyling too broad Exception\n", + " # pylint: disable=W0703\n", + " try:\n", + " print(\"Downloading {} from {}...\".format(fname, url))\n", + " #r = requests.get(url, stream=True, verify=verify_ssl)\n", + " #if r.status_code != 200:\n", + " # raise RuntimeError(\"Failed downloading url {}\".format(url))\n", + " #with open(fname, \"wb\") as f:\n", + " # for chunk in r.iter_content(chunk_size=1024):\n", + " # if chunk: # filter out keep-alive new chunks\n", + " # f.write(chunk)\n", + " if sha1_hash and not _check_sha1(fname, sha1_hash):\n", + " raise UserWarning(\"File {} is downloaded but the content hash does not match.\"\n", + " \" The repo may be outdated or download may be incomplete. \"\n", + " \"If the `repo_url` is overridden, consider switching to \"\n", + " \"the default repo.\".format(fname))\n", + " break\n", + " except Exception as e:\n", + " retries -= 1\n", + " if retries <= 0:\n", + " raise e\n", + " else:\n", + " print(\"download failed, retrying, {} attempt{} left\"\n", + " .format(retries, \"s\" if retries > 1 else \"\"))\n", + "\n", + " return fname\n", + "\n", + "\n", + "def _check_sha1(filename, sha1_hash):\n", + " \"\"\"\n", + " Check whether the sha1 hash of the file content matches the expected hash.\n", + "\n", + " Parameters\n", + " ----------\n", + " filename : str\n", + " Path to the file.\n", + " sha1_hash : str\n", + " Expected sha1 hash in hexadecimal digits.\n", + "\n", + " Returns\n", + " -------\n", + " bool\n", + " Whether the file content matches the expected hash.\n", + " \"\"\"\n", + " sha1 = hashlib.sha1()\n", + " with open(filename, \"rb\") as f:\n", + " while True:\n", + " data = f.read(1048576)\n", + " if not data:\n", + " break\n", + " sha1.update(data)\n", + "\n", + " return sha1.hexdigest() == sha1_hash\n" + ], + "metadata": { + "cellView": "form", + "id": "BaZ1iwyuO_bl" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!zip -r EAT_code /content/EAT_code" + ], + "metadata": { + "id": "wAFm27NITV_C" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file