Spaces:
Runtime error
Runtime error
import gradio as gr | |
import spaces | |
import subprocess | |
import os | |
import shutil | |
import string | |
import random | |
import glob | |
from pypdf import PdfReader | |
from sentence_transformers import SentenceTransformer | |
# Configuration, overridable through environment variables.
model_name = os.environ.get("MODEL", "Snowflake/snowflake-arctic-embed-m")
# Maximum number of characters per chunk when splitting source documents.
chunk_size = int(os.environ.get("CHUNK_SIZE", 128))
# Default character budget for the chunks returned by a single query.
default_max_characters = int(os.environ.get("DEFAULT_MAX_CHARACTERS", 258))
if chunk_size <= 0:
    # A non-positive chunk size can never split text into pieces; fail fast at
    # startup with a clear message instead of misbehaving while chunking.
    raise ValueError(f"CHUNK_SIZE must be a positive integer, got {chunk_size}")

# The embedding model is loaded once at import time and reused by every query.
model = SentenceTransformer(model_name)
# To run on a GPU, move the model explicitly, e.g. model.to(device="cuda").
def embed(queries, chunks) -> dict[str, list[tuple[int, float]]]:
    """Score every chunk against every query.

    Args:
        queries: Query strings; each is encoded with the model's "query" prompt.
        chunks: Document text chunks, encoded without a prompt.

    Returns:
        A dict mapping each query to a list of ``(chunk_index, score)`` pairs,
        one per chunk in input order. The score is the dot product of the query
        embedding and the chunk embedding. Duplicate queries collapse into a
        single key.
    """
    query_embeddings = model.encode(queries, prompt_name="query")
    document_embeddings = model.encode(chunks)
    # One row of chunk scores per query.
    scores = query_embeddings @ document_embeddings.T
    return {
        # float() turns numpy scalars into plain Python floats, matching the
        # annotation and keeping the result JSON-serialisable.
        query: [(idx, float(score)) for idx, score in enumerate(query_scores)]
        for query, query_scores in zip(queries, scores)
    }
def extract_text_from_pdf(reader):
    """Concatenate the text of every page that yields any text.

    Each page with text is prefixed by a ``---- Page N ----`` header, where N is
    the zero-based page index. Pages are separated by a blank line.

    Args:
        reader: A ``pypdf.PdfReader``, or any object exposing ``pages`` whose
            items provide an ``extract_text()`` method.

    Returns:
        The combined text with surrounding whitespace stripped, or ``""`` when
        no page produced any text.
    """
    sections = []
    for idx, page in enumerate(reader.pages):
        # Extract once per page and reuse the result.
        text = page.extract_text()
        # Pages without extractable text are skipped; a None result is
        # treated the same as an empty string.
        if text:
            sections.append(f"---- Page {idx} ----\n{text}\n\n")
    return "".join(sections).strip()
def convert(filename) -> str:
    """Return the plain-text content of a supported source file.

    Plain-text formats are read as-is (decoded as UTF-8); PDFs are converted
    with pypdf. Extension matching is case-insensitive, so e.g. ``NOTES.TXT``
    and ``Report.PDF`` are accepted.

    Args:
        filename: Path to the file (``str`` or ``os.PathLike``).

    Returns:
        The file's text content.

    Raises:
        ValueError: If the file extension is not supported.
        UnicodeDecodeError: If a plain-text file is not valid UTF-8 (this is a
            subclass of ValueError).
    """
    plain_text_filetypes = (
        ".txt",
        ".csv",
        ".tsv",
        ".md",
        ".yaml",
        ".toml",
        ".json",
        ".json5",
        ".jsonc",
    )
    path = os.fspath(filename)
    lowered = path.lower()
    # Already plain text: no conversion needed, so return the content directly.
    if lowered.endswith(plain_text_filetypes):
        # Explicit encoding so the result does not depend on the host locale.
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    if lowered.endswith(".pdf"):
        return extract_text_from_pdf(PdfReader(path))
    raise ValueError(f"Unsupported file type: {filename}")
