File size: 4,960 Bytes
b0477e7
5481095
 
 
 
5768645
5054b14
5768645
34a8460
5481095
1e0081c
34a8460
5481095
 
 
 
 
5d0dd01
5481095
 
 
 
 
 
 
 
 
 
 
12b81f7
362cbdc
 
5481095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d510e2f
5481095
 
741b955
 
 
 
 
 
 
 
 
5481095
 
741b955
 
 
5054b14
 
5dad4d1
6c5586b
741b955
 
 
08c0bda
6c5586b
 
 
 
741b955
 
 
6c5586b
741b955
37db268
5481095
 
 
 
 
 
 
 
 
1e0081c
5481095
 
 
b0477e7
ebce04e
 
 
 
 
1e0081c
ebce04e
5481095
3e64b70
 
5481095
 
 
34a8460
2e6ebbe
 
5481095
 
 
4900245
 
5481095
 
 
 
 
 
 
 
5768645
 
 
5481095
5768645
 
5481095
 
 
 
 
 
 
 
 
 
10a8d15
5481095
 
 
 
 
 
054de85
5481095
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from io import BytesIO
from fastapi import FastAPI, Form, Depends, Request, File, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from pymilvus import connections


import os
import pypdf
from uuid import uuid4

from langchain.text_splitter import RecursiveCharacterTextSplitter
from pymilvus import MilvusClient, db, utility, Collection, CollectionSchema, FieldSchema, DataType
from sentence_transformers import SentenceTransformer
import torch
from milvus_singleton import MilvusClientSingleton


# Hugging Face cache locations inside the container image.
# NOTE(review): huggingface_hub / transformers usually read these variables when
# they are first imported, and sentence_transformers is imported above, so
# assigning them here may come too late to take effect — confirm, and consider
# exporting them in the container environment instead. The model load below
# passes cache_folder explicitly, so its download location does not rely on them.
os.environ.update({
    'HF_HOME': '/app/cache',
    'HF_MODULES_CACHE': '/app/cache/hf_modules',
})

# Embedding model shared by all requests; placed on the GPU when torch reports
# one is available, otherwise on the CPU.
embedding_model = SentenceTransformer(
    'Alibaba-NLP/gte-large-en-v1.5',
    device=('cuda' if torch.cuda.is_available() else 'cpu'),
    trust_remote_code=True,
    cache_folder='/app/cache',
)

# Milvus collection holding the document chunks and their vectors.
collection_name = "rag"

# Client backed by a local Milvus database file.
# (MilvusClientSingleton is imported above but is not used here.)
milvus_client = MilvusClient(uri="/app/milvus_data/milvus_demo.db")

def document_to_embeddings(content: str, *, show_progress_bar: bool = True) -> list:
    """Embed *content* with the module-level SentenceTransformer model.

    Args:
        content: Text to embed.
        show_progress_bar: Forwarded to ``SentenceTransformer.encode``.
            Pass ``False`` to suppress the per-call progress bar when
            embedding many short texts one at a time.

    Returns:
        The embedding as a plain Python list. ``encode`` returns a NumPy
        array by default; converting it makes the value match the declared
        ``-> list`` return type and keeps it JSON- and Milvus-friendly.
    """
    embedding = embedding_model.encode(content, show_progress_bar=show_progress_bar)
    # Both NumPy arrays and torch tensors expose .tolist().
    return embedding.tolist()

# ASGI application on which the route decorators below register endpoints.
app = FastAPI()

# Wide-open CORS policy: every origin, method and header is accepted, and
# credentialed requests are allowed.
# NOTE(review): narrow allow_origins to the real front-end origins before
# production; a wildcard combined with allow_credentials is very permissive.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

