Spaces:
Sleeping
Sleeping
File size: 4,960 Bytes
b0477e7 5481095 5768645 5054b14 5768645 34a8460 5481095 1e0081c 34a8460 5481095 5d0dd01 5481095 12b81f7 362cbdc 5481095 d510e2f 5481095 741b955 5481095 741b955 5054b14 5dad4d1 6c5586b 741b955 08c0bda 6c5586b 741b955 6c5586b 741b955 37db268 5481095 1e0081c 5481095 b0477e7 ebce04e 1e0081c ebce04e 5481095 3e64b70 5481095 34a8460 2e6ebbe 5481095 4900245 5481095 5768645 5481095 5768645 5481095 10a8d15 5481095 054de85 5481095 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
from io import BytesIO
from fastapi import FastAPI, Form, Depends, Request, File, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from pymilvus import connections
import os
import pypdf
from uuid import uuid4
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pymilvus import MilvusClient, db, utility, Collection, CollectionSchema, FieldSchema, DataType
from sentence_transformers import SentenceTransformer
import torch
from milvus_singleton import MilvusClientSingleton
# NOTE(review): these variables are assigned *after* `sentence_transformers`
# (and with it `transformers` / `huggingface_hub`) has already been imported
# above. Those libraries likely read HF_HOME / HF_MODULES_CACHE into
# module-level constants at import time, so these assignments probably do not
# take effect for them. Move them above the imports, or set them in the
# container environment, if the cache location matters. The model weights are
# still directed to /app/cache by the explicit `cache_folder` argument below.
os.environ['HF_HOME'] = '/app/cache'
os.environ['HF_MODULES_CACHE'] = '/app/cache/hf_modules'

# Run the embedding model on the GPU when one is available, otherwise on CPU.
_embedding_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Sentence embedding model shared by /insert and /rag.
embedding_model = SentenceTransformer(
    'Alibaba-NLP/gte-large-en-v1.5',
    trust_remote_code=True,  # permits executing model code downloaded from the Hub
    device=_embedding_device,
    cache_folder='/app/cache',
)

# Name of the Milvus collection that /insert writes to and /rag searches.
collection_name = "rag"

# Milvus client backed by a local database file (Milvus Lite style URI).
milvus_client = MilvusClient(uri="/app/milvus_data/milvus_demo.db")
def document_to_embeddings(content: "str | list[str]", *, show_progress_bar: bool = True) -> list:
    """Embed text with the module-level ``embedding_model``.

    Args:
        content: A single string, or a list of strings to embed in one
            batched ``encode`` call. One call is much cheaper than one call
            per string.
        show_progress_bar: Forwarded to ``SentenceTransformer.encode``.
            Defaults to ``True`` to match the previous behaviour.

    Returns:
        A plain Python list. For a single string this is a flat list of
        floats; for a list of strings it is one such list per input, in
        order. ``encode`` returns a NumPy array, and ``tolist()`` makes the
        ``-> list`` annotation true. Plain lists are accepted by pymilvus
        ``insert``/``search``.
    """
    embeddings = embedding_model.encode(content, show_progress_bar=show_progress_bar)
    return embeddings.tolist()
app = FastAPI()

