# mms-zeroshot / zeroshot.py
# vineelpratap's picture
# Update zeroshot.py
# 175fc0f verified
import json
import os
import re
import subprocess
import tempfile

import librosa
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from torchaudio.models.decoder import ctc_decoder
from transformers import Wav2Vec2ForCTC, AutoProcessor

from utils.lm import create_unigram_lm, maybe_generate_pseudo_bigram_arpa
from utils.text_norm import text_normalize
# Directory containing the uroman romanizer. The path is relative, so it is
# resolved against the current working directory; importing this module fails
# with an AssertionError when the directory is missing.
uroman_dir = "uroman"
assert os.path.exists(uroman_dir)
# Perl script that uromanize() runs to romanize the word list.
UROMAN_PL = os.path.join(uroman_dir, "bin", "uroman.pl")
# Sampling rate (Hz) that process() loads / resamples audio to before inference.
ASR_SAMPLING_RATE = 16_000
# Default decoder word scores; process() picks one depending on whether a
# language model is in use.
WORD_SCORE_DEFAULT_IF_LM = -0.18
WORD_SCORE_DEFAULT_IF_NOLM = -3.5
# Default LM weight used by process() when a language model is in use
# (process() uses 0 when there is no LM).
LM_SCORE_DEFAULT = 1.48
# Hugging Face Hub repository that the processor, model and token list come from.
MODEL_ID = "mms-meta/mms-zeroshot-300m"
# Both objects are created once at import time and shared by every process()
# call (presumably fetched from the Hub or its local cache on first use).
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
# Local path of the token list; passed to the CTC decoder as `tokens`.
token_file = hf_hub_download(
repo_id=MODEL_ID,
filename="tokens.txt",
)
class MY_LOG:
    """Running text log that starts out as ``"[START]"``."""

    def __init__(self):
        self.text = "[START]"

    def add(self, new_log, new_line=True):
        """Append ``new_log`` to the log and return the whole log text.

        The new entry is joined to the existing text with a newline, or with
        a single space when ``new_line`` is false. Leading and trailing
        whitespace of the combined text is stripped afterwards.
        """
        separator = "\n" if new_line else " "
        combined = "".join((self.text, separator, new_log))
        self.text = combined.strip()
        return self.text
def error_check_file(filepath):
    """Validate that ``filepath`` is a string naming an existing path.

    Returns:
        An error message string when the check fails, otherwise ``None``.
    """
    if not isinstance(filepath, str):
        return "Expected file to be of type 'str'. Instead got {}".format(
            type(filepath)
        )
    if not os.path.exists(filepath):
        # Report the offending path itself (not its type) so the message is
        # actionable.
        return "Input file '{}' doesn't exists".format(filepath)
    return None
def norm_uroman(text):
    """Normalize romanized text to lowercase ``a-z``, apostrophes and single spaces.

    The text is lowercased, the typographic apostrophe is mapped to ``'``,
    every other character outside ``a-z``, ``'`` and space is replaced by a
    space, runs of spaces are collapsed, and the result is stripped.
    """
    lowered = text.lower().replace("’", "'")
    # Anything that is not a lowercase ASCII letter, apostrophe or space
    # becomes a space.
    only_allowed = re.sub("([^a-z' ])", " ", lowered)
    collapsed = re.sub(" +", " ", only_allowed)
    return collapsed.strip()
def uromanize(words):
    """Romanize ``words`` with uroman and build a character-level lexicon.

    The words are written one per line to a temporary file, piped through
    ``perl UROMAN_PL -l xxx``, and each output line is normalized with
    ``norm_uroman`` and stripped of all whitespace.

    Args:
        words: Sequence of words; output line ``i`` is matched to ``words[i]``.

    Returns:
        Dict mapping each word to its spelling: the romanized characters
        separated by single spaces and followed by ``" |"``. Words whose
        uroman output line is blank are omitted.

    Raises:
        subprocess.CalledProcessError: If the perl/uroman process exits with
            a non-zero status.
        OSError: If ``perl`` cannot be started.
    """
    iso = "xxx"
    with tempfile.NamedTemporaryFile() as tf, tempfile.NamedTemporaryFile() as tf2:
        # The word list is multilingual text: write it as UTF-8 rather than
        # in whatever encoding the current locale happens to use.
        with open(tf.name, "w", encoding="utf-8") as f:
            f.write("\n".join(words))

        # Run uroman without a shell and fail loudly on a non-zero exit
        # status, instead of silently producing an empty/partial lexicon.
        with open(tf.name, "rb") as src, open(tf2.name, "wb") as dst:
            subprocess.run(
                ["perl", UROMAN_PL, "-l", iso],
                stdin=src,
                stdout=dst,
                check=True,
            )

        lexicon = {}
        with open(tf2.name, encoding="utf-8") as f:
            for idx, line in enumerate(f):
                if not line.strip():
                    continue
                spelling = re.sub(r"\s+", "", norm_uroman(line)).strip()
                # One token per character, terminated by "|".
                lexicon[words[idx]] = " ".join(spelling) + " |"
    return lexicon
