Spaces:
Build error
Build error
import torch | |
import nltk | |
from scipy.io.wavfile import write | |
import librosa | |
import hashlib | |
from typing import List | |
def embed_questions(
    question_model, question_tokenizer, questions, max_length=128, device="cpu"
):
    """Encode questions with the question encoder and return pooled vectors.

    The questions are tokenized (padded/truncated to ``max_length``), the
    resulting ids and attention mask are moved to ``device``, and the model's
    ``pooler_output`` is returned as a NumPy array on the CPU.
    """
    encoded = question_tokenizer(
        questions,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        model_output = question_model(input_ids, attention_mask)

    pooled = model_output.pooler_output
    return pooled.cpu().numpy()
def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cpu"):
    """Encode the ``"text"`` entries of ``passages`` with the context encoder.

    Returns a dict with a single ``"embeddings"`` key holding the model's
    ``pooler_output`` as a NumPy array on the CPU.
    """
    encoded = ctx_tokenizer(
        passages["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        model_output = ctx_model(input_ids, attention_mask)

    pooled = model_output.pooler_output
    return {"embeddings": pooled.cpu().numpy()}
class Document:
    """Minimal document container exposing ``meta``, ``content`` and ``id``."""

    def __init__(self, meta=None, content: str = "", id_: str = ""):
        """Store the document fields.

        Args:
            meta: Optional metadata mapping. When omitted (or ``None``) each
                instance gets its own fresh empty dict.
            content: Text of the document.
            id_: Identifier, stored on the instance as ``id``.
        """
        # A literal ``{}`` default would be created once and shared by every
        # instance built without ``meta``; allocate a new dict per instance.
        self.meta = {} if meta is None else meta
        self.content = content
        self.id = id_
def _alter_docs_for_haystack(passages):
    """Wrap each passage in a ``Document`` whose id is its position as a string."""
    documents = []
    for position, passage_text in enumerate(passages):
        documents.append(Document(content=passage_text, id_=str(position)))
    return documents
def embed_passages_haystack(
    dpr_model,
    passages,
):
    """Embed the ``"text"`` entries of ``passages`` via ``dpr_model.embed_documents``.

    The texts are first wrapped in ``Document`` objects; the result is
    returned in a dict under the ``"embeddings"`` key.
    """
    documents = _alter_docs_for_haystack(passages["text"])
    return {"embeddings": dpr_model.embed_documents(documents)}
def correct_casing(input_sentence):
    """Capitalize the first character of each sentence in transcribed text.

    The text is split with ``nltk.sent_tokenize`` and the sentences are
    re-joined with single spaces.
    """
    capitalized = []
    for sentence in nltk.sent_tokenize(input_sentence):
        first_char = sentence[0]
        # Only the first occurrence (the leading character) is replaced.
        capitalized.append(sentence.replace(first_char, first_char.capitalize(), 1))
    return " ".join(capitalized)
def clean_transcript(text):
    """Remove every upper-cased pad marker from ``text``."""
    pad_marker = "[pad]".upper()
    return text.replace(pad_marker, "")
def add_question_symbols(text):
    """Wrap ``text`` in Spanish question marks unless they are already there.

    An opening mark is prepended when the text does not start with one and a
    closing "?" is appended when it does not end with one. Empty input is
    returned unchanged instead of raising ``IndexError`` on the index
    lookups.
    """
    if not text:
        # Nothing to punctuate; indexing an empty string would fail.
        return text
    if not text.startswith("¿"):
        text = "¿" + text
    if not text.endswith("?"):
        text = text + "?"
    return text
def remove_chars_to_tts(text):
    """Return ``text`` with every comma swapped for a single space."""
    comma_separated_parts = text.split(",")
    return " ".join(comma_separated_parts)
def transcript(input_file, audio_array, processor, model):
    """Transcribe audio to text and format the result as a question.

    Args:
        input_file: Path of the audio file to transcribe. Replaced by
            ``"temp.wav"`` when ``audio_array`` is provided.
        audio_array: Optional ``(rate, sample)`` pair. When truthy it is
            written to ``temp.wav`` and that file is transcribed instead.
        processor: Called on each audio block with ``return_tensors="pt"``;
            its result must expose ``input_values``. Its ``decode`` method
            turns predicted ids into text.
        model: Called on the input values; its result must expose ``logits``.

    Returns:
        The concatenated per-block transcriptions, cut to 3800 characters and
        wrapped in question marks, or an empty string when the stream yields
        no audio blocks.
    """
    if audio_array:
        rate, sample = audio_array
        # NOTE(review): fixed file name in the working directory; concurrent
        # calls would overwrite each other's recording - confirm usage.
        write("temp.wav", rate, sample)
        input_file = "temp.wav"

    # Native rate of the file; blocks are resampled to 16 kHz below if needed.
    sample_rate = librosa.get_samplerate(input_file)

    # Stream instead of loading the whole file. frame_length and hop_length
    # both equal the sample rate, so each frame is one second long and every
    # block of 20 frames covers 20 seconds of audio.
    stream = librosa.stream(
        input_file,
        block_length=20,
        frame_length=sample_rate,
        hop_length=sample_rate,
    )

    pieces = []
    for speech in stream:
        if len(speech.shape) > 1:
            # Two-dimensional block: add the first two columns together.
            # NOTE(review): librosa.stream is called with its default mono
            # setting, so this branch may never run - confirm.
            speech = speech[:, 0] + speech[:, 1]
        if sample_rate != 16000:
            speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)

        input_values = processor(speech, return_tensors="pt").input_values
        # Inference only: skip autograd bookkeeping, like the embedding
        # helpers in this file do.
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.decode(
            predicted_ids[0],
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
        transcription = clean_transcript(transcription)
        pieces.append(correct_casing(transcription.lower()) + ". ")

    whole_text = "".join(pieces)[:3800]
    if not whole_text:
        # No audio blocks were produced; there is nothing to punctuate.
        return whole_text
    # NOTE(review): every piece ends with ". ", so the closing "?" is added
    # after a trailing space - confirm this output format is intended.
    return add_question_symbols(whole_text)
def parse_final_answer(answer_text: str, contexts: List):
    """Render the answer and up to five context snippets as HTML strings.

    Returns a ``(answer, docs)`` tuple: ``answer`` is the answer text in a
    bold paragraph, ``docs`` holds one justified paragraph per context (first
    five only), newline-separated. Each paragraph, opening tag included, is
    cut to 250 characters before the "[...]" suffix and closing tag are added.
    """
    opening_tag = """<p style="text-align: justify;">"""

    snippets = []
    for context in contexts[:5]:
        # The 250-character limit counts the opening tag as well.
        clipped = (opening_tag + context)[:250]
        snippets.append(clipped + "[...]</p>")

    docs = "\n".join(snippets)
    answer = f"<p><b>{answer_text}</b></p> \n\n\n"
    return answer, docs