vtiyyal1 commited on
Commit
18f5c04
1 Parent(s): 7883776

Upload 9 files

Browse files

Uploading chat files for the Gradio app

Files changed (9) hide show
  1. README.md +6 -6
  2. app.py +35 -0
  3. feed_to_llm.py +101 -0
  4. feed_to_llm_v2.py +85 -0
  5. full_chain.py +33 -0
  6. get_articles.py +140 -0
  7. get_keywords.py +63 -0
  8. requirements.txt +12 -0
  9. rerank.py +278 -0
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
- title: Tobacco Watcher Chat
3
- emoji: 🌖
4
- colorFrom: red
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 5.6.0
8
  app_file: app.py
9
  pinned: false
10
- short_description: RAG Chatbot for the tobacco watcher website
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Manual Copy
3
+ emoji: 🐨
4
+ colorFrom: indigo
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.25.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.chat_models import ChatOpenAI
2
+ from langchain.schema import AIMessage, HumanMessage
3
+ import openai
4
+ import gradio as gr
5
+ from full_chain import get_response
6
+
7
+ import os
8
+
9
+ api_key = os.getenv("OPENAI_API_KEY")
10
+ client = openai.OpenAI(api_key=api_key)
11
+
12
+
13
def create_hyperlink(url, title, domain):
    """Render an article reference as an HTML anchor followed by its domain."""
    anchor = f"<a href='{url}'>{title}</a>"
    return f"{anchor} ({domain})"
15
+
16
+
17
def predict(message, history):
    """Chat handler for the Gradio interface.

    Parameters:
        message: the user's question.
        history: prior chat turns (required by gr.ChatInterface; unused here).

    Returns:
        The LLM answer followed by one HTML hyperlink per cited source article.
    """
    responder, links, titles, domains = get_response(message, rerank_type="crossencoder")
    # Turn each bare URL into "<a ...>title</a> (domain)" for display.
    sources = [create_hyperlink(link, title, domain)
               for link, title, domain in zip(links, titles, domains)]
    return responder + "\n" + "\n".join(sources)
27
+
28
+
29
# Gradio chat UI: every submitted message is routed through predict();
# the example prompts seed the interface for first-time users.
example_questions = [
    "How many Americans Smoke?",
    "What are some measures taken by the Indian Government to reduce the smoking population?",
    "Does smoking negatively affect my health?",
]
gr.ChatInterface(predict, examples=example_questions).launch()
feed_to_llm.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.chat_models import ChatOpenAI
2
+
3
+ from langchain.schema import (
4
+ HumanMessage,
5
+ SystemMessage
6
+ )
7
+ import tiktoken
8
+ import re
9
+
10
+
11
def num_tokens_from_string(string: str, encoder) -> int:
    """Return how many tokens `encoder` produces for `string`."""
    return len(encoder.encode(string))
