"""Spanish biomedical question-answering demo served through Gradio.

Loads, at import time, every model and resource the request handlers use:
speech-to-text, text-to-speech, the DPR retriever, the mT5 answer generator,
the similarity / ranking models and the passage corpus.
"""
from typing import List

from datasets import load_dataset
from transformers import (
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
    MT5ForConditionalGeneration,
    AutoTokenizer,
    AutoModelForCTC,
    Wav2Vec2Tokenizer,
)
import gradio as gr
from haystack.nodes import DensePassageRetriever
from haystack.document_stores import InMemoryDocumentStore
import numpy as np
from sentence_transformers import SentenceTransformer, util, CrossEncoder

from general_utils import (
    embed_questions,
    transcript,
    remove_chars_to_tts,
    parse_final_answer,
)
from article_app import article, description, examples

# Number of passages requested from the FAISS index for each question.
topk = 21
# Corpus documents must have more characters than this to be kept.
minchars = 200
# Retrieved passages must have more words than this to be used as context.
min_snippet_length = 20
# Single place to choose the torch device; every model below follows it.
device = "cpu"
# Questions mentioning any of these terms get a canned refusal.
covidterms = ["covid19", "covid", "coronavirus", "covid-19", "sars-cov-2"]

# Speech-to-text models, keyed by the name offered in the UI dropdown.
models = {
    "wav2vec2-iic": {
        "processor": Wav2Vec2Tokenizer.from_pretrained(
            "IIC/wav2vec2-spanish-multilibrispeech"
        ),
        "model": AutoModelForCTC.from_pretrained(
            "IIC/wav2vec2-spanish-multilibrispeech"
        ),
    },
}

# Text-to-speech interface loaded from the Hugging Face hub.
tts_es = gr.Interface.load("huggingface/facebook/tts_transformer-es-css10")

# Default generation arguments; request handlers override some of them.
params_generate = {
    "min_length": 50,
    "max_length": 250,
    "do_sample": False,
    "early_stopping": True,
    "num_beams": 8,
    "temperature": 1.0,
    "top_k": None,
    "top_p": None,
    "no_repeat_ngram_size": 3,
    "num_return_sequences": 1,
}

# Dense retriever; only its query encoder is used at request time.
dpr = DensePassageRetriever(
    document_store=InMemoryDocumentStore(),
    query_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
    passage_embedding_model="IIC/dpr-spanish-passage_encoder-allqa-base",
    max_seq_len_query=64,
    max_seq_len_passage=256,
    batch_size=512,
    use_gpu=device != "cpu",
)

mt5_tokenizer = AutoTokenizer.from_pretrained("IIC/mt5-base-lfqa-es")
# Inputs are moved to `device` before generation, so the model must live
# on the same device.
mt5_lfqa = MT5ForConditionalGeneration.from_pretrained(
    "IIC/mt5-base-lfqa-es"
).to(device)

similarity_model = SentenceTransformer(
    "distiluse-base-multilingual-cased", device=device
)
# NOTE(review): loaded but not referenced by the visible request path.
crossencoder = CrossEncoder("IIC/roberta-base-bne-ranker", device=device)

# Passage corpus; documents that are too short are discarded up front.
dataset = load_dataset("IIC/spanish_biomedical_crawled_corpus", split="train")
dataset = dataset.filter(lambda example: len(example["text"]) > minchars)
# Attach the pre-built FAISS index of passage embeddings to the corpus.
dataset.load_faiss_index(
    "embeddings",
    "dpr_index_bio_newdpr.faiss",
)


def query_index(question: str):
    """Retrieve the passages closest to *question* from the FAISS index.

    The question is embedded with the DPR retriever and the ``topk`` nearest
    examples are fetched from the ``"embeddings"`` index.

    Returns:
        The texts of the retrieved passages that contain more than
        ``min_snippet_length`` whitespace-separated words, in retrieval order.
    """
    question_embedding = dpr.embed_queries([question])[0]
    _, closest_passages = dataset.get_nearest_examples(
        "embeddings", question_embedding, k=topk
    )
    return [
        passage
        for passage in closest_passages["text"]
        if len(passage.split()) > min_snippet_length
    ]


def sort_on_similarity(question, contexts, include_rank: int = 5):
    """Return the *include_rank* contexts most similar to *question*.

    Similarity is the cosine similarity between the sentence embeddings of
    the question and of each context; the result is ordered from most to
    least similar. An empty *contexts* yields an empty list.
    """
    if not contexts:
        return []
    # TODO: plug in our own cross-encoder here.
    question_encoded = similarity_model.encode([question])[0]
    ctxs_encoded = similarity_model.encode(contexts)
    # One cos_sim call gives a (1, n) matrix. Flatten it to a 1-D score
    # vector so argsort actually ranks the contexts against each other
    # (sorting a list of per-context 1x1 tensors sorts along a length-1
    # axis and always yields index 0).
    similarity_scores = (
        util.cos_sim(question_encoded, ctxs_encoded).cpu().numpy().ravel()
    )
    similarity_ranking_idx = np.argsort(similarity_scores)[::-1]
    return [contexts[idx] for idx in similarity_ranking_idx[:include_rank]]


def create_context(contexts: List):
    """Join the contexts into one string, each one preceded by a newline."""
    return "\n" + "\n".join(contexts)


def create_model_input(question: str, context: str):
    """Build the prompt string fed to the answer generation model."""
    return f"question: {question} context: {context}"


