# Source: opensearchspace / automm_semantic_embedding.py
# Last commit: 5e17fcf "wip" (suzhoum)
import ir_datasets
import pandas as pd
from autogluon.multimodal import MultiModalPredictor
# Fraction of the corpus to embed. It is kept tiny, presumably so the demo
# finishes quickly.
SAMPLE_FRAC = 0.0001
# Seed for the document sample, so repeated runs embed the same documents.
# Set to None to draw a fresh random sample on each run.
SAMPLE_SEED = 0

# Load the BEIR FiQA dev split once; docs and queries both come from it.
dataset = ir_datasets.load("beir/fiqa/dev")
docs_df = (
    pd.DataFrame(dataset.docs_iter())
    .set_index("doc_id")
    .sample(frac=SAMPLE_FRAC, random_state=SAMPLE_SEED)
)
# NOTE(review): query_df is not read anywhere else in this script. It is kept
# for interactive use; remove it if nothing else depends on it.
query_df = pd.DataFrame(dataset.queries_iter()).set_index("query_id")
# Hugging Face checkpoint used as the text encoder, plus the single demo query.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
query = "What happened when the dot com bubble burst?"

# Embedding-only predictor. fit() is never called in this script; the
# checkpoint is selected through AutoMM's hf_text hyperparameter.
predictor = MultiModalPredictor(
    pipeline="feature_extraction",
    hyperparameters={"model.hf_text.checkpoint_name": model_name},
)

# Embed every column of the sampled corpus, then the single query.
# NOTE(review): the list input is read back later under key '0', which is
# presumably AutoMM's default column name for plain lists; confirm.
document_embedding = predictor.extract_embedding(docs_df)
query_embedding = predictor.extract_embedding([query])
import numpy as np

# Number of best-matching documents to print.
TOP_K = 2


def _l2_normalize(x, eps=1e-12):
    """Return *x* scaled to unit L2 norm along its last axis.

    Norms are clamped to *eps*, so an all-zero row yields zeros instead of
    NaN (a NaN score would corrupt the ranking).
    """
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norms, eps)


# Cosine similarity is the dot product of L2-normalised vectors.
# NOTE(review): '0' is the key the list-input query embedding is read under,
# and 'text' is the docs_df column; confirm both against the AutoMM version.
q_norm = _l2_normalize(query_embedding['0'])
d_norm = _l2_normalize(document_embedding['text'])
scores = d_norm.dot(q_norm[0])

print(f'Question: {query}')
print()
# argsort of the negated scores orders positions from most to least similar.
# Number results by rank (1, 2, ...), not by the document's position in docs_df.
for rank, idx in enumerate(np.argsort(-scores)[:TOP_K], start=1):
    print(f'Top {rank} result:')
    print('-----------------')
    print(docs_df['text'].iloc[idx])
    print()