File size: 4,150 Bytes
2e98c79
 
 
 
3bbb5f4
2e98c79
10f043b
7d6132f
4295f9e
7d6132f
 
2ec7158
2e98c79
 
ff486ce
2e8a9c7
 
 
 
2e98c79
2a5653d
2e98c79
 
2a5653d
2e98c79
2a5653d
7d6132f
2e98c79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a5653d
2e98c79
2e8a9c7
 
 
2e98c79
 
92ed022
7d6132f
 
 
 
 
92ed022
 
 
 
 
 
 
7d6132f
 
 
2a5653d
7d6132f
 
2a5653d
 
 
 
 
2e98c79
 
7d6132f
 
2a5653d
2e98c79
7d6132f
2a5653d
 
 
92ed022
 
 
 
2e98c79
7d6132f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import gradio as gr
import chromadb
from sentence_transformers import SentenceTransformer
import spaces

@spaces.GPU
def get_embeddings(queries, task):
    """Encode queries with the Linq-Embed-Mistral sentence-transformer model.

    Each query is wrapped in an instruction prompt built from ``task`` before
    it is encoded. The model is instantiated inside this function, i.e. on
    every call, using the ``HF_TOKEN`` environment variable as auth token.

    Args:
        queries: Iterable of query strings.
        task: Instruction text placed in front of every query.

    Returns:
        Whatever ``SentenceTransformer.encode`` returns for the list of
        prompts (one embedding per query).
    """
    encoder = SentenceTransformer(
        "Linq-AI-Research/Linq-Embed-Mistral",
        use_auth_token=os.getenv("HF_TOKEN"),
    )
    return encoder.encode(
        [f"Instruct: {task}\nQuery: {query}" for query in queries]
    )

# Open the persistent Chroma client rooted at ./chroma and fetch the German
# and English collections by name.
client = chromadb.PersistentClient(path="./chroma")
collection_de = client.get_collection(name="phil_de")
collection_en = client.get_collection(name="phil_en")

# Author names offered in the "Authors" dropdown for each database.
authors_list_de = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]
authors_list_en = [
    "Friedrich Nietzsche",
    "Joscha Bach",
]

def query_chroma(collection, embedding, authors, n_results=10):
    """Query a Chroma collection with one embedding and flatten the hits.

    Args:
        collection: Object exposing Chroma's ``query`` method.
        embedding: Query vector, either a NumPy array (anything with a
            ``tolist`` method) or a plain sequence of floats.
        authors: Author names to restrict the search to. An empty or ``None``
            value disables the author filter.
        n_results: Maximum number of hits to request (default 10).

    Returns:
        A list with one dict per hit, holding the keys ``id``, ``author``,
        ``book``, ``section``, ``title``, ``text`` and ``distance``. Missing
        metadata fields fall back to placeholder strings. If anything raises,
        the single-element list ``[{"error": <message>}]`` is returned
        instead, so this function itself never raises.
    """
    try:
        # Pass None (not {}) when unfiltered: some Chroma versions validate
        # the filter and reject an empty dict, while None always means
        # "no filter".
        where_filter = {"author": {"$in": list(authors)}} if authors else None
        vector = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)

        results = collection.query(
            query_embeddings=[vector],
            n_results=n_results,
            where=where_filter,
            include=["documents", "metadatas", "distances"],
        )

        # Only one embedding was sent, so only the first batch of each
        # result column is relevant. Absent or None columns become empty.
        columns = [
            (results.get(key) or [[]])[0]
            for key in ("ids", "metadatas", "documents", "distances")
        ]

        formatted_results = []
        for id_, metadata, document_text, distance in zip(*columns):
            # A record stored without metadata comes back as None.
            metadata = metadata or {}
            formatted_results.append({
                "id": id_,
                "author": metadata.get('author', 'Unknown author'),
                "book": metadata.get('book', 'Unknown book'),
                "section": metadata.get('section', 'Unknown section'),
                "title": metadata.get('title', 'Untitled'),
                "text": document_text,
                "distance": distance,
            })
        return formatted_results
    except Exception as e:
        # Deliberate catch-all at the UI boundary: report the failure as data
        # so the caller can show it instead of crashing the request.
        return [{"error": str(e)}]

def update_authors(database):
    """Return a Dropdown update with the authors of the selected database.

    Args:
        database: The selected database label; ``"German"`` selects the
            German author list, any other value the English one.

    Returns:
        A ``gr.update`` payload that replaces the dropdown choices and clears
        the current selection.
    """
    choices = authors_list_de if database == "German" else authors_list_en
    # Reset the selection as well: authors picked for the other database would
    # otherwise stay selected and keep filtering the next search.
    return gr.update(choices=choices, value=[])



# Build the search UI: database and author selectors, a query box, and one
# collapsible accordion of results per query.
with gr.Blocks(css=".custom-markdown { border: 1px solid #ccc; padding: 10px; border-radius: 5px; }") as demo:
    gr.Markdown("Enter your query, filter authors (default is all), click **Search** to search.")
    database_inp = gr.Dropdown(label="Database", choices=["English", "German"], value="German")
    author_inp = gr.Dropdown(label="Authors", choices=authors_list_de, multiselect=True)
    inp = gr.Textbox(label="Query", placeholder="Enter questions separated by semicolons...")
    btn = gr.Button("Search")
    # Holds the latest search output: a list of (query, results) pairs.
    results = gr.State()

    # Swap the author choices whenever the database selection changes.
    database_inp.change(
        fn=update_authors,
        inputs=[database_inp],
        outputs=[author_inp],
    )

    def perform_query(queries, authors, database):
        """Embed every semicolon-separated query and search the chosen database.

        Args:
            queries: Raw textbox content; individual questions are separated
                by ``;``.
            authors: Author names selected in the dropdown (may be empty).
            database: ``"German"`` to search the German collection, anything
                else to search the English one.

        Returns:
            A list of ``(query, results)`` tuples, where ``results`` is the
            list returned by ``query_chroma``. Empty when no non-blank query
            was entered.
        """
        task = "Given a question, retrieve passages that answer the question"
        # Drop surrounding whitespace and blank fragments (e.g. from a
        # trailing ';') so no empty prompt is embedded or displayed.
        questions = [part.strip() for part in (queries or "").split(';')]
        questions = [question for question in questions if question]
        if not questions:
            # Nothing to search for; skip the embedding call entirely.
            return []

        embeddings = get_embeddings(questions, task)
        collection = collection_de if database == "German" else collection_en
        return [
            (question, query_chroma(collection, embedding, authors))
            for question, embedding in zip(questions, embeddings)
        ]

    btn.click(
        perform_query,
        inputs=[inp, author_inp, database_inp],
        outputs=[results]
    )

    @gr.render(inputs=[results])
    def display_accordion(data):
        """Render one closed accordion per query with its formatted results.

        Args:
            data: The stored list of ``(query, results)`` tuples, or ``None``
                before the first search has run.
        """
        # The state holds no value until the first search; draw nothing then.
        if not data:
            return
        for query, res in data:
            with gr.Accordion(query, open=False):
                sections = []
                for r in res:
                    if "error" in r:
                        # query_chroma reports failures as {"error": message}.
                        sections.append(f"**Error:** {r['error']}")
                    else:
                        sections.append(f"**{r['author']}, {r['book']}**\n\n{r['text']}")
                # A blank line between entries keeps each heading in its own
                # Markdown paragraph instead of merging into the prior passage.
                markdown_contents = "\n\n".join(sections) or "No results found."
                with gr.Column():
                    gr.Markdown(value=markdown_contents, elem_classes="custom-markdown")

demo.launch()