File size: 4,386 Bytes
568499b
 
a284f57
568499b
 
 
a284f57
 
568499b
a284f57
 
568499b
a284f57
 
568499b
 
 
 
655f181
34c16bd
 
655f181
34c16bd
 
655f181
34c16bd
a284f57
34c16bd
c5a2694
a284f57
 
 
 
 
 
 
568499b
a284f57
 
568499b
 
a284f57
568499b
 
 
 
 
a284f57
 
568499b
 
a284f57
 
 
 
 
568499b
 
 
a284f57
568499b
 
 
 
 
a284f57
568499b
 
 
 
 
 
 
 
 
 
 
34c16bd
 
 
568499b
 
 
 
a284f57
568499b
 
 
 
 
 
 
a284f57
568499b
 
 
 
a284f57
568499b
 
 
 
 
 
 
 
a284f57
568499b
 
a284f57
 
 
568499b
 
a284f57
 
 
 
 
 
 
 
 
568499b
a284f57
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from typing import Dict, List

import torch
import pandas as pd
import streamlit as st
from findkit import retrieval_pipeline

import config
from search_utils import (
    RetrievalPipelineWrapper,
    get_doc_cols,
    get_repos_with_descriptions,
    get_retrieval_df,
    merge_cols,
)


class RetrievalApp:
    """Streamlit app for semantic retrieval over repository documents.

    Loads a retrieval dataframe, lets the user choose device / model /
    displayed text features in the sidebar, and renders search results
    (plus gold-standard results when the query is a known task).
    """

    def is_cuda_available(self):
        """Return True iff a CUDA device is usable by torch.

        Uses the public ``torch.cuda.is_available()`` instead of the
        private ``torch._C._cuda_init()`` guarded by a bare ``except:``
        (which also swallowed KeyboardInterrupt/SystemExit).
        """
        return torch.cuda.is_available()

    def get_device_options(self):
        """Device choices for the sidebar selectbox; cuda listed first when usable."""
        return ["cuda", "cpu"] if self.is_cuda_available() else ["cpu"]

    @st.cache(allow_output_mutation=True)
    def get_retrieval_df(self):
        """Load the retrieval dataframe from ``self.data_path`` (cached across reruns).

        NOTE(review): ``st.cache`` on an instance method hashes ``self``;
        consider moving the cache to a module-level function — confirm
        against the Streamlit version in use.
        """
        return get_retrieval_df(self.data_path, config.text_list_cols)

    def __init__(self, data_path="lambdaofgod/pwc_repositories_with_dependencies"):
        """Build the sidebar controls and load data.

        Parameters
        ----------
        data_path : str
            Dataset identifier passed through to ``get_retrieval_df``.
        """
        self.data_path = data_path
        self.device = st.sidebar.selectbox("device", self.get_device_options())
        print("loading data")

        # .copy() so later in-place edits don't mutate the cached dataframe
        self.retrieval_df = self.get_retrieval_df().copy()

        model_name = st.sidebar.selectbox("model", config.model_names)
        self.query_encoder_name = "lambdaofgod/query-" + model_name
        self.document_encoder_name = "lambdaofgod/document-" + model_name

        doc_cols = get_doc_cols(model_name)

        st.sidebar.text("using models")
        st.sidebar.text("https://huggingface.co/" + self.query_encoder_name)
        # fixed scheme casing: was "HTTP://", inconsistent with the line above
        st.sidebar.text("https://huggingface.co/" + self.document_encoder_name)

        self.additional_shown_cols = st.sidebar.multiselect(
            label="used text features", options=config.text_cols, default=doc_cols
        )

    @staticmethod
    def show_retrieval_results(
        retrieval_pipe: RetrievalPipelineWrapper,
        query: str,
        k: int,
        all_queries: List[str],
        description_length: int,
        repos_by_query: Dict[str, pd.DataFrame],
        additional_shown_cols: List[str],
    ):
        """Render search results for ``query``.

        If the query is one of the gold-standard queries, also show the
        gold-standard repositories for it inside an expander.

        Parameters
        ----------
        retrieval_pipe : RetrievalPipelineWrapper
            Pipeline used for searching; its ``X_df`` holds the documents.
        query : str
            Free-text user query.
        k : int
            Number of results to fetch.
        all_queries : List[str]
            Gold-standard query list used for the membership check.
        description_length : int
            Number of description words to display per result.
        repos_by_query : Dict[str, pd.DataFrame]
            Groupby object mapping each gold-standard query to its repos
            (``get_group`` is called on it, so a pandas GroupBy works).
        additional_shown_cols : List[str]
            Extra text columns to include in the result table.
        """
        print("started retrieval")
        if query in all_queries:
            with st.expander(
                "query is in gold standard set queries. Toggle viewing gold standard results?"
            ):
                st.write("gold standard results")
                task_repos = repos_by_query.get_group(query)
                st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
        with st.spinner(text="fetching results"):
            # rendered as raw HTML so result links stay clickable
            st.write(
                retrieval_pipe.search(
                    query, k, description_length, additional_shown_cols
                ).to_html(escape=False, index=False),
                unsafe_allow_html=True,
            )
        print("finished retrieval")

    def run_app(self, retrieval_pipeline):
        """Draw the main page: result-count controls, query box, and results.

        Parameters
        ----------
        retrieval_pipeline : RetrievalPipelineWrapper
            Ready-to-use pipeline (see ``get_retrieval_pipeline``).
        """
        retrieved_results = st.sidebar.number_input("number of results", value=10)
        description_length = st.sidebar.number_input(
            "number of used description words", value=10
        )

        # one row per unique task with its document count
        tasks_deduped = (
            self.retrieval_df["tasks"].explode().value_counts().reset_index()
        )
        tasks_deduped.columns = ["task", "documents per task"]
        with st.sidebar.expander("View test set queries"):
            st.table(tasks_deduped.explode("task"))
        repos_by_query = self.retrieval_df.explode("tasks").groupby("tasks")
        query = st.text_input("input query", value="metric learning")
        RetrievalApp.show_retrieval_results(
            retrieval_pipeline,
            query,
            retrieved_results,
            tasks_deduped["task"].to_list(),
            description_length,
            repos_by_query,
            self.additional_shown_cols,
        )

    @st.cache(allow_output_mutation=True)
    def get_retrieval_pipeline(self, displayed_retrieval_df):
        """Build (and cache) the retrieval pipeline for the selected encoders.

        NOTE(review): same ``st.cache``-on-method caveat as
        ``get_retrieval_df`` — ``self`` participates in the cache key.
        """
        return RetrievalPipelineWrapper.setup_from_encoder_names(
            self.query_encoder_name,
            self.document_encoder_name,
            displayed_retrieval_df["document"],
            displayed_retrieval_df,
            device=self.device,
        )

    def main(self):
        """Entry point: prepare the display dataframe, build the pipeline, run the UI."""
        print("setting up retrieval_pipe")
        displayed_retrieval_df = merge_cols(
            self.retrieval_df.copy(), self.additional_shown_cols
        )
        retrieval_pipeline = self.get_retrieval_pipeline(displayed_retrieval_df)
        self.run_app(retrieval_pipeline)