File size: 4,032 Bytes
7bcf8d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import librosa
from transformers import Wav2Vec2ForCTC, AutoProcessor
import torch
import json

from huggingface_hub import hf_hub_download
from torchaudio.models.decoder import ctc_decoder

# Sampling rate (Hz) used both when loading audio and when calling the processor.
ASR_SAMPLING_RATE = 16_000

# Maps language code -> language name. Each non-blank line of the data file is
# "<iso> <name>", split on the first space only so names may contain spaces.
ASR_LANGUAGES = {}
with open("data/asr/all_langs.tsv", encoding="utf-8") as f:
    for line in f:
        # Drop the line terminator so names do not carry a trailing "\n".
        line = line.rstrip("\n")
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing empty line at end of file).
            continue
        iso, name = line.split(" ", 1)
        ASR_LANGUAGES[iso] = name

# Hub id of the checkpoint; both the processor and the model below are loaded
# from it.
MODEL_ID = "facebook/mms-1b-all"

# Loaded once, at import time, and shared by every transcribe() call.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)


# Per-language LM decoding settings, keyed by language code. transcribe() reads
# the keys "lmfile", "tokensfile", "lexiconfile", "lmweight", "wordscore" and
# "silweight" from each entry.
# NOTE(review): the download that would fill this dict is commented out below,
# so unless something else populates it, it stays empty and transcribe() always
# takes the greedy-decoding branch.
lm_decoding_config = {}
# lm_decoding_configfile = hf_hub_download(
#     repo_id="facebook/mms-cclms",
#     filename="decoding_config.json",
#     subfolder="mms-1b-all",
# )

# with open(lm_decoding_configfile) as f:
#     lm_decoding_config = json.loads(f.read())

# allow language model decoding for specific languages
# (a language must be listed here AND have an entry in lm_decoding_config)
lm_decode_isos = ["eng"]


def _select_device():
    """Return the torch device to run inference on: CUDA, else MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Guard the attribute lookup: not every torch build exposes backends.mps.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available() and mps.is_built():
        return torch.device("mps")
    return torch.device("cpu")


def _download_cclm_file(repo_path):
    """Download ``<subfolder>/<filename>`` from the facebook/mms-cclms repo."""
    subfolder, filename = repo_path.rsplit("/", 1)
    return hf_hub_download(
        repo_id="facebook/mms-cclms",
        filename=filename,
        subfolder=subfolder,
    )


def _beam_search_decode(logits, decoding_config):
    """Decode ``logits`` with the LM-backed CTC beam-search decoder."""
    lm_file = _download_cclm_file(decoding_config["lmfile"])
    token_file = _download_cclm_file(decoding_config["tokensfile"])
    lexicon_path = decoding_config["lexiconfile"]
    # The lexicon is optional; None is passed through to ctc_decoder as-is.
    lexicon_file = (
        _download_cclm_file(lexicon_path) if lexicon_path is not None else None
    )

    beam_search_decoder = ctc_decoder(
        lexicon=lexicon_file,
        tokens=token_file,
        lm=lm_file,
        nbest=1,
        beam_size=500,
        beam_size_token=50,
        lm_weight=float(decoding_config["lmweight"]),
        word_score=float(decoding_config["wordscore"]),
        sil_score=float(decoding_config["silweight"]),
        blank_token="<s>",
    )
    # The decoder is fed CPU tensors regardless of the inference device.
    beam_search_result = beam_search_decoder(logits.to("cpu"))
    return " ".join(beam_search_result[0][0].words).strip()


def transcribe(
    audio_source=None, microphone=None, file_upload=None, lang="eng (English)"
):
    """Transcribe one audio file in the requested language.

    Args:
        audio_source: Selector for the input. If its string form contains
            "upload" (case-insensitive) ``file_upload`` is used, otherwise
            ``microphone`` is used.
        microphone: Path of the recorded audio, or a dict whose ``"name"``
            entry holds that path.
        file_upload: Path of the uploaded audio.
        lang: Language label; its first whitespace-separated token is the
            language code used for the tokenizer and the model adapter.

    Returns:
        The transcription string, or a message starting with ``"ERROR:"``
        when no audio path or no language code was supplied.
    """
    if isinstance(microphone, dict):
        # HACK: microphone variable is a dict when running on examples
        microphone = microphone["name"]

    use_upload = "upload" in str(audio_source or "").lower()
    audio_fp = file_upload if use_upload else microphone
    if not audio_fp:
        # Fail early with a readable message instead of handing None/"" to
        # librosa. Returned rather than raised so that a caller displaying the
        # result (presumably a demo UI) shows it in place of the transcription.
        return "ERROR: You have to either use the microphone or upload an audio file"

    lang_tokens = (lang or "").split()
    if not lang_tokens:
        return "ERROR: You have to select a language"
    lang_code = lang_tokens[0]

    audio_samples = librosa.load(audio_fp, sr=ASR_SAMPLING_RATE, mono=True)[0]

    # Switch both the tokenizer vocabulary and the adapter weights to lang_code.
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    device = _select_device()
    model.to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        outputs = model(**inputs).logits

    # LM decoding needs both a decoding config for the language and the
    # language being explicitly allowed; everything else is decoded greedily.
    if lang_code in lm_decoding_config and lang_code in lm_decode_isos:
        return _beam_search_decode(outputs, lm_decoding_config[lang_code])

    ids = torch.argmax(outputs, dim=-1)[0]
    return processor.decode(ids)


# Example inputs, one row per example. The four entries appear to line up with
# transcribe()'s parameters: (audio_source, microphone, file_upload, lang).
ASR_EXAMPLES = [
    [None, "assets/english.mp3", None, "eng (English)"],
    # [None, "assets/tamil.mp3", None, "tam (Tamil)"],
    # [None, "assets/burmese.mp3", None, "mya (Burmese)"],
]

# User-facing note. It describes greedy decoding for every language because
# lm_decoding_config is left empty in this module, so transcribe() never takes
# the LM beam-search branch; update the text if LM decoding is re-enabled.
ASR_NOTE = """
The above demo uses greedy decoding for all languages; beam-search decoding with a language model (LM) is currently disabled.
Check out the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding.
"""