"""Streamlit page that runs a semantic sentence search against Elasticsearch."""
import os

import streamlit as st
from elasticsearch import Elasticsearch

from embedders.labse import LaBSE

# Embedding models selectable in the UI, keyed by the name shown in the
# selectbox. The same name is used to build the "<name>_features" field
# that is queried in the "sentences" index.
_MODELS = {"LaBSE": LaBSE}

# Number of neighbouring sentences fetched on each side of a hit.
_CONTEXT_WINDOW = 3


def _context_text(document, number_range):
    """Return the concatenated sentences of *document* within *number_range*.

    Args:
        document: Value matched against the ``document`` field of the
            ``sentences`` index.
        number_range: Body of an Elasticsearch ``range`` clause applied to
            the ``number`` field (e.g. ``{"gte": 4, "lt": 7}``).

    Returns:
        The matching sentences ordered by ``number`` and joined without a
        separator.
    """
    response = es.search(
        index="sentences",
        query={
            "bool": {
                "must": [
                    {"term": {"document": document}},
                    {"range": {"number": number_range}},
                ]
            }
        },
    )
    # The query does not request an ordering, so sort client-side.
    hits = sorted(
        response["hits"]["hits"],
        key=lambda hit: hit["_source"]["number"],
    )
    return "".join(hit["_source"]["sentence"] for hit in hits)


def _document_name(document):
    """Return ``"<title> - <author>"`` for *document*, or a placeholder.

    A placeholder is returned instead of raising when the ``documents``
    index has no entry whose ``id`` equals *document*, so one orphaned
    sentence does not abort the whole result list.
    """
    response = es.search(
        index="documents",
        query={"bool": {"must": [{"term": {"id": document}}]}},
    )
    hits = response["hits"]["hits"]
    if not hits:
        return f"Unknown document ({document})"
    data = hits[0]["_source"]
    return f"{data['title']} - {data['author']}"


def search():
    """Embed the current query, run the vector search and render the hits.

    Reads the module-level widgets/values (``query``, ``model_name``,
    ``limit``, ``status_indicator``, ``results_placeholder``) and the ``es``
    client. For every hit it renders the document name, the score, and the
    matched sentence in bold, surrounded by up to ``_CONTEXT_WINDOW``
    sentences of context on each side.
    """
    if not query.strip():
        # Nothing meaningful to embed; skip the model load and the search.
        status_indicator.write("Please enter a query.")
        return

    status_indicator.write(f"Loading model {model_name}...")
    model = _MODELS[model_name]()

    status_indicator.write("Computing query embeddings...")
    # Takes the first row of the model output; presumably the output is
    # (batch, dim) with a single-row batch for one query - TODO confirm.
    query_vector = model(query)[0, :].tolist()

    status_indicator.write("Performing query...")
    target_field = f"{model_name}_features"
    results = es.search(
        index="sentences",
        query={
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    # Cosine similarity lies in [-1, 1]; adding 1.0 shifts
                    # the score into [0, 2] so it is never negative.
                    "source": f"cosineSimilarity(params.query_vector, '{target_field}') + 1.0",
                    "params": {"query_vector": query_vector},
                },
            }
        },
        size=limit,
    )

    # Several hits may belong to the same document; look each name up once.
    document_names = {}

    for result in results["hits"]["hits"]:
        source = result["_source"]
        sentence = source["sentence"]
        score = result["_score"]
        document = source["document"]
        number = source["number"]

        previous_context = _context_text(
            document,
            {"gte": number - _CONTEXT_WINDOW, "lt": number},
        )
        subsequent_context = _context_text(
            document,
            {"lte": number + _CONTEXT_WINDOW, "gt": number},
        )

        if document not in document_names:
            document_names[document] = _document_name(document)
        document_name = document_names[document]

        results_placeholder.markdown(
            f"#### {document_name} (score: {score:.2f})\n"
            f"{previous_context} **{sentence}** {subsequent_context}"
        )

    status_indicator.write("Results ready...")


# ELASTIC_AUTH is expected as "user:password". Split on the first colon only
# and pass a 2-tuple, which is the documented shape for ``basic_auth``.
es = Elasticsearch(
    os.environ["ELASTIC_HOST"],
    basic_auth=tuple(os.environ["ELASTIC_AUTH"].split(":", 1)),
)

st.header("Serica Semantic Search")
st.write("Perform a semantic search using a Sentence Embedding Transformer model on the SERICA database")

model_name = st.selectbox("Model", list(_MODELS))
# The second positional argument of number_input is min_value, not the
# default, so pass both by keyword: default 10, but allow fewer results.
limit = st.number_input("Number of results", min_value=1, value=10)
query = st.text_input("Query", value="")
status_indicator = st.empty()
do_search = st.button("Search")
results_placeholder = st.container()

if do_search:
    search()