File size: 3,049 Bytes
8b54370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from transformers import AutoConfig
from sentence_transformers import SentenceTransformer
import lancedb
import torch
import pyarrow as pa
import pandas as pd
import numpy as np
import tqdm

class VectorDB:
    """Builds and queries a LanceDB table of embedded action descriptions.

    The constructor reads a CSV of (name, description) rows, embeds each
    description with a SentenceTransformer model, and writes the vectors
    into a LanceDB table (overwriting any previous table of the same name).
    The query helpers then search that table by vector similarity.
    """

    # LanceDB schema column / table names.
    vector_column = "vector"
    description_column = "description"
    name_column = "name"
    table_name = "pimcore_actions"
    # Class-level defaults; overwritten per-instance in __init__.
    emb_model = ''
    db_location = ''

    def __init__(self, emb_model, db_location, actions_list_file_path, num_sub_vectors, batch_size):
        """Embed the actions CSV and persist the vectors into LanceDB.

        Args:
            emb_model: Hugging Face model id (or local path) of the
                sentence-embedding model; its config's ``hidden_size`` is
                used as the vector dimension.
            db_location: Directory of the LanceDB database.
            actions_list_file_path: CSV whose first column is the action
                name and whose second column is its textual description.
            num_sub_vectors: Intended PQ sub-vector count; only validated
                here (the embedding dimension must divide evenly by it).
            batch_size: Number of rows encoded per ``model.encode`` call.

        Raises:
            AssertionError: If the embedding dimension is not divisible
                by ``num_sub_vectors``.
        """
        self.emb_model = emb_model
        self.db_location = db_location

        emb_config = AutoConfig.from_pretrained(emb_model)
        emb_dimension = emb_config.hidden_size

        assert emb_dimension % num_sub_vectors == 0, \
            "Embedding size must be divisible by the num of sub vectors"

        model = SentenceTransformer(emb_model)
        model.eval()
        # Log only after the model has actually been loaded.
        print('Model loaded...')
        print(emb_model)

        # Prefer Apple-silicon MPS, then CUDA, then CPU.
        if torch.backends.mps.is_available():
            device = "mps"
        elif torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        print(f"Device: {device}")

        db = lancedb.connect(db_location)

        schema = pa.schema(
            [
                pa.field(self.vector_column, pa.list_(pa.float32(), emb_dimension)),
                pa.field(self.description_column, pa.string()),
                pa.field(self.name_column, pa.string())
            ]
        )
        tbl = db.create_table(self.table_name, schema=schema, mode="overwrite")

        df = pd.read_csv(actions_list_file_path)
        rows = df.values  # each row: (name, description, ...)

        print("Starting vector generation")
        num_batches = int(np.ceil(len(rows) / batch_size))
        for i in tqdm.tqdm(range(num_batches)):
            # Drop degenerate zero-length rows, as the original did.
            batch = [row for row in rows[i * batch_size:(i + 1) * batch_size] if len(row) > 0]
            if not batch:
                continue
            names = [entry[0] for entry in batch]
            to_encode = [entry[1] for entry in batch]
            try:
                encoded = model.encode(to_encode, normalize_embeddings=True, device=device)
                # Use a distinct name; the original shadowed the outer `df`.
                batch_df = pd.DataFrame({
                    self.vector_column: [list(vec) for vec in encoded],
                    self.description_column: to_encode,
                    self.name_column: names
                })
                tbl.add(batch_df)
            except Exception as exc:
                # Narrowed from a bare `except:` (which also swallowed
                # KeyboardInterrupt/SystemExit); keep best-effort batch
                # skipping but surface the reason.
                print(f"batch {i} was skipped: {exc}")
        print("Vector generation done.")

    def get_embedding_db_as_pandas(self):
        """Return the entire embeddings table as a pandas DataFrame."""
        db = lancedb.connect(self.db_location)
        tbl = db.open_table(self.table_name)
        return tbl.to_pandas()

    def retrieve_prefiltered_hits(self, query, k):
        """Embed ``query`` and return the top-``k`` nearest actions.

        Args:
            query: Free-text query string.
            k: Number of hits to return.

        Returns:
            Tuple ``(names, descriptions)`` of parallel lists, ranked by
            vector similarity.
        """
        # BUG FIX: previously connected to the hard-coded ".lancedb" path,
        # ignoring the configured self.db_location used everywhere else.
        db = lancedb.connect(self.db_location)
        table = db.open_table(self.table_name)
        retriever = SentenceTransformer(self.emb_model)

        query_vec = retriever.encode(query)
        documents = table.search(query_vec, vector_column_name=self.vector_column).limit(k).to_list()
        names = [doc[self.name_column] for doc in documents]
        descriptions = [doc[self.description_column] for doc in documents]

        return names, descriptions