# pages/1_Retrieval_App.py
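# Streamlit page that embeds a free-text query and retrieves the most similar
# Papers with Code repositories through a findkit retrieval pipeline.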
from typing import List

import datasets
import pandas as pd
import sentence_transformers
import streamlit as st
from findkit import feature_extractors, indexes, retrieval_pipeline
from toolz import partial

import config


def truncate_description(description, length=50):
    return " ".join(description.split()[:length])


def get_repos_with_descriptions(repos_df, repos):
    return repos_df.loc[repos]


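# Queries the pipeline for the k nearest documents, adds GitHub links, and
# truncates the long text columns before display.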
def search_f(
    retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
    query: str,
    k: int,
    description_length: int,
    doc_cols: List[str],
):
    results = retrieval_pipe.find_similar(query, k)
    results["link"] = "https://github.com/" + results["repo"]
    for col in doc_cols:
        results[col] = results[col].apply(
            lambda desc: truncate_description(desc, description_length)
        )
    shown_cols = ["repo", "tasks", "link", "distance"] + doc_cols
    return results.reset_index(drop=True)[shown_cols]


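# If the query is one of the gold standard task queries, also show the
# annotated repositories for that task next to the retrieved results.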
def show_retrieval_results(
    retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
    query: str,
    k: int,
    all_queries: List[str],
    description_length: int,
    repos_by_query: pd.core.groupby.DataFrameGroupBy,
    doc_cols: List[str],
):
    print("started retrieval")
    if query in all_queries:
        with st.expander(
            "Query is in the gold standard set. Expand to view gold standard results."
        ):
            st.write("gold standard results")
            task_repos = repos_by_query.get_group(query)
            st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
    with st.spinner(text="fetching results"):
        st.write(
            search_f(retrieval_pipe, query, k, description_length, doc_cols).to_html(
                escape=False, index=False
            ),
            unsafe_allow_html=True,
        )
    print("finished retrieval")


def setup_pipeline(
    extractor: feature_extractors.SentenceEncoderFeatureExtractor,
    documents_df: pd.DataFrame,
    text_col: str,
):
    return retrieval_pipeline.RetrievalPipelineFactory.build(
        documents_df[text_col], metadata=documents_df
    )


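# Cached so the encoders and the index are built once per session. Documents
# and queries get separate sentence-transformers encoders; the NMSLIB index
# ranks neighbors by cosine similarity ("cosinesimil").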
@st.cache(allow_output_mutation=True)
def setup_retrieval_pipeline(
    query_encoder_path, document_encoder_path, documents, metadata
):
    document_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
        sentence_transformers.SentenceTransformer(document_encoder_path, device="cpu")
    )
    query_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
        sentence_transformers.SentenceTransformer(query_encoder_path, device="cpu")
    )
    retrieval_pipe = retrieval_pipeline.RetrievalPipelineFactory(
        feature_extractor=document_encoder,
        query_feature_extractor=query_encoder,
        index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
    )
    return retrieval_pipe.build(documents, metadata=metadata)


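# Page body: sidebar controls for result count, description truncation, and
# extra shown columns; the per-task document counts double as the gold
# standard query list.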
def app(retrieval_pipe, retrieval_df, doc_col):
    retrieved_results = st.sidebar.number_input("number of results", value=10)
    description_length = st.sidebar.number_input(
        "number of used description words", value=10
    )

    tasks_deduped = retrieval_df["tasks"].explode().value_counts().reset_index()
    tasks_deduped.columns = ["task", "documents per task"]
    with st.sidebar.expander("View test set queries"):
        st.table(tasks_deduped)

    additional_shown_cols = st.sidebar.multiselect(
        label="additional cols", options=[doc_col], default=doc_col
    )
    repos_by_query = retrieval_df.explode("tasks").groupby("tasks")

    query = st.text_input("input query", value="metric learning")
    show_retrieval_results(
        retrieval_pipe,
        query,
        retrieved_results,
        tasks_deduped["task"].to_list(),
        description_length,
        repos_by_query,
        additional_shown_cols,
    )


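# Loads the repositories dataset from the Hugging Face Hub, deduplicates by
# repo, and indexes the chosen document column.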
def app_main(
    query_encoder_path,
    document_encoder_path,
    data_path,
):
    print("loading data")
    retrieval_df = (
        datasets.load_dataset(data_path)["train"]
        .to_pandas()
        .drop_duplicates(subset=["repo"])
        .reset_index(drop=True)
    )
    print("setting up retrieval_pipe")
    doc_col = "dependencies"
    retrieval_pipe = setup_retrieval_pipeline(
        query_encoder_path, document_encoder_path, retrieval_df[doc_col], retrieval_df
    )
    app(retrieval_pipe, retrieval_df, doc_col)


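# Streamlit executes the page top to bottom, so app_main runs at import time.
# The encoder names come from config.py, e.g. (placeholder values; the real
# names live in the repo's config):
#   query_encoder_model_name = "sentence-transformers/all-MiniLM-L6-v2"
#   document_encoder_model_name = "sentence-transformers/all-MiniLM-L6-v2"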
app_main(
    query_encoder_path=config.query_encoder_model_name,
    document_encoder_path=config.document_encoder_model_name,
    data_path="lambdaofgod/pwc_repositories_with_dependencies",
)