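"""Streamlit app for semantic search over GitHub repositories.

Encodes a free-text query with a sentence-transformers model and looks up
similar repositories from the `lambdaofgod/pwc_repositories_with_dependencies`
dataset using a findkit retrieval pipeline backed by an NMSLIB cosine index.
"""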
from typing import List

import datasets
import pandas as pd
import sentence_transformers
import streamlit as st
from findkit import feature_extractors, indexes, retrieval_pipeline
from toolz import partial

import config

def truncate_description(description, length=50):
    """Keep at most `length` whitespace-separated words of a description."""
    return " ".join(description.split()[:length])


def get_repos_with_descriptions(repos_df, repos):
    """Look up the metadata rows for the given repository names."""
    return repos_df.loc[repos]
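
# For example:
#   truncate_description("deep metric learning for images", length=3)
#   -> "deep metric learning"
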
def search_f(
    retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
    query: str,
    k: int,
    description_length: int,
    doc_cols: List[str],
):
    # find_similar returns a DataFrame with the k nearest documents and their
    # metadata columns (including "repo", "tasks" and "distance").
    results = retrieval_pipe.find_similar(query, k)
    results["link"] = "https://github.com/" + results["repo"]
    for col in doc_cols:
        results[col] = results[col].apply(
            lambda desc: truncate_description(desc, description_length)
        )
    shown_cols = ["repo", "tasks", "link", "distance"] + doc_cols
    return results.reset_index(drop=True)[shown_cols]
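
# Example call (a sketch; `pipe` stands for a pipeline built by
# setup_retrieval_pipeline below):
#   search_f(pipe, "metric learning", k=10, description_length=10,
#            doc_cols=["dependencies"])
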
def show_retrieval_results(
    retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
    query: str,
    k: int,
    all_queries: List[str],
    description_length: int,
    repos_by_query: pd.core.groupby.DataFrameGroupBy,
    doc_cols: List[str],
):
    print("started retrieval")
    if query in all_queries:
        with st.expander(
            "The query is one of the gold standard queries. Expand to view its gold standard results."
        ):
            st.write("gold standard results")
            task_repos = repos_by_query.get_group(query)
            st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
    with st.spinner(text="fetching results"):
        # Render results as raw HTML so the GitHub links stay clickable.
        st.write(
            search_f(retrieval_pipe, query, k, description_length, doc_cols).to_html(
                escape=False, index=False
            ),
            unsafe_allow_html=True,
        )
    print("finished retrieval")
def setup_pipeline(
    extractor: feature_extractors.SentenceEncoderFeatureExtractor,
    documents_df: pd.DataFrame,
    text_col: str,
):
    # Note: the `extractor` argument is not wired in here; the app uses
    # setup_retrieval_pipeline below instead.
    return retrieval_pipeline.RetrievalPipelineFactory.build(
        documents_df[text_col], metadata=documents_df
    )
def setup_retrieval_pipeline(
    query_encoder_path, document_encoder_path, documents, metadata
):
    # Separate encoders for documents and queries, both running on CPU.
    document_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
        sentence_transformers.SentenceTransformer(document_encoder_path, device="cpu")
    )
    query_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
        sentence_transformers.SentenceTransformer(query_encoder_path, device="cpu")
    )
    retrieval_pipe = retrieval_pipeline.RetrievalPipelineFactory(
        feature_extractor=document_encoder,
        query_feature_extractor=query_encoder,
        # NMSLIB approximate nearest neighbor index with cosine similarity.
        index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
    )
    return retrieval_pipe.build(documents, metadata=metadata)
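
# Minimal usage sketch (the encoder names below are hypothetical placeholders;
# the app reads the real ones from config):
#   pipe = setup_retrieval_pipeline(
#       "sentence-transformers/all-MiniLM-L6-v2",  # hypothetical
#       "sentence-transformers/all-MiniLM-L6-v2",  # hypothetical
#       retrieval_df["dependencies"],
#       retrieval_df,
#   )
#   pipe.find_similar("metric learning", 10)
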
def app(retrieval_pipe, retrieval_df, doc_col):
    retrieved_results = st.sidebar.number_input("number of results", value=10)
    description_length = st.sidebar.number_input(
        "number of used description words", value=10
    )
    # One row per task with the number of repositories tagged with it.
    tasks_deduped = retrieval_df["tasks"].explode().value_counts().reset_index()
    tasks_deduped.columns = ["task", "documents per task"]
    with st.sidebar.expander("View test set queries"):
        st.table(tasks_deduped)
    additional_shown_cols = st.sidebar.multiselect(
        label="additional cols", options=[doc_col], default=doc_col
    )
    repos_by_query = retrieval_df.explode("tasks").groupby("tasks")
    query = st.text_input("input query", value="metric learning")
    show_retrieval_results(
        retrieval_pipe,
        query,
        retrieved_results,
        tasks_deduped["task"].to_list(),
        description_length,
        repos_by_query,
        additional_shown_cols,
    )
def app_main(
    query_encoder_path,
    document_encoder_path,
    data_path,
):
    print("loading data")
    # Load the repositories dataset and keep one row per repository.
    retrieval_df = (
        datasets.load_dataset(data_path)["train"]
        .to_pandas()
        .drop_duplicates(subset=["repo"])
        .reset_index(drop=True)
    )
    print("setting up retrieval_pipe")
    doc_col = "dependencies"
    retrieval_pipe = setup_retrieval_pipeline(
        query_encoder_path, document_encoder_path, retrieval_df[doc_col], retrieval_df
    )
    app(retrieval_pipe, retrieval_df, doc_col)
if __name__ == "__main__":
    app_main(
        query_encoder_path=config.query_encoder_model_name,
        document_encoder_path=config.document_encoder_model_name,
        data_path="lambdaofgod/pwc_repositories_with_dependencies",
    )
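
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py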