Ahmadzei committed on
Commit
8ea3b28
1 Parent(s): 114d969

added reranking to semantic search

Browse files
Files changed (2) hide show
  1. .env +1 -0
  2. backend/semantic_search.py +28 -2
.env CHANGED
@@ -3,6 +3,7 @@ export EMB_MODEL=sentence-transformers/all-MiniLM-L6-v2
3
  export TOP_K=5
4
  export HF_MODEL=mistralai/Mistral-7B-Instruct-v0.2
5
  export OPENAI_MODEL=gpt-4-turbo-preview
 
6
 
7
  #### SECRETS ####
8
  export OPENAI_API_KEY=
 
3
  export TOP_K=5
4
  export HF_MODEL=mistralai/Mistral-7B-Instruct-v0.2
5
  export OPENAI_MODEL=gpt-4-turbo-preview
6
+ export CROSS_ENC_MODEL=cross-encoder/ms-marco-MiniLM-L-6-v2
7
 
8
  #### SECRETS ####
9
  export OPENAI_API_KEY=
backend/semantic_search.py CHANGED
@@ -2,10 +2,25 @@ import lancedb
2
  import os
3
  import gradio as gr
4
  from sentence_transformers import SentenceTransformer
 
 
5
 
 
 
 
 
 
 
 
6
 
7
- db = lancedb.connect(".lancedb")
 
 
 
 
 
8
 
 
9
  TABLE = db.open_table(os.getenv("TABLE_NAME"))
10
  VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
11
  TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
@@ -13,13 +28,24 @@ BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
13
 
14
  retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
15
 
 
 
 
 
 
 
 
16
 
17
- def retrieve(query, k):
18
  query_vec = retriever.encode(query)
19
  try:
20
  documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
21
  documents = [doc[TEXT_COLUMN] for doc in documents]
22
 
 
 
 
 
23
  return documents
24
 
25
  except Exception as e:
 
2
  import os
3
  import gradio as gr
4
  from sentence_transformers import SentenceTransformer
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+ import torch
7
 
8
+ # For Text Similarity and Relevance Ranking:
9
+ # valhalla/distilbart-mnli-12-3
10
+ # cross-encoder/stsb-roberta-large
11
+ #
12
+ # For Question Answering:
13
+ # deepset/roberta-base-squad2
14
+ # cross-encoder/quora-distilroberta-base
15
 
16
+ CROSS_ENC_MODEL = os.getenv("CROSS_ENC_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
17
+
18
+ # Initialize the tokenizer and model for reranking
19
+ tokenizer = AutoTokenizer.from_pretrained(CROSS_ENC_MODEL)
20
+ cross_encoder = AutoModelForSequenceClassification.from_pretrained(CROSS_ENC_MODEL)
21
+ cross_encoder.eval() # Put model in evaluation mode
22
 
23
+ db = lancedb.connect(".lancedb")
24
  TABLE = db.open_table(os.getenv("TABLE_NAME"))
25
  VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
26
  TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
 
28
 
29
  retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
30
 
31
def rerank(query, documents):
    """Reorder ``documents`` by cross-encoder relevance to ``query``.

    Each document is paired with the query, the pairs are tokenized with the
    module-level ``tokenizer`` and scored with the module-level
    ``cross_encoder``; the documents are then returned sorted from the
    highest score to the lowest.

    Args:
        query: The search query text.
        documents: Candidate document texts to reorder.

    Returns:
        A new list holding the same documents, best match first. Documents
        with equal scores keep their incoming relative order. An empty input
        yields an empty list without invoking the tokenizer or the model.
    """
    if not documents:
        # Nothing to score; avoid handing the tokenizer an empty batch.
        return []

    # One (query, document) pair per candidate.
    pairs = [[query, doc] for doc in documents]
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = cross_encoder(**inputs).logits

    # Drop only the trailing logit dimension. A bare squeeze() would also
    # collapse the batch dimension when there is a single document, leaving
    # a 0-d tensor that cannot be iterated by zip().
    # NOTE(review): assumes the model emits one logit per pair (as the
    # ms-marco cross-encoders do) -- confirm before swapping in a
    # multi-label model.
    scores = logits.squeeze(-1).tolist()

    # Sort on plain Python floats rather than on tensor elements.
    ranked = sorted(zip(scores, documents), key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in ranked]
38
 
39
+ def retrieve(query, k, rr=True):
40
  query_vec = retriever.encode(query)
41
  try:
42
  documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
43
  documents = [doc[TEXT_COLUMN] for doc in documents]
44
 
45
+ # Rerank the retrieved documents if rr (rerank) is True
46
+ if rr:
47
+ documents = rerank(query, documents)
48
+
49
  return documents
50
 
51
  except Exception as e: