import os
import pickle

import faiss
import numpy as np
import requests
import streamlit as st
from datasets import load_dataset

# Local-model alternative (unused; query embedding is delegated to the
# Hugging Face Inference API inside main() instead):
# from transformers import AutoModel, AutoTokenizer
# from vector_engine.utils import vector_search


@st.cache
def read_data(dataset_repo='dhmeltzer/asks_validation_embedded'):
    """Read the dataset from the Hugging Face Hub."""
    # NOTE: assumes the embedded examples live in the 'train' split;
    # adjust the split name if the repo is organized differently.
    return load_dataset(dataset_repo)['train']


#@st.cache(allow_output_mutation=True)
#def load_bert_model(name="nli-distilbert-base"):
#    """Instantiate a sentence-level DistilBERT model."""
#    return AutoModel.from_pretrained(f'sentence-transformers/{name}')
#
#@st.cache(allow_output_mutation=True)
#def load_tokenizer(name="nli-distilbert-base"):
#    return AutoTokenizer.from_pretrained(f'sentence-transformers/{name}')


@st.cache(allow_output_mutation=True)
def load_faiss_index(path_to_faiss="./faiss_index_small.pickle"):
    """Load and deserialize the Faiss index."""
    with open(path_to_faiss, "rb") as h:
        data = pickle.load(h)
    return faiss.deserialize_index(data)


def main():
    # Load data and models.
    data = read_data()
    #model = load_bert_model()
    #tok = load_tokenizer()
    faiss_index = load_faiss_index()

    # Embed queries with the hosted Inference API instead of a local model.
    model_id = "sentence-transformers/nli-distilbert-base"
    api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
    # Read the API token from the environment rather than hard-coding a
    # secret in the source (the variable name HF_API_TOKEN is a convention
    # chosen here).
    headers = {"Authorization": f"Bearer {os.environ['HF_API_TOKEN']}"}

    def query(texts):
        response = requests.post(
            api_url,
            headers=headers,
            json={"inputs": texts, "options": {"wait_for_model": True}},
        )
        return response.json()

    st.title("Vector-based searches with Sentence Transformers and Faiss")

    # User search.
    user_input = st.text_area("Search box", "ELI5 Dataset")

    # Filters.
    st.sidebar.markdown("**Filters**")
    filter_scores = st.sidebar.slider("Citations", 0, 250, 0)
    num_results = st.sidebar.slider("Number of search results", 1, 50, 1)

    # Fetch results.
    if user_input:
        # Embed the query and retrieve its nearest neighbours from the index.
        vector = query([user_input])
        _, I = faiss_index.search(np.array(vector).astype("float32"), k=num_results)
        #D, I = vector_search([user_input], tok, model, faiss_index, num_results)

        # Score filtering is currently disabled; filter_scores is unused:
        #frame = data[(data.scores >= filter_scores)]
        frame = data

        st.write(user_input)

        # Render individual results.
        for id_ in I.flatten().tolist():
            f = frame[id_]
            #if id_ in set(frame.id):
            #    f = frame[(frame.id == id_)]
            #else:
            #    continue
            st.write(
                f"""**{f['title']}**

**text**: {f['selftext']}
"""
            )


if __name__ == "__main__":
    main()
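# Usage (a sketch, assuming the file is saved as app.py and the token
# environment variable matches the HF_API_TOKEN name assumed above):
#
#   export HF_API_TOKEN=<your Hugging Face API token>
#   streamlit run app.py
#
# The app also expects the serialized Faiss index at
# ./faiss_index_small.pickle relative to the working directory.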