{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "source": [ "#@title set font and wrap\n", "from IPython.display import HTML, display\n", "\n", "def set_css():\n", " display(HTML('''\n", " \n", " '''))\n", "get_ipython().events.register('pre_run_cell', set_css)" ], "metadata": { "cellView": "form", "id": "2eK7PTTQiAWv" }, "execution_count": 1, "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LBBjwIn0ROsF" }, "outputs": [], "source": [ "#!git clone https://github.com/yuangan/EAT_code.git\n", "!wget https://huggingface.co/waveydaveygravy/styletalk/resolve/main/EAT_code.zip\n", "!unzip /content/EAT_code.zip\n", "%cd /content/EAT_code\n", "!pip install -r requirements.txt\n", "!pip install resampy\n", "#!pip install face_detection\n", "!pip install python_speech_features\n", "#@title make directories if not done yet\n", "!mkdir tensorflow\n", "!mkdir Results\n", "!mkdir ckpt\n", "!mkdir demo\n", "!mkdir Utils\n", "%cd /content/EAT_code/tensorflow\n", "!mkdir models\n", "%cd /content/EAT_code\n", "print(\"done\")" ] }, { "cell_type": "code", "source": [ "#@title in /content/EAT_code/config/deepprompt_eam3d_st_tanh_304_3090_all.yaml change batch_size to 1 line 70" ], "metadata": { "cellView": "form", "id": "_Mxxg6d-WakT" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#@title make directories if not done yet\n", "!mkdir tensorflow\n", "!mkdir ckpt\n", "!mkdir demo\n", "!mkdir Utils\n", "%cd /content/EAT_code/demo\n", "!mkdir imgs1\n", "!mkdir imgs_cropped1\n", "!mkdir imgs_latent1\n", "%cd /content/EAT_code/tensorflow\n", "!mkdir models\n", "%cd /content/EAT_code" ], "metadata": { "id": "3bf2Yff2DEdH" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#@title donwload and place deepspeech model\n", "%cd /content/EAT_code\n", "!gdown --id 1KK15n2fOdfLECWN5wvX54mVyDt18IZCo && unzip -q ckpt.zip -d ckpt\n", "!gdown --id 1MeFGC7ig-vgpDLdhh2vpTIiElrhzZmgT && unzip -q demo.zip -d demo\n", "!gdown --id 1HGVzckXh-vYGZEUUKMntY1muIbkbnRcd && unzip -q Utils.zip -d Utils\n", "%cd /content/EAT_code/tensorflow/models\n", "!wget https://github.com/osmr/deepspeech_features/releases/download/v0.0.1/deepspeech-0_1_0-b90017e8.pb.zip\n", "%cd /content/EAT_code" ], "metadata": { "id": "ucK4MQbq0yHx" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#@title download models\n", "%cd /content/EAT_code\n", "\n", "!gdown --id 1KK15n2fOdfLECWN5wvX54mVyDt18IZCo && unzip -q ckpt.zip -d ckpt\n", "!gdown --id 1MeFGC7ig-vgpDLdhh2vpTIiElrhzZmgT && unzip -q demo.zip -d demo\n", "!gdown --id 1HGVzckXh-vYGZEUUKMntY1muIbkbnRcd && unzip -q Utils.zip -d Utils" ], "metadata": { "id": "eDDFgToSSEgj", "cellView": "form" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#@title upload custom mp4 to videos\n", "%cd /content/EAT_code/preprocess/video\n", "from google.colab import files\n", "uploaded = files.upload()\n", "%cd /content/EAT_code" ], "metadata": { "id": "rlfKO73uUXFT", "colab": { "base_uri": "https://localhost:8080/", "height": 106 }, "outputId": "e8116307-0a3f-4eea-ebd1-925f3e17a4bf" }, "execution_count": 5, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " " ] }, "metadata": {} }, { "output_type": 
"stream", "name": "stdout", "text": [ "/content/EAT_code/preprocess/video\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " Upload widget is only available when the cell has been executed in the\n", " current browser session. Please rerun this cell to enable.\n", " \n", " " ] }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "Saving bo_1resized.mp4 to bo_1resized.mp4\n", "/content/EAT_code\n" ] } ] }, { "cell_type": "code", "source": [ "#@title extract boundary boxes\n", "!python /content/EAT_code/preprocess/extract_bbox.py" ], "metadata": { "cellView": "form", "id": "llzj0RprSPu6" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#@title place custom video in preprocess/video\n", "%cd /content/EAT_code/preprocess\n", "!python /content/EAT_code/preprocess/preprocess_video.py # --deepspeech \"/content/EAT_code/tensorflow/models/deepspeech-0_1_0-b90017e8.pb\"" ], "metadata": { "id": "p99eLmnjW-e1" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ " print(f\"Number of paths: {len(self.A_paths)}\")\n", " print(\"Paths:\")\n", " for path in self.A_paths:\n", " print(path)" ], "metadata": { "id": "B4LWLwEvJOsp" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "import numpy as np\n", "\n", "# Replace \"path/to/your/file.npy\" with the actual path to your file\n", "data = np.load(\"/content/EAT_code/demo/video_processed/obama/latent_evp_25/obama.npy\")\n", "\n", "print(f\"Type: {type(data)}\")\n", "print(f\"Shape: {data.shape}\")" ], "metadata": { "id": "V-hPYJBgYIKx" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "!python /content/EAT_code/demo.py --root_wav /content/EAT_code/demo/video_processed1/bo_1resized --emo ang --save_dir /content/EAT_code/Results" ], "metadata": { "id": "APJUoQkgaqa9", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "outputId": "e8150705-1bae-4420-8306-069aa274fb1d" }, "execution_count": 14, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " " ] }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "deepprompt_eam3d_all_final_313\n", "cuda is available\n", "/usr/local/lib/python3.10/dist-packages/torch/functional.py:568: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. 
(Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:2228.)\n", " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", " 0% 0/1 [00:00 1:\n", " wave_tensor = wave_tensor[:, 0]\n", " mel_tensor = to_melspec(wave_tensor)\n", " mel_tensor = (torch.log(1e-5 + mel_tensor) - mean) / std\n", " name_len = min(mel_tensor.shape[1], poseimg.shape[0], deepfeature.shape[0])\n", "\n", " audio_frames = []\n", " poseimgs = []\n", " deep_feature = []\n", "\n", " pad, deep_pad = np.load('pad.npy', allow_pickle=True)\n", "\n", " if name_len < num_frames:\n", " diff = num_frames - name_len\n", " if diff > 2:\n", " print(f\"Attention: the frames are {diff} more than name_len, we will use name_len to replace num_frames\")\n", " num_frames=name_len\n", " for k in he_driving.keys():\n", " he_driving[k] = he_driving[k][:name_len, :]\n", " for rid in range(0, num_frames):\n", " audio = []\n", " poses = []\n", " deeps = []\n", " for i in range(rid - opt['num_w'], rid + opt['num_w'] + 1):\n", " if i < 0:\n", " audio.append(pad)\n", " poses.append(poseimg[0])\n", " deeps.append(deep_pad)\n", " elif i >= name_len:\n", " audio.append(pad)\n", " poses.append(poseimg[-1])\n", " deeps.append(deep_pad)\n", " else:\n", " audio.append(mel_tensor[:, i])\n", " poses.append(poseimg[i])\n", " deeps.append(deepfeature[i])\n", "\n", " audio_frames.append(torch.stack(audio, dim=1))\n", " poseimgs.append(poses)\n", " deep_feature.append(deeps)\n", " audio_frames = torch.stack(audio_frames, dim=0)\n", " poseimgs = torch.from_numpy(np.array(poseimgs))\n", " deep_feature = torch.from_numpy(np.array(deep_feature)).to(torch.float)\n", " return audio_frames, poseimgs, deep_feature, source_img, he_source, he_driving, num_frames, y_trg, z_trg, latent_path_driving\n", "\n", "def load_ckpt(ckpt, kp_detector, generator, audio2kptransformer, sidetuning, emotionprompt):\n", " checkpoint = torch.load(ckpt, map_location=torch.device('cpu'))\n", " if audio2kptransformer is not None:\n", " audio2kptransformer.load_state_dict(checkpoint['audio2kptransformer'])\n", " if generator is not None:\n", " generator.load_state_dict(checkpoint['generator'])\n", " if kp_detector is not None:\n", " kp_detector.load_state_dict(checkpoint['kp_detector'])\n", " if sidetuning is not None:\n", " sidetuning.load_state_dict(checkpoint['sidetuning'])\n", " if emotionprompt is not None:\n", " emotionprompt.load_state_dict(checkpoint['emotionprompt'])\n", "\n", "import cv2\n", "import dlib\n", "from tqdm import tqdm\n", "from skimage import transform as tf\n", "detector = dlib.get_frontal_face_detector()\n", "predictor = dlib.shape_predictor('/content/EAT_code/demo/shape_predictor_68_face_landmarks.dat')\n", "\n", "def shape_to_np(shape, dtype=\"int\"):\n", " # initialize the list of (x, y)-coordinates\n", " coords = np.zeros((shape.num_parts, 2), dtype=dtype)\n", "\n", " # loop over all facial landmarks and convert them\n", " # to a 2-tuple of (x, y)-coordinates\n", " for i in range(0, shape.num_parts):\n", " coords[i] = (shape.part(i).x, shape.part(i).y)\n", "\n", " # return the list of (x, y)-coordinates\n", " return coords\n", "\n", "def crop_image(image_path, out_path):\n", " template = np.load('./demo/M003_template.npy')\n", " image = cv2.imread(image_path)\n", " gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n", " rects = detector(gray, 1) #detect human face\n", " if len(rects) != 1:\n", " return 0\n", " for (j, rect) in enumerate(rects):\n", " shape = predictor(gray, rect) #detect 68 points\n", " shape = shape_to_np(shape)\n", "\n", " 
pts2 = np.float32(template[:47,:])\n", " pts1 = np.float32(shape[:47,:]) #eye and nose\n", " tform = tf.SimilarityTransform()\n", " tform.estimate( pts2, pts1) #Set the transformation matrix with the explicit parameters.\n", "\n", " dst = tf.warp(image, tform, output_shape=(256, 256))\n", "\n", " dst = np.array(dst * 255, dtype=np.uint8)\n", "\n", " cv2.imwrite(out_path, dst)\n", "\n", "def preprocess_imgs(allimgs, tmp_allimgs_cropped):\n", " name_cropped = []\n", " for path in tmp_allimgs_cropped:\n", " name_cropped.append(os.path.basename(path))\n", " for path in allimgs:\n", " if os.path.basename(path) in name_cropped:\n", " continue\n", " else:\n", " out_path = path.replace('imgs/', 'imgs_cropped/')\n", " crop_image(path, out_path)\n", "\n", "from sync_batchnorm import DataParallelWithCallback\n", "def load_checkpoints_extractor(config_path, checkpoint_path, cpu=False):\n", "\n", " with open(config_path) as f:\n", " config = yaml.load(f, Loader=yaml.FullLoader)\n", "\n", " kp_detector = KPDetector(**config['model_params']['kp_detector_params'],\n", " **config['model_params']['common_params'])\n", " if not cpu:\n", " kp_detector.cuda()\n", "\n", " he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],\n", " **config['model_params']['common_params'])\n", " if not cpu:\n", " he_estimator.cuda()\n", "\n", " if cpu:\n", " checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))\n", " else:\n", " checkpoint = torch.load(checkpoint_path)\n", "\n", " kp_detector.load_state_dict(checkpoint['kp_detector'])\n", " he_estimator.load_state_dict(checkpoint['he_estimator'])\n", "\n", " if not cpu:\n", " kp_detector = DataParallelWithCallback(kp_detector)\n", " he_estimator = DataParallelWithCallback(he_estimator)\n", "\n", " kp_detector.eval()\n", " he_estimator.eval()\n", "\n", " return kp_detector, he_estimator\n", "\n", "def estimate_latent(driving_video, kp_detector, he_estimator):\n", " with torch.no_grad():\n", " predictions = []\n", " driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).cuda()\n", " kp_canonical = kp_detector(driving[:, :, 0])\n", " he_drivings = {'yaw': [], 'pitch': [], 'roll': [], 't': [], 'exp': []}\n", "\n", " for frame_idx in range(driving.shape[2]):\n", " driving_frame = driving[:, :, frame_idx]\n", " he_driving = he_estimator(driving_frame)\n", " for k in he_drivings.keys():\n", " he_drivings[k].append(he_driving[k])\n", " return [kp_canonical, he_drivings]\n", "\n", "def extract_keypoints(extract_list):\n", " kp_detector, he_estimator = load_checkpoints_extractor(config_path='config/vox-256-spade.yaml', checkpoint_path='./ckpt/pretrain_new_274.pth.tar')\n", " if not os.path.exists('./demo/imgs_latent/'):\n", " os.makedirs('./demo/imgs_latent/')\n", " for imgname in tqdm(extract_list):\n", " path_frames = [imgname]\n", " filesname=os.path.basename(imgname)[:-4]\n", " if os.path.exists(f'./demo/imgs_latent/'+filesname+'.npy'):\n", " continue\n", " driving_frames = []\n", " for im in path_frames:\n", " driving_frames.append(imageio.imread(im))\n", " driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_frames]\n", "\n", " kc, he = estimate_latent(driving_video, kp_detector, he_estimator)\n", " kc = kc['value'].cpu().numpy()\n", " for k in he:\n", " he[k] = torch.cat(he[k]).cpu().numpy()\n", " np.save('./demo/imgs_latent/'+filesname, [kc, he])\n", "\n", "def preprocess_cropped_imgs(allimgs_cropped):\n", " extract_list = []\n", " for img_path in allimgs_cropped:\n", " 
if not os.path.exists(img_path.replace('cropped', 'latent')[:-4]+'.npy'):\n", " extract_list.append(img_path)\n", " if len(extract_list) > 0:\n", " print('=========', \"Extract latent keypoints from New image\", '======')\n", " extract_keypoints(extract_list)\n", "\n", "def test(ckpt, emotype, save_dir=\" \"):\n", " # with open(\"config/vox-transformer2.yaml\") as f:\n", " with open(\"/content/EAT_code/config/deepprompt_eam3d_st_tanh_304_3090_all.yaml\") as f:\n", " config = yaml.load(f, Loader=yaml.FullLoader)\n", " cur_path = os.getcwd()\n", " generator, kp_detector, audio2kptransformer, sidetuning, emotionprompt = build_model(config)\n", " load_ckpt(ckpt, kp_detector=kp_detector, generator=generator, audio2kptransformer=audio2kptransformer, sidetuning=sidetuning, emotionprompt=emotionprompt)\n", "\n", " audio2kptransformer.eval()\n", " generator.eval()\n", " kp_detector.eval()\n", " sidetuning.eval()\n", " emotionprompt.eval()\n", "\n", " all_wavs2 = [f'{root_wav}/{os.path.basename(root_wav)}.wav']\n", " allimg = glob.glob('./demo/imgs/*.jpg')\n", " tmp_allimg_cropped = glob.glob('./demo/imgs_cropped/*.jpg')\n", " preprocess_imgs(allimg, tmp_allimg_cropped) # crop and align images\n", "\n", " allimg_cropped = glob.glob('./demo/imgs_cropped/*.jpg')\n", " preprocess_cropped_imgs(allimg_cropped) # extract latent keypoints if necessary\n", "\n", " for ind in tqdm(range(len(all_wavs2))):\n", " for img_path in tqdm(allimg_cropped):\n", " audio_path = all_wavs2[ind]\n", " # read in data\n", " audio_frames, poseimgs, deep_feature, source_img, he_source, he_driving, num_frames, y_trg, z_trg, latent_path_driving = prepare_test_data(img_path, audio_path, config['model_params']['audio2kp_params'], emotype)\n", "\n", "\n", " with torch.no_grad():\n", " source_img = torch.from_numpy(source_img).unsqueeze(0).cuda()\n", " kp_canonical = kp_detector(source_img, with_feature=True) # {'value': value, 'jacobian': jacobian}\n", " kp_cano = kp_canonical['value']\n", "\n", " x = {}\n", " x['mel'] = audio_frames.unsqueeze(1).unsqueeze(0).cuda()\n", " x['z_trg'] = z_trg.unsqueeze(0).cuda()\n", " x['y_trg'] = torch.tensor(y_trg, dtype=torch.long).cuda().reshape(1)\n", " x['pose'] = poseimgs.cuda()\n", " x['deep'] = deep_feature.cuda().unsqueeze(0)\n", " x['he_driving'] = {'yaw': torch.from_numpy(he_driving['yaw']).cuda().unsqueeze(0),\n", " 'pitch': torch.from_numpy(he_driving['pitch']).cuda().unsqueeze(0),\n", " 'roll': torch.from_numpy(he_driving['roll']).cuda().unsqueeze(0),\n", " 't': torch.from_numpy(he_driving['t']).cuda().unsqueeze(0),\n", " }\n", "\n", " ### emotion prompt\n", " emoprompt, deepprompt = emotionprompt(x)\n", " a2kp_exps = []\n", " emo_exps = []\n", " T = 5\n", " if T == 1:\n", " for i in range(x['mel'].shape[1]):\n", " xi = {}\n", " xi['mel'] = x['mel'][:,i,:,:,:].unsqueeze(1)\n", " xi['z_trg'] = x['z_trg']\n", " xi['y_trg'] = x['y_trg']\n", " xi['pose'] = x['pose'][i,:,:,:,:].unsqueeze(0)\n", " xi['deep'] = x['deep'][:,i,:,:,:].unsqueeze(1)\n", " xi['he_driving'] = {'yaw': x['he_driving']['yaw'][:,i,:].unsqueeze(0),\n", " 'pitch': x['he_driving']['pitch'][:,i,:].unsqueeze(0),\n", " 'roll': x['he_driving']['roll'][:,i,:].unsqueeze(0),\n", " 't': x['he_driving']['t'][:,i,:].unsqueeze(0),\n", " }\n", " he_driving_emo_xi, input_st_xi = audio2kptransformer(xi, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}\n", " emo_exp = sidetuning(input_st_xi, emoprompt, deepprompt)\n", " 
a2kp_exps.append(he_driving_emo_xi['emo'])\n", " emo_exps.append(emo_exp)\n", " elif T is not None:\n", " for i in range(x['mel'].shape[1]//T+1):\n", " if i*T >= x['mel'].shape[1]:\n", " break\n", " xi = {}\n", " xi['mel'] = x['mel'][:,i*T:(i+1)*T,:,:,:]\n", " xi['z_trg'] = x['z_trg']\n", " xi['y_trg'] = x['y_trg']\n", " xi['pose'] = x['pose'][i*T:(i+1)*T,:,:,:,:]\n", " xi['deep'] = x['deep'][:,i*T:(i+1)*T,:,:,:]\n", " xi['he_driving'] = {'yaw': x['he_driving']['yaw'][:,i*T:(i+1)*T,:],\n", " 'pitch': x['he_driving']['pitch'][:,i*T:(i+1)*T,:],\n", " 'roll': x['he_driving']['roll'][:,i*T:(i+1)*T,:],\n", " 't': x['he_driving']['t'][:,i*T:(i+1)*T,:],\n", " }\n", " he_driving_emo_xi, input_st_xi = audio2kptransformer(xi, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}\n", " emo_exp = sidetuning(input_st_xi, emoprompt, deepprompt)\n", " a2kp_exps.append(he_driving_emo_xi['emo'])\n", " emo_exps.append(emo_exp)\n", "\n", " if T is None:\n", " he_driving_emo, input_st = audio2kptransformer(x, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}\n", " emo_exps = sidetuning(input_st, emoprompt, deepprompt).reshape(-1, 45)\n", " else:\n", " he_driving_emo = {}\n", " he_driving_emo['emo'] = torch.cat(a2kp_exps, dim=0)\n", " emo_exps = torch.cat(emo_exps, dim=0).reshape(-1, 45)\n", "\n", " exp = he_driving_emo['emo']\n", " device = exp.get_device()\n", " exp = torch.mm(exp, expU.t().to(device))\n", " exp = exp + expmean.expand_as(exp).to(device)\n", " exp = exp + emo_exps\n", "\n", "\n", " source_area = ConvexHull(kp_cano[0].cpu().numpy()).volume\n", " exp = exp * source_area\n", "\n", " he_new_driving = {'yaw': torch.from_numpy(he_driving['yaw']).cuda(),\n", " 'pitch': torch.from_numpy(he_driving['pitch']).cuda(),\n", " 'roll': torch.from_numpy(he_driving['roll']).cuda(),\n", " 't': torch.from_numpy(he_driving['t']).cuda(),\n", " 'exp': exp}\n", " he_driving['exp'] = torch.from_numpy(he_driving['exp']).cuda()\n", "\n", " kp_source = keypoint_transformation(kp_canonical, he_source, False)\n", " mean_source = torch.mean(kp_source['value'], dim=1)[0]\n", " kp_driving = keypoint_transformation(kp_canonical, he_new_driving, False)\n", " mean_driving = torch.mean(torch.mean(kp_driving['value'], dim=1), dim=0)\n", " kp_driving['value'] = kp_driving['value']+(mean_source-mean_driving).unsqueeze(0).unsqueeze(0)\n", " bs = kp_source['value'].shape[0]\n", " predictions_gen = []\n", " for i in tqdm(range(num_frames)):\n", " kp_si = {}\n", " kp_si['value'] = kp_source['value'][0].unsqueeze(0)\n", " kp_di = {}\n", " kp_di['value'] = kp_driving['value'][i].unsqueeze(0)\n", " generated = generator(source_img, kp_source=kp_si, kp_driving=kp_di, prompt=emoprompt)\n", " predictions_gen.append(\n", " (np.transpose(generated['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0] * 255).astype(np.uint8))\n", "\n", " log_dir = save_dir\n", " os.makedirs(os.path.join(log_dir, \"temp\"), exist_ok=True)\n", "\n", " f_name = os.path.basename(img_path[:-4]) + \"_\" + emotype + \"_\" + os.path.basename(latent_path_driving)[:-4] + \".mp4\"\n", " video_path = os.path.join(log_dir, \"temp\", f_name)\n", " imageio.mimsave(video_path, predictions_gen, fps=25.0)\n", "\n", " save_video = os.path.join(log_dir, f_name)\n", " cmd = r'ffmpeg -loglevel error -y -i \"%s\" -i \"%s\" -vcodec copy -shortest \"%s\"' % (video_path, audio_path, save_video)\n", " os.system(cmd)\n", " 
os.remove(video_path)\n", "\n", "if __name__ == '__main__':\n", " argparser = argparse.ArgumentParser()\n", " argparser.add_argument(\"--save_dir\", type=str, default=\" \", help=\"path of the output directory\")\n", " argparser.add_argument(\"--name\", type=str, default=\"deepprompt_eam3d_all_final_313\", help=\"name of the pretrained checkpoint, also used for the output folder\")\n", " argparser.add_argument(\"--emo\", type=str, default=\"hap\", help=\"emotion type ('ang', 'con', 'dis', 'fea', 'hap', 'neu', 'sad', 'sur')\")\n", " argparser.add_argument(\"--root_wav\", type=str, default='./demo/video_processed/M003_neu_1_001', help=\"path of the processed driving-video folder that contains its .wav file\")\n", " args = argparser.parse_args()\n", "\n", " root_wav=args.root_wav\n", "\n", " if len(args.name) > 1:\n", " name = args.name\n", " print(name)\n", " test(f'/content/EAT_code/ckpt/deepprompt_eam3d_all_final_313.pth.tar', args.emo, save_dir=f'./demo/output/{name}/')\n", "\n" ], "metadata": { "cellView": "form", "id": "cMz72QBbgRsc" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "CmeH0D2ayRrf" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#@title transformer.py with fixed paths\n", "import torch.nn as nn\n", "import torch\n", "from modules.util import mydownres2Dblock\n", "import numpy as np\n", "from modules.util import AntiAliasInterpolation2d, make_coordinate_grid\n", "from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d\n", "import torch.nn.functional as F\n", "import copy\n", "\n", "\n", "class PositionalEncoding(nn.Module):\n", "\n", " def __init__(self, d_hid, n_position=200):\n", " super(PositionalEncoding, self).__init__()\n", "\n", " # Not a parameter\n", " self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))\n", "\n", " def _get_sinusoid_encoding_table(self, n_position, d_hid):\n", " ''' Sinusoid position encoding table '''\n", " # TODO: make it with torch instead of numpy\n", "\n", " def get_position_angle_vec(position):\n", " return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]\n", "\n", " sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])\n", " sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i\n", " sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1\n", "\n", " return torch.FloatTensor(sinusoid_table).unsqueeze(0)\n", "\n", " def forward(self, winsize):\n", " return self.pos_table[:, :winsize].clone().detach()\n", "\n", "def _get_activation_fn(activation):\n", " \"\"\"Return an activation function given a string\"\"\"\n", " if activation == \"relu\":\n", " return F.relu\n", " if activation == \"gelu\":\n", " return F.gelu\n", " if activation == \"glu\":\n", " return F.glu\n", " raise RuntimeError(F\"activation should be relu/gelu/glu, not {activation}.\")\n", "\n", "def _get_clones(module, N):\n", " return nn.ModuleList([copy.deepcopy(module) for i in range(N)])\n", "\n", "### light weight transformer encoder\n", "class TransformerST(nn.Module):\n", "\n", " def __init__(self, d_model=128, nhead=8, num_encoder_layers=6,\n", " num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,\n", " activation=\"relu\", normalize_before=False):\n", " super().__init__()\n", "\n", " encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,\n", " dropout, activation, normalize_before)\n", " encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n", " self.encoder =
TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)\n", "\n", " self._reset_parameters()\n", "\n", " self.d_model = d_model\n", " self.nhead = nhead\n", "\n", " def _reset_parameters(self):\n", " for p in self.parameters():\n", " if p.dim() > 1:\n", " nn.init.xavier_uniform_(p)\n", "\n", " def forward(self, src, pos_embed):\n", " # flatten NxCxHxW to HWxNxC\n", "\n", " src = src.permute(1, 0, 2)\n", " pos_embed = pos_embed.permute(1, 0, 2)\n", "\n", " memory = self.encoder(src, pos=pos_embed)\n", "\n", " return memory\n", "\n", "class TransformerEncoder(nn.Module):\n", "\n", " def __init__(self, encoder_layer, num_layers, norm=None):\n", " super().__init__()\n", " self.layers = _get_clones(encoder_layer, num_layers)\n", " self.num_layers = num_layers\n", " self.norm = norm\n", "\n", " def forward(self, src, mask = None, src_key_padding_mask = None, pos = None):\n", " output = src+pos\n", "\n", " for layer in self.layers:\n", " output = layer(output, src_mask=mask,\n", " src_key_padding_mask=src_key_padding_mask, pos=pos)\n", "\n", " if self.norm is not None:\n", " output = self.norm(output)\n", "\n", " return output\n", "\n", "class TransformerDecoder(nn.Module):\n", "\n", " def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):\n", " super().__init__()\n", " self.layers = _get_clones(decoder_layer, num_layers)\n", " self.num_layers = num_layers\n", " self.norm = norm\n", " self.return_intermediate = return_intermediate\n", "\n", " def forward(self, tgt, memory, tgt_mask = None, memory_mask = None, tgt_key_padding_mask = None,\n", " memory_key_padding_mask = None,\n", " pos = None,\n", " query_pos = None):\n", " output = tgt+pos+query_pos\n", "\n", " intermediate = []\n", "\n", " for layer in self.layers:\n", " output = layer(output, memory, tgt_mask=tgt_mask,\n", " memory_mask=memory_mask,\n", " tgt_key_padding_mask=tgt_key_padding_mask,\n", " memory_key_padding_mask=memory_key_padding_mask,\n", " pos=pos, query_pos=query_pos)\n", " if self.return_intermediate:\n", " intermediate.append(self.norm(output))\n", "\n", " if self.norm is not None:\n", " output = self.norm(output)\n", " if self.return_intermediate:\n", " intermediate.pop()\n", " intermediate.append(output)\n", "\n", " if self.return_intermediate:\n", " return torch.stack(intermediate)\n", "\n", " return output.unsqueeze(0)\n", "\n", "\n", "class Transformer(nn.Module):\n", "\n", " def __init__(self, d_model=128, nhead=8, num_encoder_layers=6,\n", " num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,\n", " activation=\"relu\", normalize_before=False,\n", " return_intermediate_dec=True):\n", " super().__init__()\n", "\n", " encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,\n", " dropout, activation, normalize_before)\n", " encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n", " self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)\n", "\n", " decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,\n", " dropout, activation, normalize_before)\n", " decoder_norm = nn.LayerNorm(d_model)\n", " self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,\n", " return_intermediate=return_intermediate_dec)\n", "\n", " self._reset_parameters()\n", "\n", " self.d_model = d_model\n", " self.nhead = nhead\n", "\n", " def _reset_parameters(self):\n", " for p in self.parameters():\n", " if p.dim() > 1:\n", " nn.init.xavier_uniform_(p)\n", "\n", " def forward(self, src, query_embed, 
pos_embed):\n", " # flatten NxCxHxW to HWxNxC\n", "\n", " src = src.permute(1, 0, 2)\n", " pos_embed = pos_embed.permute(1, 0, 2)\n", " query_embed = query_embed.permute(1, 0, 2)\n", "\n", " tgt = torch.zeros_like(query_embed)\n", " memory = self.encoder(src, pos=pos_embed)\n", "\n", " hs = self.decoder(tgt, memory,\n", " pos=pos_embed, query_pos=query_embed)\n", " return hs, memory\n", "\n", "\n", "class TransformerDeep(nn.Module):\n", "\n", " def __init__(self, d_model=128, nhead=8, num_encoder_layers=6,\n", " num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,\n", " activation=\"relu\", normalize_before=False,\n", " return_intermediate_dec=True):\n", " super().__init__()\n", "\n", " encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,\n", " dropout, activation, normalize_before)\n", " encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n", " self.encoder = TransformerEncoderDeep(encoder_layer, num_encoder_layers, encoder_norm)\n", "\n", " decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,\n", " dropout, activation, normalize_before)\n", " decoder_norm = nn.LayerNorm(d_model)\n", " self.decoder = TransformerDecoderDeep(decoder_layer, num_decoder_layers, decoder_norm,\n", " return_intermediate=return_intermediate_dec)\n", "\n", " self._reset_parameters()\n", "\n", " self.d_model = d_model\n", " self.nhead = nhead\n", "\n", " def _reset_parameters(self):\n", " for p in self.parameters():\n", " if p.dim() > 1:\n", " nn.init.xavier_uniform_(p)\n", "\n", " def forward(self, src, query_embed, pos_embed, deepprompt):\n", " # flatten NxCxHxW to HWxNxC\n", "\n", " # print('src before permute: ', src.shape) # 5, 12, 128\n", " src = src.permute(1, 0, 2)\n", " # print('src after permute: ', src.shape) # 12, 5, 128\n", " pos_embed = pos_embed.permute(1, 0, 2)\n", " query_embed = query_embed.permute(1, 0, 2)\n", "\n", " tgt = torch.zeros_like(query_embed) # actually is tgt + query_embed\n", " memory = self.encoder(src, deepprompt, pos=pos_embed)\n", "\n", " hs = self.decoder(tgt, deepprompt, memory,\n", " pos=pos_embed, query_pos=query_embed)\n", " return hs, memory\n", "\n", "class TransformerEncoderDeep(nn.Module):\n", "\n", " def __init__(self, encoder_layer, num_layers, norm=None):\n", " super().__init__()\n", " self.layers = _get_clones(encoder_layer, num_layers)\n", " self.num_layers = num_layers\n", " self.norm = norm\n", "\n", " def forward(self, src, deepprompt, mask = None, src_key_padding_mask = None, pos = None):\n", " # print('input: ', src.shape) # 12 5 128\n", " # print('deepprompt:', deepprompt.shape) # 1 6 128\n", " ### TODO: add deep prompt in encoder\n", " bs = src.shape[1]\n", " bbs = deepprompt.shape[0]\n", " idx=0\n", " emoprompt = deepprompt[:,idx,:]\n", " emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n", " # print(emoprompt.shape) # 1 5 128\n", " src = torch.cat([src, emoprompt], dim=0)\n", " # print(src.shape) # 13 5 128\n", " output = src+pos\n", "\n", " for layer in self.layers:\n", " output = layer(output, src_mask=mask,\n", " src_key_padding_mask=src_key_padding_mask, pos=pos)\n", "\n", " ### deep prompt\n", " if idx+1 < len(self.layers):\n", " idx = idx + 1\n", " # print(idx)\n", " emoprompt = deepprompt[:,idx,:]\n", " emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n", " # print(output.shape) # 13 5 128\n", " output = torch.cat([output[:-1], emoprompt], dim=0)\n", " # print(output.shape) # 13 5 128\n", "\n", " if self.norm is not 
None:\n", " output = self.norm(output)\n", "\n", " return output\n", "\n", "class TransformerDecoderDeep(nn.Module):\n", "\n", " def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):\n", " super().__init__()\n", " self.layers = _get_clones(decoder_layer, num_layers)\n", " self.num_layers = num_layers\n", " self.norm = norm\n", " self.return_intermediate = return_intermediate\n", "\n", " def forward(self, tgt, deepprompt, memory, tgt_mask = None, memory_mask = None, tgt_key_padding_mask = None,\n", " memory_key_padding_mask = None,\n", " pos = None,\n", " query_pos = None):\n", " # print('input: ', query_pos.shape) 12 5 128\n", " ### TODO: add deep prompt in encoder\n", " bs = query_pos.shape[1]\n", " bbs = deepprompt.shape[0]\n", " idx=0\n", " emoprompt = deepprompt[:,idx,:]\n", " emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n", " query_pos = torch.cat([query_pos, emoprompt], dim=0)\n", " # print(query_pos.shape) # 13 5 128\n", " # print(torch.sum(tgt)) # 0\n", " output = pos+query_pos\n", "\n", " intermediate = []\n", "\n", " for layer in self.layers:\n", " output = layer(output, memory, tgt_mask=tgt_mask,\n", " memory_mask=memory_mask,\n", " tgt_key_padding_mask=tgt_key_padding_mask,\n", " memory_key_padding_mask=memory_key_padding_mask,\n", " pos=pos, query_pos=query_pos)\n", " if self.return_intermediate:\n", " intermediate.append(self.norm(output))\n", "\n", " ### deep prompt\n", " if idx+1 < len(self.layers):\n", " idx = idx + 1\n", " # print(idx)\n", " emoprompt = deepprompt[:,idx,:]\n", " emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n", " # print(output.shape) # 13 5 128\n", " output = torch.cat([output[:-1], emoprompt], dim=0)\n", " # print(output.shape) # 13 5 128\n", "\n", " if self.norm is not None:\n", " output = self.norm(output)\n", " if self.return_intermediate:\n", " intermediate.pop()\n", " intermediate.append(output)\n", "\n", " if self.return_intermediate:\n", " return torch.stack(intermediate)\n", "\n", " return output.unsqueeze(0)\n", "\n", "\n", "class TransformerEncoderLayer(nn.Module):\n", "\n", " def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,\n", " activation=\"relu\", normalize_before=False):\n", " super().__init__()\n", " self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n", " # Implementation of Feedforward model\n", " self.linear1 = nn.Linear(d_model, dim_feedforward)\n", " self.dropout = nn.Dropout(dropout)\n", " self.linear2 = nn.Linear(dim_feedforward, d_model)\n", "\n", " self.norm1 = nn.LayerNorm(d_model)\n", " self.norm2 = nn.LayerNorm(d_model)\n", " self.dropout1 = nn.Dropout(dropout)\n", " self.dropout2 = nn.Dropout(dropout)\n", "\n", " self.activation = _get_activation_fn(activation)\n", " self.normalize_before = normalize_before\n", "\n", " def with_pos_embed(self, tensor, pos):\n", " return tensor if pos is None else tensor + pos\n", "\n", " def forward_post(self,\n", " src,\n", " src_mask = None,\n", " src_key_padding_mask = None,\n", " pos = None):\n", " # q = k = self.with_pos_embed(src, pos)\n", " src2 = self.self_attn(src, src, value=src, attn_mask=src_mask,\n", " key_padding_mask=src_key_padding_mask)[0]\n", " src = src + self.dropout1(src2)\n", " src = self.norm1(src)\n", " src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))\n", " src = src + self.dropout2(src2)\n", " src = self.norm2(src)\n", " return src\n", "\n", " def forward_pre(self, src,\n", " src_mask = None,\n", " 
src_key_padding_mask = None,\n", " pos = None):\n", " src2 = self.norm1(src)\n", " # q = k = self.with_pos_embed(src2, pos)\n", " src2 = self.self_attn(src2, src2, value=src2, attn_mask=src_mask,\n", " key_padding_mask=src_key_padding_mask)[0]\n", " src = src + self.dropout1(src2)\n", " src2 = self.norm2(src)\n", " src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))\n", " src = src + self.dropout2(src2)\n", " return src\n", "\n", " def forward(self, src,\n", " src_mask = None,\n", " src_key_padding_mask = None,\n", " pos = None):\n", " if self.normalize_before:\n", " return self.forward_pre(src, src_mask, src_key_padding_mask, pos)\n", " return self.forward_post(src, src_mask, src_key_padding_mask, pos)\n", "\n", "class TransformerDecoderLayer(nn.Module):\n", "\n", " def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,\n", " activation=\"relu\", normalize_before=False):\n", " super().__init__()\n", " self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n", " self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n", " # Implementation of Feedforward model\n", " self.linear1 = nn.Linear(d_model, dim_feedforward)\n", " self.dropout = nn.Dropout(dropout)\n", " self.linear2 = nn.Linear(dim_feedforward, d_model)\n", "\n", " self.norm1 = nn.LayerNorm(d_model)\n", " self.norm2 = nn.LayerNorm(d_model)\n", " self.norm3 = nn.LayerNorm(d_model)\n", " self.dropout1 = nn.Dropout(dropout)\n", " self.dropout2 = nn.Dropout(dropout)\n", " self.dropout3 = nn.Dropout(dropout)\n", "\n", " self.activation = _get_activation_fn(activation)\n", " self.normalize_before = normalize_before\n", "\n", " def with_pos_embed(self, tensor, pos):\n", " return tensor if pos is None else tensor + pos\n", "\n", " def forward_post(self, tgt, memory,\n", " tgt_mask = None,\n", " memory_mask = None,\n", " tgt_key_padding_mask = None,\n", " memory_key_padding_mask = None,\n", " pos = None,\n", " query_pos = None):\n", " # q = k = self.with_pos_embed(tgt, query_pos)\n", " tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask,\n", " key_padding_mask=tgt_key_padding_mask)[0]\n", " tgt = tgt + self.dropout1(tgt2)\n", " tgt = self.norm1(tgt)\n", " tgt2 = self.multihead_attn(query=tgt,\n", " key=memory,\n", " value=memory, attn_mask=memory_mask,\n", " key_padding_mask=memory_key_padding_mask)[0]\n", " tgt = tgt + self.dropout2(tgt2)\n", " tgt = self.norm2(tgt)\n", " tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))\n", " tgt = tgt + self.dropout3(tgt2)\n", " tgt = self.norm3(tgt)\n", " return tgt\n", "\n", " def forward_pre(self, tgt, memory,\n", " tgt_mask = None,\n", " memory_mask = None,\n", " tgt_key_padding_mask = None,\n", " memory_key_padding_mask = None,\n", " pos = None,\n", " query_pos = None):\n", " tgt2 = self.norm1(tgt)\n", " # q = k = self.with_pos_embed(tgt2, query_pos)\n", " tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask,\n", " key_padding_mask=tgt_key_padding_mask)[0]\n", " tgt = tgt + self.dropout1(tgt2)\n", " tgt2 = self.norm2(tgt)\n", " tgt2 = self.multihead_attn(query=tgt2,\n", " key=memory,\n", " value=memory, attn_mask=memory_mask,\n", " key_padding_mask=memory_key_padding_mask)[0]\n", " tgt = tgt + self.dropout2(tgt2)\n", " tgt2 = self.norm3(tgt)\n", " tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))\n", " tgt = tgt + self.dropout3(tgt2)\n", " return tgt\n", "\n", " def forward(self, tgt, memory,\n", " tgt_mask = None,\n", " memory_mask = None,\n", " 
tgt_key_padding_mask = None,\n", " memory_key_padding_mask = None,\n", " pos = None,\n", " query_pos = None):\n", " if self.normalize_before:\n", " return self.forward_pre(tgt, memory, tgt_mask, memory_mask,\n", " tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)\n", " return self.forward_post(tgt, memory, tgt_mask, memory_mask,\n", " tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)\n", "\n", "from Utils.JDC.model import JDCNet\n", "from modules.audioencoder import AudioEncoder, MappingNetwork, StyleEncoder, AdaIN, EAModule\n", "\n", "def headpose_pred_to_degree(pred):\n", " device = pred.device\n", " idx_tensor = [idx for idx in range(66)]\n", " idx_tensor = torch.FloatTensor(idx_tensor).to(device)\n", " pred = F.softmax(pred, dim=1)\n", " degree = torch.sum(pred*idx_tensor, axis=1)\n", " # degree = F.one_hot(degree.to(torch.int64), num_classes=66)\n", " return degree\n", "\n", "def get_rotation_matrix(yaw, pitch, roll):\n", " yaw = yaw / 180 * 3.14\n", " pitch = pitch / 180 * 3.14\n", " roll = roll / 180 * 3.14\n", "\n", " roll = roll.unsqueeze(1)\n", " pitch = pitch.unsqueeze(1)\n", " yaw = yaw.unsqueeze(1)\n", "\n", " pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),\n", " torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch),\n", " torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1)\n", "\n", " pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)\n", "\n", " yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw),\n", " torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),\n", " -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1)\n", " yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)\n", "\n", " roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll),\n", " torch.sin(roll), torch.cos(roll), torch.zeros_like(roll),\n", " torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1)\n", " roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)\n", "\n", " rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat)\n", "\n", " return yaw, pitch, roll, yaw_mat.view(yaw_mat.shape[0], 9), pitch_mat.view(pitch_mat.shape[0], 9), roll_mat.view(roll_mat.shape[0], 9), rot_mat.view(rot_mat.shape[0], 9)\n", "\n", "class Audio2kpTransformerBBoxQDeep(nn.Module):\n", " def __init__(self, embedding_dim, num_kp, num_w, face_adain=False):\n", " super(Audio2kpTransformerBBoxQDeep, self).__init__()\n", " self.embedding_dim = embedding_dim\n", " self.num_kp = num_kp\n", " self.num_w = num_w\n", "\n", "\n", " self.embedding = nn.Embedding(41, embedding_dim)\n", "\n", " self.face_shrink = nn.Linear(240, 32)\n", " self.hp_extractor = nn.Linear(45, 128)\n", "\n", " self.pos_enc = PositionalEncoding(128,20)\n", " input_dim = 1\n", "\n", " self.decode_dim = 64\n", " self.audio_embedding = nn.Sequential( # n x 29 x 16\n", " nn.Conv1d(29, 32, kernel_size=3, stride=2,\n", " padding=1, bias=True), # n x 32 x 8\n", " nn.LeakyReLU(0.02, True),\n", " nn.Conv1d(32, 32, kernel_size=3, stride=2,\n", " padding=1, bias=True), # n x 32 x 4\n", " nn.LeakyReLU(0.02, True),\n", " nn.Conv1d(32, 64, kernel_size=3, stride=2,\n", " padding=1, bias=True), # n x 64 x 2\n", " nn.LeakyReLU(0.02, True),\n", " nn.Conv1d(64, 64, kernel_size=3, stride=2,\n", " padding=1, bias=True), # n x 64 x 1\n", " nn.LeakyReLU(0.02, True),\n", " )\n", " self.encoder_fc1 = nn.Sequential(\n", " nn.Linear(192, 128),\n", " nn.LeakyReLU(0.02, True),\n", " nn.Linear(128, 
128),\n", " )\n", "\n", " self.audio_embedding2 = nn.Sequential(nn.Conv2d(1, 8, (3, 17), stride=(1, 1), padding=(1, 0)),\n", " # nn.GroupNorm(4, 8, affine=True),\n", " BatchNorm2d(8),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(8, 32, (13, 13), stride=(1, 1), padding=(6, 6)))\n", "\n", " self.audioencoder = AudioEncoder(dim_in=64, style_dim=128, max_conv_dim=512, w_hpf=0, F0_channel=256)\n", " # self.mappingnet = MappingNetwork(latent_dim=16, style_dim=128, num_domains=8, hidden_dim=512)\n", " # self.stylenet = StyleEncoder(dim_in=64, style_dim=64, num_domains=8, max_conv_dim=512)\n", " self.face_adain = face_adain\n", " if self.face_adain:\n", " self.fadain = AdaIN(style_dim=128, num_features=32)\n", " # norm = 'layer_2d' #\n", " norm = 'batch'\n", "\n", " self.decodefeature_extract = nn.Sequential(mydownres2Dblock(self.decode_dim,32, normalize = norm),\n", " mydownres2Dblock(32,48, normalize = norm),\n", " mydownres2Dblock(48,64, normalize = norm),\n", " mydownres2Dblock(64,96, normalize = norm),\n", " mydownres2Dblock(96,128, normalize = norm),\n", " nn.AvgPool2d(2))\n", "\n", " self.feature_extract = nn.Sequential(mydownres2Dblock(input_dim,32),\n", " mydownres2Dblock(32,64),\n", " mydownres2Dblock(64,128),\n", " mydownres2Dblock(128,128),\n", " mydownres2Dblock(128,128),\n", " nn.AvgPool2d(2))\n", " self.transformer = Transformer()\n", " self.kp = nn.Linear(128, 32)\n", "\n", " # for m in self.modules():\n", " # if isinstance(m, nn.Conv2d):\n", " # # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n", " # nn.init.xavier_uniform_(m.weight, gain=1)\n", "\n", " # if isinstance(m, nn.Linear):\n", " # # trunc_normal_(m.weight, std=.03)\n", " # nn.init.xavier_uniform_(m.weight, gain=1)\n", " # if isinstance(m, nn.Linear) and m.bias is not None:\n", " # nn.init.constant_(m.bias, 0)\n", "\n", " F0_path = './Utils/JDC/bst.t7'\n", " F0_model = JDCNet(num_class=1, seq_len=32)\n", " params = torch.load(F0_path, map_location='cpu')['net']\n", " F0_model.load_state_dict(params)\n", " self.f0_model = F0_model\n", "\n", " def rotation_and_translation(self, headpose, bbs, bs):\n", " # print(headpose['roll'].shape, headpose['yaw'].shape, headpose['pitch'].shape, headpose['t'].shape)\n", "\n", " yaw = headpose_pred_to_degree(headpose['yaw'].reshape(bbs*bs, -1))\n", " pitch = headpose_pred_to_degree(headpose['pitch'].reshape(bbs*bs, -1))\n", " roll = headpose_pred_to_degree(headpose['roll'].reshape(bbs*bs, -1))\n", " yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v = get_rotation_matrix(yaw, pitch, roll)\n", " t = headpose['t'].reshape(bbs*bs, -1)\n", " # hp = torch.cat([yaw, pitch, roll, yaw_v, pitch_v, roll_v, t], dim=1)\n", " hp = torch.cat([yaw.unsqueeze(1), pitch.unsqueeze(1), roll.unsqueeze(1), yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v, t], dim=1)\n", " # hp = torch.cat([yaw, pitch, roll, torch.sin(yaw), torch.sin(pitch), torch.sin(roll), torch.cos(yaw), torch.cos(pitch), torch.cos(roll), t], dim=1)\n", " return hp\n", "\n", " def forward(self, x, initial_kp = None, return_strg=False, emoprompt=None, hp=None, side=False):\n", " bbs, bs, seqlen, _, _ = x['deep'].shape\n", " # ph = x[\"pho\"].reshape(bbs*bs*seqlen, 1)\n", " if hp is None:\n", " hp = self.rotation_and_translation(x['he_driving'], bbs, bs)\n", " hp = self.hp_extractor(hp)\n", "\n", " pose_feature = x[\"pose\"].reshape(bbs*bs*seqlen,1,64,64)\n", " # pose_feature = self.down_pose(pose).contiguous()\n", " ### phoneme input feature\n", " # phoneme_embedding = self.embedding(ph.long())\n", " # 
phoneme_embedding = phoneme_embedding.reshape(bbs*bs*seqlen, 1, 16, 16)\n", " # phoneme_embedding = F.interpolate(phoneme_embedding, scale_factor=4)\n", " # input_feature = torch.cat((pose_feature, phoneme_embedding), dim=1)\n", " # print('input_feature: ', input_feature.shape)\n", " # input_feature = phoneme_embedding\n", "\n", " audio = x['deep'].reshape(bbs*bs*seqlen, 16, 29).permute(0, 2, 1)\n", " deep_feature = self.audio_embedding(audio).squeeze(-1)# ([264, 32, 16, 16])\n", " # print(deep_feature.shape)\n", "\n", " input_feature = pose_feature\n", " # print(input_feature.shape)\n", " # assert(0)\n", " input_feature = self.feature_extract(input_feature).reshape(bbs*bs*seqlen, 128)\n", " input_feature = torch.cat([input_feature, deep_feature], dim=1)\n", " input_feature = self.encoder_fc1(input_feature).reshape(bbs*bs, seqlen, 128)\n", " # phoneme_embedding = self.phoneme_shrink(phoneme_embedding.squeeze(1))# 24*11, 128\n", " input_feature = torch.cat([input_feature, hp.unsqueeze(1)], dim=1)\n", "\n", " ### decode audio feature\n", " ### use iteration to avoid batchnorm2d in different audio sequence\n", " decoder_features = []\n", " for i in range(bbs):\n", " F0 = self.f0_model.get_feature_GAN(x['mel'][i].reshape(bs, 1, 80, seqlen))\n", " if emoprompt is None:\n", " audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), s=None, masks=None, F0=F0))\n", " else:\n", " audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), s=emoprompt[i].unsqueeze(0), masks=None, F0=F0))\n", " audio2 = torch.permute(audio_feature, (0, 3, 1, 2)).reshape(bs*seqlen, 1, 64, 80)\n", " decoder_feature = self.audio_embedding2(audio2)\n", "\n", " # decoder_feature = torch.cat([decoder_feature, audio2], dim=1)\n", " # decoder_feature = F.interpolate(decoder_feature, scale_factor=2)# ([264, 35, 64, 64])\n", " face_map = initial_kp[\"prediction_map\"][i].reshape(15*16, 64*64).permute(1, 0).reshape(64*64, 15*16)\n", " feature_map = self.face_shrink(face_map).permute(1, 0).reshape(1, 32, 64, 64)\n", " if self.face_adain:\n", " feature_map = self.fadain(feature_map, emoprompt)\n", " decoder_feature = self.decodefeature_extract(torch.cat(\n", " (decoder_feature,\n", " feature_map.repeat(bs, seqlen, 1, 1, 1).reshape(bs * seqlen, 32, 64, 64)),\n", " dim=1)).reshape(bs, seqlen, 128)\n", " decoder_features.append(decoder_feature)\n", " decoder_feature = torch.cat(decoder_features, dim=0)\n", "\n", " decoder_feature = torch.cat([decoder_feature, hp.unsqueeze(1)], dim=1)\n", "\n", " # decoder_feature = torch.cat([decoder_feature], dim=1)\n", "\n", " # a2kp transformer\n", " # position embedding\n", " if emoprompt is None:\n", " posi_em = self.pos_enc(self.num_w*2+1+1) # 11 + headpose token\n", " else:\n", " posi_em = self.pos_enc(self.num_w*2+1+1+1) # 11 + headpose token + emotion prompt\n", " emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(1)\n", " input_feature = torch.cat([input_feature, emoprompt], dim=1)\n", " decoder_feature = torch.cat([decoder_feature, emoprompt], dim=1)\n", " out = {}\n", " output_feature, memory = self.transformer(input_feature, decoder_feature, posi_em, )\n", " output_feature = output_feature[-1, self.num_w] # returned intermediate output [6, 13, bbs*bs, 128]\n", " out[\"emo\"] = self.kp(output_feature)\n", " if side:\n", " input_st = {}\n", " input_st['hp'] = hp\n", " input_st['decoder_feature'] = decoder_feature\n", " input_st['memory'] = memory\n", " return out, input_st\n", " else:\n", " return out\n", "\n", "\n", "class 
Audio2kpTransformerBBoxQDeepPrompt(nn.Module):\n", " def __init__(self, embedding_dim, num_kp, num_w, face_ea=False):\n", " super(Audio2kpTransformerBBoxQDeepPrompt, self).__init__()\n", " self.embedding_dim = embedding_dim\n", " self.num_kp = num_kp\n", " self.num_w = num_w\n", "\n", "\n", " self.embedding = nn.Embedding(41, embedding_dim)\n", "\n", " self.face_shrink = nn.Linear(240, 32)\n", " self.hp_extractor = nn.Linear(45, 128)\n", "\n", " self.pos_enc = PositionalEncoding(128,20)\n", " input_dim = 1\n", "\n", " self.decode_dim = 64\n", " self.audio_embedding = nn.Sequential( # n x 29 x 16\n", " nn.Conv1d(29, 32, kernel_size=3, stride=2,\n", " padding=1, bias=True), # n x 32 x 8\n", " nn.LeakyReLU(0.02, True),\n", " nn.Conv1d(32, 32, kernel_size=3, stride=2,\n", " padding=1, bias=True), # n x 32 x 4\n", " nn.LeakyReLU(0.02, True),\n", " nn.Conv1d(32, 64, kernel_size=3, stride=2,\n", " padding=1, bias=True), # n x 64 x 2\n", " nn.LeakyReLU(0.02, True),\n", " nn.Conv1d(64, 64, kernel_size=3, stride=2,\n", " padding=1, bias=True), # n x 64 x 1\n", " nn.LeakyReLU(0.02, True),\n", " )\n", " self.encoder_fc1 = nn.Sequential(\n", " nn.Linear(192, 128),\n", " nn.LeakyReLU(0.02, True),\n", " nn.Linear(128, 128),\n", " )\n", "\n", " self.audio_embedding2 = nn.Sequential(nn.Conv2d(1, 8, (3, 17), stride=(1, 1), padding=(1, 0)),\n", " # nn.GroupNorm(4, 8, affine=True),\n", " BatchNorm2d(8),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(8, 32, (13, 13), stride=(1, 1), padding=(6, 6)))\n", "\n", " self.audioencoder = AudioEncoder(dim_in=64, style_dim=128, max_conv_dim=512, w_hpf=0, F0_channel=256)\n", " self.face_ea = face_ea\n", " if self.face_ea:\n", " self.fea = EAModule(style_dim=128, num_features=32)\n", " norm = 'batch'\n", "\n", " self.decodefeature_extract = nn.Sequential(mydownres2Dblock(self.decode_dim,32, normalize = norm),\n", " mydownres2Dblock(32,48, normalize = norm),\n", " mydownres2Dblock(48,64, normalize = norm),\n", " mydownres2Dblock(64,96, normalize = norm),\n", " mydownres2Dblock(96,128, normalize = norm),\n", " nn.AvgPool2d(2))\n", "\n", " self.feature_extract = nn.Sequential(mydownres2Dblock(input_dim,32),\n", " mydownres2Dblock(32,64),\n", " mydownres2Dblock(64,128),\n", " mydownres2Dblock(128,128),\n", " mydownres2Dblock(128,128),\n", " nn.AvgPool2d(2))\n", " self.transformer = TransformerDeep()\n", " self.kp = nn.Linear(128, 32)\n", "\n", " F0_path = '/content/EAT_code/Utils/JDC/bst.t7'\n", " F0_model = JDCNet(num_class=1, seq_len=32)\n", " params = torch.load(F0_path, map_location='cpu')['net']\n", " F0_model.load_state_dict(params)\n", " self.f0_model = F0_model\n", "\n", " def rotation_and_translation(self, headpose, bbs, bs):\n", " yaw = headpose_pred_to_degree(headpose['yaw'].reshape(bbs*bs, -1))\n", " pitch = headpose_pred_to_degree(headpose['pitch'].reshape(bbs*bs, -1))\n", " roll = headpose_pred_to_degree(headpose['roll'].reshape(bbs*bs, -1))\n", " yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v = get_rotation_matrix(yaw, pitch, roll)\n", " t = headpose['t'].reshape(bbs*bs, -1)\n", " hp = torch.cat([yaw.unsqueeze(1), pitch.unsqueeze(1), roll.unsqueeze(1), yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v, t], dim=1)\n", " return hp\n", "\n", " def forward(self, x, initial_kp = None, return_strg=False, emoprompt=None, deepprompt=None, hp=None, side=False):\n", " bbs, bs, seqlen, _, _ = x['deep'].shape\n", " # ph = x[\"pho\"].reshape(bbs*bs*seqlen, 1)\n", " if hp is None:\n", " hp = self.rotation_and_translation(x['he_driving'], bbs, bs)\n", " hp = 
self.hp_extractor(hp)\n", "\n", " pose_feature = x[\"pose\"].reshape(bbs*bs*seqlen,1,64,64)\n", "\n", " audio = x['deep'].reshape(bbs*bs*seqlen, 16, 29).permute(0, 2, 1)\n", " deep_feature = self.audio_embedding(audio).squeeze(-1)# ([264, 32, 16, 16])\n", "\n", " input_feature = pose_feature\n", " input_feature = self.feature_extract(input_feature).reshape(bbs*bs*seqlen, 128)\n", " input_feature = torch.cat([input_feature, deep_feature], dim=1)\n", " input_feature = self.encoder_fc1(input_feature).reshape(bbs*bs, seqlen, 128)\n", " input_feature = torch.cat([input_feature, hp.unsqueeze(1)], dim=1)\n", "\n", " ### decode audio feature\n", " ### use iteration to avoid batchnorm2d in different audio sequence\n", " decoder_features = []\n", " for i in range(bbs):\n", " F0 = self.f0_model.get_feature_GAN(x['mel'][i].reshape(bs, 1, 80, seqlen))\n", " if emoprompt is None:\n", " audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), s=None, masks=None, F0=F0))\n", " else:\n", " audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), s=emoprompt[i].unsqueeze(0), masks=None, F0=F0))\n", " audio2 = torch.permute(audio_feature, (0, 3, 1, 2)).reshape(bs*seqlen, 1, 64, 80)\n", " decoder_feature = self.audio_embedding2(audio2)\n", "\n", " face_map = initial_kp[\"prediction_map\"][i].reshape(15*16, 64*64).permute(1, 0).reshape(64*64, 15*16)\n", " face_feature_map = self.face_shrink(face_map).permute(1, 0).reshape(1, 32, 64, 64)\n", " if self.face_ea:\n", " face_feature_map = self.fea(face_feature_map, emoprompt)\n", " decoder_feature = self.decodefeature_extract(torch.cat(\n", " (decoder_feature,\n", " face_feature_map.repeat(bs, seqlen, 1, 1, 1).reshape(bs * seqlen, 32, 64, 64)),\n", " dim=1)).reshape(bs, seqlen, 128)\n", " decoder_features.append(decoder_feature)\n", " decoder_feature = torch.cat(decoder_features, dim=0)\n", "\n", " decoder_feature = torch.cat([decoder_feature, hp.unsqueeze(1)], dim=1)\n", "\n", " # a2kp transformer\n", " # position embedding\n", " if emoprompt is None:\n", " posi_em = self.pos_enc(self.num_w*2+1+1) # 11 + headpose token\n", " else:\n", " posi_em = self.pos_enc(self.num_w*2+1+1+1) # 11 + headpose token + deep emotion prompt\n", " out = {}\n", " output_feature, memory = self.transformer(input_feature, decoder_feature, posi_em, deepprompt)\n", " output_feature = output_feature[-1, self.num_w] # returned intermediate output [6, 13, bbs*bs, 128]\n", " out[\"emo\"] = self.kp(output_feature)\n", " if side:\n", " input_st = {}\n", " input_st['hp'] = hp\n", " input_st['face_feature_map'] = face_feature_map\n", " input_st['bs'] = bs\n", " input_st['bbs'] = bbs\n", " return out, input_st\n", " else:\n", " return out\n", "\n", "\n" ], "metadata": { "cellView": "form", "id": "DZwVMAPDgZvO" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#@title deepspeechfeatures.py fixed\n", "\n", "\"\"\"\n", " DeepSpeech features processing routines.\n", " NB: Based on VOCA code. 
See the corresponding license restrictions.\n", "\"\"\"\n", "\n", "__all__ = ['conv_audios_to_deepspeech']\n", "\n", "import numpy as np\n", "import warnings\n", "import resampy\n", "from scipy.io import wavfile\n", "from python_speech_features import mfcc\n", "import tensorflow as tf\n", "\n", "\n", "def conv_audios_to_deepspeech(audios,\n", " out_files,\n", " num_frames_info,\n", " deepspeech_pb_path,\n", " audio_window_size=16,\n", " audio_window_stride=1):\n", " \"\"\"\n", " Convert list of audio files into files with DeepSpeech features.\n", "\n", " Parameters\n", " ----------\n", " audios : list of str or list of None\n", " Paths to input audio files.\n", " out_files : list of str\n", " Paths to output files with DeepSpeech features.\n", " num_frames_info : list of int\n", " List of numbers of frames.\n", " deepspeech_pb_path : str\n", " Path to DeepSpeech 0.1.0 frozen model.\n", " audio_window_size : int, default 16\n", " Audio window size.\n", " audio_window_stride : int, default 1\n", " Audio window stride.\n", " \"\"\"\n", " graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net(deepspeech_pb_path)\n", "\n", " with tf.compat.v1.Session(graph=graph) as sess:\n", " for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info):\n", " audio_sample_rate, audio = wavfile.read(audio_file_path)\n", " if audio.ndim != 1:\n", " warnings.warn(\"Audio has multiple channels, the first channel is used\")\n", " audio = audio[:, 0]\n", " ds_features = pure_conv_audio_to_deepspeech(\n", " audio=audio,\n", " audio_sample_rate=audio_sample_rate,\n", " audio_window_size=audio_window_size,\n", " audio_window_stride=audio_window_stride,\n", " num_frames=num_frames,\n", " net_fn=lambda x: sess.run(\n", " logits_ph,\n", " feed_dict={\n", " input_node_ph: x[np.newaxis, ...],\n", " input_lengths_ph: [x.shape[0]]}))\n", " np.save(out_file_path, ds_features)\n", "\n", "\n", "# data_util/deepspeech_features/deepspeech_features.py\n", "def prepare_deepspeech_net(deepspeech_pb_path):\n", " # Load graph and place_holders:\n", " with tf.io.gfile.GFile(deepspeech_pb_path, \"rb\") as f:\n", " graph_def = tf.compat.v1.GraphDef()\n", " graph_def.ParseFromString(f.read())\n", "\n", " graph = tf.compat.v1.get_default_graph()\n", "\n", " tf.import_graph_def(graph_def, name=\"deepspeech\")\n", " # check all graphs\n", " # print('~'*50, [tensor for tensor in graph._nodes_by_name], '~'*50)\n", " # print('~'*50, [tensor.name for tensor in graph.get_operations()], '~'*50)\n", " # i modified\n", " logits_ph = graph.get_tensor_by_name(\"logits:0\")\n", " input_node_ph = graph.get_tensor_by_name(\"input_node:0\")\n", " input_lengths_ph = graph.get_tensor_by_name(\"input_lengths:0\")\n", " # original\n", " # logits_ph = graph.get_tensor_by_name(\"deepspeech/logits:0\")\n", " # input_node_ph = graph.get_tensor_by_name(\"deepspeech/input_node:0\")\n", " # input_lengths_ph = graph.get_tensor_by_name(\"deepspeech/input_lengths:0\")\n", "\n", " return graph, logits_ph, input_node_ph, input_lengths_ph\n", "\n", "\n", "def pure_conv_audio_to_deepspeech(audio,\n", " audio_sample_rate,\n", " audio_window_size,\n", " audio_window_stride,\n", " num_frames,\n", " net_fn):\n", " \"\"\"\n", " Core routine for converting audion into DeepSpeech features.\n", "\n", " Parameters\n", " ----------\n", " audio : np.array\n", " Audio data.\n", " audio_sample_rate : int\n", " Audio sample rate.\n", " audio_window_size : int\n", " Audio window size.\n", " audio_window_stride : int\n", " Audio window 
stride.\n", " num_frames : int or None\n", " Numbers of frames.\n", " net_fn : func\n", " Function for DeepSpeech model call.\n", "\n", " Returns\n", " -------\n", " np.array\n", " DeepSpeech features.\n", " \"\"\"\n", " target_sample_rate = 16000\n", " if audio_sample_rate != target_sample_rate:\n", " resampled_audio = resampy.resample(\n", " x=audio.astype(np.float),\n", " sr_orig=audio_sample_rate,\n", " sr_new=target_sample_rate)\n", " else:\n", " resampled_audio = audio.astype(np.float32)\n", " input_vector = conv_audio_to_deepspeech_input_vector(\n", " audio=resampled_audio.astype(np.int16),\n", " sample_rate=target_sample_rate,\n", " num_cepstrum=26,\n", " num_context=9)\n", "\n", " network_output = net_fn(input_vector)\n", "\n", " deepspeech_fps = 50\n", " video_fps = 60\n", " audio_len_s = float(audio.shape[0]) / audio_sample_rate\n", " if num_frames is None:\n", " num_frames = int(round(audio_len_s * video_fps))\n", " else:\n", " video_fps = num_frames / audio_len_s\n", " network_output = interpolate_features(\n", " features=network_output[:, 0],\n", " input_rate=deepspeech_fps,\n", " output_rate=video_fps,\n", " output_len=num_frames)\n", "\n", " # Make windows:\n", " zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1]))\n", " network_output = np.concatenate((zero_pad, network_output, zero_pad), axis=0)\n", " windows = []\n", " for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride):\n", " windows.append(network_output[window_index:window_index + audio_window_size])\n", "\n", " return np.array(windows)\n", "\n", "\n", "def conv_audio_to_deepspeech_input_vector(audio,\n", " sample_rate,\n", " num_cepstrum,\n", " num_context):\n", " \"\"\"\n", " Convert audio raw data into DeepSpeech input vector.\n", "\n", " Parameters\n", " ----------\n", " audio : np.array\n", " Audio data.\n", " audio_sample_rate : int\n", " Audio sample rate.\n", " num_cepstrum : int\n", " Number of cepstrum.\n", " num_context : int\n", " Number of context.\n", "\n", " Returns\n", " -------\n", " np.array\n", " DeepSpeech input vector.\n", " \"\"\"\n", " # Get mfcc coefficients:\n", " features = mfcc(\n", " signal=audio,\n", " samplerate=sample_rate,\n", " numcep=num_cepstrum)\n", "\n", " # We only keep every second feature (BiRNN stride = 2):\n", " features = features[::2]\n", "\n", " # One stride per time step in the input:\n", " num_strides = len(features)\n", "\n", " # Add empty initial and final contexts:\n", " empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype)\n", " features = np.concatenate((empty_context, features, empty_context))\n", "\n", " # Create a view into the array with overlapping strides of size\n", " # numcontext (past) + 1 (present) + numcontext (future):\n", " window_size = 2 * num_context + 1\n", " train_inputs = np.lib.stride_tricks.as_strided(\n", " features,\n", " shape=(num_strides, window_size, num_cepstrum),\n", " strides=(features.strides[0], features.strides[0], features.strides[1]),\n", " writeable=False)\n", "\n", " # Flatten the second and third dimensions:\n", " train_inputs = np.reshape(train_inputs, [num_strides, -1])\n", "\n", " train_inputs = np.copy(train_inputs)\n", " train_inputs = (train_inputs - np.mean(train_inputs)) / np.std(train_inputs)\n", "\n", " return train_inputs\n", "\n", "\n", "def interpolate_features(features,\n", " input_rate,\n", " output_rate,\n", " output_len):\n", " \"\"\"\n", " Interpolate DeepSpeech features.\n", "\n", " Parameters\n", " ----------\n", " features 
: np.array\n", " DeepSpeech features.\n", " input_rate : int\n", " Input rate (FPS).\n", " output_rate : int\n", " Output rate (FPS).\n", " output_len : int\n", " Output data length.\n", "\n", " Returns\n", " -------\n", " np.array\n", " Interpolated data.\n", " \"\"\"\n", " input_len = features.shape[0]\n", " num_features = features.shape[1]\n", " input_timestamps = np.arange(input_len) / float(input_rate)\n", " output_timestamps = np.arange(output_len) / float(output_rate)\n", " output_features = np.zeros((output_len, num_features))\n", " for feature_idx in range(num_features):\n", " output_features[:, feature_idx] = np.interp(\n", " x=output_timestamps,\n", " xp=input_timestamps,\n", " fp=features[:, feature_idx])\n", " return output_features\n" ], "metadata": { "cellView": "form", "id": "Jq0kqup1Ogqd" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#@title deepspeechstore fixed\n", "\n", "\"\"\"\n", " Routines for loading the DeepSpeech model.\n", "\"\"\"\n", "\n", "__all__ = ['get_deepspeech_model_file']\n", "\n", "import os\n", "import zipfile\n", "import logging\n", "import hashlib\n", "\n", "\n", "deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'\n", "\n", "\n", "def get_deepspeech_model_file(local_model_store_dir_path=os.path.join(\"~\", \"/content/EAT_code/tensorflow\", \"models\")):\n", " \"\"\"\n", " Return the location of the pretrained model on the local file system. This function will download it from the\n", " online model zoo when the model cannot be found or its hash does not match. The root directory will be created if it doesn't exist.\n", "\n", " Parameters\n", " ----------\n", " local_model_store_dir_path : str, default '/content/EAT_code/tensorflow/models'\n", " Location for keeping the model parameters.\n", "\n", " Returns\n", " -------\n", " file_path\n", " Path to the requested pretrained model file.\n", " \"\"\"\n", " sha1_hash = \"b90017e816572ddce84f5843f1fa21e6a377975e\"\n", " file_name = \"deepspeech-0_1_0-b90017e8.pb\"\n", " local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)\n", " file_path = os.path.join(local_model_store_dir_path, file_name)\n", " if os.path.exists(file_path):\n", " if _check_sha1(file_path, sha1_hash):\n", " return file_path\n", " else:\n", " logging.warning(\"Mismatch in the content of model file detected. Downloading again.\")\n", " else:\n", " logging.info(\"Model file not found. Downloading to {}.\".format(file_path))\n", "\n", " if not os.path.exists(local_model_store_dir_path):\n", " os.makedirs(local_model_store_dir_path)\n", "\n", " zip_file_path = file_path + \".zip\"\n", " _download(\n", " url=\"{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip\".format(\n", " repo_url=deepspeech_features_repo_url,\n", " repo_release_tag=\"v0.0.1\",\n", " file_name=file_name),\n", " path=zip_file_path,\n", " overwrite=False)\n", " with zipfile.ZipFile(zip_file_path) as zf:\n", " zf.extractall(local_model_store_dir_path)\n", " os.remove(zip_file_path)\n", "\n", " if _check_sha1(file_path, sha1_hash):\n", " return file_path\n", " else:\n", " raise ValueError(\"Downloaded file has different hash. Please try again.\")\n", "\n", "\n", "def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):\n", " \"\"\"\n", " Download a given URL.\n", "\n", " Parameters\n", " ----------\n", " url : str\n", " URL to download.\n", " path : str, optional\n", " Destination path to store the downloaded file. 
By default stores to the\n", " current directory with the same name as in the url.\n", " overwrite : bool, optional\n", " Whether to overwrite the destination file if it already exists.\n", " sha1_hash : str, optional\n", " Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified\n", " but doesn't match.\n", " retries : integer, default 5\n", " The number of times to attempt the download in case of failure or non-200 return codes.\n", " verify_ssl : bool, default True\n", " Verify SSL certificates.\n", "\n", " Returns\n", " -------\n", " str\n", " The file path of the downloaded file.\n", " \"\"\"\n", " import warnings\n", " try:\n", " import requests\n", " except ImportError:\n", " class requests_failed_to_import(object):\n", " pass\n", " requests = requests_failed_to_import\n", "\n", " if path is None:\n", " fname = url.split(\"/\")[-1]\n", " # Empty filenames are invalid\n", " assert fname, \"Can't construct file-name from this URL. Please set the `path` option manually.\"\n", " else:\n", " path = os.path.expanduser(path)\n", " if os.path.isdir(path):\n", " fname = os.path.join(path, url.split(\"/\")[-1])\n", " else:\n", " fname = path\n", " assert retries >= 0, \"Number of retries should be at least 0\"\n", "\n", " if not verify_ssl:\n", " warnings.warn(\n", " \"Unverified HTTPS request is being made (verify_ssl=False). \"\n", " \"Adding certificate verification is strongly advised.\")\n", "\n", " if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):\n", " dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))\n", " if not os.path.exists(dirname):\n", " os.makedirs(dirname)\n", " while retries + 1 > 0:\n", " # Disable pylint warning about too broad Exception\n", " # pylint: disable=W0703\n", " try:\n", " print(\"Downloading {} from {}...\".format(fname, url))\n", " # NOTE: the HTTP download below is commented out in this notebook; the DeepSpeech .pb model\n", " # is fetched with wget earlier and placed in /content/EAT_code/tensorflow/models.\n", " #r = requests.get(url, stream=True, verify=verify_ssl)\n", " #if r.status_code != 200:\n", " # raise RuntimeError(\"Failed downloading url {}\".format(url))\n", " #with open(fname, \"wb\") as f:\n", " # for chunk in r.iter_content(chunk_size=1024):\n", " # if chunk: # filter out keep-alive new chunks\n", " # f.write(chunk)\n", " if sha1_hash and not _check_sha1(fname, sha1_hash):\n", " raise UserWarning(\"File {} is downloaded but the content hash does not match.\"\n", " \" The repo may be outdated or download may be incomplete. 
\"\n", " \"If the `repo_url` is overridden, consider switching to \"\n", " \"the default repo.\".format(fname))\n", " break\n", " except Exception as e:\n", " retries -= 1\n", " if retries <= 0:\n", " raise e\n", " else:\n", " print(\"download failed, retrying, {} attempt{} left\"\n", " .format(retries, \"s\" if retries > 1 else \"\"))\n", "\n", " return fname\n", "\n", "\n", "def _check_sha1(filename, sha1_hash):\n", " \"\"\"\n", " Check whether the sha1 hash of the file content matches the expected hash.\n", "\n", " Parameters\n", " ----------\n", " filename : str\n", " Path to the file.\n", " sha1_hash : str\n", " Expected sha1 hash in hexadecimal digits.\n", "\n", " Returns\n", " -------\n", " bool\n", " Whether the file content matches the expected hash.\n", " \"\"\"\n", " sha1 = hashlib.sha1()\n", " with open(filename, \"rb\") as f:\n", " while True:\n", " data = f.read(1048576)\n", " if not data:\n", " break\n", " sha1.update(data)\n", "\n", " return sha1.hexdigest() == sha1_hash\n" ], "metadata": { "cellView": "form", "id": "BaZ1iwyuO_bl" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "!zip -r EAT_code /content/EAT_code" ], "metadata": { "id": "wAFm27NITV_C" }, "execution_count": null, "outputs": [] } ] }