notSoNLPnerd committed
Commit 94aee35
1 Parent(s): bd2e0e7

missed commits

Files changed (2)
  1. logo/haystack-logo-colored.png +0 -0
  2. utils/backend.py +67 -0
logo/haystack-logo-colored.png ADDED
utils/backend.py ADDED
@@ -0,0 +1,67 @@
+ import streamlit as st
+ from haystack import Pipeline
+ from haystack.document_stores import FAISSDocumentStore
+ from haystack.nodes import Shaper, PromptNode, PromptTemplate, PromptModel, EmbeddingRetriever
+ from haystack.nodes.retriever.web import WebRetriever
+
+
+ @st.cache_resource(show_spinner=False)
+ def get_plain_pipeline():
+     prompt_open_ai = PromptModel(model_name_or_path="text-davinci-003", api_key=st.secrets["OPENAI_API_KEY"])
+     # A single PromptNode that passes the query straight to the OpenAI model:
+     plain_llm_template = PromptTemplate(name="plain_llm", prompt_text="Answer the following question: $query")
+     node_openai = PromptNode(prompt_open_ai, default_prompt_template=plain_llm_template, max_length=300)
+     pipeline = Pipeline()
+     pipeline.add_node(component=node_openai, name="prompt_node", inputs=["Query"])
+     return pipeline
+
+
+ @st.cache_resource(show_spinner=False)
+ def get_retrieval_augmented_pipeline():
+     ds = FAISSDocumentStore(faiss_index_path="data/my_faiss_index.faiss",
+                             faiss_config_path="data/my_faiss_index.json")
+
+     retriever = EmbeddingRetriever(
+         document_store=ds,
+         embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
+         model_format="sentence_transformers",
+         top_k=2
+     )
+     shaper = Shaper(func="join_documents", inputs={"documents": "documents"}, outputs=["documents"])
+
+     default_template = PromptTemplate(
+         name="question-answering",
+         prompt_text="Given the context please answer the question. Context: $documents; Question: "
+                     "$query; Answer:",
+     )
+     # Initialize the PromptNode
+     node = PromptNode("text-davinci-003", default_prompt_template=default_template,
+                       api_key=st.secrets["OPENAI_API_KEY"], max_length=500)
+
+     # Create a pipeline with the retriever, Shaper and PromptNode
+     pipeline = Pipeline()
+     pipeline.add_node(component=retriever, name='retriever', inputs=['Query'])
+     pipeline.add_node(component=shaper, name="shaper", inputs=["retriever"])
+     pipeline.add_node(component=node, name="prompt_node", inputs=["shaper"])
+     return pipeline
+
+
+ @st.cache_resource(show_spinner=False)
+ def get_web_retrieval_augmented_pipeline():
+     search_key = st.secrets["WEBRET_API_KEY"]
+     web_retriever = WebRetriever(api_key=search_key, search_engine_provider="SerperDev")
+     shaper = Shaper(func="join_documents", inputs={"documents": "documents"}, outputs=["documents"])
+     default_template = PromptTemplate(
+         name="question-answering",
+         prompt_text="Given the context please answer the question. Context: $documents; Question: "
+                     "$query; Answer:",
+     )
+     # Initialize the PromptNode
+     node = PromptNode("text-davinci-003", default_prompt_template=default_template,
+                       api_key=st.secrets["OPENAI_API_KEY"], max_length=500)
+     # Create a pipeline with the web retriever, Shaper and PromptNode
+     pipeline = Pipeline()
+     pipeline.add_node(component=web_retriever, name='retriever', inputs=['Query'])
+     pipeline.add_node(component=shaper, name="shaper", inputs=["retriever"])
+     pipeline.add_node(component=node, name="prompt_node", inputs=["shaper"])
+     return pipeline
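
The three functions above only build and cache pipelines; the Streamlit app is expected to call them and pass questions through Pipeline.run(). A minimal sketch of such a caller follows, assuming a hypothetical app.py at the repository root — the selector labels and result handling are illustrative and not part of this commit; only the builder names and the Haystack v1 Pipeline.run(query=...) call come from the code above.

import streamlit as st

from utils.backend import (
    get_plain_pipeline,
    get_retrieval_augmented_pipeline,
    get_web_retrieval_augmented_pipeline,
)

# Map a human-readable label to each cached pipeline builder (labels are illustrative).
PIPELINES = {
    "Plain LLM": get_plain_pipeline,
    "Retrieval augmented (static FAISS index)": get_retrieval_augmented_pipeline,
    "Retrieval augmented (web search)": get_web_retrieval_augmented_pipeline,
}

choice = st.radio("Pipeline", list(PIPELINES))
query = st.text_input("Ask a question")

if query:
    pipeline = PIPELINES[choice]()       # builders are cached via @st.cache_resource
    result = pipeline.run(query=query)   # Haystack v1 pipelines take the query here
    # PromptNode typically returns its generated text under the "results" key.
    st.write(result["results"][0])

Because the builders are wrapped in @st.cache_resource, the FAISS index, retriever, and prompt models are loaded once per process rather than on every rerun of the script.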