# CORS settings for the API.
# NOTE(review): a wildcard origin together with allow_credentials=True lets any
# website issue credentialed cross-origin requests. Browsers refuse a literal
# "*" here, so the middleware has to answer with the caller's Origin instead;
# verify against the installed Starlette version. None of the visible endpoints
# use cookies or auth headers, so replace "*" with the allowed origins (or drop
# credentials) for production.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
def split_documents(document_data, *, chunk_size=512, chunk_overlap=10):
    """Split ``document_data`` into overlapping text chunks.

    Args:
        document_data: The full text to split.
        chunk_size: Maximum chunk length, measured by the splitter's length
            function (characters with LangChain's default ``len``). Keep it
            well below the 4096 ``max_length`` of the ``content`` field that
            ``create_a_collection`` defines, or inserts will be rejected.
        chunk_overlap: Amount of text shared between consecutive chunks,
            measured the same way as ``chunk_size``.

    Returns:
        The list of chunk strings produced by
        ``RecursiveCharacterTextSplitter.split_text``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_text(document_data)
def create_a_collection(milvus_client, collection_name):
    """Create ``collection_name`` on ``milvus_client`` with a cosine FLAT index.

    Schema:
        id      VARCHAR(max_length=40), primary key. Callers in this module
                store ``str(uuid4())`` here (36 characters).
        content VARCHAR(max_length=4096), the chunk text.
        vector  FLOAT_VECTOR(dim=1024), the chunk embedding. The dimension
                must equal the size of the vectors produced by the embedding
                model; 1024 is hard-coded here.

    The index is created through the same ``milvus_client`` that creates the
    collection. Previously it went through a separate ORM connection to a
    hard-coded URI, which ignored the client argument. Passing
    ``index_params`` to ``create_collection`` also makes pymilvus load the
    collection, so it can be searched right after the first insert.
    """
    id_field = FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=40, is_primary=True)
    content_field = FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=4096)
    vector_field = FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024)
    schema = CollectionSchema(fields=[id_field, content_field, vector_field])

    # FLAT = exact brute-force search; it takes no build parameters (no nlist).
    index_params = milvus_client.prepare_index_params()
    index_params.add_index(
        field_name="vector",
        index_type="FLAT",
        metric_type="COSINE",
    )

    milvus_client.create_collection(
        collection_name=collection_name,
        schema=schema,
        index_params=index_params,
    )
@app.get("/")
async def root():
    """Trivial liveness endpoint that always returns the same greeting."""
    greeting = "Hello World"
    return dict(message=greeting)
def _extract_pdf_text(raw_pdf: bytes) -> str:
    """Return the text of every page of ``raw_pdf``, pages separated by newlines.

    The newline keeps the last word of one page from being glued to the first
    word of the next. Pages whose ``extract_text()`` yields nothing (e.g. no
    text layer) contribute an empty string.

    Raises:
        pypdf.errors.PdfReadError: if ``raw_pdf`` cannot be parsed as a PDF.
    """
    reader = pypdf.PdfReader(BytesIO(raw_pdf))
    return "\n".join(page.extract_text() or "" for page in reader.pages)


@app.post("/insert")
async def insert(file: UploadFile = File(...)):
    """Ingest an uploaded PDF into the ``collection_name`` Milvus collection.

    The PDF text is split into chunks. Each chunk is embedded and stored with
    a random UUID as its id. The collection is created on first use.

    Responses:
        200 ``{"result": "good"}``: all chunks were inserted.
        400 ``{"error": ...}``: the upload is not a readable PDF, or it
            contains no extractable text.
        500 ``{"error": ...}``: the Milvus insert failed.

    NOTE(review): PDF parsing and embedding are synchronous and run on the
    event-loop thread, so a large upload blocks other requests meanwhile.
    """
    raw_pdf = await file.read()
    try:
        extracted_text = _extract_pdf_text(raw_pdf)
    except pypdf.errors.PdfReadError as e:
        return JSONResponse(status_code=400, content={"error": f"Could not read PDF: {e}"})

    chunks = split_documents(extracted_text)
    if not chunks:
        # e.g. scanned/image-only PDFs: nothing to embed, so report it instead
        # of claiming success.
        return JSONResponse(
            status_code=400,
            content={"error": "No extractable text found in the uploaded PDF."},
        )

    if not milvus_client.has_collection(collection_name):
        create_a_collection(milvus_client, collection_name)

    # Embed all chunks in one batched call rather than one call per chunk.
    vectors = document_to_embeddings(chunks)
    data_objects = [
        {"id": str(uuid4()), "vector": vector, "content": chunk}
        for chunk, vector in zip(chunks, vectors)
    ]
    print(f"Inserting {len(data_objects)} chunks into collection '{collection_name}'")

    try:
        milvus_client.insert(collection_name=collection_name, data=data_objects)
    except Exception as e:
        # JSONResponse is a response object, not an exception: it must be
        # returned. Raising it would itself fail with a TypeError.
        return JSONResponse(status_code=500, content={"error": str(e)})
    return JSONResponse(status_code=200, content={"result": 'good'})
class RAGRequest(BaseModel):
    """JSON request body accepted by ``POST /rag``."""

    # Free-text question; the /rag handler embeds it and searches stored chunks.
    question: str
@app.post("/rag")
async def rag(request: RAGRequest):
    """Return the stored chunks most similar to ``request.question``.

    Responses:
        200 ``{"result": [chunk_text, ...]}``: up to 5 chunk texts, in the
            order returned by the Milvus search. No distances are included.
        400 ``{"message": ...}``: the question is empty or whitespace-only.
        400 ``{"error": ...}``: no documents have been inserted yet (the
            collection does not exist).
        500 ``{"error": ...}``: embedding or search failed.

    NOTE(review): the question is embedded synchronously on the event-loop
    thread.
    """
    question = request.question
    if not question.strip():
        return JSONResponse(status_code=400, content={"message": "Please provide a question!"})
    if not milvus_client.has_collection(collection_name):
        return JSONResponse(
            status_code=400,
            content={"error": "No documents have been inserted yet."},
        )
    try:
        search_res = milvus_client.search(
            collection_name=collection_name,
            data=[document_to_embeddings(question)],
            limit=5,  # return the top 5 matches
            output_fields=["content"],  # return the stored chunk text
        )
        # search_res holds one hit list per query vector; exactly one was sent.
        retrieved_chunks = [hit["entity"]["content"] for hit in search_res[0]]
    except Exception as e:
        # Failures here are server-side (model or database), not client errors.
        return JSONResponse(status_code=500, content={"error": str(e)})
    return JSONResponse(status_code=200, content={"result": retrieved_chunks})