Spaces:
Build error
Build error
def get_bm25_search_hits(corpus, sparse_scores, top_n=50):
    """Return the first ``top_n`` entries of *sparse_scores* as Python ints.

    Args:
        corpus: accepted for backward compatibility; not needed to select
            the hits.
        sparse_scores: iterable of corpus indices, presumably already ordered
            best-first (e.g. an argsort of BM25 scores) — confirm with callers.
        top_n: maximum number of indices to return; values <= 0 yield ``[]``.

    Returns:
        list[int]: at most ``top_n`` indices, in the order they were given.
    """
    hits = []
    for idx in sparse_scores:
        # Stop as soon as we have top_n hits instead of scanning the rest.
        if len(hits) >= top_n:
            break
        hits.append(int(idx))
    return hits
# Values that the "All" year / quarter selections expand to.
_ALL_YEARS = (2020, 2019, 2018, 2017, 2016)
_ALL_QUARTERS = ("Q1", "Q2", "Q3", "Q4")


def _participant_flag(participant_type):
    """Map a participant type to the ``QA_Flag`` metadata value."""
    return "Answer" if participant_type == "Company Speaker" else "Question"


def _build_metadata_filter(
    year, quarter, ticker, participant_type, keywords=None, indices=None
):
    """Build the metadata filter shared by the dense and hybrid queries.

    Args:
        year: a year (int or numeric string) or ``"All"``.
        quarter: ``"Q1"``-``"Q4"`` or ``"All"``.
        ticker: value required in the ``Ticker`` field.
        participant_type: ``"Company Speaker"`` -> ``QA_Flag == "Answer"``,
            anything else -> ``"Question"``.
        keywords: optional list; adds ``Keywords $in keywords``.
        indices: optional list; adds ``index $in indices``.

    Returns:
        dict: a Pinecone-style metadata filter.
    """
    if year == "All":
        year_filter = {"$in": list(_ALL_YEARS)}
    else:
        year_filter = {"$eq": int(year)}
    # "All" is a selection sentinel, not a stored quarter value, so expand it
    # regardless of whether a specific year was chosen.
    if quarter == "All":
        quarter_filter = {"$in": list(_ALL_QUARTERS)}
    else:
        quarter_filter = {"$eq": quarter}
    metadata_filter = {
        "Year": year_filter,
        "Quarter": quarter_filter,
        "Ticker": {"$eq": ticker},
        "QA_Flag": {"$eq": _participant_flag(participant_type)},
    }
    # Optional constraints are only sent when provided; an explicit
    # {"$in": None} is not a meaningful filter.
    if keywords is not None:
        metadata_filter["Keywords"] = {"$in": keywords}
    if indices is not None:
        metadata_filter["index"] = {"$in": indices}
    return metadata_filter


def _apply_score_threshold(xc, threshold):
    """Keep only matches in *xc* scoring ``>= threshold`` (mutates and returns *xc*)."""
    xc["matches"] = [
        match for match in xc["matches"] if match["score"] >= threshold
    ]
    return xc


def query_pinecone(
    dense_vec,
    top_k,
    index,
    year,
    quarter,
    ticker,
    participant_type,
    keywords=None,
    indices=None,
    threshold=0.25,
):
    """Run a dense-vector query with year/quarter/ticker/speaker filters.

    Args:
        dense_vec: query embedding, sent as ``vector``.
        top_k: maximum number of matches requested from the index.
        index: object exposing a Pinecone-style ``query`` method.
        year: a year (int or numeric string) or ``"All"`` for 2016-2020.
        quarter: ``"Q1"``-``"Q4"`` or ``"All"`` for every quarter.
        ticker: value matched against the ``Ticker`` metadata field.
        participant_type: ``"Company Speaker"`` selects answers; anything
            else selects questions.
        keywords: optional list; filters with ``Keywords $in keywords``.
        indices: optional list; filters with ``index $in indices``.
        threshold: matches scoring below this are dropped.

    Returns:
        The query response with ``matches`` limited to scores ``>= threshold``.
    """
    xc = index.query(
        vector=dense_vec,
        top_k=top_k,
        filter=_build_metadata_filter(
            year, quarter, ticker, participant_type, keywords, indices
        ),
        include_metadata=True,
    )
    return _apply_score_threshold(xc, threshold)


def query_pinecone_sparse(
    dense_vec,
    sparse_vec,
    top_k,
    index,
    year,
    quarter,
    ticker,
    participant_type,
    keywords=None,
    indices=None,
    threshold=0.25,
):
    """Run a hybrid (dense + sparse) query with the same filters as ``query_pinecone``.

    Args:
        dense_vec: dense query embedding, sent as ``vector``.
        sparse_vec: sparse query representation, sent as ``sparse_vector``.
        top_k, index, year, quarter, ticker, participant_type, keywords,
        indices, threshold: as for ``query_pinecone``.

    Returns:
        The query response with ``matches`` limited to scores ``>= threshold``.
    """
    xc = index.query(
        vector=dense_vec,
        sparse_vector=sparse_vec,
        top_k=top_k,
        filter=_build_metadata_filter(
            year, quarter, ticker, participant_type, keywords, indices
        ),
        include_metadata=True,
    )
    return _apply_score_threshold(xc, threshold)