14
+
15
+
16
def feed_articles_to_gpt_with_links(information, question):
    """Answer `question` with a chat LLM grounded in reranked articles.

    Parameters:
        information: list of (score, content, uuid, title, domain) tuples,
            as produced by the rerank functions.
        question: the user's question.

    Returns:
        (answer, links, titles, domains) describing only the articles the
        model actually cited, or ("I could not find the answer.", [], [], [])
        when no grounded answer is found.
    """
    prompt = "The following pieces of information includes relevant articles. \nUse the following sentences to answer question. \nIf you don't know the answer, just say that you don't know, don't try to make up an answer. "
    prompt += "Please state the number of the article used to answer the question after your response\n"
    prompt += "\n----------------\n"

    separator = "<<<<>>>>"
    encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
    token_count = num_tokens_from_string(prompt, encoder)

    articles = [contents for score, contents, uuids, titles, domains in information]
    uuids = [uuids for score, contents, uuids, titles, domains in information]
    titles = [titles for score, contents, uuids, titles, domains in information]
    domains = [domains for score, contents, uuids, titles, domains in information]

    # Pack articles into the prompt until the token budget is reached.
    # (Fix: the article text was previously appended twice per article.)
    content = ""
    for i, article in enumerate(articles):
        addition = f"Article {i + 1}: {article}{separator}"
        token_count += num_tokens_from_string(addition, encoder)
        if token_count > 3500:
            break
        content += addition

    prompt += content
    llm = ChatOpenAI(temperature=0.0)
    response = llm([SystemMessage(content=prompt), HumanMessage(content=question)])

    # Ask the model once (fix: previously two identical API calls were made)
    # whether the response actually answered the question.
    answer_found_prompt = "Please check if the following response found the answer. If yes, return 1 and if no, return 0. \n"
    check = llm([SystemMessage(content=answer_found_prompt),
                 HumanMessage(content=response.content)])
    if check.content == "0":
        return "I could not find the answer.", [], [], []

    # Extract cited article numbers: every number following the word "article".
    lowercase_response = re.sub(r'[()]', '', response.content.lower())
    words = lowercase_response.split()
    used_article_num = []
    # Iterate up to the second-to-last word (fix: "article" as the final word
    # previously raised IndexError on words[i + 1]).
    for i, word in enumerate(words[:-1]):
        if word == "article":
            number = ''.join(c for c in words[i + 1] if c.isdigit())
            # Fix: an empty string used to crash int('') below.
            if number and number not in used_article_num:
                used_article_num.append(number)

    if not used_article_num:
        return "I could not find the answer.", [], [], []

    # Convert to 0-based indices, dropping hallucinated out-of-range citations.
    indices = [int(num) - 1 for num in used_article_num]
    indices = [i for i in indices if 0 <= i < len(uuids)]

    links = [f"https://tobaccowatcher.globaltobaccocontrol.org/articles/{uuid}/" for uuid in uuids]
    links = [links[i] for i in indices]
    titles = [titles[i] for i in indices]
    domains = [domains[i] for i in indices]

    # Strip the "(Article N)" citation from the displayed answer.
    response_without_source = re.sub(r"\(Article.*\)", "", response.content)
    return response_without_source, links, titles, domains
feed_to_llm_v2.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_openai import OpenAI
2
+
3
+ from langchain.schema import (
4
+ HumanMessage,
5
+ SystemMessage
6
+ )
7
+ import tiktoken
8
+ import re
9
+
10
+ from get_articles import save_solr_articles_full
11
+ from rerank import crossencoder_rerank_answer
12
+
13
+
14
def num_tokens_from_string(string: str, encoder) -> int:
    """Return how many tokens `encoder` produces for `string`."""
    return len(encoder.encode(string))
17
+
18
+
19
def feed_articles_to_gpt_with_links(information, question):
    """Answer `question` with an instruct LLM grounded in reranked articles.

    Parameters:
        information: list of (score, content, uuid, title, domain) tuples.
        question: the user's question.

    Returns:
        (answer_without_citation, links, titles, domains) describing only the
        articles the model cited, or (answer, [], [], []) when the response
        contains no "(Article N)" citation.
    """
    prompt = """
    You are a Question Answering machine specialized in providing information on tobacco-related queries. You have access to a curated list of articles that span various aspects of tobacco use, health effects, legislation, and quitting resources. When responding to questions, follow these guidelines:

    1. Use information from the articles to formulate your answers. Indicate the article number you're referencing at the end of your response.
    2. If the question's answer is not covered by your articles, clearly state that you do not know the answer. Do not attempt to infer or make up information.
    3. Avoid using time-relative terms like 'last year,' 'recently,' etc., as the articles' publication dates and the current date may not align. Instead, use absolute terms (e.g., 'In 2022,' 'As of the article's 2020 publication,').
    4. Aim for concise, informative responses that directly address the question asked.

    Remember, your goal is to provide accurate, helpful information on tobacco-related topics, aiding in education and informed decision-making.
    """
    prompt += "\n----------------\n"

    separator = "<<<<>>>>"
    encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
    token_count = num_tokens_from_string(prompt, encoder)

    articles = [contents for score, contents, uuids, titles, domains in information]
    uuids = [uuids for score, contents, uuids, titles, domains in information]
    titles = [titles for score, contents, uuids, titles, domains in information]
    domains = [domains for score, contents, uuids, titles, domains in information]

    # Pack articles into the prompt until the token budget is reached.
    # (Fix: the article text was previously appended twice per article.)
    content = ""
    for i, article in enumerate(articles):
        addition = f"Article {i + 1}: {article}{separator}"
        token_count += num_tokens_from_string(addition, encoder)
        if token_count > 3500:
            break
        content += addition

    prompt += content
    llm = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0.0)
    response = llm.invoke([
        SystemMessage(content=prompt),
        HumanMessage(content=question),
    ])

    # The citation is expected as the last parenthesised group, e.g. "(Article 2)".
    # Fix: an uncited response previously raised IndexError on [-1].
    cited_groups = re.findall(r'\((.*?)\)', response)
    if not cited_groups:
        return response, [], [], []

    cited_numbers = re.findall(r'\d+', cited_groups[-1])
    # 0-based indices, dropping hallucinated out-of-range citations.
    used_article_num = [int(n) - 1 for n in cited_numbers]
    used_article_num = [i for i in used_article_num if 0 <= i < len(uuids)]

    links = [f"https://tobaccowatcher.globaltobaccocontrol.org/articles/{uuid}/" for uuid in uuids]
    links = [links[i] for i in used_article_num]
    titles = [titles[i] for i in used_article_num]
    domains = [domains[i] for i in used_article_num]

    # Strip the "(Article N)" citation from the displayed answer.
    response_without_source = re.sub(r"\(Article.*\)", "", response)
    return response_without_source, links, titles, domains
