Spaces:
Sleeping
Sleeping
import os | |
import sqlite3 | |
import time | |
from typing import List | |
from dotenv import load_dotenv | |
from openai import OpenAI | |
from qdrant_client import QdrantClient | |
from utils.markdown import generate_image_description_with_empty_description | |
from utils.cli import \ | |
print_execution_time, start_loading_animation, stop_loading_animation | |
from splitter import splitter | |
from qdrant_client.models import Distance, VectorParams | |
from qdrant_client.models import PointStruct | |
# Markdown heading markers ("#" through "######") paired with the label
# passed to the splitter for each heading level.
HEADERS_TO_SPLIT_ON = [
    ("#" * level, f"Header {level}")
    for level in range(1, 7)
]
def safe_join(thread, done: List[bool]):
    """Signal a still-running worker to stop and wait for it to finish.

    Args:
        thread: Object exposing ``is_alive()`` and ``join()``, or ``None``
            when the worker was never started.
        done: One-element list used as a shared flag; ``done[0]`` is set to
            ``True`` before joining.

    Nothing happens (and ``done`` is left untouched) when ``thread`` is
    ``None`` or has already finished.
    """
    worker_is_running = thread is not None and thread.is_alive()
    if not worker_is_running:
        return

    done[0] = True
    thread.join()
def _fetch_websites(cursor):
    """Return one dict per Website row (keys: text_content, text_id, url).

    The query LEFT JOINs Image and Video, so a website with several images
    and/or videos comes back as several rows. Only the Website columns are
    used here, so rows are de-duplicated by website id (first occurrence
    wins, original order kept) to avoid processing the same page repeatedly.
    """
    cursor.execute('''
        SELECT w.id, w.url, w.text_content, w.relative_path, w.hyperrefs,
        i.name AS image_name, i.path AS image_path, i.url AS image_url, i.hyperlink AS image_hyperlink, i.alt AS image_alt,
        v.name AS video_name, v.url AS video_url, v.hyperlink AS video_hyperlink, v.alt AS video_alt
        FROM Website w
        LEFT JOIN Image i ON i.website_id = w.id
        LEFT JOIN Video v ON v.website_id = w.id
    ''')  # noqa
    seen_ids = set()
    web_sites = []
    for row in cursor.fetchall():
        website_id = row[0]
        if website_id in seen_ids:
            continue
        seen_ids.add(website_id)
        web_sites.append({
            "text_content": row[2],
            "text_id": website_id,
            "url": row[1],
        })
    return web_sites


