{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Vocos reconstruction demo\n",
    "\n",
    "Loads a generator checkpoint into `Vocos`, picks a random `.wav` file from `./audios`, runs it through `LogMelSpectrogram` and the model, and plays the original audio followed by the model output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from dataclasses import asdict\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import torchaudio\n",
    "from IPython.display import Audio, display\n",
    "\n",
    "from models.model import Vocos\n",
    "from utils.audio import LogMelSpectrogram\n",
    "from config import MelConfig, VocosConfig\n",
    "\n",
    "# Paths a reader is likely to change.\n",
    "CHECKPOINT_PATH = Path('./checkpoints/generator_0.pt')\n",
    "AUDIO_DIR = Path('./audios')\n",
    "\n",
    "\n",
    "def load_and_resample_audio(audio_path, target_sr):\n",
    "    \"\"\"Load an audio file as a single-channel waveform at ``target_sr``.\n",
    "\n",
    "    Args:\n",
    "        audio_path: Path of the file, passed to ``torchaudio.load``.\n",
    "        target_sr: Sample rate the returned waveform should have.\n",
    "\n",
    "    Returns:\n",
    "        The waveform tensor; only channel 0 is kept when the file has\n",
    "        several channels, and it is resampled when the file's rate\n",
    "        differs from ``target_sr``.\n",
    "    \"\"\"\n",
    "    y, sr = torchaudio.load(audio_path)\n",
    "    if y.size(0) > 1:\n",
    "        # Keep only the first channel: [channels, time] -> [1, time].\n",
    "        y = y[:1]\n",
    "    if sr != target_sr:\n",
    "        y = torchaudio.functional.resample(y, sr, target_sr)\n",
    "    return y\n",
    "\n",
    "\n",
    "device = 'cpu'\n",
    "\n",
    "mel_config = MelConfig()\n",
    "vocos_config = VocosConfig()\n",
    "\n",
    "mel_extractor = LogMelSpectrogram(**asdict(mel_config))\n",
    "model = Vocos(vocos_config, mel_config).to(device)\n",
    "# weights_only=True limits unpickling to tensors and plain containers, which is\n",
    "# all a state dict needs; map_location follows `device` instead of a literal.\n",
    "state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)\n",
    "model.load_state_dict(state_dict)\n",
    "model.eval()\n",
    "\n",
    "audio_paths = list(AUDIO_DIR.rglob('*.wav'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fail with a clear message instead of random.choice's bare IndexError.\n",
    "if not audio_paths:\n",
    "    raise FileNotFoundError(f'No .wav files found under {AUDIO_DIR.resolve()}')\n",
    "\n",
    "audio_path = random.choice(audio_paths)\n",
    "with torch.inference_mode():\n",
    "    audio = load_and_resample_audio(audio_path, mel_config.sample_rate).to(device)\n",
    "    mel = mel_extractor(audio)\n",
    "    recon_audio = model(mel)\n",
    "\n",
    "# Audio needs host-side array data, so convert explicitly rather than relying\n",
    "# on `device` being 'cpu'; squeeze() drops singleton batch/channel dimensions.\n",
    "display(Audio(audio.squeeze().cpu().numpy(), rate=mel_config.sample_rate))\n",
    "display(Audio(recon_audio.squeeze().cpu().numpy(), rate=mel_config.sample_rate))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lxn_vits",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}