File size: 4,508 Bytes
d2ed505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import pandas as pd
import os
import chromadb
from chromadb.utils import embedding_functions
import math





def create_domain_identification_database(
    vdb_path: str,
    collection_name: str,
    df: pd.DataFrame,
    *,
    batch_size: int = 166,
) -> None:
    """Create a ChromaDB collection from ``df`` and populate it in batches.

    Each row of ``df`` becomes one entry in the collection:

    * the ``query`` column is the document text that gets embedded,
    * the ``domain`` and ``label`` columns are stored as metadata,
    * the id is the string ``"domain_id "`` followed by the row's index.

    Embeddings are produced with the ``sentence-transformers/LaBSE`` model.

    Args:
        vdb_path (str): Path of the persistent ChromaDB instance.
        collection_name (str): Name of the collection to create.
        df (pd.DataFrame): Data to index; must expose ``query``, ``domain``
            and ``label`` columns.
        batch_size (int): Maximum number of rows sent to the collection per
            ``add`` call. Defaults to 166, the value that was previously
            hard-coded.
            NOTE(review): 166 presumably matches the maximum batch size
            ChromaDB accepts on some SQLite builds -- confirm before raising.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Validate before touching the database so a bad argument cannot leave an
    # empty collection behind.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Open (or create) the on-disk ChromaDB instance.
    chroma_client = chromadb.PersistentClient(path=vdb_path)

    # Sentence-transformers model used to embed the documents.
    embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/LaBSE"
    )

    # NOTE(review): create_collection (not get_or_create_collection) is used,
    # so this presumably fails when the collection already exists -- confirm.
    domain_identification_collection = chroma_client.create_collection(
        name=collection_name,
        embedding_function=embedding_function,
    )

    # Build documents, metadata and ids in a single pass over the dataframe.
    documents = []
    metadatas = []
    ids = []
    for row in df.itertuples():
        documents.append(row.query)
        metadatas.append({"domain": row.domain, "label": row.label})
        ids.append("domain_id " + str(row.Index))

    # Insert in chunks of at most ``batch_size`` rows; slicing past the end of
    # a list is safe, so the final (possibly shorter) chunk needs no special
    # case, and an empty dataframe results in no ``add`` call at all.
    for start in range(0, len(ids), batch_size):
        end = start + batch_size
        domain_identification_collection.add(
            documents=documents[start:end],
            metadatas=metadatas[start:end],
            ids=ids[start:end],
        )



def delete_collection_from_vector_db(vdb_path: str, collection_name: str) -> None:
    """Remove one collection from the persistent ChromaDB instance.

    Args:
        vdb_path (str): Path of the persistent ChromaDB instance.
        collection_name (str): Name of the collection to remove.
    """
    # Open the on-disk store and drop the named collection in one step.
    chromadb.PersistentClient(path=vdb_path).delete_collection(collection_name)


def list_collections_from_vector_db(vdb_path: str) -> None:
    """Print the collections stored in the persistent ChromaDB instance.

    Args:
        vdb_path (str): Path of the persistent ChromaDB instance.
    """
    client = chromadb.PersistentClient(path=vdb_path)
    available_collections = client.list_collections()
    print(available_collections)


def get_collection_from_vector_db(
    vdb_path: str, collection_name: str
) -> chromadb.Collection:
    """Return an existing collection from the persistent ChromaDB instance.

    The collection is opened with the ``sentence-transformers/LaBSE``
    embedding function attached.

    Args:
        vdb_path (str): Path of the persistent ChromaDB instance.
        collection_name (str): Name of the collection to retrieve.

    Returns:
        chromadb.Collection: The collection object returned by the client.
    """
    client = chromadb.PersistentClient(path=vdb_path)

    # Sentence-transformers model attached to the returned collection.
    labse_embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/LaBSE"
    )

    return client.get_collection(
        name=collection_name,
        embedding_function=labse_embedder,
    )


def retrieval( input_text : str,
              num_results : int,
              collection: chromadb.Collection ):
    """Look up the closest stored entry for a piece of text.

    The collection is queried with ``input_text`` and only the first returned
    hit is used, regardless of how many results were requested.

    Args:
        input_text: The received text, e.g. news, posts, or tweets.
        num_results: Number of examples requested from the collection.
        collection: Collection to query.

    Returns:
        A ``(domain, label, distance)`` tuple taken from the first hit: its
        ``"domain"`` and ``"label"`` metadata values and its distance.
    """
    query_result = collection.query(
        query_texts=[input_text],
        n_results=num_results,
    )

    # Results are nested per query text; a single text was sent, so take the
    # first list and then its first hit.
    top_metadata = query_result["metadatas"][0][0]
    domain = top_metadata["domain"]
    label = top_metadata["label"]
    distance = query_result["distances"][0][0]

    return domain, label, distance