|
import pandas as pd |
|
import os |
|
import chromadb |
|
from chromadb.utils import embedding_functions |
|
import math |
|
|
|
|
|
|
|
|
|
|
|
def create_domain_identification_database(
    vdb_path: str,
    collection_name: str,
    df: pd.DataFrame,
    *,
    batch_size: int = 166,
) -> None:
    """Create a ChromaDB collection from a labelled query dataframe.

    Every row of ``df`` becomes one entry of the new collection: the ``query``
    column is stored as the document (embedded with the
    ``sentence-transformers/LaBSE`` model), the ``domain`` and ``label``
    columns are stored as its metadata, and the entry id is
    ``"domain_id <dataframe index>"``.

    Args:
        vdb_path (str): Path of the persistent ChromaDB instance.
        collection_name (str): Name of the collection to create.
        df (pd.DataFrame): Dataset providing the ``query``, ``domain`` and
            ``label`` columns.
        batch_size (int): Maximum number of rows sent to the collection per
            ``add`` call. Defaults to 166, the value that used to be
            hard-coded here.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    # Read the dataframe before touching the database: if a column is missing
    # this fails here, instead of after an empty collection has been created.
    documents = []
    metadatas = []
    ids = []
    for row in df.itertuples():
        documents.append(row.query)
        metadatas.append({"domain": row.domain, "label": row.label})
        ids.append("domain_id " + str(row.Index))

    chroma_client = chromadb.PersistentClient(path=vdb_path)
    embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/LaBSE"
    )
    # NOTE(review): create_collection is expected to fail when the collection
    # already exists - confirm against the installed chromadb version.
    collection = chroma_client.create_collection(
        name=collection_name,
        embedding_function=embedding_function,
    )

    # Insert in slices of at most batch_size rows; the final slice is simply
    # shorter. Presumably the batching keeps each insert under a backend
    # limit - TODO confirm the origin of the default of 166.
    for start in range(0, len(documents), batch_size):
        end = start + batch_size
        collection.add(
            documents=documents[start:end],
            metadatas=metadatas[start:end],
            ids=ids[start:end],
        )
    return None
|
|
|
|
|
|
|
def delete_collection_from_vector_db(vdb_path: str, collection_name: str) -> None:
    """Remove a single collection from the persistent ChromaDB instance.

    Args:
        vdb_path (str): Path of the persistent ChromaDB instance.
        collection_name (str): Name of the collection to remove.
    """
    # Open the on-disk instance and drop the collection in one step.
    chromadb.PersistentClient(path=vdb_path).delete_collection(collection_name)
|
|
|
|
|
def list_collections_from_vector_db(vdb_path: str) -> None:
    """Print the collections stored in the persistent ChromaDB instance.

    Args:
        vdb_path (str): Path of the persistent ChromaDB instance.
    """
    available_collections = chromadb.PersistentClient(path=vdb_path).list_collections()
    print(available_collections)
|
|
|
|
|
def get_collection_from_vector_db(
    vdb_path: str, collection_name: str
) -> chromadb.Collection:
    """Load an existing collection from the persistent ChromaDB instance.

    The collection is returned with the ``sentence-transformers/LaBSE``
    embedding function attached.

    Args:
        vdb_path (str): Path of the persistent ChromaDB instance.
        collection_name (str): Name of the collection to load.

    Returns:
        chromadb.Collection: The requested collection object.
    """
    client = chromadb.PersistentClient(path=vdb_path)
    labse_embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/LaBSE"
    )
    return client.get_collection(
        name=collection_name,
        embedding_function=labse_embedder,
    )
|
|
|
|
|
def retrieval(
    input_text: str,
    num_results: int,
    collection: "chromadb.Collection",
) -> tuple:
    """Look up the domain of ``input_text`` from its first match in a collection.

    The collection is queried with ``input_text`` for ``num_results`` results,
    but only the first returned result is used to build the answer.

    Args:
        input_text (str): The received text, e.g. news, a post or a tweet.
        num_results (int): Number of examples to request from the collection.
        collection (chromadb.Collection): Collection to fetch examples from.

    Returns:
        tuple: ``(domain, label, distance)`` where ``domain`` and ``label``
        come from the first result's metadata and ``distance`` is that
        result's distance to ``input_text``.

    Raises:
        IndexError: If the query returns no results.
    """
    response = collection.query(
        query_texts=[input_text],
        n_results=num_results,
    )

    # Results are grouped per query text; a single text was sent, so the
    # matches for it sit at position 0.
    matches = response["metadatas"][0]
    if not matches:
        # Same exception type as the bare indexing would raise, but with a
        # message that explains what went wrong.
        raise IndexError(
            "retrieval: query returned no results; the collection may be empty"
        )

    first_match = matches[0]
    distance = response["distances"][0][0]
    return first_match["domain"], first_match["label"], distance