78
+
79
+ if __name__ == "__main__":
80
+ question = "How is United States fighting against tobacco addiction?"
81
+ rerank_type = "crossencoder"
82
+ llm_type = "chat"
83
+ csv_path = save_solr_articles_full(question, keyword_type="rake")
84
+ reranked_out = crossencoder_rerank_answer(csv_path, question)
85
+ feed_articles_to_gpt_with_links(reranked_out, question)
full_chain.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ from get_keywords import get_keywords
4
+ from get_articles import save_solr_articles_full
5
+ from rerank import langchain_rerank_answer, langchain_with_sources, crossencoder_rerank_answer, \
6
+ crossencoder_rerank_sentencewise, crossencoder_rerank_sentencewise_articles, no_rerank
7
+ #from feed_to_llm import feed_articles_to_gpt_with_links
8
+ from feed_to_llm_v2 import feed_articles_to_gpt_with_links
9
+
10
+
11
def get_response(question, rerank_type="crossencoder", llm_type="chat"):
    """Full RAG pipeline: retrieve articles from Solr, rerank, and answer.

    NOTE(review): rerank_type and llm_type are currently ignored — the
    pipeline always uses RAKE keywords + cross-encoder reranking; confirm
    whether dispatching on them is still intended.

    Returns whatever feed_articles_to_gpt_with_links returns:
    (answer, links, titles, domains).
    """
    articles_csv = save_solr_articles_full(question, keyword_type="rake")
    top_articles = crossencoder_rerank_answer(articles_csv, question)
    return feed_articles_to_gpt_with_links(top_articles, question)
15
+
16
+
17
+ # save_path = save_solr_articles_full(question)
18
+ # information = crossencoder_rerank_answer(save_path, question)
19
+ # response, links, titles = feed_articles_to_gpt_with_links(information, question)
20
+ #
21
+ # return response, links, titles
22
+
23
+
24
+
25
+ if __name__ == "__main__":
26
+ question = "How is United States fighting against tobacco addiction?"
27
+ rerank_type = "crossencoder"
28
+ llm_type = "chat"
29
+ response, links, titles, domains = get_response(question, rerank_type, llm_type)
30
+ print(response)
31
+ print(links)
32
+ print(titles)
33
+ print(domains)
get_articles.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pysolr import Solr
2
+ import os
3
+ import csv
4
+ from sentence_transformers import SentenceTransformer, util
5
+ import torch
6
+
7
+ from get_keywords import get_keywords
8
+ import os
9
+
10
+ """
11
+ This function creates top 15 articles from Solr and saves them in a csv file
12
+ Input:
13
+ query: str
14
+ num_articles: int
15
+ keyword_type: str (openai, rake, or na)
16
+ Output: path to csv file
17
+ """
18
+ def save_solr_articles_full(query: str, num_articles=15, keyword_type="openai") -> str:
19
+ keywords = get_keywords(query, keyword_type)
20
+ if keyword_type == "na":
21
+ keywords = query
22
+ return save_solr_articles(keywords, num_articles)
23
+
24
+
25
+ """
26
+ Removes spaces and newlines from text
27
+ Input: text: str
28
+ Output: text: str
29
+ """
30
+ def remove_spaces_newlines(text: str) -> str:
31
+ text = text.replace('\n', ' ')
32
+ text = text.replace(' ', ' ')
33
+ return text
34
+
35
+
36
def truncate_article(text: str) -> str:
    """Cap article text at 1500 whitespace-separated words.

    Text at or under the limit is returned unchanged (original whitespace
    preserved); longer text is rejoined with single spaces.
    """
    words = text.split()
    if len(words) <= 1500:
        return text
    return ' '.join(words[:1500])
