diff --git "a/test/encodec_test.ipynb" "b/test/encodec_test.ipynb"
new file mode 100644--- /dev/null
+++ "b/test/encodec_test.ipynb"
@@ -0,0 +1,338 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 71,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "datapath = \"/data/jyk/aac_dataset/clotho/validation/01 A pug struggles to breathe 1_14_2008.wav\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 72,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import Audio"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 73,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from encodec import EncodecModel\n",
+ "from encodec.utils import convert_audio\n",
+ "\n",
+ "import torchaudio\n",
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 74,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = EncodecModel.encodec_model_24khz()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 75,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "14851810\n"
+ ]
+ }
+ ],
+ "source": [
+ "from utils import count_parameters\n",
+ "print(count_parameters(model))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 76,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.set_target_bandwidth(6.0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 77,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "18.8\n",
+ "44100\n",
+ "24000\n",
+ "1\n",
+ "18.8\n"
+ ]
+ }
+ ],
+ "source": [
+ "wav, sr = torchaudio.load(datapath)\n",
+ "print(wav.shape[-1]/sr)\n",
+ "wav = convert_audio(wav, sr, model.sample_rate, model.channels)\n",
+ "wav = wav.unsqueeze(0)\n",
+ "print(sr)\n",
+ "print(model.sample_rate)\n",
+ "print(model.channels)\n",
+ "print(wav.shape[-1]/model.sample_rate)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 78,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with torch.no_grad():\n",
+ " encoded_frames = model.encode(wav)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 79,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[(tensor([[[675, 798, 635, ..., 281, 169, 457],\n",
+ " [184, 740, 961, ..., 603, 831, 857],\n",
+ " [996, 832, 967, ..., 273, 599, 771],\n",
+ " ...,\n",
+ " [763, 611, 140, ..., 18, 95, 918],\n",
+ " [938, 862, 674, ..., 661, 193, 364],\n",
+ " [412, 326, 339, ..., 614, 424, 428]]]), None)]\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(encoded_frames)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 80,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([1, 8, 1410])\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(encoded_frames[0][0].shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 81,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "codes = torch.cat([frame_codes for frame_codes, _ in encoded_frames], dim=-1)  # [B, n_q, T]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 82,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([1, 8, 1410])\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(codes.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 83,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor([[184, 740, 961, ..., 603, 831, 857]])\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(codes[:,1, :])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 84,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor(0)\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(codes[:, 1, :].min())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 85,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "code_1 = codes+1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 86,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor([[185, 741, 962, ..., 604, 832, 858]])\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(code_1[:,1, :])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 88,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with torch.no_grad():  # inference only: skip autograd bookkeeping, as in the encode cell\n",
+ "    decoded_wav = model.decode(encoded_frames)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 89,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 1, 451200])"
+ ]
+ },
+ "execution_count": 89,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "decoded_wav.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 92,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "decoded_wav = decoded_wav.squeeze().detach().numpy()  # one squeeze() already drops every size-1 dim"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 90,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from IPython.display import Audio"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 93,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 93,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "Audio(decoded_wav, rate=model.sample_rate)  # play back at the codec's own sample rate"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}