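"""Streamlit demo: vector-based search with Sentence Transformers and Faiss.

User queries are embedded through the Hugging Face Inference API
(feature-extraction pipeline) and matched against a pre-built Faiss index
over the dhmeltzer/asks_validation_embedded dataset.
"""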
import pickle

import faiss
import numpy as np
import requests
import streamlit as st
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer
from vector_engine.utils import vector_search


@st.cache
def read_data(dataset_repo='dhmeltzer/asks_validation_embedded'):
    """Read the pre-embedded dataset from the Hugging Face Hub."""
    return load_dataset(dataset_repo)


#@st.cache(allow_output_mutation=True)
#def load_bert_model(name="nli-distilbert-base"):
#    """Instantiate a sentence-level DistilBERT model."""
#    return AutoModel.from_pretrained(f'sentence-transformers/{name}')
#
#@st.cache(allow_output_mutation=True)
#def load_tokenizer(name="nli-distilbert-base"):
#    return AutoTokenizer.from_pretrained(f'sentence-transformers/{name}')


@st.cache(allow_output_mutation=True)
def load_faiss_index(path_to_faiss="./faiss_index_small.pickle"):
    """Load and deserialize the Faiss index."""
    with open(path_to_faiss, "rb") as h:
        data = pickle.load(h)
    return faiss.deserialize_index(data)


def main():
    # Load data and models (the local model/tokenizer path is disabled)
    data = read_data()
    #model = load_bert_model()
    #tok = load_tokenizer()
    faiss_index = load_faiss_index()

    # Embed queries remotely via the Hugging Face Inference API
    # feature-extraction pipeline instead of loading the model locally.
    model_id = "sentence-transformers/nli-distilbert-base"
    api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
    headers = {"Authorization": "Bearer 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"}

    def query(texts):
        """Return embeddings for a list of texts from the Inference API."""
        response = requests.post(
            api_url,
            headers=headers,
            json={"inputs": texts, "options": {"wait_for_model": True}},
        )
        return response.json()
st.title("Vector-based searches with Sentence Transformers and Faiss")
# User search
user_input = st.text_area("Search box", "ELI5 Dataset")
# Filters
st.sidebar.markdown("**Filters**")
filter_scores = st.sidebar.slider("Citations", 0, 250, 0)
num_results = st.sidebar.slider("Number of search results", 1, 50, 1)
vector = query([user_input])
# Fetch results
if user_input:
# Get paper IDs
_, I = faiss_index.search(np.array(vector).astype("float32"), k=num_results)
        #D, I = vector_search([user_input], tok, model, faiss_index, num_results)

        # Slice data on score (currently disabled)
        #frame = data[
        #    (data.scores >= filter_scores)
        #]
        frame = data

        st.write(user_input)

        # Get individual results
        for id_ in I.flatten().tolist():
            f = frame[id_]
            #if id_ in set(frame.id):
            #    f = frame[(frame.id == id_)]
            #else:
            #    continue
            st.write(
                f"""**{f['title']}**
**text**: {f['selftext']}
"""
            )


if __name__ == "__main__":
    main()