Spaces:
Running
Upload 9 files
uploading chat files for gradio
- README.md +6 -6
- app.py +35 -0
- feed_to_llm.py +101 -0
- feed_to_llm_v2.py +85 -0
- full_chain.py +33 -0
- get_articles.py +140 -0
- get_keywords.py +63 -0
- requirements.txt +12 -0
- rerank.py +278 -0
README.md
CHANGED
@@ -1,13 +1,13 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Manual Copy
+emoji: 🐨
+colorFrom: indigo
+colorTo: red
 sdk: gradio
-sdk_version:
+sdk_version: 4.25.0
 app_file: app.py
 pinned: false
-
+license: mit
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,35 @@

from langchain.chat_models import ChatOpenAI
import openai
import gradio as gr
from full_chain import get_response

import os

api_key = os.getenv("OPENAI_API_KEY")
client = openai.OpenAI(api_key=api_key)


def create_hyperlink(url, title, domain):
    # Render one source as an HTML anchor followed by its domain.
    return f"<a href='{url}'>{title}</a>" + " (" + domain + ")"


def predict(message, history):
    print("get_responses: ")
    # print(get_response(message, rerank_type="crossencoder"))
    responder, links, titles, domains = get_response(message, rerank_type="crossencoder")
    for i in range(len(links)):
        links[i] = create_hyperlink(links[i], titles[i], domains[i])

    out = responder + "\n" + "\n".join(links)

    return out


gr.ChatInterface(predict,
                 examples=[
                     "How many Americans smoke?",
                     "What are some measures taken by the Indian Government to reduce the smoking population?",
                     "Does smoking negatively affect my health?"
                 ]
                 ).launch()
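For a quick check of the link formatting without hitting the backend, a minimal sketch; the response tuple is stubbed with placeholder values, whereas in the app it comes from `get_response`:

```python
# Stubbed version of what get_response returns: (answer, links, titles, domains).
responder = "Smoking remains a leading cause of preventable death. (Article 1)"
links = ["https://example.org/article-1"]   # placeholder URL
titles = ["Example article title"]          # placeholder title
domains = ["example.org"]                   # placeholder domain

# Mirrors predict(): wrap each link in an anchor tag, then append to the answer.
formatted = [f"<a href='{u}'>{t}</a> ({d})" for u, t, d in zip(links, titles, domains)]
print(responder + "\n" + "\n".join(formatted))
```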
feed_to_llm.py
ADDED
@@ -0,0 +1,101 @@

from langchain.chat_models import ChatOpenAI

from langchain.schema import (
    HumanMessage,
    SystemMessage
)
import tiktoken
import re


def num_tokens_from_string(string: str, encoder) -> int:
    num_tokens = len(encoder.encode(string))
    return num_tokens


def feed_articles_to_gpt_with_links(information, question):
    prompt = "The following pieces of information include relevant articles. \nUse the following sentences to answer the question. \nIf you don't know the answer, just say that you don't know; don't try to make up an answer. "
    prompt += "Please state the number of the article used to answer the question after your response\n"
    end_prompt = "\n----------------\n"
    prompt += end_prompt
    content = ""
    separator = "<<<<>>>>"

    token_count = 0
    encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
    token_count += num_tokens_from_string(prompt, encoder)

    articles = [contents for score, contents, uuids, titles, domains in information]
    uuids = [uuids for score, contents, uuids, titles, domains in information]
    domains = [domains for score, contents, uuids, titles, domains in information]

    for i in range(len(articles)):
        addition = "Article " + str(i + 1) + ": " + articles[i] + separator
        token_count += num_tokens_from_string(addition, encoder)
        # stop adding articles once the context budget is exhausted
        if token_count > 3500:
            print(i)
            break

        content += addition

    prompt += content
    llm = ChatOpenAI(temperature=0.0)
    message = [
        SystemMessage(content=prompt),
        HumanMessage(content=question)
    ]

    response = llm(message)
    print(response.content)
    print("response length: ", len(response.content))

    answer_found_prompt = "Please check if the following response found the answer. If yes, return 1 and if no, return 0. \n"
    message = [
        SystemMessage(content=answer_found_prompt),
        HumanMessage(content=response.content)
    ]
    # call the model once and reuse the verdict
    answer_found = llm(message).content
    print(answer_found)
    if answer_found == "0":
        return "I could not find the answer.", [], [], []

    # sources = "\n Sources: \n"
    # for i in range(len(uuids)):
    #     link = "https://tobaccowatcher.globaltobaccocontrol.org/articles/" + uuids[i] + "/" + "\n"
    #     sources += link
    # response.content += sources

    lowercase_response = response.content.lower()
    # remove parentheses
    lowercase_response = re.sub('[()]', '', lowercase_response)
    lowercase_split = lowercase_response.split()
    used_article_num = []
    for i in range(len(lowercase_split) - 1):
        if lowercase_split[i] == "article":
            next_word = lowercase_split[i + 1]
            # get rid of non-numeric characters
            next_word = ''.join(c for c in next_word if c.isdigit())
            print("Article number: ", next_word)
            # append only if it is numeric and not already present in the list
            if next_word and next_word not in used_article_num:
                used_article_num.append(next_word)

    # if empty
    print("Used article num: ", used_article_num)
    if not used_article_num:
        print("I could not find the answer. Reached")
        return "I could not find the answer.", [], [], []

    used_article_num = [int(num) - 1 for num in used_article_num]

    links = [f"https://tobaccowatcher.globaltobaccocontrol.org/articles/{uuid}/" for uuid in uuids]
    titles = [titles for score, contents, uuids, titles, domains in information]

    links = [links[i] for i in used_article_num]
    titles = [titles[i] for i in used_article_num]
    domains = [domains[i] for i in used_article_num]

    # get rid of the substring that starts with (Article and ends with )
    response_without_source = re.sub(r"\(Article.*\)", "", response.content)

    return response_without_source, links, titles, domains
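For reference, `information` is the list of reranker tuples `(score, content, uuid, title, domain)` produced by rerank.py. A stubbed call, assuming `OPENAI_API_KEY` is set; the article text and uuid are placeholders, and the call does spend tokens:

```python
from feed_to_llm import feed_articles_to_gpt_with_links

# One (score, content, uuid, title, domain) tuple with illustrative values.
information = [
    (0.91,
     "In 2022 roughly 11% of U.S. adults were current cigarette smokers.",
     "00000000-0000-0000-0000-000000000000",
     "U.S. smoking rates", "example.org"),
]
answer, links, titles, domains = feed_articles_to_gpt_with_links(
    information, "How many Americans smoke?")
print(answer)
print(links)
```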
feed_to_llm_v2.py
ADDED
@@ -0,0 +1,85 @@

from langchain_openai import OpenAI

from langchain.schema import (
    HumanMessage,
    SystemMessage
)
import tiktoken
import re

from get_articles import save_solr_articles_full
from rerank import crossencoder_rerank_answer


def num_tokens_from_string(string: str, encoder) -> int:
    num_tokens = len(encoder.encode(string))
    return num_tokens


def feed_articles_to_gpt_with_links(information, question):
    prompt = """
    You are a Question Answering machine specialized in providing information on tobacco-related queries. You have access to a curated list of articles that span various aspects of tobacco use, health effects, legislation, and quitting resources. When responding to questions, follow these guidelines:

    1. Use information from the articles to formulate your answers. Indicate the article number you're referencing at the end of your response.
    2. If the question's answer is not covered by your articles, clearly state that you do not know the answer. Do not attempt to infer or make up information.
    3. Avoid using time-relative terms like 'last year,' 'recently,' etc., as the articles' publication dates and the current date may not align. Instead, use absolute terms (e.g., 'In 2022,' 'As of the article's 2020 publication,').
    4. Aim for concise, informative responses that directly address the question asked.

    Remember, your goal is to provide accurate, helpful information on tobacco-related topics, aiding in education and informed decision-making.
    """
    end_prompt = "\n----------------\n"
    prompt += end_prompt
    content = ""
    separator = "<<<<>>>>"

    token_count = 0
    encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
    token_count += num_tokens_from_string(prompt, encoder)

    articles = [contents for score, contents, uuids, titles, domains in information]
    uuids = [uuids for score, contents, uuids, titles, domains in information]
    domains = [domains for score, contents, uuids, titles, domains in information]

    for i in range(len(articles)):
        addition = "Article " + str(i + 1) + ": " + articles[i] + separator
        token_count += num_tokens_from_string(addition, encoder)
        # stop adding articles once the context budget is exhausted
        if token_count > 3500:
            print(i)
            break

        content += addition

    prompt += content
    llm = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0.0)
    message = [
        SystemMessage(content=prompt),
        HumanMessage(content=question)
    ]

    response = llm.invoke(message)
    print(response)
    print("response length:", len(response))

    # the last parenthesized group is expected to hold the article citation;
    # if the model cited nothing, return the raw response with no sources
    groups = re.findall(r'\((.*?)\)', response)
    if not groups:
        return response, [], [], []
    source = groups[-1]

    # get integers from source
    source = re.findall(r'\d+', source)
    used_article_num = [int(i) - 1 for i in source]

    links = [f"https://tobaccowatcher.globaltobaccocontrol.org/articles/{uuid}/" for uuid in uuids]
    titles = [titles for score, contents, uuids, titles, domains in information]

    links = [links[i] for i in used_article_num]
    titles = [titles[i] for i in used_article_num]
    domains = [domains[i] for i in used_article_num]

    response_without_source = re.sub(r"\(Article.*\)", "", response)
    return response_without_source, links, titles, domains


if __name__ == "__main__":
    question = "How is United States fighting against tobacco addiction?"
    rerank_type = "crossencoder"
    llm_type = "chat"
    csv_path = save_solr_articles_full(question, keyword_type="rake")
    reranked_out = crossencoder_rerank_answer(csv_path, question)
    feed_articles_to_gpt_with_links(reranked_out, question)
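The v2 parsing assumes the model ends its answer with a parenthesized citation such as `(Article 2)`. A standalone sketch of just that extraction step:

```python
import re

response = "Several states have raised tobacco taxes. (Article 2, Article 3)"
# The last parenthesized group holds the citation; digits become 0-based indices.
groups = re.findall(r'\((.*?)\)', response)
source = groups[-1] if groups else ""
used_article_num = [int(i) - 1 for i in re.findall(r'\d+', source)]
print(used_article_num)  # [1, 2]
```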
full_chain.py
ADDED
@@ -0,0 +1,33 @@

import os
import pandas as pd
from get_keywords import get_keywords
from get_articles import save_solr_articles_full
from rerank import langchain_rerank_answer, langchain_with_sources, crossencoder_rerank_answer, \
    crossencoder_rerank_sentencewise, crossencoder_rerank_sentencewise_articles, no_rerank
# from feed_to_llm import feed_articles_to_gpt_with_links
from feed_to_llm_v2 import feed_articles_to_gpt_with_links


def get_response(question, rerank_type="crossencoder", llm_type="chat"):
    # rerank_type and llm_type are accepted for compatibility, but the current
    # pipeline always uses RAKE keywords, cross-encoder reranking, and the v2 feeder.
    csv_path = save_solr_articles_full(question, keyword_type="rake")
    reranked_out = crossencoder_rerank_answer(csv_path, question)
    return feed_articles_to_gpt_with_links(reranked_out, question)


# save_path = save_solr_articles_full(question)
# information = crossencoder_rerank_answer(save_path, question)
# response, links, titles = feed_articles_to_gpt_with_links(information, question)
#
# return response, links, titles


if __name__ == "__main__":
    question = "How is United States fighting against tobacco addiction?"
    rerank_type = "crossencoder"
    llm_type = "chat"
    response, links, titles, domains = get_response(question, rerank_type, llm_type)
    print(response)
    print(links)
    print(titles)
    print(domains)
get_articles.py
ADDED
@@ -0,0 +1,140 @@

from pysolr import Solr
import os
import csv
from sentence_transformers import SentenceTransformer, util
import torch

from get_keywords import get_keywords

"""
This function fetches the top 15 articles from Solr and saves them in a csv file
Input:
    query: str
    num_articles: int
    keyword_type: str (openai, rake, or na)
Output: path to csv file
"""
def save_solr_articles_full(query: str, num_articles=15, keyword_type="openai") -> str:
    keywords = get_keywords(query, keyword_type)
    if keyword_type == "na":
        keywords = query
    return save_solr_articles(keywords, num_articles)


"""
Removes extra spaces and newlines from text
Input: text: str
Output: text: str
"""
def remove_spaces_newlines(text: str) -> str:
    text = text.replace('\n', ' ')
    text = text.replace('  ', ' ')
    return text


# truncates long articles to 1500 words
def truncate_article(text: str) -> str:
    split = text.split()
    if len(split) > 1500:
        split = split[:1500]
    text = ' '.join(split)
    return text


"""
Searches Solr for articles based on keywords and saves them in a csv file
Input:
    keywords: str
    num_articles: int
Output: path to csv file
Minor details:
    Removes duplicate articles to start with.
    Articles with dead urls are removed since those articles are often weird.
    Articles whose titles share their first five words with an earlier article are skipped; they are usually duplicates with minor changes.
    If one of title, uuid, cleaned_content, url is missing, the article is skipped.
"""
def save_solr_articles(keywords: str, num_articles=15) -> str:
    solr_key = os.getenv("SOLR_KEY")
    SOLR_ARTICLES_URL = f"https://website:{solr_key}@solr.machines.globalhealthwatcher.org:8080/solr/articles/"
    solr = Solr(SOLR_ARTICLES_URL, verify=False)

    # No duplicates
    fq = ['-dups:0']

    query = f'text:({keywords})' + " AND " + "dead_url:(false)"

    # Get top 2*num_articles articles and then remove malformed or duplicate articles
    outputs = solr.search(query, fq=fq, sort="score desc", rows=num_articles * 2)

    article_count = 0

    save_path = os.path.join("data", "articles.csv")
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))

    with open(save_path, 'w', newline='') as csvfile:
        fieldnames = ['title', 'uuid', 'content', 'url', 'domain']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC)
        writer.writeheader()

        title_five_words = set()

        for d in outputs.docs:
            if article_count == num_articles:
                break

            # skip if any required field is missing
            if 'title' not in d or 'uuid' not in d or 'cleaned_content' not in d or 'url' not in d:
                continue

            title_cleaned = remove_spaces_newlines(d['title'])

            split = title_cleaned.split()
            # skip if the first five words of the title duplicate an earlier article
            if len(split) >= 5:
                five_words = ' '.join(split[:5])
                if five_words in title_five_words:
                    continue
                title_five_words.add(five_words)

            article_count += 1

            cleaned_content = remove_spaces_newlines(d['cleaned_content'])
            cleaned_content = truncate_article(cleaned_content)

            domain = ""
            if 'domain' not in d:
                domain = "Not Specified"
            else:
                domain = d['domain']
            print(domain)

            writer.writerow({'title': title_cleaned, 'uuid': d['uuid'], 'content': cleaned_content, 'url': d['url'],
                             'domain': domain})
    return save_path


def save_embedding_base_articles(query, article_embeddings, titles, contents, uuids, urls, num_articles=15):
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, article_embeddings, top_k=15)
    hits = hits[0]
    corpus_ids = [item['corpus_id'] for item in hits]
    r_contents = [contents[idx] for idx in corpus_ids]
    r_titles = [titles[idx] for idx in corpus_ids]
    r_uuids = [uuids[idx] for idx in corpus_ids]
    r_urls = [urls[idx] for idx in corpus_ids]

    save_path = os.path.join("data", "articles.csv")
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))

    with open(save_path, 'w', newline='', encoding="utf-8") as csvfile:
        fieldNames = ['title', 'uuid', 'content', 'url']
        writer = csv.DictWriter(csvfile, fieldnames=fieldNames, quoting=csv.QUOTE_NONNUMERIC)
        writer.writeheader()
        for i in range(num_articles):
            writer.writerow({'title': r_titles[i], 'uuid': r_uuids[i], 'content': r_contents[i], 'url': r_urls[i]})
    return save_path
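A sketch of the retrieval step, assuming `SOLR_KEY` is set and the Solr host above is reachable from your network:

```python
import os
from get_articles import save_solr_articles_full

os.environ.setdefault("SOLR_KEY", "<solr-password>")  # placeholder credential

# Writes data/articles.csv with title, uuid, content, url, domain columns.
csv_path = save_solr_articles_full(
    "How is the United States fighting tobacco addiction?",
    num_articles=15,
    keyword_type="rake",
)
print(csv_path)
```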
get_keywords.py
ADDED
@@ -0,0 +1,63 @@

from langchain.chat_models import ChatOpenAI
from langchain.schema import (
    HumanMessage,
    SystemMessage
)

from rake_nltk import Rake
import nltk
nltk.download('stopwords')
nltk.download('punkt')

"""
This function takes in a user query and returns keywords
Input:
    user_query: str
    keyword_type: str (openai, rake, or na)
    If the keyword type is na, the user query is returned unchanged.
Output: keywords: str
"""
def get_keywords(user_query: str, keyword_type: str) -> str:
    if keyword_type == "openai":
        return get_keywords_openai(user_query)
    elif keyword_type == "rake":
        return get_keywords_rake(user_query)
    else:
        return user_query


"""
This function takes a user query and returns keywords using rake_nltk
rake_nltk actually returns keyphrases, not keywords. Since using keyphrases did not show improvement, we join them
into a plain keyword string to match the output type of the other keyword functions.
Input:
    user_query: str
Output: keywords: str
"""
def get_keywords_rake(user_query: str) -> str:
    r = Rake()
    r.extract_keywords_from_text(user_query)
    keyphrases = r.get_ranked_phrases()

    # join the ranked phrases into one space-separated keyword string
    out = ""
    for phrase in keyphrases:
        out += phrase + " "
    return out


"""
This function takes a user query and returns keywords using openai
Input:
    user_query: str
Output: keywords: str
"""
def get_keywords_openai(user_query: str) -> str:
    llm = ChatOpenAI(temperature=0.0)
    command = "return the keywords of the following query. response should be words separated by commas. "
    message = [
        SystemMessage(content=command),
        HumanMessage(content=user_query)
    ]
    response = llm(message)
    res = response.content.replace(",", "")
    return res
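A quick offline check of the RAKE path; no API key is needed, and the exact output depends on rake_nltk's ranking:

```python
from get_keywords import get_keywords

# RAKE drops stopwords and returns ranked phrases joined into one string,
# e.g. something like "indian government smoking population measures taken reduce".
print(get_keywords(
    "What are some measures taken by the Indian Government to reduce the smoking population?",
    keyword_type="rake",
))
```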
requirements.txt
ADDED
@@ -0,0 +1,12 @@

gradio==4.25.0
langchain==0.1.14
langchain_core==0.1.40
langchain_openai==0.1.1
nltk==3.8.1
openai==1.16.2
pandas==2.2.1
pysolr==3.9.0
rake_nltk==1.0.6
sentence_transformers==2.2.2
tiktoken==0.5.2
torch==2.1.2
rerank.py
ADDED
@@ -0,0 +1,278 @@

# reranks the top articles from a given csv file

from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
from langchain.document_loaders import CSVLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.vectorstores import DocArrayInMemorySearch
from sentence_transformers import CrossEncoder, util
from nltk import sent_tokenize
import pandas as pd

"""
This function reranks the top articles (15 -> 4) from a given csv, then sends them to the LLM
Input:
    csv_path: str
    question: str
    top_n: int
Output:
    response: str
    links: list of str
    titles: list of str

The other functions in this file do not send articles to the LLM. This is an exception.
Created using langchain RAG functions. Deprecated.
Update: Use langchain_RAG instead.
"""


def langchain_rerank_answer(csv_path, question, source='url', top_n=4):
    llm = ChatOpenAI(temperature=0.0)
    loader = CSVLoader(csv_path, source_column="url")

    index = VectorstoreIndexCreator(
        vectorstore_cls=DocArrayInMemorySearch,
    ).from_loaders([loader])

    # prompt_template = """You are a chatbot that answers tobacco related questions with sources. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
    # {context}
    # Question: {question}"""
    # PROMPT = PromptTemplate(
    #     template=prompt_template, input_variables=["context", "question"]
    # )
    # chain_type_kwargs = {"prompt": PROMPT}

    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=index.vectorstore.as_retriever(),
        verbose=False,
        return_source_documents=True,
        # chain_type_kwargs=chain_type_kwargs,
        # chain_type_kwargs = {
        #     "document_separator": "<<<<>>>>>"
        # },
    )

    answer = qa({"query": question})
    sources = answer['source_documents']
    sources_out = [source.metadata['source'] for source in sources]

    return answer['result'], sources_out


"""
Langchain with sources.
This function is deprecated. Use langchain_RAG instead.
"""


def langchain_with_sources(csv_path, question, top_n=4):
    llm = ChatOpenAI(temperature=0.0)
    loader = CSVLoader(csv_path, source_column="uuid")
    index = VectorstoreIndexCreator(
        vectorstore_cls=DocArrayInMemorySearch,
    ).from_loaders([loader])

    qa = RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=index.vectorstore.as_retriever(),
    )
    output = qa({"question": question}, return_only_outputs=True)
    return output['answer'], output['sources']


"""
Reranks the top articles using a cross-encoder.
Uses cross-encoder/ms-marco-MiniLM-L-6-v2 for scoring / reranking.
Input:
    csv_path: str
    question: str
    top_n: int
Output:
    out_values: list of (score, content, uuid, title, domain) tuples
"""


# returns list of top n similar articles using the cross-encoder
def crossencoder_rerank_answer(csv_path: str, question: str, top_n=4) -> list:
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    articles = pd.read_csv(csv_path)
    contents = articles['content'].tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()

    # biencoder retrieval does not have domain
    if 'domain' not in articles:
        domain = [""] * len(contents)
    else:
        domain = articles['domain'].tolist()

    cross_inp = [[question, content] for content in contents]
    cross_scores = cross_encoder.predict(cross_inp)
    scores_sentences = list(zip(cross_scores, contents, uuids, titles, domain))
    scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)

    out_values = scores_sentences[:top_n]

    # if a score is negative, truncate there; keep at least the single best match
    for idx in range(len(out_values)):
        if out_values[idx][0] < 0:
            out_values = out_values[:idx]
            if len(out_values) == 0:
                out_values = scores_sentences[:1]
            break
    return out_values


def crossencoder_rerank_sentencewise(csv_path: str, question: str, top_n=10) -> list:
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    articles = pd.read_csv(csv_path)
    contents = articles['content'].tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()

    if 'domain' not in articles:
        domain = [""] * len(contents)
    else:
        domain = articles['domain'].tolist()

    # split each article into sentences, carrying its metadata along
    sentences = []
    new_uuids = []
    new_titles = []
    new_domains = []
    for idx in range(len(contents)):
        sents = sent_tokenize(contents[idx])
        sentences.extend(sents)
        new_uuids.extend([uuids[idx]] * len(sents))
        new_titles.extend([titles[idx]] * len(sents))
        new_domains.extend([domain[idx]] * len(sents))

    cross_inp = [[question, sent] for sent in sentences]
    cross_scores = cross_encoder.predict(cross_inp)
    scores_sentences = list(zip(cross_scores, sentences, new_uuids, new_titles, new_domains))
    scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)

    out_values = scores_sentences[:top_n]

    # if a score is negative, truncate there; keep at least the single best match
    for idx in range(len(out_values)):
        if out_values[idx][0] < 0:
            out_values = out_values[:idx]
            if len(out_values) == 0:
                out_values = scores_sentences[:1]
            break

    return out_values


def crossencoder_rerank_sentencewise_sentence_chunks(csv_path, question, top_n=10, chunk_size=2):
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    articles = pd.read_csv(csv_path)
    contents = articles['content'].tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()

    # embeddings do not have domain as a column
    if 'domain' not in articles:
        domain = [""] * len(contents)
    else:
        domain = articles['domain'].tolist()

    sentences = []
    new_uuids = []
    new_titles = []
    new_domains = []

    for idx in range(len(contents)):
        sents = sent_tokenize(contents[idx])
        sents_merged = []

        # if the number of sentences is less than chunk size, merge and join
        if len(sents) < chunk_size:
            sents_merged.append(' '.join(sents))
        else:
            # sliding window of chunk_size consecutive sentences
            for i in range(0, len(sents) - chunk_size + 1):
                sents_merged.append(' '.join(sents[i:i + chunk_size]))

        sentences.extend(sents_merged)
        new_uuids.extend([uuids[idx]] * len(sents_merged))
        new_titles.extend([titles[idx]] * len(sents_merged))
        new_domains.extend([domain[idx]] * len(sents_merged))

    cross_inp = [[question, sent] for sent in sentences]
    cross_scores = cross_encoder.predict(cross_inp)
    scores_sentences = list(zip(cross_scores, sentences, new_uuids, new_titles, new_domains))
    scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)

    out_values = scores_sentences[:top_n]

    for idx in range(len(out_values)):
        if out_values[idx][0] < 0:
            out_values = out_values[:idx]
            if len(out_values) == 0:
                out_values = scores_sentences[:1]
            break

    return out_values


def crossencoder_rerank_sentencewise_articles(csv_path, question, top_n=4):
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    contents, uuids, titles, domain = load_articles(csv_path)

    # score per sentence, but keep the full article content alongside each sentence
    sentences = []
    contents_elongated = []
    new_uuids = []
    new_titles = []
    new_domains = []

    for idx in range(len(contents)):
        sents = sent_tokenize(contents[idx])
        sentences.extend(sents)
        new_uuids.extend([uuids[idx]] * len(sents))
        contents_elongated.extend([contents[idx]] * len(sents))
        new_titles.extend([titles[idx]] * len(sents))
        new_domains.extend([domain[idx]] * len(sents))

    cross_inp = [[question, sent] for sent in sentences]
    cross_scores = cross_encoder.predict(cross_inp)
    scores_sentences = list(zip(cross_scores, contents_elongated, new_uuids, new_titles, new_domains))
    scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)

    # keep only the best-scoring entry per uuid
    score_sentences_compressed = []
    for item in scores_sentences:
        if not score_sentences_compressed:
            score_sentences_compressed.append(item)
        else:
            if item[2] not in [x[2] for x in score_sentences_compressed]:
                score_sentences_compressed.append(item)

    scores_sentences = score_sentences_compressed
    return scores_sentences[:top_n]


def no_rerank(csv_path, question, top_n=4):
    contents, uuids, titles, domains = load_articles(csv_path)
    return list(zip(contents, uuids, titles, domains))[:top_n]


def load_articles(csv_path: str):
    articles = pd.read_csv(csv_path)
    contents = articles['content'].tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()
    if 'domain' not in articles:
        domain = [""] * len(contents)
    else:
        domain = articles['domain'].tolist()
    return contents, uuids, titles, domain
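A self-contained check of the cross-encoder reranker on a throwaway CSV; the two articles are made up, and the model downloads on first use:

```python
import pandas as pd
from rerank import crossencoder_rerank_answer

# Tiny articles.csv matching the columns get_articles.py writes.
pd.DataFrame({
    "title": ["Smoking rates fall", "Unrelated story"],
    "uuid": ["u1", "u2"],
    "content": ["U.S. adult smoking declined again in 2022.",
                "A new bridge opened downtown."],
    "url": ["https://example.org/1", "https://example.org/2"],
    "domain": ["example.org", "example.org"],
}).to_csv("tmp_articles.csv", index=False)

# Returns up to top_n (score, content, uuid, title, domain) tuples, best first.
for score, content, uuid, title, domain in crossencoder_rerank_answer(
        "tmp_articles.csv", "How many Americans smoke?", top_n=2):
    print(round(float(score), 3), title)
```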