# BounWiki / setup_modules.py
# Last commit 9bf0a0f by LeoGitGuy ("added files"); file size 2.41 kB.
from haystack.nodes.retriever import EmbeddingRetriever
from haystack.nodes import TableReader, FARMReader, RouteDocuments, JoinAnswers
from haystack import Pipeline
# Short, human-friendly text-reader names mapped to Hugging Face model ids.
# create_readers_and_pipeline() passes its text_reader_type straight to
# FARMReader, so presumably a caller looks a name up here first — verify.
text_reader_types = {
    "minilm": "deepset/minilm-uncased-squad2",
    # NOTE(review): the key says "distilroberta" but the model is tinyroberta;
    # confirm which one is intended.
    "distilroberta": "deepset/tinyroberta-squad2",
    "electra-base": "deepset/electra-base-squad2",
    "bert-base": "deepset/bert-base-cased-squad2",
    "deberta-large": "deepset/deberta-v3-large-squad2",
    # NOTE(review): placeholder TODO text, not a model id — presumably not
    # loadable by FARMReader until an OpenAI answer generator is implemented.
    "gpt3": "implement openai answer generator",
}
# Short table-reader names mapped to model ids (the default table_reader_type
# of create_readers_and_pipeline() is the "tapas" entry's value).
table_reader_types = {
    "tapas": "deepset/tapas-large-nq-hn-reader",
    # NOTE(review): placeholder TODO text, not a model id — table-to-text
    # conversion is not implemented in this file.
    "text": "implement changing tables to text",
}
def create_retriever(document_store, embedding_model="deepset/all-mpnet-base-v2-table"):
    """Create an EmbeddingRetriever over *document_store* and embed its documents.

    Args:
        document_store: Haystack document store the retriever is bound to.
            This call mutates it via ``update_embeddings``.
        embedding_model: Model id handed to ``EmbeddingRetriever``. Defaults to
            ``"deepset/all-mpnet-base-v2-table"``, the value that used to be
            hard-coded here, so existing callers are unaffected.

    Returns:
        Tuple ``(document_store, retriever)``: the same store object that was
        passed in, plus the newly created retriever.
    """
    retriever = EmbeddingRetriever(document_store=document_store, embedding_model=embedding_model)
    # Compute embeddings for the stored documents with this retriever's model so
    # dense retrieval has vectors to search against.
    document_store.update_embeddings(retriever=retriever)
    return document_store, retriever
def create_readers_and_pipeline(
    retriever,
    text_reader_type="deepset/roberta-base-squad2",
    table_reader_type="deepset/tapas-large-nq-hn-reader",
    use_table=True,
    use_text=True,
):
    """Assemble a Haystack QA pipeline that starts with *retriever*.

    Topology by flags:
      * text only:  Query -> EmbeddingRetriever -> TextReader
      * table only: Query -> EmbeddingRetriever -> TableReader
      * both:       Query -> EmbeddingRetriever -> RouteDocuments, whose
        output_1 feeds TextReader and output_2 feeds TableReader; JoinAnswers
        then merges both readers' answers. (Presumably RouteDocuments' default
        content-type split sends text to output_1 and tables to output_2 —
        confirm against the installed Haystack version.)
      * neither:    the pipeline contains only the retriever node.

    Args:
        retriever: Component added as the first node, named "EmbeddingRetriever".
        text_reader_type: Model id passed verbatim to ``FARMReader``. It is used
            only when *use_text* is truthy. Short names from ``text_reader_types``
            are not resolved here.
        table_reader_type: Model id passed verbatim to ``TableReader``. It is used
            only when *use_table* is truthy.
        use_table: Whether to build and wire a table reader.
        use_text: Whether to build and wire a text reader.

    Returns:
        The assembled ``Pipeline``.
    """
    text_reader = None
    table_reader = None

    # Readers are built text-first, then table, announcing each one.
    if use_text:
        print("Initializing Text reader..")
        text_reader = FARMReader(text_reader_type)
    if use_table:
        print("Initializing table reader..")
        table_reader = TableReader(table_reader_type)

    hybrid = use_table and use_text
    if hybrid:
        # Only the mixed text+table topology needs routing and answer merging.
        router = RouteDocuments()
        merger = JoinAnswers()

    pipeline = Pipeline()
    pipeline.add_node(component=retriever, name="EmbeddingRetriever", inputs=["Query"])

    if hybrid:
        pipeline.add_node(component=router, name="RouteDocuments", inputs=["EmbeddingRetriever"])
        pipeline.add_node(component=text_reader, name="TextReader", inputs=["RouteDocuments.output_1"])
        pipeline.add_node(component=table_reader, name="TableReader", inputs=["RouteDocuments.output_2"])
        pipeline.add_node(component=merger, name="JoinAnswers", inputs=["TextReader", "TableReader"])
    elif use_table:
        pipeline.add_node(component=table_reader, name="TableReader", inputs=["EmbeddingRetriever"])
    elif use_text:
        pipeline.add_node(component=text_reader, name="TextReader", inputs=["EmbeddingRetriever"])
    # With neither flag set, no reader node is added.
    return pipeline