"""Gradio demo for UTMOSv2 MOS prediction.

Builds an inference configuration from ``utmosv2.config.fusion_stage3``,
defines ``predict_mos`` for a single uploaded file, and launches a Gradio
Blocks UI around it.
"""

import importlib
from types import SimpleNamespace

import gradio as gr
import pandas as pd
import spaces
import torch

from utmosv2.utils import get_dataset, get_model

description = (
    "# 🚀 UTMOSv2 demo\n\n"
    "[![GitHub](https://img.shields.io/badge/-GitHub-181717.svg?logo=github&style=flat)](https://github.com/sarulab-speech/UTMOSv2)\n\n"
    "This is a demonstration of MOS prediction using UTMOSv2. "
    "This demonstration only accepts `.wav` format. Best at 16 kHz sampling rate."
)

device = torch.device("cuda")

# Copy every non-dunder attribute of the fusion_stage3 config module into a
# mutable namespace, then override the fields used for inference here.
config = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{attr: getattr(config, attr) for attr in config.__dict__ if not attr.startswith("__")}
)
cfg.reproduce = False
cfg.config = "fusion_stage3"
cfg.print_config = False
cfg.data_config = None
cfg.phase = "inference"
cfg.num_workers = 1

# Full (non-quick) prediction averages NUM_REPEATS forward passes for each of
# NUM_FOLDS per-fold checkpoints.
NUM_FOLDS = 5
NUM_REPEATS = 5


def _load_fold_model(fold: int):
    """Point the shared ``cfg`` at fold *fold*'s checkpoint and return the model in eval mode."""
    # get_model reads now_fold/weight from the shared cfg, so set them first.
    cfg.now_fold = fold
    cfg.weight = f"models/fusion_stage3/fold{fold}_s42_best_model.pth"
    return get_model(cfg, device).eval()


def _predict_once(model, data: pd.DataFrame):
    """Run one forward pass of *model* on the single row in *data*; return the scalar output."""
    # The dataset is rebuilt on every call, as in the original loop. The UI
    # text describes predictions on randomly selected frames, so presumably
    # each call samples a new frame — TODO confirm in utmosv2.
    test_dataset = get_dataset(cfg, data, "test")
    # Every item of the sample except the last is a model input; the last is
    # presumably the (dummy) "mos" target — NOTE(review): confirm.
    inputs = [
        torch.tensor(t, dtype=torch.float32).unsqueeze(0).to(device)
        for t in test_dataset[0][:-1]
    ]
    return model(*inputs).cpu().numpy()[0][0]


@spaces.GPU
@torch.inference_mode()
def predict_mos(audio_path: str, domain: str, quick: bool) -> float:
    """Predict the MOS of one audio file with the UTMOSv2 fusion_stage3 models.

    Args:
        audio_path: Filesystem path of the uploaded audio (``gr.Audio`` with
            ``type="filepath"``); ``None`` when nothing was uploaded.
        domain: Data-domain ID written into the ``dataset`` column.
        quick: If true, return a single forward pass of the fold-0 model.
            Otherwise, average NUM_REPEATS passes for each of the NUM_FOLDS
            fold models.

    Returns:
        The predicted MOS as a NumPy scalar taken from the model output. The
        ``float`` annotation is nominal.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if not audio_path:
        # Surface a readable message in the UI instead of a crash deep inside
        # the data pipeline.
        raise gr.Error("Please upload a `.wav` audio file before submitting.")

    data = pd.DataFrame({"file_path": [audio_path]})
    data["dataset"] = domain
    data["mos"] = 0  # dummy target; only the prediction is used

    preds = 0.0
    if quick:
        # NOTE(review): quick mode uses only fold 0 and one pass. The checkbox
        # text ("a single repetition") may instead intend one pass per fold —
        # confirm the intended behavior.
        preds += _predict_once(_load_fold_model(0), data)
        return preds

    for fold in range(NUM_FOLDS):
        model = _load_fold_model(fold)
        for _ in range(NUM_REPEATS):
            preds += _predict_once(model, data)
    preds /= NUM_FOLDS * NUM_REPEATS
    return preds


with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(type="filepath", label="Audio")
            domain = gr.Dropdown(
                [
                    "sarulab",
                    "bvcc",
                    "somos",
                    "blizzard2008",
                    "blizzard2009",
                    "blizzard2010-EH1",
                    "blizzard2010-EH2",
                    "blizzard2010-ES1",
                    "blizzard2010-ES3",
                    "blizzard2011",
                ],
                label="Data-domain ID for the MOS prediction",
                value="sarulab",
            )
            quick = gr.Checkbox(
                label="Quick prediction",
                value=True,
                # Adjacent literals concatenate into ONE str. There is no
                # trailing comma: with one, `info` would become a 1-tuple.
                info=(
                    "UTMOSv2 makes predictions repeatedly for five randomly selected frames "
                    "of the input speech waveform for all five folds. "
                    "To make quick predictions by reducing this to a single repetition, "
                    "check this checkbox:"
                ),
            )
            submit = gr.Button(value="Submit")
        with gr.Column():
            output = gr.Textbox(label="Predicted MOS", type="text")
    submit.click(fn=predict_mos, inputs=[audio, domain, quick], outputs=[output])

demo.queue().launch()