Spaces:
Build error
Build error
import logging | |
import time | |
from pathlib import Path | |
import lancedb | |
from sentence_transformers import SentenceTransformer | |
import spaces | |
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Vector database ---------------------------------------------------------
# Time how long it takes to connect to LanceDB and open the table.
start_time = time.perf_counter()

# The database directory "lancedb" is resolved relative to the second-level
# parent of this file's path.
proj_dir = Path(__file__).parents[1]
db = lancedb.connect(proj_dir / "lancedb")
tbl = db.open_table("arabic-wiki")

# Report the elapsed time for the connect + open_table steps above.
lancedb_loading_time = time.perf_counter() - start_time
logger.info(f"Time taken to load LanceDB: {lancedb_loading_time:.6f} seconds")

# --- Embedding models --------------------------------------------------------
# Restart the timer for the model-loading phase.
start_time = time.perf_counter()

# The same checkpoint is instantiated twice: once on CUDA and once on CPU.
name = "sentence-transformers/paraphrase-multilingual-minilm-l12-v2"
st_model_gpu = SentenceTransformer(name, device="cuda")
st_model_cpu = SentenceTransformer(name, device="cpu")
# used for both training and querying
def call_embed_func(query):
    """Embed ``query``, falling back to the CPU model if ``embed_func`` fails.

    Args:
        query: Input forwarded unchanged to the model's ``encode`` method.

    Returns:
        The result of ``embed_func(query)`` when it succeeds; otherwise the
        result of ``st_model_cpu.encode(query)``.
    """
    try:
        return embed_func(query)
    except Exception:
        # Catch ``Exception`` instead of using a bare ``except`` so that
        # KeyboardInterrupt and SystemExit still propagate. The traceback is
        # attached so the reason for the fallback is visible in the logs.
        logger.warning('Using CPU', exc_info=True)
        return st_model_cpu.encode(query)
def embed_func(query):
    """Encode ``query`` with the module-level ``st_model_gpu`` model.

    Args:
        query: Input passed straight to ``st_model_gpu.encode``.

    Returns:
        Whatever ``st_model_gpu.encode`` returns for ``query``.
    """
    embedding = st_model_gpu.encode(query)
    return embedding
def vector_search(query_vector, top_k):
    """Search the module-level table ``tbl`` with ``query_vector``.

    Args:
        query_vector: Value passed to ``tbl.search``.
        top_k: Value passed to ``limit`` on the resulting search.

    Returns:
        The result of calling ``to_list()`` on the limited search.
    """
    search = tbl.search(query_vector)
    limited = search.limit(top_k)
    return limited.to_list()
def retriever(query, top_k=3):
    """Embed ``query`` and return the matching rows from the vector table.

    Args:
        query: Input handed to ``call_embed_func`` to obtain the query vector.
        top_k: Result limit forwarded to ``vector_search`` (default 3).

    Returns:
        The value returned by ``vector_search`` for the embedded query.
    """
    return vector_search(call_embed_func(query), top_k)
# Report the time elapsed since ``start_time`` was last reset, i.e. the
# model-loading phase plus the definitions above.
retriever_loading_time = time.perf_counter() - start_time
logger.info(
    f"Time taken to load EmbeddingRetriever: {retriever_loading_time:.6f} seconds"
)