def filter_lexicon(lexicon, word_counts):
    """Keep a single word for every distinct spelling in ``lexicon``.

    Args:
        lexicon: Dict mapping word -> spelling.
        word_counts: Dict mapping word -> count; only consulted for words
            that share a spelling with at least one other word.

    Returns:
        A new dict mapping the chosen word -> spelling, one entry per
        distinct spelling, in order of each spelling's first appearance.
    """
    candidates_by_spelling = {}
    for word, spelling in lexicon.items():
        candidates_by_spelling.setdefault(spelling, []).append(word)

    filtered = {}
    for spelling, candidates in candidates_by_spelling.items():
        if len(candidates) > 1:
            # Prefer the word with the highest count; break ties with the
            # fewest characters, then with the earliest position in `lexicon`.
            chosen = min(candidates, key=lambda w: (-word_counts[w], len(w)))
        else:
            chosen = candidates[0]
        filtered[chosen] = spelling
    return filtered
def load_words(filepath):
    """Count the normalized words found in a text file.

    All lines are stripped, joined with spaces, passed through
    ``text_normalize`` and split on whitespace.

    Args:
        filepath: Path of the text file to read.

    Returns:
        A ``(words, num_sentences)`` tuple: ``words`` maps each normalized
        word to its number of occurrences, ``num_sentences`` is the number
        of lines in the file.
    """
    # The word list is multilingual text: decode it as UTF-8 explicitly
    # instead of relying on the locale default. "utf-8-sig" also drops a
    # leading byte-order mark so it cannot end up glued to the first word.
    with open(filepath, encoding="utf-8-sig") as f:
        lines = f.readlines()

    normalized = text_normalize(" ".join(line.strip() for line in lines))

    words = {}
    for w in normalized.split():
        words[w] = words.get(w, 0) + 1
    return words, len(lines)
def _load_audio_samples(audio_data):
    """Return the audio samples at ``ASR_SAMPLING_RATE`` for either input form."""
    if isinstance(audio_data, tuple):
        # Microphone input: a (sample_rate, samples) pair.
        sr, audio_samples = audio_data
        # Scaled by 1/32768 -- presumably 16-bit PCM to [-1, 1); confirm
        # against the caller.
        # NOTE(review): unlike the file path below (mono=True), nothing here
        # down-mixes multi-channel input -- confirm callers only pass mono.
        audio_samples = (audio_samples / 32768.0).astype(float)
        if sr != ASR_SAMPLING_RATE:
            audio_samples = librosa.resample(
                audio_samples, orig_sr=sr, target_sr=ASR_SAMPLING_RATE
            )
        return audio_samples

    # File upload: a path on disk.
    assert isinstance(audio_data, str)
    return librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)[0]