43
+
44
+
45
+ """
46
+ Searches Solr for articles based on keywords and saves them in a csv file
47
+ Input:
48
+ keywords: str
49
+ num_articles: int
50
+ Output: path to csv file
51
+ Minor details:
52
+ Removes duplicate articles to start with.
53
+ Articles with dead urls are removed since those articles are often wierd.
54
+ Articles with titles that start with five starting words are removed. they are usually duplicates with minor changes.
55
+ If one of title, uuid, cleaned_content, url are missing the article is skipped.
56
+ """
57
+ def save_solr_articles(keywords: str, num_articles=15) -> str:
58
+ solr_key = os.getenv("SOLR_KEY")
59
+ SOLR_ARTICLES_URL = f"https://website:{solr_key}@solr.machines.globalhealthwatcher.org:8080/solr/articles/"
60
+ solr = Solr(SOLR_ARTICLES_URL, verify=False)
61
+
62
+ # No duplicates
63
+ fq = ['-dups:0']
64
+
65
+ query = f'text:({keywords})' + " AND " + "dead_url:(false)"
66
+
67
+ # Get top 2*num_articles articles and then remove misformed or duplicate articles
68
+ outputs = solr.search(query, fq=fq, sort="score desc", rows=num_articles * 2)
69
+
70
+ article_count = 0
71
+
72
+ save_path = os.path.join("data", "articles.csv")
73
+ if not os.path.exists(os.path.dirname(save_path)):
74
+ os.makedirs(os.path.dirname(save_path))
75
+
76
+ with open(save_path, 'w', newline='') as csvfile:
77
+ fieldnames = ['title', 'uuid', 'content', 'url', 'domain']
78
+ writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC)
79
+ writer.writeheader()
80
+
81
+ title_five_words = set()
82
+
83
+ for d in outputs.docs:
84
+ if article_count == num_articles:
85
+ break
86
+
87
+ # skip if title returns a keyerror
88
+ if 'title' not in d or 'uuid' not in d or 'cleaned_content' not in d or 'url' not in d:
89
+ continue
90
+
91
+ title_cleaned = remove_spaces_newlines(d['title'])
92
+
93
+ split = title_cleaned.split()
94
+ # skip if title is a duplicate
95
+ if not len(split) < 5:
96
+ five_words = title_cleaned.split()[:5]
97
+ five_words = ' '.join(five_words)
98
+ if five_words in title_five_words:
99
+ continue
100
+ title_five_words.add(five_words)
101
+
102
+ article_count += 1
103
+
104
+ cleaned_content = remove_spaces_newlines(d['cleaned_content'])
105
+ cleaned_content = truncate_article(cleaned_content)
106
+
107
+ domain = ""
108
+ if 'domain' not in d:
109
+ domain = "Not Specified"
110
+ else:
111
+ domain = d['domain']
112
+ print(domain)
113
+
114
+ writer.writerow({'title': title_cleaned, 'uuid': d['uuid'], 'content': cleaned_content, 'url': d['url'],
115
+ 'domain': domain})
116
+ return save_path
117
+
118
+
119
def save_embedding_base_articles(query, article_embeddings, titles, contents, uuids, urls, num_articles=15):
    """Pick the top articles for `query` by bi-encoder semantic search and save to CSV.

    Input:
        query: the user question
        article_embeddings: precomputed embeddings aligned with the other lists
        titles/contents/uuids/urls: parallel article metadata lists
        num_articles: number of rows to write
    Output: path to the written CSV ("data/articles.csv").

    Note: unlike save_solr_articles, the output has no 'domain' column.
    """
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, article_embeddings, top_k=15)[0]
    corpus_ids = [hit['corpus_id'] for hit in hits]

    save_path = os.path.join("data", "articles.csv")
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    with open(save_path, 'w', newline='', encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=['title', 'uuid', 'content', 'url'],
                                quoting=csv.QUOTE_NONNUMERIC)
        writer.writeheader()
        # Fix: bounded slice avoids an IndexError when num_articles exceeds
        # the number of search hits.
        for idx in corpus_ids[:num_articles]:
            writer.writerow({'title': titles[idx], 'uuid': uuids[idx],
                             'content': contents[idx], 'url': urls[idx]})
    return save_path
