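"""Search backends for a Streamlit Hugging Face model-search demo.

Each backend (HfApi keyword search, retrieve & rerank, BM25) returns hits in
a common ``{"hits": [...], "count": n}`` shape; ``paginator`` pages any list
of results through a Streamlit selectbox.
"""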
from huggingface_hub import HfApi, ModelFilter
from pprint import pprint
from hf_search import hf_search
import streamlit as st
import itertools

@st.cache
def hf_api(query, limit=5, filters=None):
    """Search Hub models with the official HfApi client."""
    filters = filters or {}  # avoid the mutable-default-argument pitfall
    print("query", query)
    print("filters", filters)
    print("limit", limit)

    api = HfApi()
    filt = ModelFilter(
        task=filters.get("task"),
        library=filters.get("library"),
    )
    models = api.list_models(search=query, filter=filt, limit=limit, full=True)
    hits = []
    for model in models:
        model = model.__dict__
        hits.append(
            {
                "modelId": model.get("modelId"),
                "tags": model.get("tags"),
                "downloads": model.get("downloads"),
                "likes": model.get("likes"),
            }
        )
    count = len(hits)
    # `list_models` already caps results at `limit`; this re-truncation is defensive.
    if len(hits) > limit:
        hits = hits[:limit]
    pprint(hits)
    return {"hits": hits, "count": count}


@st.cache
def semantic_search(query, limit=5, filters=None):
    """Semantic search over models via hf_search's retrieve & rerank method."""
    filters = filters or {}  # avoid the mutable-default-argument pitfall
    print("query", query)
    print("filters", filters)
    print("limit", limit)

    hits = hf_search(query=query, method="retrieve & rerank", limit=limit, filters=filters)
    hits = [
        {
            "modelId": hit["modelId"],
            "tags": hit["tags"],
            "downloads": hit["downloads"],
            "likes": hit["likes"],
            "readme": hit.get("readme", None),
        }
        for hit in hits
    ]
    return {"hits": hits, "count": len(hits)}


@st.cache
def bm25_search(query, limit=5, filters=None):
    """Keyword (BM25) search over models via hf_search."""
    filters = filters or {}  # avoid the mutable-default-argument pitfall
    print("query", query)
    print("filters", filters)
    print("limit", limit)

    # TODO: filters
    hits = hf_search(query=query, method="bm25", limit=limit)
    hits = [
        {
            "modelId": hit["modelId"],
            "tags": hit["tags"],
            "downloads": hit["downloads"],
            "likes": hit["likes"],
            "readme": hit.get("readme", None),
        }
        for hit in hits
    ]
    # De-duplicate hits by modelId (dict keys preserve first-seen order).
    hits = list({hit["modelId"]: hit for hit in hits}.values())
    return {"hits": hits, "count": len(hits)}


def paginator(label, articles, articles_per_page=10, on_sidebar=True):
    # https://gist.github.com/treuille/2ce0acb6697f205e44e3e0f576e810b7
    """Lets the user paginate a set of article.
    Parameters
    ----------
    label : str
        The label to display over the pagination widget.
    articles : Iterator[Any]
        The articles to display in the paginator.
    articles_per_page: int
        The number of articles to display per page.
    on_sidebar: bool
        Whether to display the paginator widget on the sidebar.

    Returns
    -------
    Iterator[Tuple[int, Any]]
        An iterator over *only the articles on that page*, paired with
        each article's index in the full list.
    """

    # Figure out where to display the paginator
    if on_sidebar:
        location = st.sidebar.empty()
    else:
        location = st.empty()

    # Display a pagination selectbox in the specified location.
    articles = list(articles)
    n_pages = (len(articles) - 1) // articles_per_page + 1
    page_format_func = lambda i: f"Results {i * articles_per_page} to {(i + 1) * articles_per_page - 1}"
    page_number = location.selectbox(label, range(n_pages), format_func=page_format_func)

    # Iterate over the articles in the page to let the user display them.
    min_index = page_number * articles_per_page
    max_index = min_index + articles_per_page

    return itertools.islice(enumerate(articles), min_index, max_index)
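

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original app): one way the
# backends above might be wired to `paginator` in a Streamlit page. The widget
# labels, the fixed limit of 50, and passing no filters are all assumptions.
# ---------------------------------------------------------------------------
query = st.text_input("Search the Hub", value="text classification")
method = st.selectbox("Search method", ["hf_api", "semantic", "bm25"])
search_fn = {"hf_api": hf_api, "semantic": semantic_search, "bm25": bm25_search}[method]
if query:
    results = search_fn(query, limit=50)
    st.write(f"{results['count']} hits")
    # paginator yields (index, hit) pairs for the selected page only.
    for i, hit in paginator("Results page", results["hits"], articles_per_page=10):
        st.write(f"{i}. **{hit['modelId']}** (downloads: {hit.get('downloads')}, likes: {hit.get('likes')})")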