Spaces:
Runtime error
Runtime error
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
import soundfile as sf | |
import torch | |
import gradio as gr | |
# Single checkpoint identifier shared by the processor and the CTC model so
# the two are always loaded from the same pretrained weights/config.
MODEL_ID = "facebook/wav2vec2-large-robust-ft-libri-960h"

# Both objects are created once at import time and reused for every request.
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
def map_to_array(file):
    """Read a sound file and return its samples as a mono waveform.

    Args:
        file: Anything ``soundfile.read`` accepts (this script passes a path).

    Returns:
        A one-dimensional array of samples. The sample rate reported by
        ``soundfile`` is discarded, so no resampling happens here.
        NOTE(review): the checkpoint presumably expects 16 kHz audio --
        confirm that callers only upload clips at that rate.
    """
    speech, _ = sf.read(file)
    # soundfile yields a (frames, channels) array for multi-channel files;
    # the downstream processor expects one 1-D waveform per clip, so average
    # the channels down to mono. Mono input is already 1-D and is untouched.
    if speech.ndim > 1:
        speech = speech.mean(axis=1)
    return speech
def inference(audio):
    """Transcribe one audio clip with the wav2vec 2.0 CTC model.

    Args:
        audio: Either an object exposing the file path as ``.name`` (which is
            how this script's audio input is consumed) or a plain path string.

    Returns:
        The decoded transcription of the clip as a single string.
    """
    # Use ``.name`` when present; otherwise treat the argument itself as the
    # path, so a bare string path works as well.
    audio_path = getattr(audio, "name", audio)

    # Turn the raw waveform into model-ready tensors.
    input_values = processor(
        map_to_array(audio_path),
        return_tensors="pt",
        padding="longest",
    ).input_values  # Batch size 1

    # Pure inference: disable autograd so no graph is built or retained for
    # the forward pass.
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy decoding: pick the highest-scoring token id per frame, then let
    # the processor turn the id sequence into text.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    return transcription[0]
# --- Gradio UI wiring -------------------------------------------------------
# NOTE(review): `gr.inputs` / `gr.outputs` are the legacy component
# namespaces -- confirm the pinned Gradio version still provides them.
inputs = gr.inputs.Audio(type="file", label="Input Audio")
outputs = gr.outputs.Textbox(label="Output Text")

# Static text shown on the demo page.
title = "Robust wav2vec 2.0"
description = "Gradio demo for Robust wav2vec 2.0. To use it, simply upload your audio, or click one of the examples to load them. Read more at the links below. Currently supports .wav and .flac files"
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.01027' target='_blank'>Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training</a> | <a href='https://github.com/pytorch/fairseq' target='_blank'>Github Repo</a></p>"

# One example row; the path is given relative to the working directory.
examples = [["poem.wav"]]

# Build the interface around `inference` and start serving it immediately.
demo = gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()