def chunk_to_length(text, max_length=512):
    """Split ``text`` into consecutive pieces of at most ``max_length`` characters.

    Every piece except possibly the last is exactly ``max_length`` long, and the
    pieces concatenate back to ``text``. An empty ``text`` yields ``[""]``, so
    callers always receive at least one chunk.

    Args:
        text: The string to split.
        max_length: Maximum chunk length in characters; must be positive.

    Returns:
        The list of chunks.

    Raises:
        ValueError: If ``max_length`` is not positive.
    """
    if max_length <= 0:
        # A non-positive length could never consume the text.
        raise ValueError(f"max_length must be positive, got {max_length}")
    # Slicing at fixed offsets copies each character once. Repeatedly
    # re-slicing the remainder would copy the rest of the string on every step.
    chunks = [text[start:start + max_length] for start in range(0, len(text), max_length)]
    # Keep the single empty chunk for empty input.
    return chunks or [text]
def predict(query, max_characters) -> dict[str, list[str]]:
    """Return the document chunks most similar to ``query`` within a character budget.

    Every chunk of every loaded document (the module-level ``docs`` mapping) is
    scored by the dot product of its embedding with the query embedding. Chunks
    are then taken greedily from the highest score down. Selection stops at the
    first chunk that would push the total length over ``max_characters``.

    Args:
        query: The question asked about the documents.
        max_characters: Maximum total number of characters across all returned
            chunks.

    Returns:
        A dict mapping each source filename to its selected chunks, in descending
        order of similarity. Documents with no selected chunk are omitted.
    """
    query_embedding = model.encode(query, prompt_name="query")

    # Collect (filename, chunk, similarity) for every chunk of every document.
    all_chunks = []
    for filename, doc in docs.items():
        similarities = doc["embeddings"] @ query_embedding.T
        all_chunks.extend(
            (filename, chunk, sim) for chunk, sim in zip(doc["chunks"], similarities)
        )

    # Most similar first; the sort is stable, so ties keep document/chunk order.
    all_chunks.sort(key=lambda item: item[2], reverse=True)

    relevant_chunks = {}
    total_chars = 0
    for filename, chunk, _ in all_chunks:
        # Stop at the first chunk that does not fit rather than skipping it.
        if total_chars + len(chunk) > max_characters:
            break
        relevant_chunks.setdefault(filename, []).append(chunk)
        total_chars += len(chunk)
    return relevant_chunks
# Load, chunk and embed every file in ``sources/`` once at startup. ``predict``
# searches the resulting mapping: filename -> {"chunks": [...], "embeddings": ...}.
docs = {}
# Sorted so the load order (and tie-breaking between equal scores) is
# deterministic across filesystems.
for filename in sorted(glob.glob("sources/*")):
    # Skip the placeholder that marks where documents go, and any
    # sub-directories matched by the glob.
    if filename.endswith("add_your_files_here") or not os.path.isfile(filename):
        continue
    try:
        converted_doc = convert(filename)
    except ValueError as err:
        # An unsupported file type or undecodable text file should not take
        # down the whole app; report it and continue with the other files.
        print(f"Skipping {filename}: {err}")
        continue
    chunks = chunk_to_length(converted_doc, chunk_size)
    docs[filename] = {
        "chunks": chunks,
        "embeddings": model.encode(chunks),
    }
# Expose ``predict`` as a Gradio app: a query and a character budget go in,
# and the selected chunks, grouped by source file, come out as JSON.
query_input = gr.Textbox(label="Query asked about the documents")
budget_input = gr.Number(label="Max output characters", value=default_max_characters)
demo = gr.Interface(
    predict,
    inputs=[query_input, budget_input],
    outputs=[gr.JSON(label="Relevant chunks")],
    title="RAG Community Tool Template demo",
    description="This is a demo of the RAG Community Tool Template. To use RAG in HuggingChat with your own documents, start by cloning this space, add your documents to the `sources` folder, and then create a community tool with this space!",
)
demo.launch()