def _save_parent_texts(cursor, texts_parent_splitted):
    """Create the ParentText table if needed and insert every parent chunk.

    The caller is responsible for committing the transaction.
    """
    # The DDL is loop-invariant, so it runs once instead of once per row.
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS ParentText
        (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        content TEXT,
        parent_id INTEGER,
        url TEXT,
        FOREIGN KEY (parent_id) REFERENCES Website(id)
        );
    ''')
    cursor.executemany(
        '''
        INSERT INTO ParentText (content, parent_id, url)
        VALUES (?, ?, ?);
        ''',
        (
            (
                parent["content"].page_content,
                parent["parent_id"],
                parent["url"],
            )
            for parent in texts_parent_splitted
        ),
    )


def run():
    """Build the RAG index: sqlite websites -> chunks -> embeddings -> Qdrant.

    Steps:
      1. Load environment variables and connect to sqlite, Qdrant and OpenAI.
      2. Recreate the Qdrant collection named by ``COLLECTION_NAME``.
      3. Read websites from sqlite and fill in missing image descriptions.
      4. Split texts into parent/child chunks; persist parents in sqlite.
      5. Embed each child chunk with ``OPENAI_MODEL_EMBEDDING`` and upsert
         it into Qdrant.
      6. Print a summary report.

    Any exception is reported on stdout and swallowed; loading animations are
    always stopped and the sqlite connection is always closed.
    """
    loading_split = None
    loading_image_description = None
    loading_embedding = None
    done_split = [False]
    done_image_description = [False]
    done_embeding = [False]
    conn = None  # lets ``finally`` know whether there is anything to close
    try:
        load_dotenv()
        start_time = time.time()
        print("➗ Começando o RAG...")

        conn = sqlite3.connect('../meu_banco.db')
        cursor = conn.cursor()
        print("✅ Conectado ao sqlite.")

        # NOTE(review): ":memory:" presumably keeps the vectors only for the
        # lifetime of this process — confirm this is intended.
        qdrant_client = QdrantClient(":memory:")
        print("✅ Conectado ao qdrant.")

        oa_client = OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY")
        )
        print("✅ Conectado a OpenAI.")

        # Read the loop-invariant settings once.
        collection_name = os.environ.get("COLLECTION_NAME")
        embedding_model = os.environ.get("OPENAI_MODEL_EMBEDDING")

        if qdrant_client.collection_exists(collection_name):
            qdrant_client.delete_collection(
                collection_name=collection_name
            )
        # NOTE(review): size=1536 must match the output dimension of the
        # configured embedding model — verify against OPENAI_MODEL_EMBEDDING.
        qdrant_client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=1536, distance=Distance.COSINE)
        )
        print("✅ Collection criada.")

        web_sites = _fetch_websites(cursor)

        loading_image_description = start_loading_animation(
            done_image_description,
            "Gerando Descrições de imagens caso não tenha..."
        )
        texts_with_images_descriptions, error_images_description = \
            generate_image_description_with_empty_description(
                oa_client,
                web_sites
            )
        stop_loading_animation(
            done_image_description,
            loading_image_description
        )
        print("✅ Descrição de imagens geradas.")

        # The "splitting" animation now wraps the split itself; previously it
        # was started only after the split had already completed.
        loading_split = start_loading_animation(
            done_split,
            "Dividindo textos..."
        )
        texts_parent_splitted, texts_child_splitted = splitter.split(
            texts_with_images_descriptions,
            HEADERS_TO_SPLIT_ON
        )
        stop_loading_animation(done_split, loading_split)
        print("✅ Textos dividos.")

        _save_parent_texts(cursor, texts_parent_splitted)
        # Without an explicit commit sqlite3 rolls the inserts back when the
        # connection is closed, so the parent texts would never be persisted.
        conn.commit()
        print("✅ Texto pai salvo no sqlite.")

        total_child_splitted = len(texts_child_splitted)
        for count_embeddings, text_child_splitted in enumerate(
            texts_child_splitted
        ):
            # A fresh flag per iteration: each animation gets its own stop
            # signal, and ``finally`` joins whichever one is still running.
            done_embeding = [False]
            loading_embedding = start_loading_animation(
                done_embeding,
                f"""Gerando embeddings e salvando no qdrant: {count_embeddings} de {total_child_splitted}"""  # noqa
            )
            child_content = text_child_splitted["content"].page_content
            embedding = oa_client.embeddings.create(
                input=[child_content],
                model=embedding_model
            ).data[0].embedding
            qdrant_client.upsert(
                collection_name=collection_name,
                points=[PointStruct(
                    id=str(text_child_splitted["id"]),
                    vector=embedding,
                    payload={
                        "content": child_content,
                        "parent_id": text_child_splitted["parent_id"],
                        "type": "text"
                    }
                )]
            )
            stop_loading_animation(done_embeding, loading_embedding)
        print("✅ Texto filho salvo no qdrant.")

        print("✅ RAG finalizado.")
        print(
            f"""
            📊 Relatório:\n
            \t Tempo de execução: {print_execution_time(start_time)}\n
            \t Textos Filhos e Embedding gerados: {len(texts_child_splitted)}\n
            \t Textos pai gerados: {len(texts_parent_splitted)}\n
            \t WebSites recuperados da base: {len(web_sites)}\n
            \t Erros ao gerar descrição de imagens: {len(error_images_description)}\n
            """)  # noqa
    except Exception as error:
        # Top-level boundary of the script: report the failure and fall
        # through to cleanup instead of crashing with a traceback.
        print(f"❌ Erro: {error}")
    finally:
        # Stop any animation that is still running (e.g. after an exception).
        safe_join(loading_split, done_split)
        safe_join(loading_image_description, done_image_description)
        safe_join(loading_embedding, done_embeding)
        # Release the sqlite handle; uncommitted work is discarded here.
        if conn is not None:
            conn.close()
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    run()