File size: 1,233 Bytes
e5e9b34
 
 
 
 
a84c313
e5e9b34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a84c313
 
e5e9b34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a84c313
e5e9b34
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47

import functools
import time

import librosa
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor
from transformers import set_seed


# Sampling rate (Hz) the audio is resampled to AND declared to the processor.
# A single constant guarantees the two values can never drift apart.
_SAMPLING_RATE = 16_000


@functools.lru_cache(maxsize=1)
def _load_model(model_id: str, target_lang: str):
    """Load and memoize the processor and CTC model for ``target_lang``.

    Only the most recently used (model_id, target_lang) pair is cached, so
    repeated transcriptions in the same language reuse the loaded weights
    instead of reloading them on every call, without keeping models for
    several languages in memory at once.

    Returns
    -------
    tuple
        ``(processor, model)`` as returned by ``from_pretrained``.
    """
    processor = AutoProcessor.from_pretrained(model_id, target_lang=target_lang)
    # NOTE(review): ignore_mismatched_sizes=True is presumably needed because the
    # language-specific output layer differs in size from the checkpoint's — confirm.
    model = Wav2Vec2ForCTC.from_pretrained(
        model_id, target_lang=target_lang, ignore_mismatched_sizes=True
    )
    return processor, model


def transcribe(fp: str, target_lang: str) -> str:
    """
    Transcribe the speech contained in an audio file.

    The processor/model pair is cached between calls (see ``_load_model``),
    so only the first call for a given language pays the loading cost.

    Parameters
    ----------
    fp : str
        The file path to the audio file.
    target_lang : str
        The ISO-3 code of the target language.

    Returns
    -------
    str
        The transcribed text.
    """
    # Ensure replicability
    set_seed(555)
    # perf_counter is monotonic, unlike time.time(), so the measurement cannot
    # be skewed by system clock adjustments.
    start_time = time.perf_counter()

    # Load (or reuse the cached) transcription model
    processor, model = _load_model("facebook/mms-1b-all", target_lang)

    # Resample to the same rate the processor is told the signal has.
    signal, _ = librosa.load(fp, sr=_SAMPLING_RATE)
    inputs = processor(signal, sampling_rate=_SAMPLING_RATE, return_tensors="pt")

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Most likely token per frame; index 0 selects the single input passed
    # to the processor above.
    ids = torch.argmax(logits, dim=-1)[0]
    transcript = processor.decode(ids)

    print("Time elapsed: ", int(time.perf_counter() - start_time), " seconds")
    return transcript