# Spaces:
# Runtime error
# Runtime error
# File size: 4,386 Bytes
# 568499b a284f57 568499b a284f57 568499b a284f57 568499b a284f57 568499b 655f181 34c16bd 655f181 34c16bd 655f181 34c16bd a284f57 34c16bd c5a2694 a284f57 568499b a284f57 568499b a284f57 568499b a284f57 568499b a284f57 568499b a284f57 568499b a284f57 568499b 34c16bd 568499b a284f57 568499b a284f57 568499b a284f57 568499b a284f57 568499b a284f57 568499b a284f57 568499b a284f57 |
# NOTE(review): the lines above are page-scrape residue (status text, commit-hash
# gutter, line-number gutter) — commented out so the module parses; confirm they
# can be deleted entirely.
from typing import Dict, List
import torch
import pandas as pd
import streamlit as st
from findkit import retrieval_pipeline
import config
from search_utils import (
RetrievalPipelineWrapper,
get_doc_cols,
get_repos_with_descriptions,
get_retrieval_df,
merge_cols,
)
class RetrievalApp:
    """Streamlit application for dense retrieval over a repository dataset.

    Sidebar controls pick the torch device, the query/document encoder pair
    and the text columns shown with each result; the main pane takes a
    free-text query and renders the top-k retrieved repositories.
    """

    def is_cuda_available(self):
        """Return True when torch can use a CUDA device.

        Uses the public ``torch.cuda.is_available()`` API instead of poking
        the private ``torch._C._cuda_init()`` and swallowing every exception
        with a bare ``except:``.
        """
        try:
            return torch.cuda.is_available()
        except RuntimeError:
            # e.g. driver present but unusable
            return False

    def get_device_options(self):
        """Device choices for the sidebar selectbox; CUDA listed first if usable."""
        if self.is_cuda_available():
            return ["cuda", "cpu"]
        else:
            return ["cpu"]

    # NOTE(review): st.cache on an instance method hashes `self` into the
    # cache key — confirm this is intended if multiple instances exist.
    @st.cache(allow_output_mutation=True)
    def get_retrieval_df(self):
        """Load and cache the dataframe backing retrieval for ``data_path``."""
        return get_retrieval_df(self.data_path, config.text_list_cols)

    def __init__(self, data_path="lambdaofgod/pwc_repositories_with_dependencies"):
        """Build all sidebar widgets and load the retrieval dataframe.

        Args:
            data_path: dataset identifier passed through to
                ``search_utils.get_retrieval_df``.
        """
        self.data_path = data_path
        self.device = st.sidebar.selectbox("device", self.get_device_options())
        print("loading data")
        # copy() so later mutation does not leak into the st.cache'd frame
        self.retrieval_df = self.get_retrieval_df().copy()
        model_name = st.sidebar.selectbox("model", config.model_names)
        self.query_encoder_name = "lambdaofgod/query-" + model_name
        self.document_encoder_name = "lambdaofgod/document-" + model_name
        doc_cols = get_doc_cols(model_name)
        st.sidebar.text("using models")
        st.sidebar.text("https://huggingface.co/" + self.query_encoder_name)
        # was "HTTP://..." — normalized to match the query-encoder link above
        st.sidebar.text("https://huggingface.co/" + self.document_encoder_name)
        self.additional_shown_cols = st.sidebar.multiselect(
            label="used text features", options=config.text_cols, default=doc_cols
        )

    @staticmethod
    def show_retrieval_results(
        retrieval_pipe: RetrievalPipelineWrapper,
        query: str,
        k: int,
        all_queries: List[str],
        description_length: int,
        repos_by_query: Dict[str, pd.DataFrame],
        additional_shown_cols: List[str],
    ):
        """Render search results for ``query`` in the main pane.

        When the query is one of the gold-standard queries, an expander with
        the annotated results is shown first.

        Args:
            retrieval_pipe: wrapper exposing ``search`` and the indexed ``X_df``.
            query: free-text user query.
            k: number of results to fetch.
            all_queries: queries that have gold-standard annotations.
            description_length: number of description words to display.
            repos_by_query: pandas groupby of repositories keyed by task/query.
            additional_shown_cols: extra text columns shown per result.
        """
        print("started retrieval")
        if query in all_queries:
            with st.expander(
                "query is in gold standard set queries. Toggle viewing gold standard results?"
            ):
                st.write("gold standard results")
                task_repos = repos_by_query.get_group(query)
                st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
        with st.spinner(text="fetching results"):
            # escape=False keeps rendered links; the HTML comes from our own
            # dataframe columns, not from user input.
            st.write(
                retrieval_pipe.search(
                    query, k, description_length, additional_shown_cols
                ).to_html(escape=False, index=False),
                unsafe_allow_html=True,
            )
        print("finished retrieval")

    def run_app(self, retrieval_pipeline):
        """Wire sidebar inputs to a retrieval run and display the results.

        Note: the ``retrieval_pipeline`` parameter (a RetrievalPipelineWrapper)
        shadows the imported ``findkit.retrieval_pipeline`` module inside this
        method; the name is kept for caller compatibility.
        """
        retrieved_results = st.sidebar.number_input("number of results", value=10)
        description_length = st.sidebar.number_input(
            "number of used description words", value=10
        )
        tasks_deduped = (
            self.retrieval_df["tasks"].explode().value_counts().reset_index()
        )
        tasks_deduped.columns = ["task", "documents per task"]
        with st.sidebar.expander("View test set queries"):
            st.table(tasks_deduped.explode("task"))
        repos_by_query = self.retrieval_df.explode("tasks").groupby("tasks")
        query = st.text_input("input query", value="metric learning")
        RetrievalApp.show_retrieval_results(
            retrieval_pipeline,
            query,
            retrieved_results,
            tasks_deduped["task"].to_list(),
            description_length,
            repos_by_query,
            self.additional_shown_cols,
        )

    @st.cache(allow_output_mutation=True)
    def get_retrieval_pipeline(self, displayed_retrieval_df):
        """Build (and cache) the retrieval pipeline over the displayed frame."""
        return RetrievalPipelineWrapper.setup_from_encoder_names(
            self.query_encoder_name,
            self.document_encoder_name,
            displayed_retrieval_df["document"],
            displayed_retrieval_df,
            device=self.device,
        )

    def main(self):
        """Entry point: assemble the document frame, build the pipeline, run."""
        print("setting up retrieval_pipe")
        displayed_retrieval_df = merge_cols(
            self.retrieval_df.copy(), self.additional_shown_cols
        )
        retrieval_pipeline = self.get_retrieval_pipeline(displayed_retrieval_df)
        self.run_app(retrieval_pipeline)
|