def format_query(query_results):
    """Collect the ``Text`` metadata of each match, preserving match order.

    Args:
        query_results: mapping with a ``matches`` list whose items carry a
            ``metadata`` dict containing ``Text``.

    Returns:
        list: the passage texts, one per match.
    """
    passages = []
    for match in query_results["matches"]:
        passages.append(match["metadata"]["Text"])
    return passages
def sentence_id_combine(data, query_results, lag=1):
    """Expand matched sentences into passages that include their neighbours.

    Each match's ``Sentence_id`` is widened to ``id - lag .. id + lag``. The
    union of those ids is sorted and cut into consecutive chunks of
    ``2 * lag + 1`` ids. For each chunk, the ``Text`` of the rows in *data*
    whose ``Sentence_id`` is in the chunk is joined with single spaces.

    Args:
        data: DataFrame with ``Sentence_id`` and ``Text`` columns.
        query_results: mapping whose ``matches`` items carry
            ``metadata["Sentence_id"]``.
        lag: number of neighbouring sentences to include on each side.

    Returns:
        list[str]: one joined passage per chunk.

    NOTE(review): chunks are cut from the merged, sorted id list rather than
    per match. Overlapping windows can therefore split into passages that do
    not centre on a hit; e.g. hits 2 and 3 with lag=1 give chunks [1, 2, 3]
    and [4]. Confirm this is intended.
    """
    chunk_size = lag * 2 + 1
    hit_ids = [
        match["metadata"]["Sentence_id"] for match in query_results["matches"]
    ]
    # Offsets are added to each id (rather than used as range bounds) so
    # non-integer ids such as floats still work.
    window_ids = set()
    for hit_id in hit_ids:
        for offset in range(-lag, lag + 1):
            window_ids.add(hit_id + offset)
    ordered_ids = sorted(window_ids)

    passages = []
    for start in range(0, len(ordered_ids), chunk_size):
        chunk = ordered_ids[start : start + chunk_size]
        in_chunk = data["Sentence_id"].isin(chunk)
        passages.append(" ".join(data.loc[in_chunk, "Text"].to_list()))
    return passages
def text_lookup(data, sentence_ids):
    """Join the entries of *data* at the given positions with ``". "``.

    Args:
        data: indexed positionally via ``.iloc``; presumably a pandas Series
            of sentence strings (confirm with callers).
        sentence_ids: positions (anything ``.iloc`` accepts) to select.

    Returns:
        str: the selected entries separated by ``". "``.
    """
    selected_rows = data.iloc[sentence_ids]
    separator = ". "
    return separator.join(selected_rows.to_list())
def year_quarter_range(start_quarter, start_year, end_quarter, end_year):
    """Return every ``(year, quarter)`` pair from start to end, inclusive.

    Pairs are in chronological order, with years as ``int`` and quarters as
    ``"Q1"``-``"Q4"``. An empty list is returned when the start quarter lies
    after the end quarter.

    Args:
        start_quarter: first quarter, ``"Q1"``-``"Q4"``.
        start_year: first year (int or numeric string).
        end_quarter: last quarter, ``"Q1"``-``"Q4"``.
        end_year: last year (int or numeric string).

    Returns:
        list[tuple[int, str]]: e.g. ``("Q4", 2018, "Q1", 2019)`` ->
        ``[(2018, "Q4"), (2019, "Q1")]``.

    Raises:
        ValueError: if a quarter is not one of ``Q1``-``Q4`` or a year string
            is not numeric.
    """
    quarter_names = ("Q1", "Q2", "Q3", "Q4")

    def to_ordinal(quarter, year):
        # Encode (year, quarter) as a single count of quarters so the range
        # can cross year boundaries with plain integer arithmetic.
        if quarter not in quarter_names:
            raise ValueError(
                f"unknown quarter {quarter!r}; expected one of {quarter_names}"
            )
        return int(year) * 4 + quarter_names.index(quarter)

    first = to_ordinal(start_quarter, start_year)
    last = to_ordinal(end_quarter, end_year)
    return [(n // 4, quarter_names[n % 4]) for n in range(first, last + 1)]
def multi_document_query(
    dense_query_embedding,
    sparse_query_embedding,
    num_results,
    pinecone_index,
    start_quarter,
    start_year,
    end_quarter,
    end_year,
    ticker,
    participant_type,
    threshold,
):
    """Placeholder for a query spanning a range of quarters.

    NOTE(review): not implemented. The body ignores every argument and
    always returns ``None``; callers must not rely on a result yet.
    """
    return None