get_keywords.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.chat_models import ChatOpenAI
2
+ from langchain.schema import (
3
+ HumanMessage,
4
+ SystemMessage
5
+ )
6
+
7
+ from rake_nltk import Rake
8
+ import nltk
9
+ nltk.download('stopwords')
10
+ nltk.download('punkt')
11
+ """
12
+ This function takes in user query and returns keywords
13
+ Input:
14
+ user_query: str
15
+ keyword_type: str (openai, rake, or na)
16
+ If the keyword type is na, then user query is returned.
17
+ Output: keywords: str
18
+ """
19
+ def get_keywords(user_query: str, keyword_type: str) -> str:
20
+ if keyword_type == "openai":
21
+ return get_keywords_openai(user_query)
22
+ if keyword_type == "rake":
23
+ return get_keywords_rake(user_query)
24
+ else:
25
+ return user_query
26
+
27
+
28
+ """
29
+ This function takes user query and returns keywords using rake_nltk
30
+ rake_nltk actually returns keyphrases, not keywords. Since using keyphrases did not show improvement, we are using keywords
31
+ to match the output type of the other keyword functions.
32
+ Input:
33
+ user_query: str
34
+ Output: keywords: str
35
+ """
36
+ def get_keywords_rake(user_query: str) -> str:
37
+ r = Rake()
38
+ r.extract_keywords_from_text(user_query)
39
+ keyphrases = r.get_ranked_phrases()
40
+
41
+ # If we want to get keyphrases, return keyphrases but should do keywords
42
+ out = ""
43
+ for phrase in keyphrases:
44
+ out += phrase + " "
45
+ return out
46
+
47
+
48
+ """
49
+ This function takes user query and returns keywords using openai
50
+ Input:
51
+ user_query: str
52
+ Output: keywords: str
53
+ """
54
+ def get_keywords_openai(user_query: str) -> str:
55
+ llm = ChatOpenAI(temperature=0.0)
56
+ command = "return the keywords of the following query. response should be words separated by commas. "
57
+ message = [
58
+ SystemMessage(content=command),
59
+ HumanMessage(content=user_query)
60
+ ]
61
+ response = llm(message)
62
+ res = response.content.replace(",", "")
63
+ return res
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.25.0
2
+ langchain==0.1.14
3
+ langchain_core==0.1.40
4
+ langchain_openai==0.1.1
5
+ nltk==3.8.1
6
+ openai==1.16.2
7
+ pandas==2.2.1
8
+ pysolr==3.9.0
9
+ rake_nltk==1.0.6
10
+ sentence_transformers==2.2.2
11
+ tiktoken==0.5.2
12
+ torch==2.1.2
rerank.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # reranks the top articles from a given csv file
2
+
3
+ from langchain.chat_models import ChatOpenAI
4
+ from langchain.chains import RetrievalQA
5
+ from langchain.chat_models import ChatOpenAI
6
+ from langchain.document_loaders import CSVLoader
7
+ from langchain.indexes import VectorstoreIndexCreator
8
+ from langchain.vectorstores import DocArrayInMemorySearch
9
+ from sentence_transformers import CrossEncoder, util
10
+ from langchain.chains import RetrievalQAWithSourcesChain
11
+ from nltk import sent_tokenize
12
+ import pandas as pd
13
+ import time
14
+
15
+ """
16
+ This function rerank top articles (15 -> 4) from a given csv, then sends to LLM
17
+ Input:
18
+ csv_path: str
19
+ question: str
20
+ top_n: int
21
+ Output:
22
+ response: str
23
+ links: list of str
24
+ titles: list of str
25
+
26
+ Other functions in this file does not send articles to LLM. This is an exception.
27
+ Created using langchain RAG functions. Deprecated.
28
+ Update: Use langchain_RAG instead.
29
+ """
30
+
31
+
32
+ def langchain_rerank_answer(csv_path, question, source='url', top_n=4):
33
+ llm = ChatOpenAI(temperature=0.0)
34
+ loader = CSVLoader(csv_path, source_column="url")
35
+
36
+ index = VectorstoreIndexCreator(
37
+ vectorstore_cls=DocArrayInMemorySearch,
38
+ ).from_loaders([loader])
39
+
40
+ # prompt_template = """You are an a chatbot that answers tobacco related questions with source. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
41
+ # {context}
42
+ # Question: {question}"""
43
+ # PROMPT = PromptTemplate(
44
+ # template=prompt_template, input_variables=["context", "question"]
45
+ # )
46
+ # chain_type_kwargs = {"prompt": PROMPT}
47
+
48
+ qa = RetrievalQA.from_chain_type(
49
+ llm=llm,
50
+ chain_type="stuff",
51
+ retriever=index.vectorstore.as_retriever(),
52
+ verbose=False,
53
+ return_source_documents=True,
54
+ # chain_type_kwargs=chain_type_kwargs,
55
+ # chain_type_kwargs = {
56
+ # "document_separator": "<<<<>>>>>"
57
+ # },
58
+ )
59
+
60
+ answer = qa({"query": question})
61
+ sources = answer['source_documents']
62
+ sources_out = [source.metadata['source'] for source in sources]
63
+
64
+ return answer['result'], sources_out
65
+
66
+
67
+ """
68
+ Langchain with sources.
69
+ This function is deprecated. Use langchain_RAG instead.
70
+ """
71
+
72
+
73
+ def langchain_with_sources(csv_path, question, top_n=4):
74
+ llm = ChatOpenAI(temperature=0.0)
75
+ loader = CSVLoader(csv_path, source_column="uuid")
76
+ index = VectorstoreIndexCreator(
77
+ vectorstore_cls=DocArrayInMemorySearch,
78
+ ).from_loaders([loader])
79
+
80
+ qa = RetrievalQAWithSourcesChain.from_chain_type(
81
+ llm=llm,
82
+ chain_type="stuff",
83
+ retriever=index.vectorstore.as_retriever(),
84
+ )
85
+ output = qa({"question": question}, return_only_outputs=True)
86
+ return output['answer'], output['sources']
87
+
88
+
89
+ """
90
+ Reranks the top articles using crossencoder.
91
+ Uses cross-encoder/ms-marco-MiniLM-L-6-v2 for embedding / reranking.
92
+ Input:
93
+ csv_path: str
94
+ question: str
95
+ top_n: int
96
+ Output:
97
+ out_values: list of [content, uuid, title]
98
+ """
99
+
100
+
101
+ # returns list of top n similar articles using crossencoder
102
+ def crossencoder_rerank_answer(csv_path: str, question: str, top_n=4) -> list:
103
+ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
104
+ articles = pd.read_csv(csv_path)
105
+ contents = articles['content'].tolist()
106
+ uuids = articles['uuid'].tolist()
107
+ titles = articles['title'].tolist()
108
+
109
+ # biencoder retrieval does not have domain
110
+ if 'domain' not in articles:
111
+ domain = [""] * len(contents)
112
+ else:
113
+ domain = articles['domain'].tolist()
114
+
115
+ cross_inp = [[question, content] for content in contents]
116
+ cross_scores = cross_encoder.predict(cross_inp)
117
+ scores_sentences = list(zip(cross_scores, contents, uuids, titles, domain))
118
+ scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)
119
+
120
+ out_values = scores_sentences[:top_n]
121
+
122
+ # if score is less than 0, truncate
123
+ for idx in range(len(out_values)):
124
+ if out_values[idx][0] < 0:
125
+ out_values = out_values[:idx]
126
+ if len(out_values) == 0:
127
+ out_values = scores_sentences[:1]
128
+
129
+ break
130
+ # print(out_values)
131
+ return out_values
132
+
133
+
134
def crossencoder_rerank_sentencewise(csv_path: str, question: str, top_n=10) -> list:
    """Rerank individual sentences (not whole articles) against the question.

    Output:
        list of (score, sentence, uuid, title, domain) tuples, best first;
        negative-score entries are dropped, falling back to the single best
        sentence if every score is negative.
    """
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    contents, uuids, titles, domain = load_articles(csv_path)

    # Explode each article into sentences, keeping metadata aligned.
    sentences, sent_uuids, sent_titles, sent_domains = [], [], [], []
    for article_idx, text in enumerate(contents):
        sents = sent_tokenize(text)
        sentences.extend(sents)
        sent_uuids.extend([uuids[article_idx]] * len(sents))
        sent_titles.extend([titles[article_idx]] * len(sents))
        sent_domains.extend([domain[article_idx]] * len(sents))

    scores = cross_encoder.predict([[question, s] for s in sentences])
    ranked = sorted(zip(scores, sentences, sent_uuids, sent_titles, sent_domains),
                    key=lambda item: item[0], reverse=True)

    top = [item for item in ranked[:top_n] if item[0] >= 0]
    return top if top else ranked[:1]