def _pick_device():
    """Return the torch device to run on: CUDA, else MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if (
        hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    ):
        return torch.device("mps")
    return torch.device("cpu")


def process(
    audio_data,
    words_file,
    lm_path=None,
    wscore=None,
    lmscore=None,
    wscore_usedefault=True,
    lmscore_usedefault=True,
    autolm=True,
    reference=None,
):
    """Transcribe audio with lexicon-constrained CTC beam search.

    This is a generator so that callers can display progress while it runs.

    Args:
        audio_data: Either a ``(sample_rate, samples)`` tuple (microphone) or
            the path of an audio file.
        words_file: Path of a text file with the words/sentences that define
            the lexicon.
        lm_path: Optional path of a language model file. ``None`` or a blank
            string means no LM. Overridden when a unigram LM is auto-created.
        wscore: Word score for the decoder; only used when
            ``wscore_usedefault`` is false.
        lmscore: LM weight for the decoder; only used when
            ``lmscore_usedefault`` is false.
        wscore_usedefault: Use the module default word score (which depends
            on whether an LM is in use).
        lmscore_usedefault: Use the module default LM weight (0 without LM).
        autolm: Build a unigram LM from ``words_file`` when at least one word
            occurs more than twice.
        reference: Unused.

    Yields:
        ``(transcription, log_text)`` tuples. ``transcription`` is ``""``
        until decoding has finished, or an ``"ERROR: ..."`` message when the
        run is aborted; ``log_text`` is the accumulated log.
    """
    transcription, logs = "", MY_LOG()
    if not audio_data or not words_file:
        yield "ERROR: Empty audio data or words file", logs.text
        return

    audio_samples = _load_audio_samples(audio_data)
    yield transcription, logs.add(f"Number of audio samples: {len(audio_samples)}")

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    device = _pick_device()
    model.to(device)
    inputs = inputs.to(device)
    yield transcription, logs.add(f"Using device: {device}")

    with torch.no_grad():
        outputs = model(**inputs).logits

    # Set up lexicon and decoder.
    yield transcription, logs.add("Loading words....")
    try:
        word_counts, num_sentences = load_words(words_file)
    except Exception as e:
        yield f"ERROR: Loading words failed '{str(e)}'", logs.text
        return

    yield transcription, logs.add(
        f"Loaded {len(word_counts)} words from {num_sentences} lines.\nPreparing lexicon...."
    )
    try:
        lexicon = uromanize(list(word_counts.keys()))
    except Exception as e:
        yield f"ERROR: Creating lexicon failed '{str(e)}'", logs.text
        return

    yield transcription, logs.add(f"Leixcon size: {len(lexicon)}")

    # A blank LM path means "no LM". Normalize it before branching on
    # `lm_path is None`, so that an empty string takes the lexicon-filtering
    # path instead of being handed to the pseudo-bigram generator.
    if lm_path is not None and not lm_path.strip():
        lm_path = None

    # Both temporary files are removed when the block exits, including when
    # the generator is closed early. `tmp_file` may hold the auto-created LM.
    with tempfile.NamedTemporaryFile() as tmp_file, tempfile.NamedTemporaryFile() as lexicon_file:
        # The input may be running text or a plain word list; repeated words
        # are taken as a sign of running text, which is worth an LM.
        # NOTE(review): the threshold in code is `> 2`, although this check
        # was previously described as "count > 1" -- confirm the intended
        # cutoff.
        if autolm and any(cnt > 2 for cnt in word_counts.values()):
            yield transcription, logs.add("Creating unigram LM...", False)
            lm_path = tmp_file.name
            create_unigram_lm(word_counts, num_sentences, lm_path)
            yield transcription, logs.add("OK")

        if lm_path is None:
            yield transcription, logs.add("Filtering lexicon....")
            lexicon = filter_lexicon(lexicon, word_counts)
            yield transcription, logs.add(
                f"Ok. Leixcon size after filtering: {len(lexicon)}"
            )
        else:
            # kenlm throws an error if a unigram LM is being used.
            # HACK: generate a bigram LM from the unigram LM and a dummy
            # bigram to trick it.
            maybe_generate_pseudo_bigram_arpa(lm_path)

        # One "<word> <spelling>" entry per line; UTF-8 because the words
        # are multilingual text.
        with open(lexicon_file.name, "w", encoding="utf-8") as f:
            f.writelines(
                f"{word} {spelling}\n" for word, spelling in lexicon.items()
            )

        if wscore_usedefault:
            wscore = (
                WORD_SCORE_DEFAULT_IF_LM
                if lm_path is not None
                else WORD_SCORE_DEFAULT_IF_NOLM
            )
        if lmscore_usedefault:
            lmscore = LM_SCORE_DEFAULT if lm_path is not None else 0

        yield transcription, logs.add(
            f"Using word score: {wscore}\nUsing lm score: {lmscore}"
        )

        beam_search_decoder = ctc_decoder(
            lexicon=lexicon_file.name,
            tokens=token_file,
            lm=lm_path,
            nbest=1,
            beam_size=500,
            beam_size_token=50,
            lm_weight=lmscore,
            word_score=wscore,
            sil_score=0,
            blank_token="<s>",
        )
        # The decoder is given CPU tensors regardless of the inference device.
        beam_search_result = beam_search_decoder(outputs.to("cpu"))
        transcription = " ".join(beam_search_result[0][0].words).strip()

    yield transcription, logs.add("[DONE]")
# for i in process("upload/english/english.mp3", "upload/english/c4_5k_sentences.txt"):
# print(i)
# for i in process("upload/ligurian/ligurian_1.mp3", "upload/ligurian/zenamt_5k_sentences.txt"):
# print(i)