def split_documents(document_data, *, chunk_size=512, chunk_overlap=10):
    """Split *document_data* into overlapping text chunks.

    Args:
        document_data: The full text to split.
        chunk_size: Maximum chunk length, measured by the splitter's length
            function (LangChain's default is ``len``, i.e. characters).
        chunk_overlap: How much consecutive chunks may overlap, in the same
            units as ``chunk_size``.

    Returns:
        The list of chunk strings produced by
        ``RecursiveCharacterTextSplitter.split_text``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_text(document_data)

def create_a_collection(milvus_client, collection_name):
    """Create the chunk collection together with its vector index.

    Schema:
        * ``id`` — VARCHAR primary key, up to 40 characters (room for a
          ``str(uuid4())``, which is 36 characters).
        * ``content`` — VARCHAR chunk text, up to 4096.
        * ``vector`` — 1024-dimensional FLOAT_VECTOR.

    The index is passed to ``create_collection`` itself. That way
    ``MilvusClient`` builds it and loads the collection in one call, so the
    collection is immediately searchable. It also means no second ORM
    connection to a hard-coded URI is needed.

    Args:
        milvus_client: The ``MilvusClient`` to create the collection on.
        collection_name: Name of the collection to create.
    """
    id_field = FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=40, is_primary=True)
    content_field = FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=4096)
    # dim must equal the embedding model's output size (presumably 1024 for
    # gte-large-en-v1.5 — confirm if the model changes).
    vector_field = FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024)

    schema = CollectionSchema(fields=[id_field, content_field, vector_field])

    # Exact (brute-force) FLAT index scored by cosine similarity. FLAT takes
    # no build parameters such as nlist, which only applies to IVF indexes.
    index_params = milvus_client.prepare_index_params()
    index_params.add_index(
        field_name="vector",
        index_type="FLAT",
        metric_type="COSINE",
    )

    milvus_client.create_collection(
        collection_name=collection_name,
        schema=schema,
        index_params=index_params,
    )

@app.get("/")
async def root():
    """Health-check endpoint that always replies with a fixed greeting."""
    payload = {"message": "Hello World"}
    return payload

def _extract_pdf_text(raw: bytes) -> str:
    """Return the text of every page in the PDF bytes *raw*, pages separated by newlines.

    Raises:
        pypdf.errors.PdfReadError: if *raw* cannot be parsed as a PDF.
    """
    reader = pypdf.PdfReader(BytesIO(raw))
    # A separator keeps the last word of one page from being glued to the
    # first word of the next; `or ""` guards against pages with no text.
    return "\n".join(page.extract_text() or "" for page in reader.pages)


@app.post("/insert")
async def insert(file: UploadFile = File(...)):
    """Index an uploaded PDF into the Milvus collection.

    Steps:
        1. Extract the PDF's text.
        2. Split it into chunks and embed each chunk.
        3. Insert one ``{"id", "vector", "content"}`` row per chunk,
           creating the collection first if it does not exist yet.

    Responses:
        200 ``{"result": "good"}`` — rows inserted (possibly zero for a PDF
        with no extractable text).
        400 ``{"error": ...}`` — the upload is not a readable PDF.
        500 ``{"error": ...}`` — the Milvus insert failed.
    """
    # NOTE(review): PDF parsing and embedding are CPU/GPU-bound and run on
    # the event loop here; consider offloading them to a worker thread.
    raw = await file.read()

    try:
        extracted_text = _extract_pdf_text(raw)
    except pypdf.errors.PdfReadError as e:
        return JSONResponse(status_code=400, content={"error": f"Invalid PDF: {e}"})

    # Only create the collection once we know the upload is usable.
    if not milvus_client.has_collection(collection_name):
        create_a_collection(milvus_client, collection_name)

    chunks = split_documents(extracted_text)
    print(f"Split uploaded document into {len(chunks)} chunks")

    data_objects = [
        {
            "id": str(uuid4()),
            "vector": document_to_embeddings(chunk),
            "content": chunk,
        }
        for chunk in chunks
    ]

    try:
        milvus_client.insert(collection_name=collection_name, data=data_objects)
    except Exception as e:
        # JSONResponse is a response, not an exception: it must be returned.
        return JSONResponse(status_code=500, content={"error": str(e)})
    return JSONResponse(status_code=200, content={"result": 'good'})
    
class RAGRequest(BaseModel):
    """JSON request body accepted by ``POST /rag``."""

    # Free-text question; its embedding drives the similarity search.
    question: str
    
@app.post("/rag")
async def rag(request: RAGRequest):
    """Return the stored chunks most similar to ``request.question``.

    Responses:
        200 ``{"result": [content, ...]}`` — up to 5 chunk texts, in the
        order Milvus returns them (most similar first, per Milvus docs).
        400 ``{"message": ...}`` — the question is empty or only whitespace.
        500 ``{"error": ...}`` — embedding or the Milvus search failed (for
        example, nothing has been inserted yet). This matches how ``/insert``
        reports backend failures.
    """
    question = request.question
    if not question.strip():
        return JSONResponse(status_code=400, content={"message": "Please provide a question!"})

    try:
        search_res = milvus_client.search(
            collection_name=collection_name,
            data=[document_to_embeddings(question)],
            limit=5,  # number of nearest chunks to return
            output_fields=["content"],  # return the chunk text with each hit
        )
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

    # search() yields one hit list per query vector; exactly one was sent.
    retrieved_contents = [hit["entity"]["content"] for hit in search_res[0]]
    return JSONResponse(status_code=200, content={"result": retrieved_contents})