File size: 1,630 Bytes
8b15eea
 
9a4e478
8b15eea
 
 
9a4e478
8ea4cbd
8b15eea
 
 
 
 
 
 
 
 
 
 
 
9a4e478
8b15eea
 
 
 
 
 
 
9a4e478
ed7d1fc
9a4e478
 
8b15eea
 
9a4e478
 
 
 
 
 
 
 
8ea4cbd
8b15eea
9a4e478
 
8b15eea
 
 
 
9a4e478
8b15eea
9a4e478
8b15eea
 
 
 
 
 
9a4e478
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import logging
import time
from pathlib import Path

import lancedb
from sentence_transformers import SentenceTransformer

import spaces


# Setting up the logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Start the timer for opening the LanceDB database and table
start_time = time.perf_counter()

# Project root: the parent of the directory that contains this file.
proj_dir = Path(__file__).parents[1]

# Connect to the on-disk LanceDB database at "<proj_dir>/lancedb" and open the
# 'arabic-wiki' table that vector_search queries.
db = lancedb.connect(proj_dir / "lancedb")
tbl = db.open_table('arabic-wiki')

# Log the time taken to connect to LanceDB and open the table
lancedb_loading_time = time.perf_counter() - start_time
logger.info(f"Time taken to load LanceDB: {lancedb_loading_time:.6f} seconds")

# Restart the timer for loading the embedding models. It is read again at the
# end of the module, after the helper functions are defined.
start_time = time.perf_counter()

# The same multilingual model is loaded twice: one copy on 'cuda' for the
# normal path (embed_func) and one on 'cpu' used as the fallback in
# call_embed_func.
name = "sentence-transformers/paraphrase-multilingual-minilm-l12-v2"
st_model_gpu = SentenceTransformer(name, device='cuda')
st_model_cpu = SentenceTransformer(name, device='cpu')


# Shared embedding entry point. The original note said it is "used for both
# training and querying" — presumably index building and retrieval; confirm
# against callers.
def call_embed_func(query):
    """Embed *query*, preferring the GPU model and falling back to the CPU one.

    ``embed_func`` (the GPU path) is tried first. If it raises an
    ``Exception``, the failure is logged as a warning and the CPU model
    ``st_model_cpu`` encodes *query* instead. ``KeyboardInterrupt`` and
    ``SystemExit`` are not caught, so interrupts and shutdowns still propagate.

    Args:
        query: Passed unchanged to ``SentenceTransformer.encode`` (presumably
            a ``str`` or a list of ``str`` — confirm with callers).

    Returns:
        Whatever ``SentenceTransformer.encode`` returns for *query*.
    """
    try:
        return embed_func(query)
    except Exception as exc:  # not bare: don't swallow Ctrl-C / SystemExit
        # Keep the cause visible; otherwise a persistent GPU failure silently
        # degrades every request to CPU.
        logger.warning("Using CPU (GPU embedding failed: %r)", exc)
        return st_model_cpu.encode(query)


@spaces.GPU
def embed_func(query):
    """Encode *query* with the GPU-resident SentenceTransformer model.

    The ``spaces.GPU`` decorator requests a GPU for the duration of the call
    when running on Hugging Face Spaces. Callers go through
    ``call_embed_func``, which catches failures from this function and falls
    back to the CPU model.
    """
    embedding = st_model_gpu.encode(query)
    return embedding


def vector_search(query_vector, top_k):
    """Run a LanceDB vector search on the 'arabic-wiki' table.

    Args:
        query_vector: Embedding to search with (as produced by
            ``call_embed_func``).
        top_k: Maximum number of results to return.

    Returns:
        The search results as materialized by LanceDB's ``to_list()``.
    """
    search = tbl.search(query_vector)
    limited_search = search.limit(top_k)
    return limited_search.to_list()


def retriever(query, top_k=3):
    """Embed *query* and return its ``top_k`` LanceDB search results.

    The query is embedded via ``call_embed_func`` (GPU with CPU fallback), and
    the resulting vector is passed to ``vector_search``.

    Args:
        query: Text to retrieve documents for.
        top_k: Maximum number of documents to return (default 3).

    Returns:
        The list produced by ``vector_search``.
    """
    return vector_search(call_embed_func(query), top_k)


# Log the time taken to load the embedding models. `start_time` was last reset
# just before the GPU and CPU SentenceTransformer models were constructed, so
# this mostly measures model loading (plus the cheap function definitions in
# between). The "EmbeddingRetriever" wording in the message appears to be a
# leftover label — NOTE(review): no retriever object is built here.
retriever_loading_time = time.perf_counter() - start_time
logger.info(f"Time taken to load EmbeddingRetriever: {retriever_loading_time:.6f} seconds")