def generate_answer(model_input, update_params):
    """Generate answers for *model_input* with the mT5 model.

    Args:
        model_input: Prompt text (see ``create_model_input``).
        update_params: Generation arguments that override the defaults in
            ``params_generate`` for this call only.

    Returns:
        A list with one ``{"generated_text": ...}`` dict per decoded sequence.
    """
    encoded = mt5_tokenizer(
        model_input,
        truncation=True,
        padding=True,
        return_tensors="pt",
        max_length=1024,
    )
    # Merge into a fresh dict: the module-level defaults are shared by all
    # requests and must not be modified by a single one.
    generation_kwargs = {**params_generate, **update_params}
    answers_encoded = mt5_lfqa.generate(
        input_ids=encoded["input_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
        **generation_kwargs,
    )
    answers = mt5_tokenizer.batch_decode(
        answers_encoded,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return [{"generated_text": answer} for answer in answers]


def search_and_answer(
    question,
    audio_file,
    audio_array,
    min_length_answer,
    num_beams,
    no_repeat_ngram_size,
    temperature,
    max_answer_length,
    wav2vec2_name,
    do_tts,
):
    """Answer a typed or spoken question; this is the Gradio callback.

    Args:
        question: Typed question; when empty the audio inputs are transcribed.
        audio_file: Uploaded audio, passed through to ``transcript``.
        audio_array: Recorded audio, passed through to ``transcript``.
        min_length_answer: Minimum length of the generated answer.
        num_beams: Beam count for generation (cast to ``int``).
        no_repeat_ngram_size: Size of n-grams that may not repeat.
        temperature: Generation temperature.
        max_answer_length: Maximum length of the generated answer.
        wav2vec2_name: Key into ``models`` selecting the speech recognizer.
        do_tts: Whether to synthesize the answer as audio.

    Returns:
        A ``(answer, audio)`` pair: the answer formatted by
        ``parse_final_answer`` and either the synthesized audio or a
        placeholder audio file name. Questions mentioning a term from
        ``covidterms`` get a fixed refusal message instead.
    """
    update_params = {
        "min_length": min_length_answer,
        "max_length": max_answer_length,
        "num_beams": int(num_beams),
        "temperature": temperature,
        "no_repeat_ngram_size": no_repeat_ngram_size,
    }
    if not question:
        s2t_model = models[wav2vec2_name]["model"]
        s2t_processor = models[wav2vec2_name]["processor"]
        question = transcript(
            audio_file, audio_array, processor=s2t_processor, model=s2t_model
        )
        print(f"Transcripted question: *** {question} ****")

    # The terms contain no spaces, so a substring scan over the whole
    # question matches exactly what a per-word scan would.
    lowered_question = question.lower()
    if any(term in lowered_question for term in covidterms):
        return (
            "Del COVID no queremos saber ya más nada, lo sentimos, pregúntame sobre otra cosa :P ",
            "tmptdsnrh_8.flac",
        )

    contexts = query_index(question)
    contexts = sort_on_similarity(question, contexts)
    context = create_context(contexts)
    model_input = create_model_input(question, context)
    answers = generate_answer(model_input, update_params)
    final_answer = answers[0]["generated_text"]
    # Synthesize speech from the raw answer, before it is formatted for
    # display together with its contexts.
    if do_tts:
        audio_answer = tts_es(remove_chars_to_tts(final_answer))
    final_answer = parse_final_answer(final_answer, contexts)
    return final_answer, audio_answer if do_tts else "audio_troll.flac"


if __name__ == "__main__":
    # Input order must match the positional parameters of search_and_answer.
    gr.Interface(
        search_and_answer,
        inputs=[
            gr.inputs.Textbox(
                lines=2,
                label="Question",
                placeholder="Type your question (in spanish) to the system.",
                optional=True,
            ),
            gr.inputs.Audio(
                source="upload",
                type="filepath",
                label="Upload your audio asking a question here.",
                optional=True,
            ),
            gr.inputs.Audio(
                source="microphone",
                type="numpy",
                label="Record your audio asking a question.",
                optional=True,
            ),
            gr.inputs.Slider(
                minimum=10,
                maximum=200,
                default=50,
                label="Minimum size for the answer",
                step=1,
            ),
            gr.inputs.Slider(
                minimum=4, maximum=12, default=8, label="number of beams", step=1
            ),
            gr.inputs.Slider(
                minimum=2,
                maximum=5,
                default=3,
                label="no repeat n-gram size",
                step=1,
            ),
            gr.inputs.Slider(
                minimum=0.8,
                maximum=2.0,
                default=1.0,
                label="temperature",
                step=0.1,
            ),
            gr.inputs.Slider(
                minimum=220,
                maximum=360,
                default=250,
                label="maximum answer length",
                step=1,
            ),
            gr.inputs.Dropdown(
                ["wav2vec2-iic"],
                type="value",
                default=None,
                label="Select the speech recognition model.",
                optional=False,
            ),
            gr.inputs.Checkbox(
                default=False, label="Text to Speech", optional=True
            ),
        ],
        outputs=[
            gr.outputs.HTML(label="Answer from the system."),
            gr.outputs.Audio(label="Answer in audio"),
        ],
        description=description,
        examples=examples,
        theme="grass",
        article=article,
        thumbnail="IIC_logoP.png",
        css="https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap.min.css",
    ).launch()