# Spaces:
# Runtime error
# Runtime error
from typing import Dict, List

import pandas as pd
import streamlit as st
import torch
from findkit import retrieval_pipeline

import config
from search_utils import (
    RetrievalPipelineWrapper,
    get_doc_cols,
    get_repos_with_descriptions,
    get_retrieval_df,
    merge_cols,
)
class RetrievalApp:
    """Streamlit application for semantic retrieval over repository metadata.

    Sidebar widgets select the torch device, the encoder model pair and the
    displayed text columns; the main pane accepts a free-text query and
    renders the retrieved repositories as an HTML table.
    """

    def __init__(self, data_path="lambdaofgod/pwc_repositories_with_dependencies"):
        """Build the sidebar controls and load the retrieval dataframe.

        Args:
            data_path: dataset identifier forwarded to ``get_retrieval_df``.
        """
        self.data_path = data_path
        # Widget creation order defines the sidebar layout: device first.
        self.device = st.sidebar.selectbox("device", self.get_device_options())
        print("loading data")
        self.retrieval_df = self.get_retrieval_df().copy()
        model_name = st.sidebar.selectbox("model", config.model_names)
        self.query_encoder_name = "lambdaofgod/query-" + model_name
        self.document_encoder_name = "lambdaofgod/document-" + model_name
        doc_cols = get_doc_cols(model_name)
        st.sidebar.text("using models")
        st.sidebar.text("https://huggingface.co/" + self.query_encoder_name)
        # BUG FIX: the scheme was uppercase ("HTTP://"), inconsistent with the
        # query-encoder link above and yielding an awkward displayed URL.
        st.sidebar.text("https://huggingface.co/" + self.document_encoder_name)
        self.additional_shown_cols = st.sidebar.multiselect(
            label="used text features", options=config.text_cols, default=doc_cols
        )

    def is_cuda_available(self):
        """Return True when torch can use a CUDA device.

        Uses the public ``torch.cuda.is_available()`` API instead of calling
        the private ``torch._C._cuda_init()`` under a bare ``except:`` (which
        also swallowed ``KeyboardInterrupt``/``SystemExit``).
        """
        return torch.cuda.is_available()

    def get_device_options(self):
        """Device choices offered in the sidebar; CUDA listed first if usable."""
        return ["cuda", "cpu"] if self.is_cuda_available() else ["cpu"]

    def get_retrieval_df(self):
        """Load the retrieval dataframe for ``self.data_path``."""
        return get_retrieval_df(self.data_path, config.text_list_cols)

    @staticmethod
    def show_retrieval_results(
        retrieval_pipe: "RetrievalPipelineWrapper",
        query: str,
        k: int,
        all_queries: List[str],
        description_length: int,
        repos_by_query: Dict[str, pd.DataFrame],
        additional_shown_cols: List[str],
    ):
        """Render retrieval results for *query* in the main pane.

        If the query is one of the gold-standard queries, also show the gold
        standard repositories in an expander.

        Args:
            retrieval_pipe: pipeline wrapper; ``.search`` and ``.X_df`` are used.
            query: free-text query entered by the user.
            k: number of results to retrieve.
            all_queries: gold-standard query list used for membership check.
            description_length: number of description words to show.
            repos_by_query: grouped repositories; ``.get_group(query)`` is
                called, so this is a pandas groupby despite the Dict annotation
                (annotation kept for interface compatibility).
            additional_shown_cols: extra text columns to display.
        """
        # @staticmethod added: the function never used ``self`` and is invoked
        # as ``RetrievalApp.show_retrieval_results(...)``.
        print("started retrieval")
        if query in all_queries:
            with st.expander(
                "query is in gold standard set queries. Toggle viewing gold standard results?"
            ):
                st.write("gold standard results")
                task_repos = repos_by_query.get_group(query)
                st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
        with st.spinner(text="fetching results"):
            st.write(
                retrieval_pipe.search(
                    query, k, description_length, additional_shown_cols
                ).to_html(escape=False, index=False),
                unsafe_allow_html=True,
            )
        print("finished retrieval")

    def run_app(self, retrieval_pipeline):
        """Wire up the query input/result widgets and run one retrieval pass."""
        retrieved_results = st.sidebar.number_input("number of results", value=10)
        description_length = st.sidebar.number_input(
            "number of used description words", value=10
        )
        tasks_deduped = (
            self.retrieval_df["tasks"].explode().value_counts().reset_index()
        )
        tasks_deduped.columns = ["task", "documents per task"]
        with st.sidebar.expander("View test set queries"):
            st.table(tasks_deduped.explode("task"))
        repos_by_query = self.retrieval_df.explode("tasks").groupby("tasks")
        query = st.text_input("input query", value="metric learning")
        RetrievalApp.show_retrieval_results(
            retrieval_pipeline,
            query,
            retrieved_results,
            tasks_deduped["task"].to_list(),
            description_length,
            repos_by_query,
            self.additional_shown_cols,
        )

    def get_retrieval_pipeline(self, displayed_retrieval_df):
        """Build the retrieval pipeline from the configured encoder names."""
        return RetrievalPipelineWrapper.setup_from_encoder_names(
            self.query_encoder_name,
            self.document_encoder_name,
            displayed_retrieval_df["document"],
            displayed_retrieval_df,
            device=self.device,
        )

    def main(self):
        """Entry point: assemble the display dataframe, pipeline, and app."""
        print("setting up retrieval_pipe")
        displayed_retrieval_df = merge_cols(
            self.retrieval_df.copy(), self.additional_shown_cols
        )
        retrieval_pipeline = self.get_retrieval_pipeline(displayed_retrieval_df)
        self.run_app(retrieval_pipeline)