174
+
175
+
176
def crossencoder_rerank_sentencewise_sentence_chunks(csv_path, question, top_n=10, chunk_size=2):
    """Rerank overlapping chunks of `chunk_size` consecutive sentences.

    Articles shorter than chunk_size become a single chunk.
    Output:
        list of (score, chunk_text, uuid, title, domain) tuples, best first;
        negative-score entries are dropped, falling back to the single best
        chunk if every score is negative.
    """
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    contents, uuids, titles, domain = load_articles(csv_path)

    chunks, chunk_uuids, chunk_titles, chunk_domains = [], [], [], []
    for idx, text in enumerate(contents):
        sents = sent_tokenize(text)
        if len(sents) < chunk_size:
            merged = [' '.join(sents)]
        else:
            # Sliding window of chunk_size sentences, stride 1.
            merged = [' '.join(sents[i:i + chunk_size])
                      for i in range(len(sents) - chunk_size + 1)]
        chunks.extend(merged)
        chunk_uuids.extend([uuids[idx]] * len(merged))
        chunk_titles.extend([titles[idx]] * len(merged))
        chunk_domains.extend([domain[idx]] * len(merged))

    scores = cross_encoder.predict([[question, c] for c in chunks])
    ranked = sorted(zip(scores, chunks, chunk_uuids, chunk_titles, chunk_domains),
                    key=lambda item: item[0], reverse=True)

    top = [item for item in ranked[:top_n] if item[0] >= 0]
    return top if top else ranked[:1]
