# %%
"""Scratch round-trip of one audio clip through the 24 kHz EnCodec model.

Steps: load a wav file, convert it to the model's sample rate and channel
count, encode it to discrete codes, inspect the codes, decode them back to a
waveform and play the result.

The ``# %%`` markers keep the original notebook cell boundaries.  Values
quoted as "recorded" in comments are the outputs stored in the notebook for
the clip below; they are not recomputed here.
"""

# `datasets.Audio` is not used anywhere below.  It is aliased so that it no
# longer gets silently rebound by `IPython.display.Audio`, which is the one
# actually used for playback.
from datasets import Audio as DatasetsAudio  # noqa: F401
from encodec import EncodecModel
from encodec.utils import convert_audio
from IPython.display import Audio, display
import torch
import torchaudio

from utils import count_parameters

# %%
# NOTE(review): machine-specific absolute path -- adjust before running
# elsewhere.
datapath = "/data/jyk/aac_dataset/clotho/validation/01 A pug struggles to breathe 1_14_2008.wav"

# %%
model = EncodecModel.encodec_model_24khz()
print(count_parameters(model))  # recorded: 14851810

# %%
model.set_target_bandwidth(6.0)

# %%
wav, sr = torchaudio.load(datapath)
print(wav.shape[-1] / sr)  # clip duration in seconds; recorded: 18.8

# Convert to the model's sample rate / channel count, then add a leading
# batch dimension for `model.encode`.
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
wav = wav.unsqueeze(0)

print(sr)                                # recorded: 44100
print(model.sample_rate)                 # recorded: 24000
print(model.channels)                    # recorded: 1
print(wav.shape[-1] / model.sample_rate)  # duration after conversion; recorded: 18.8

# %%
# Encode and decode are inference only, so both run without gradient
# tracking.  The original decoded outside `no_grad` and then had to
# `.detach()` the result.
with torch.no_grad():
    encoded_frames = model.encode(wav)
    decoded_wav = model.decode(encoded_frames)

# %%
# Recorded: a one-element list holding a `(codes, None)` tuple.
print(encoded_frames)
print(encoded_frames[0][0].shape)  # recorded: torch.Size([1, 8, 1410])

# %%
# Join the code tensors of all frames along the last (time) axis.
codes = torch.cat([frame[0] for frame in encoded_frames], dim=-1)  # [B, n_q, T]
print(codes.shape)  # recorded: torch.Size([1, 8, 1410])

# %%
# Codes at index 1 of the n_q axis, and their minimum value.
print(codes[:, 1, :])        # recorded: tensor([[184, 740, 961, ..., 603, 831, 857]])
print(codes[:, 1, :].min())  # recorded: tensor(0)

# %%
# Every code shifted up by one.
# NOTE(review): presumably this frees index 0 for a padding/special token --
# confirm against the consumer of `code_1`.
code_1 = codes + 1
print(code_1[:, 1, :])  # recorded: tensor([[185, 741, 962, ..., 604, 832, 858]])

# %%
print(decoded_wav.shape)  # recorded: torch.Size([1, 1, 451200])

# %%
# A single `squeeze()` already drops every size-1 dimension.  The array gets
# its own name so `decoded_wav` stays a tensor and this cell can be re-run.
decoded_np = decoded_wav.squeeze().cpu().numpy()

# %%
# Play back at the model's own sample rate rather than a hard-coded 24000.
display(Audio(decoded_np, rate=model.sample_rate))