# ramGPT / searchEmbeddings.py
# Last commit: c891946 "1st version" (Xia)
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 19 14:45:37 2023
@author: Hua
"""
import pandas as pd
import json
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
import numpy as np
# Load the precomputed embedding table. Each GPT2Embeddings cell holds a
# JSON-encoded numeric array that is decoded below.
df = pd.read_csv('RAMEmbeddings.csv')

# Raw JSON strings of the GPT2 embeddings, one per row of df (a pd.Series).
gpt2bds = df['GPT2Embeddings']

# Decode every row into a float32 numpy array; list position i corresponds
# to row i of df.
gpt2list = list(map(lambda raw: np.float32(np.array(json.loads(raw))), gpt2bds))
# Path of the fine-tuned GPT2 SentenceTransformer used to encode queries.
_GPT2_MODEL_PATH = 'sembeddings/model_gpt_trained'
# Lazily-loaded model instance; see _get_gpt2_model().
_gpt2_model = None


def _get_gpt2_model():
    """Return the GPT2 SentenceTransformer, loading it on first use only.

    Loading the model from disk is expensive, so it is cached at module level
    instead of being reloaded on every call to search().
    """
    global _gpt2_model
    if _gpt2_model is None:
        _gpt2_model = SentenceTransformer(_GPT2_MODEL_PATH)
    return _gpt2_model


def search(inputs, top_k=5):
    """Return the rows of ``df`` most similar to *inputs* by GPT2 embedding.

    Parameters
    ----------
    inputs : str
        Query text to encode. A single query is expected (the original
        implementation also only worked for one) -- TODO confirm no caller
        passes several.
    top_k : int, optional
        Maximum number of rows to return (default 5, the previously
        hard-coded value).

    Returns
    -------
    pandas.DataFrame
        Up to ``top_k`` rows of ``df`` ordered from most to least similar
        (cosine similarity), re-indexed 0..k-1. Empty (same columns) when no
        embeddings are stored or ``top_k <= 0``.

    Raises
    ------
    ValueError
        If *inputs* encodes to more than one query embedding.
    """
    n_rows = len(gpt2list)
    k = min(top_k, n_rows)
    if k <= 0:
        # argpartition would reject kth out of range; return an empty frame.
        return df.iloc[[]].reset_index(drop=True)

    query = _get_gpt2_model().encode(inputs)

    # Score the query against all stored embeddings in one batched call;
    # for a single query the result has shape (1, n_rows).
    corpus = np.stack(gpt2list)
    scores = util.pytorch_cos_sim(query, corpus)
    sims = np.array(scores.flatten().tolist(), dtype=float)
    if sims.shape[0] != n_rows:
        raise ValueError('search() expects a single query')

    # argpartition selects the k best in O(n) but leaves them unordered,
    # so sort just those k, best first.
    top = np.argpartition(sims, -k)[-k:]
    top = top[np.argsort(sims[top])[::-1]]

    # Positions in gpt2list match row positions in df, hence iloc.
    return df.iloc[top].reset_index(drop=True)