File size: 4,508 Bytes
d2ed505 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import pandas as pd
import os
import chromadb
from chromadb.utils import embedding_functions
import math
def create_domain_identification_database(
    vdb_path: str,
    collection_name: str,
    df: pd.DataFrame,
    *,
    batch_size: int = 166,
) -> None:
    """Embed the rows of ``df`` and store them in a new ChromaDB collection.

    Each row of ``df`` becomes one record in the collection:
      * document  -- the ``query`` column (embedded with the LaBSE sentence-transformer),
      * metadata  -- ``{"domain": <domain column>, "label": <label column>}``,
      * id        -- ``"domain_id <row index>"``.

    Args:
        vdb_path (str): Path of the persistent ChromaDB instance.
        collection_name (str): Name of the collection to create.
        df (pd.DataFrame): Dataset with ``query``, ``domain`` and ``label`` columns.
        batch_size (int): Number of rows sent per ``collection.add`` call.
            Defaults to 166, the value that was previously hard-coded.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    # Client persisted at vdb_path.
    chroma_client = chromadb.PersistentClient(path=vdb_path)
    # Sentence-transformer model used to embed the documents.
    embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/LaBSE"
    )
    # create_collection (as opposed to get_or_create_collection) is used here,
    # so an existing collection with the same name is not silently reused.
    domain_identification_collection = chroma_client.create_collection(
        name=collection_name,
        embedding_function=embedding_function,
    )

    # Build documents, metadata and ids in a single pass over the DataFrame.
    documents = []
    metadatas = []
    ids = []
    for row in df.itertuples():
        documents.append(row.query)
        metadatas.append({"domain": row.domain, "label": row.label})
        ids.append("domain_id " + str(row.Index))

    # Add in fixed-size batches; the last slice is simply shorter when the
    # row count is not a multiple of batch_size.
    for start in range(0, len(ids), batch_size):
        end = start + batch_size
        domain_identification_collection.add(
            documents=documents[start:end],
            metadatas=metadatas[start:end],
            ids=ids[start:end],
        )
    return None
def delete_collection_from_vector_db(vdb_path: str, collection_name: str) -> None:
    """Drop one named collection from the ChromaDB store persisted at ``vdb_path``.

    Args:
        vdb_path (str): Location of the persistent ChromaDB instance.
        collection_name (str): Collection to remove.
    """
    client = chromadb.PersistentClient(path=vdb_path)
    client.delete_collection(collection_name)
def list_collections_from_vector_db(vdb_path: str) -> None:
    """Print whatever ``list_collections()`` reports for the ChromaDB store at ``vdb_path``.

    Args:
        vdb_path (str): Location of the persistent ChromaDB instance.
    """
    available = chromadb.PersistentClient(path=vdb_path).list_collections()
    print(available)
def get_collection_from_vector_db(vdb_path: str, collection_name: str) -> chromadb.Collection:
    """Open an existing collection from the persistent ChromaDB store at ``vdb_path``.

    The collection is bound to the LaBSE sentence-transformer embedding function,
    so text queries against it are embedded with that model.

    Args:
        vdb_path (str): Location of the persistent ChromaDB instance.
        collection_name (str): Name of the collection to open.

    Returns:
        chromadb.Collection: The requested collection handle.
    """
    client = chromadb.PersistentClient(path=vdb_path)
    labse = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/LaBSE"
    )
    return client.get_collection(name=collection_name, embedding_function=labse)
def retrieval(input_text: str,
              num_results: int,
              collection: chromadb.Collection) -> tuple:
    """Fetch the domain of the stored example most similar to ``input_text``.

    Args:
        input_text (str): The received text (e.g. news, posts or tweets).
        num_results (int): ``n_results`` passed to ``collection.query``. Only the
            first returned hit is used; the remaining hits are discarded.
        collection (chromadb.Collection): Collection to search.

    Returns:
        tuple: ``(domain, label, distance)`` taken from the first hit's metadata
        and its distance.

    Raises:
        IndexError: If the query returned no hits (e.g. the collection is empty).
    """
    fetched_domain = collection.query(
        query_texts=[input_text],
        n_results=num_results,
    )
    # query() returns one list of hits per query text; exactly one text was sent.
    hit_metadatas = fetched_domain["metadatas"][0]
    hit_distances = fetched_domain["distances"][0]
    if not hit_metadatas:
        raise IndexError(
            f"query returned no results for input text {input_text!r}; "
            "is the collection empty?"
        )
    # Extract domain name and label from the first (closest) hit.
    best_match = hit_metadatas[0]
    return best_match["domain"], best_match["label"], hit_distances[0]