Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import chromadb | |
from sentence_transformers import SentenceTransformer | |
import spaces | |
# Lazily-loaded embedding model shared by every call to get_embeddings; loading
# it per request (as before) re-instantiates a multi-GB model on each search.
_embedding_model = None


def _load_embedding_model():
    """Return the shared Linq-Embed-Mistral model, loading it on first use."""
    global _embedding_model
    if _embedding_model is None:
        _embedding_model = SentenceTransformer(
            "Linq-AI-Research/Linq-Embed-Mistral",
            use_auth_token=os.getenv("HF_TOKEN"),
        )
    return _embedding_model


def get_embeddings(queries, task):
    """Embed *queries* using the instruction prompt format of the model.

    Each query is wrapped as ``Instruct: {task}``, a newline, then
    ``Query: {query}`` before encoding.

    Args:
        queries: Iterable of query strings.
        task: Instruction text describing the retrieval task.

    Returns:
        The result of ``SentenceTransformer.encode`` on the list of prompts
        (typically a 2-D NumPy array with one row per query).
    """
    model = _load_embedding_model()
    prompts = [f"Instruct: {task}\nQuery: {query}" for query in queries]
    return model.encode(prompts)
# Open the on-disk Chroma store and look up the German and English corpora.
client = chromadb.PersistentClient(path="./chroma")
collection_de, collection_en = (
    client.get_collection(name=collection_name)
    for collection_name in ("phil_de", "phil_en")
)
# Author names offered in the "Authors" filter for each database. They are
# used as an ``author`` metadata filter, so they must match the stored values.
authors_list_de = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]
authors_list_en = [
    "Friedrich Nietzsche",
    "Joscha Bach",
]
def query_chroma(collection, embedding, authors, n_results=10):
    """Run a nearest-neighbour query for one embedding against a Chroma collection.

    Args:
        collection: Chroma collection to search.
        embedding: A single query embedding. Either an array-like exposing
            ``tolist()`` (NumPy array, tensor) or a plain sequence of floats.
        authors: Optional collection of author names. When non-empty, only
            passages whose ``author`` metadata is in it are returned.
        n_results: Maximum number of passages to return (default 10).

    Returns:
        A list of dicts with keys ``id``, ``author``, ``book``, ``section``,
        ``title``, ``text`` and ``distance``. Missing metadata fields are
        replaced with placeholder strings. On any failure, a single-element
        list ``[{"error": <message>}]`` is returned instead of raising, so the
        UI can show the problem to the user.
    """
    try:
        # Use None (Chroma's default, meaning "no filter") rather than an empty
        # dict: some Chroma versions reject ``where={}`` as an invalid filter.
        where_filter = {"author": {"$in": list(authors)}} if authors else None
        # Chroma expects plain Python lists of floats.
        if hasattr(embedding, "tolist"):
            query_vector = embedding.tolist()
        else:
            query_vector = list(embedding)
        results = collection.query(
            query_embeddings=[query_vector],
            n_results=n_results,
            where=where_filter,
            include=["documents", "metadatas", "distances"],
        )
        # Results are nested one level per query embedding; we sent exactly one.
        # ``or [[]]`` also covers keys that are present but set to None.
        ids = (results.get("ids") or [[]])[0]
        metadatas = (results.get("metadatas") or [[]])[0]
        documents = (results.get("documents") or [[]])[0]
        distances = (results.get("distances") or [[]])[0]

        formatted_results = []
        for id_, metadata, document_text, distance in zip(ids, metadatas, documents, distances):
            # Records stored without metadata come back as None.
            metadata = metadata or {}
            formatted_results.append({
                "id": id_,
                "author": metadata.get("author", "Unknown author"),
                "book": metadata.get("book", "Unknown book"),
                "section": metadata.get("section", "Unknown section"),
                "title": metadata.get("title", "Untitled"),
                "text": document_text,
                "distance": distance,
            })
        return formatted_results
    except Exception as e:
        # Deliberate best-effort boundary: the caller renders this as a message.
        return [{"error": str(e)}]
def update_authors(database):
    """Build a Gradio update that swaps the author-filter choices.

    "German" selects the German author list; any other value falls back to
    the English list.
    """
    if database == "German":
        choices = authors_list_de
    else:
        choices = authors_list_en
    return gr.update(choices=choices)
def _split_queries(raw_queries):
    """Split semicolon-separated user input into stripped, non-empty queries."""
    return [part.strip() for part in (raw_queries or "").split(";") if part.strip()]


def _format_query_results(query, res):
    """Render one query's hits (as returned by ``query_chroma``) as Markdown."""
    if not res:
        body = "No results found."
    elif "error" in res[0]:
        body = f"Error retrieving data: {res[0]['error']}"
    else:
        # Blank lines between entries so Markdown renders every hit as its own
        # paragraph instead of merging it into the previous passage.
        body = "\n\n".join(f"**{r['author']}, {r['book']}**\n\n{r['text']}" for r in res)
    return f"### {query}\n\n{body}"


with gr.Blocks() as demo:
    gr.Markdown("Enter your query, filter authors (default is all), click **Search** to search.")
    database_inp = gr.Dropdown(label="Database", choices=["English", "German"], value="German")
    author_inp = gr.Dropdown(label="Authors", choices=authors_list_de, multiselect=True)
    inp = gr.Textbox(label="Query", placeholder="Enter questions separated by semicolons...")
    btn = gr.Button("Search")
    results = gr.State()  # Raw (query, hits) pairs from the most recent search.
    results_md = gr.Markdown()  # Rendered view of the most recent search.

    def perform_query(queries, authors, database):
        """Embed each semicolon-separated query and search the chosen collection.

        Returns:
            A list of ``(query, hits)`` tuples, where ``hits`` is the output of
            ``query_chroma``. Returns an empty list when no non-blank query
            was entered.
        """
        task = "Given a question, retrieve passages that answer the question"
        query_list = _split_queries(queries)
        if not query_list:
            return []
        embeddings = get_embeddings(query_list, task)
        collection = collection_de if database == "German" else collection_en
        return [
            (query, query_chroma(collection, embedding, authors))
            for query, embedding in zip(query_list, embeddings)
        ]

    def display_accordion(data):
        """Render ``perform_query`` output as one Markdown string, a section per query.

        The original version created ``gr.Accordion`` components inside an
        ordinary function and was never wired to an event, so nothing was ever
        shown. A Markdown string can be returned to an existing output component.
        """
        if not data:
            return "Please enter at least one query."
        return "\n\n---\n\n".join(_format_query_results(query, res) for query, res in data)

    def search_and_display(queries, authors, database):
        """Run the search, then return both the raw results and their rendering."""
        data = perform_query(queries, authors, database)
        return data, display_accordion(data)

    btn.click(
        search_and_display,
        inputs=[inp, author_inp, database_inp],
        outputs=[results, results_md],
    )
    database_inp.change(
        fn=update_authors,
        inputs=[database_inp],
        outputs=[author_inp],
    )

demo.launch()