226
+
227
+
228
def crossencoder_rerank_sentencewise_articles(csv_path, question, top_n=4):
    """Score articles by their best sentence and return top_n whole articles.

    Every sentence is scored against the question; articles are ranked by
    their highest-scoring sentence and deduplicated by uuid.
    Output:
        list of (score, full_article_content, uuid, title, domain) tuples.
    """
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    contents, uuids, titles, domain = load_articles(csv_path)

    sentences = []
    contents_elongated = []
    new_uuids = []
    new_titles = []
    new_domains = []
    for idx in range(len(contents)):
        sents = sent_tokenize(contents[idx])
        sentences.extend(sents)
        new_uuids.extend([uuids[idx]] * len(sents))
        contents_elongated.extend([contents[idx]] * len(sents))
        new_titles.extend([titles[idx]] * len(sents))
        new_domains.extend([domain[idx]] * len(sents))

    cross_inp = [[question, sent] for sent in sentences]
    cross_scores = cross_encoder.predict(cross_inp)
    scores_sentences = sorted(
        zip(cross_scores, contents_elongated, new_uuids, new_titles, new_domains),
        key=lambda x: x[0], reverse=True)

    # Keep only the first (highest-scoring) entry per uuid.
    # Fix: the previous dedup re-scanned the accumulated list on every
    # iteration (O(n^2)); a seen-set keeps the same first-occurrence order.
    seen_uuids = set()
    deduped = []
    for item in scores_sentences:
        if item[2] not in seen_uuids:
            seen_uuids.add(item[2])
            deduped.append(item)

    return deduped[:top_n]
261
+
262
+
263
def no_rerank(csv_path, question, top_n=4):
    """Baseline: return the first top_n articles in CSV order, unscored.

    Output tuples are (content, uuid, title, domain) — note: no leading score
    element, unlike the crossencoder rerankers. `question` is unused.
    """
    contents, uuids, titles, domains = load_articles(csv_path)
    rows = zip(contents, uuids, titles, domains)
    return list(rows)[:top_n]
266
+
267
+
268
def load_articles(csv_path: str):
    """Read an articles CSV and return (contents, uuids, titles, domains) lists.

    The 'domain' column is optional (biencoder CSVs lack it); missing domains
    come back as empty strings.
    """
    frame = pd.read_csv(csv_path)
    contents = frame['content'].tolist()
    uuids = frame['uuid'].tolist()
    titles = frame['title'].tolist()
    if 'domain' in frame:
        domains = frame['domain'].tolist()
    else:
        domains = [""] * len(contents)
    return contents, uuids, titles, domains
278
+