# reranks the top articles from a given csv file
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
from langchain.indexes import VectorstoreIndexCreator
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.vectorstores import DocArrayInMemorySearch
from sentence_transformers import CrossEncoder
from nltk.tokenize import sent_tokenize
import pandas as pd
import time
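# NOTE: sent_tokenize below needs the NLTK 'punkt' tokenizer data; if it is
# missing, download it once with: import nltk; nltk.download('punkt')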
""" | |
This function rerank top articles (15 -> 4) from a given csv, then sends to LLM | |
Input: | |
csv_path: str | |
question: str | |
top_n: int | |
Output: | |
response: str | |
links: list of str | |
titles: list of str | |
Other functions in this file does not send articles to LLM. This is an exception. | |
Created using langchain RAG functions. Deprecated. | |
Update: Use langchain_RAG instead. | |
""" | |
def langchain_rerank_answer(csv_path, question, source='url', top_n=4):
    llm = ChatOpenAI(temperature=0.0)
    loader = CSVLoader(csv_path, source_column=source)
    index = VectorstoreIndexCreator(
        vectorstore_cls=DocArrayInMemorySearch,
    ).from_loaders([loader])
    # prompt_template = """You are a chatbot that answers tobacco related questions with sources. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
    # {context}
    # Question: {question}"""
    # PROMPT = PromptTemplate(
    #     template=prompt_template, input_variables=["context", "question"]
    # )
    # chain_type_kwargs = {"prompt": PROMPT}
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=index.vectorstore.as_retriever(search_kwargs={"k": top_n}),
        verbose=False,
        return_source_documents=True,
        # chain_type_kwargs=chain_type_kwargs,
        # chain_type_kwargs={"document_separator": "<<<<>>>>>"},
    )
    answer = qa({"query": question})
    source_documents = answer['source_documents']
    sources_out = [doc.metadata['source'] for doc in source_documents]
    return answer['result'], sources_out
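# Example usage (hypothetical csv path and question; the csv needs a 'url'
# column plus the article text, as CSVLoader above assumes):
# response, links = langchain_rerank_answer("retrieved_articles.csv",
#                                           "Does vaping help smokers quit?")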
""" | |
Langchain with sources. | |
This function is deprecated. Use langchain_RAG instead. | |
""" | |
def langchain_with_sources(csv_path, question, top_n=4):
    llm = ChatOpenAI(temperature=0.0)
    loader = CSVLoader(csv_path, source_column="uuid")
    index = VectorstoreIndexCreator(
        vectorstore_cls=DocArrayInMemorySearch,
    ).from_loaders([loader])
    qa = RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=index.vectorstore.as_retriever(),
    )
    output = qa({"question": question}, return_only_outputs=True)
    return output['answer'], output['sources']
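# Example usage (hypothetical inputs): RetrievalQAWithSourcesChain typically
# returns 'sources' as a single string of source ids rather than a list:
# answer, sources = langchain_with_sources("retrieved_articles.csv",
#                                          "Is nicotine addictive?")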
""" | |
Reranks the top articles using crossencoder. | |
Uses cross-encoder/ms-marco-MiniLM-L-6-v2 for embedding / reranking. | |
Input: | |
csv_path: str | |
question: str | |
top_n: int | |
Output: | |
out_values: list of [content, uuid, title] | |
""" | |
# returns a list of the top_n most similar articles using the cross-encoder
def crossencoder_rerank_answer(csv_path: str, question: str, top_n=4) -> list:
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    articles = pd.read_csv(csv_path)
    contents = articles['content'].tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()
    # biencoder retrieval results do not include a domain column
    if 'domain' not in articles:
        domain = [""] * len(contents)
    else:
        domain = articles['domain'].tolist()
    cross_inp = [[question, content] for content in contents]
    cross_scores = cross_encoder.predict(cross_inp)
    scores_sentences = list(zip(cross_scores, contents, uuids, titles, domain))
    scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)
    out_values = scores_sentences[:top_n]
    # scores are sorted in descending order, so drop everything from the first
    # negative score on; if all scores are negative, keep the single best article
    for idx in range(len(out_values)):
        if out_values[idx][0] < 0:
            out_values = out_values[:idx]
            if len(out_values) == 0:
                out_values = scores_sentences[:1]
            break
    return out_values
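# Example (hypothetical csv path and question): each returned item unpacks as
# (score, content, uuid, title, domain), e.g.
# for score, content, uuid, title, domain in crossencoder_rerank_answer(
#         "retrieved_articles.csv", "Is nicotine addictive?"):
#     print(f"{score:.3f}  {title}")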
# reranks at sentence level: every sentence of every article is scored against
# the question, and the top_n sentences (with their article metadata) are returned
def crossencoder_rerank_sentencewise(csv_path: str, question: str, top_n=10) -> list:
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    articles = pd.read_csv(csv_path)
    contents = articles['content'].tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()
    if 'domain' not in articles:
        domain = [""] * len(contents)
    else:
        domain = articles['domain'].tolist()
    # split each article into sentences, repeating the article metadata per sentence
    sentences = []
    new_uuids = []
    new_titles = []
    new_domains = []
    for idx in range(len(contents)):
        sents = sent_tokenize(contents[idx])
        sentences.extend(sents)
        new_uuids.extend([uuids[idx]] * len(sents))
        new_titles.extend([titles[idx]] * len(sents))
        new_domains.extend([domain[idx]] * len(sents))
    cross_inp = [[question, sent] for sent in sentences]
    cross_scores = cross_encoder.predict(cross_inp)
    scores_sentences = list(zip(cross_scores, sentences, new_uuids, new_titles, new_domains))
    scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)
    out_values = scores_sentences[:top_n]
    # truncate at the first negative score; if all scores are negative,
    # keep the single best sentence
    for idx in range(len(out_values)):
        if out_values[idx][0] < 0:
            out_values = out_values[:idx]
            if len(out_values) == 0:
                out_values = scores_sentences[:1]
            break
    return out_values
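# The next function scores overlapping chunks of consecutive sentences instead
# of single sentences. For example, with chunk_size=2 the sentences
# ["A.", "B.", "C."] yield the chunks ["A. B.", "B. C."].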
def crossencoder_rerank_sentencewise_sentence_chunks(csv_path, question, top_n=10, chunk_size=2):
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    articles = pd.read_csv(csv_path)
    contents = articles['content'].tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()
    # embedding csvs do not have a domain column
    if 'domain' not in articles:
        domain = [""] * len(contents)
    else:
        domain = articles['domain'].tolist()
    sentences = []
    new_uuids = []
    new_titles = []
    new_domains = []
    for idx in range(len(contents)):
        sents = sent_tokenize(contents[idx])
        sents_merged = []
        # if an article has fewer sentences than chunk_size, join them all into
        # one chunk; otherwise slide a window of chunk_size sentences over it
        if len(sents) < chunk_size:
            sents_merged.append(' '.join(sents))
        else:
            for i in range(0, len(sents) - chunk_size + 1):
                sents_merged.append(' '.join(sents[i:i + chunk_size]))
        sentences.extend(sents_merged)
        new_uuids.extend([uuids[idx]] * len(sents_merged))
        new_titles.extend([titles[idx]] * len(sents_merged))
        new_domains.extend([domain[idx]] * len(sents_merged))
    cross_inp = [[question, sent] for sent in sentences]
    cross_scores = cross_encoder.predict(cross_inp)
    scores_sentences = list(zip(cross_scores, sentences, new_uuids, new_titles, new_domains))
    scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)
    out_values = scores_sentences[:top_n]
    # truncate at the first negative score; fall back to the best chunk if all are negative
    for idx in range(len(out_values)):
        if out_values[idx][0] < 0:
            out_values = out_values[:idx]
            if len(out_values) == 0:
                out_values = scores_sentences[:1]
            break
    return out_values
# scores individual sentences but returns the full articles that contain the
# top-scoring sentences, deduplicated by uuid
def crossencoder_rerank_sentencewise_articles(csv_path, question, top_n=4):
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    contents, uuids, titles, domain = load_articles(csv_path)
    sentences = []
    contents_elongated = []
    new_uuids = []
    new_titles = []
    new_domains = []
    for idx in range(len(contents)):
        sents = sent_tokenize(contents[idx])
        sentences.extend(sents)
        new_uuids.extend([uuids[idx]] * len(sents))
        contents_elongated.extend([contents[idx]] * len(sents))
        new_titles.extend([titles[idx]] * len(sents))
        new_domains.extend([domain[idx]] * len(sents))
    cross_inp = [[question, sent] for sent in sentences]
    cross_scores = cross_encoder.predict(cross_inp)
    scores_sentences = list(zip(cross_scores, contents_elongated, new_uuids, new_titles, new_domains))
    scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)
    # keep only the first (highest-scoring) entry per article uuid
    seen_uuids = set()
    scores_sentences_compressed = []
    for item in scores_sentences:
        if item[2] not in seen_uuids:
            seen_uuids.add(item[2])
            scores_sentences_compressed.append(item)
    return scores_sentences_compressed[:top_n]
# baseline: returns the first top_n articles in file order, without reranking
def no_rerank(csv_path, question, top_n=4):
    contents, uuids, titles, domains = load_articles(csv_path)
    return list(zip(contents, uuids, titles, domains))[:top_n]
# loads the csv and returns per-article (contents, uuids, titles, domain) lists;
# the domain column is optional and defaults to empty strings
def load_articles(csv_path: str):
    articles = pd.read_csv(csv_path)
    contents = articles['content'].tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()
    if 'domain' not in articles:
        domain = [""] * len(contents)
    else:
        domain = articles['domain'].tolist()
    return contents, uuids, titles, domain
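# Minimal smoke test (hypothetical csv path and question; the csv must have
# 'content', 'uuid', and 'title' columns, and optionally 'domain'):
if __name__ == "__main__":
    results = crossencoder_rerank_answer("retrieved_articles.csv",
                                         "Does vaping help smokers quit?")
    for score, content, uuid, title, domain in results:
        print(f"{score:.3f}  {title}  ({uuid})")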