Synced repo using 'sync_with_huggingface' Github Action
- .dockerignore +4 -0
- .env.example +10 -0
- CreateEnv.ps1 +15 -0
- DATABASE.py +13 -0
- Middleware.py +65 -0
- MongoChainGenerator.py +58 -0
- MongoEmbeddingGenerator.py +34 -0
- OtherFun.py +38 -0
- ReadME.md +53 -0
- appConfig.py +62 -0
- docker-compose.yml +5 -0
- main.py +83 -4
- requirements.txt +27 -2
- verifyToken.py +26 -0
.dockerignore
ADDED
@@ -0,0 +1,4 @@
+.vercel
+.venv
+.idea
+__pycache__
.env.example
ADDED
@@ -0,0 +1,10 @@
+# Authentication to HuggingFaceHub
+HUGGINGFACEHUB_API_TOKEN=
+
+# Authentication Token secret
+JWT_SECRET=
+
+# database
+MONGO_DB_URL=
+MONGO_DB_NAME=
+MONGO_DB_NAME_CACHE=
CreateEnv.ps1
ADDED
@@ -0,0 +1,15 @@
+# Set the project directory and virtual environment name
+$venvName = "venv"
+
+# Check if the virtual environment folder exists
+$venvExists = Test-Path ($venvName)
+
+if (-not $venvExists) {
+    # Create virtual environment if it doesn't exist
+    Write-Host "Creating virtual environment..."
+    python -m venv $venvName
+}
+Write-Host "Activate venv"
+& venv\Scripts\Activate.ps1
+Write-Host "Installing dependencies..."
+python -m pip install -r "requirements.txt"
DATABASE.py
ADDED
@@ -0,0 +1,13 @@
+from appConfig import *
+from pymongo import MongoClient
+
+
+class DATABASE():
+    client = None
+
+    def __init__(self):
+        self._initialize_mongodb_client()
+
+    def _initialize_mongodb_client(self):
+        if DATABASE.client is None:
+            DATABASE.client = MongoClient(ENV_VAR.MONGO_DB_URL)
Middleware.py
ADDED
@@ -0,0 +1,65 @@
+from MongoChainGenerator import *
+from MongoEmbeddingGenerator import *
+from DATABASE import *
+from appConfig import LOG
+
+
+class Main:
+    qa_chains = {}
+    embedding_generator = None
+
+    def __init__(self) -> None:
+        DATABASE()
+        self._initialize_embedding_generator()
+        self._load_existing_qa_chains()
+
+    def _initialize_embedding_generator(self):
+        if Main.embedding_generator is None:
+            Main.embedding_generator = MongoEmbeddingGenerator(repo_id=CONST_VAR.EMBEDDING_MODEL_REPO_ID)
+            LOG.debug("Embedding generator initialized")
+
+    def _load_existing_qa_chains(self):
+        chats = DATABASE.client["chatData"]["chats"].find()
+        for chat in chats:
+            if chat["collectionName"] not in Main.qa_chains:
+                self.create_exist_chains(chat)
+
+    def create_exist_chains(self, chat):
+        if chat["collectionName"] not in Main.qa_chains:
+            qa_generator = MongoChainGenerator(
+                embedding_model=Main.embedding_generator.embedding_model,
+                db_collection_name=chat["collectionName"],
+                template_context=chat["templateContext"]
+            )
+            Main.qa_chains[chat["collectionName"]] = qa_generator.generate_retrieval_qa_chain()
+            LOG.debug("Chain created for collection " + chat["collectionName"])
+        else:
+            LOG.debug("Chain already exists for collection " + chat["collectionName"])
+
+    def generate_embedding(self, content: str, file_name: str, collection_name: str):
+        return Main.embedding_generator.generate_embeddings(content, file_name, collection_name)
+
+    def generate_tmp_embedding_and_chain(self, contents: str, tmp_collection_name):
+        qa_generator = MongoChainGenerator(
+            embedding_model=Main.embedding_generator.embedding_model,
+            template_context=CONST_VAR.TEMPLATE_CONTEXT,
+            tmp_vector_embedding=Main.embedding_generator.generate_tmp_embeddings(pdf_bytes=contents)
+        )
+        Main.qa_chains[tmp_collection_name] = qa_generator.generate_retrieval_qa_chain()
+        LOG.debug(tmp_collection_name + ' chain created')
+
+    def ask_question(self, question: str, collection_name):
+        if collection_name in Main.qa_chains:
+            try:
+                LOG.debug(collection_name + " answering")
+                response = Main.qa_chains[collection_name]({"query": question, "early_stopping": True, "min_length": 2000, "max_tokens": 5000})
+                return response["result"]
+            except Exception as e:
+                LOG.error("An error occurred while answering question: {}".format(str(e)))
+                return "Retry to ask question! An error occurred: {}".format(str(e))
+        else:
+            LOG.warning("Chain for collection '{}' not found.".format(collection_name))
+            return "Chain for collection '{}' not found.".format(collection_name)
+
+    def check_collection_name(self, collection_name):
+        return collection_name in self.qa_chains
MongoChainGenerator.py
ADDED
@@ -0,0 +1,58 @@
+from appConfig import *
+from langchain.chains import RetrievalQA
+from langchain.prompts import PromptTemplate
+from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
+from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch
+from langchain.vectorstores.faiss import FAISS
+from huggingface_hub import login
+
+login(token=ENV_VAR.HUGGINGFACEHUB_API_TOKEN, write_permission=True, add_to_git_credential=True)
+
+class MongoChainGenerator:
+    LLM = None
+
+    def __init__(self, embedding_model, template_context, db_collection_name=None, tmp_vector_embedding=None):
+        if db_collection_name:
+            self._load_vectors(embedding_model, db_collection_name)
+        else:
+            self._create_tmp_retriever(tmp_vector_embedding)
+
+        self._initialize_prompt(template_context)
+
+        if MongoChainGenerator.LLM is None:
+            self._initialize_llm()
+
+    def _create_tmp_retriever(self, tmp_vector_embedding: FAISS):
+        self.qa_retriever = tmp_vector_embedding.as_retriever(search_type="similarity", search_kwargs={"k": 7})
+        LOG.debug("Temporary retriever created")
+
+    def _load_vectors(self, embedding_model, db_collection_name):
+        self.qa_retriever = MongoDBAtlasVectorSearch.from_connection_string(
+            connection_string=ENV_VAR.MONGO_DB_URL,
+            namespace=ENV_VAR.MONGO_DB_NAME + "." + db_collection_name,
+            embedding=embedding_model,
+        ).as_retriever(search_type="similarity", search_kwargs={"k": 7})
+        LOG.debug("Retriever loaded from MongoDB Atlas")
+
+    def _initialize_prompt(self, template_context):
+        template = template_context + """
+        {context}
+
+        Question: {question} all related details.
+        Answer:"""
+        self.prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+        LOG.debug("Prompt template initialized")
+
+    def _initialize_llm(self):
+        MongoChainGenerator.LLM = HuggingFaceEndpoint(repo_id=CONST_VAR.TEXT_GENERATOR_MODEL_REPO_ID, temperature=0.8, max_new_tokens=4096)
+        # MongoChainGenerator.LLM = HuggingFaceHub(repo_id=CONST_VAR.TEXT_GENERATOR_MODEL_REPO_ID, model_kwargs={"temperature": 0.85, "return_full_text": False, "max_length": 4096, "max_new_tokens": 4096})
+        LOG.info("LLM initialized")
+
+    def generate_retrieval_qa_chain(self):
+        chain = RetrievalQA.from_chain_type(
+            llm=MongoChainGenerator.LLM,
+            retriever=self.qa_retriever,
+            chain_type_kwargs={"prompt": self.prompt},
+        )
+        LOG.debug("Retrieval QA chain generated")
+        return chain
MongoEmbeddingGenerator.py
ADDED
@@ -0,0 +1,34 @@
+from io import BytesIO
+import PyPDF2
+from appConfig import *
+from DATABASE import *
+from langchain.vectorstores.faiss import FAISS
+from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch
+from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
+
+class MongoEmbeddingGenerator:
+
+    def __init__(self, repo_id):
+        self.embedding_model = HuggingFaceHubEmbeddings(repo_id=repo_id, huggingfacehub_api_token=ENV_VAR.HUGGINGFACEHUB_API_TOKEN)
+        LOG.info("Embedding model initialised")
+
+    def _extract_text_from_pdf(self, pdf_bytes):
+        pdf_file = BytesIO(pdf_bytes)
+        pdf_reader = PyPDF2.PdfReader(pdf_file)
+        return [pdf_reader.pages[page_num].extract_text() for page_num in range(len(pdf_reader.pages))]
+
+    def generate_tmp_embeddings(self, pdf_bytes):
+        texts = self._extract_text_from_pdf(pdf_bytes)
+        return FAISS.from_texts(texts=texts, embedding=self.embedding_model)
+
+    def generate_embeddings(self, pdf_bytes, file_name: str, collection_name: str):
+        client = DATABASE.client
+        if client[ENV_VAR.MONGO_DB_NAME_CACHE][collection_name].find_one({"src_file_name": file_name}):
+            LOG.debug(f"Vectors already exist in MongoDB for file {file_name}")
+            return f"Vectors already exist in MongoDB for file {file_name}"
+        else:
+            texts = self._extract_text_from_pdf(pdf_bytes)
+            client[ENV_VAR.MONGO_DB_NAME_CACHE][collection_name].insert_one({"src_file_name": file_name})
+            MongoDBAtlasVectorSearch.from_texts(texts=texts, embedding=self.embedding_model, collection=client[ENV_VAR.MONGO_DB_NAME][collection_name])
+            LOG.debug(f"Vectors stored in MongoDB for file {file_name}")
+            return f"Vectors stored in MongoDB for file {file_name}"
OtherFun.py
ADDED
@@ -0,0 +1,38 @@
+from fastapi import UploadFile
+import asyncio
+from Middleware import Main
+from appConfig import LOG
+
+
+def delete_chain_after_delay(model: Main, chain_name: str):
+    async def delete_chain():
+        try:
+            await asyncio.sleep(7200)  # Sleep for 2 hours
+            if chain_name in model.qa_chains:
+                del model.qa_chains[chain_name]
+                # Log deletion
+                LOG.info(f"Chain '{chain_name}' deleted after 2 hours")
+        except Exception as e:
+            LOG.error(f"An error occurred while deleting chain '{chain_name}': {e}")
+
+    return delete_chain
+
+
+async def process_file(model: Main, collection_name: str, file: UploadFile):
+    try:
+        contents = await file.read()
+
+        file_extension = file.filename.split(".")[-1]
+
+        if file_extension == "pdf":
+            response = model.generate_embedding(
+                contents, file.filename, collection_name)
+        elif file_extension == "txt":
+            response = contents.decode("utf-8")
+        else:
+            raise ValueError(f"Unsupported file format for {file.filename}")
+
+        return response
+    except Exception as e:
+        LOG.error(f"An error occurred while processing file '{file.filename}': {e}")
+        return f"Error processing file '{file.filename}': {e}"
ReadME.md
ADDED
@@ -0,0 +1,53 @@
+# Automated Legal Document Analysis and Question Answering System
+
+## Requirements
+### Python Version
+- Python 3.9.10
+
+### Huggingface API
+- Generate your API key and place it in the .env file:
+```
+HUGGINGFACEHUB_API_TOKEN=""
+```
+
+## Setup Environment
+
+### Option 1: Automated Setup (PowerShell)
+- Run the `CreateEnv.ps1` file in PowerShell. It will:
+  - Create a virtual environment
+  - Activate it
+  - Create temporary folders
+  - Install necessary Python modules
+
+### Option 2: Manual Setup
+- Create a virtual environment:
+```bash
+python -m venv venv
+```
+- Activate the virtual environment:
+```bash
+venv\Scripts\Activate.ps1
+```
+- Install Python modules:
+```bash
+python -m pip install -r "requirements.txt"
+```
+- Create folders:
+  1. Outputs
+  2. Models (if you want to download manually) (Not necessary)
+
+## References
+- [YouTube Video Reference](https://www.youtube.com/watch?v=dXxQ0LR-3Hg&t=123s)
+- [GitHub](https://github.com/curiousily/Get-Things-Done-with-Prompt-Engineering-and-LangChain)
+
+### Models
+- [gpt2](https://huggingface.co/gpt2)
+- [gte-small](https://huggingface.co/thenlper/gte-small)
+- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
+
+### PDF Documents
+- [The Indian Penal Code](https://www.iitk.ac.in/wc/data/IPC_186045.pdf)
+
+
+### HuggingFaceHub Repository link:
+- https://huggingface.co/spaces/dhruv4023/APIchatbot
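The README covers environment setup but not how the server is launched. As a hedged sketch only: the FastAPI app in main.py could be served with uvicorn, which requirements.txt lists and whose port 8000 matches the mapping in docker-compose.yml. The entry-point name and host are assumptions for illustration, not stated in this commit.

```python
# run_local.py — illustrative sketch; "main:app" and port 8000 are assumptions
# based on main.py and the "8000:8000" mapping in docker-compose.yml.
import uvicorn

if __name__ == "__main__":
    # Serve the FastAPI instance defined in main.py on all interfaces, port 8000.
    uvicorn.run("main:app", host="0.0.0.0", port=8000)
```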
appConfig.py
ADDED
@@ -0,0 +1,62 @@
+import os
+import logging
+from dotenv import load_dotenv
+
+load_dotenv()
+
+
+class ENV_VAR():
+    MONGO_DB_URL = os.environ.get("MONGO_DB_URL")
+    MONGO_DB_NAME = os.environ.get("MONGO_DB_NAME")
+    HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+    MONGO_DB_NAME_CACHE = os.environ.get("MONGO_DB_NAME_CACHE")
+    JWT_SECRET = os.environ.get("JWT_SECRET")
+
+
+class CONST_VAR():
+    TEXT_GENERATOR_MODEL_REPO_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+    EMBEDDING_MODEL_REPO_ID = "sentence-transformers/all-MiniLM-L6-v2"
+    TEMPLATE_CONTEXT = """
+    Use the following pieces of context to answer the question at the end.
+    You should prefer information which are more related to asked question.
+    Make sure to rely on information from text only and not on questions to provide accurate responses.
+    When you find particular answer in given text, display its context useful, make sure to cite it in the your answer.
+    If you don't know the answer, just say that you don't know, don't try to make up an answer.
+    You can only use the given to you to answer the question.
+    Generate concise answers and relevant data related to the asked question.
+    You must represent the answer in proper format such as make points highlight some major information.
+    don't attach your created quetions. if you don't get answer from the given text just say i don't know and terminate answering.
+    if you get answer from the text than write all about the asked quetion and relevant data related to it.
+    don't use your own knowledge just use the provided text to answer the question.
+    """
+
+
+class LOG:
+    def __init__(self) -> None:
+        pass
+
+    @staticmethod
+    def configure_logging(level=logging.INFO):
+        logging.basicConfig(level=level)  # Set the logging level
+
+    @staticmethod
+    def debug(msg):
+        logging.debug(msg)
+
+    @staticmethod
+    def info(msg):
+        logging.info(msg)
+
+    @staticmethod
+    def warning(msg):
+        logging.warning(msg)
+
+    @staticmethod
+    def error(msg):
+        logging.error(msg)
+
+    @staticmethod
+    def critical(msg):
+        logging.critical(msg)
+
+LOG.configure_logging()  # Set logging level to INFO
docker-compose.yml
ADDED
@@ -0,0 +1,5 @@
+services:
+  web:
+    build: .
+    ports:
+      - "8000:8000"
main.py
CHANGED
@@ -1,7 +1,86 @@
-from
+from pydantic import BaseModel
+from typing import List, Optional
+from fastapi.responses import JSONResponse
+from starlette.middleware import Middleware
+from starlette.middleware.gzip import GZipMiddleware
+from starlette.middleware.cors import CORSMiddleware
+from fastapi import FastAPI, File, UploadFile, Depends, Form, BackgroundTasks

+from OtherFun import *
+from Middleware import Main
+from verifyToken import verify_token_and_role
+# import os
+origins = ["https://chatbotservernode.onrender.com","https://cbns.vercel.app", "https://hfhchatbot.vercel.app", "http://localhost:5000", "http://localhost:3000", "https://localhost:5000"]
-

+# origins = os.getenv("ALLOWED_ORIGINS", "").split(",")
+
+app = FastAPI(debug=True)
+
+app.add_middleware(GZipMiddleware)
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,  # You can replace '*' with specific origins
+    allow_credentials=True,
+    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],  # or specific methods
+    allow_headers=["Authorization", "Content-Type", "Accept"]  # or specific headers
+)
+
+model = Main()
+
+class BodyModel(BaseModel):
+    query: str
+    chain_name: Optional[str] = None  # Made chain_name optional
+
 @app.get("/")
-async def
-return
+async def home():
+    return "chatbot api server is running..."
+
+
+@app.post("/ask")
+async def askQ(body: BodyModel, token: str = Depends(verify_token_and_role)):
+    try:
+        response = model.ask_question(body.query, token["username"] if body.chain_name is None else body.chain_name)
+        return JSONResponse(content={"success": True, "data": response})
+    except Exception as e:  # Catch specific exceptions
+        return JSONResponse(content={"success": False, "error": str(e)})
+
+
+@app.post("/create/embedding")
+async def createEmbedding(collection_name: str = Form(...), files: List[UploadFile] = File(None), token: str = Depends(verify_token_and_role)):
+    try:
+        if not files:
+            return JSONResponse(content={"success": False, "error": "No files provided"})
+
+        responses = []
+        for file in files:
+            response = await process_file(model, collection_name, file)
+            responses.append(response)
+
+        return JSONResponse(content={"success": True, "responses": responses})
+    except Exception as e:
+        return JSONResponse(content={"success": False, "error": str(e)})
+
+
+@app.post("/create/tmp/chain")
+async def createTmpChain(background_tasks: BackgroundTasks, files: List[UploadFile] = File(...), token: str = Depends(verify_token_and_role)):
+    try:
+        if not files:
+            return JSONResponse(content={"success": False, "error": "No files provided"})
+
+        all_contents = b""
+        for file in files:
+            contents = await file.read()
+            all_contents += contents
+
+        file_extension = files[0].filename.split(".")[-1]
+        if file_extension == "pdf":
+            chain_name = token["username"]
+            model.generate_tmp_embedding_and_chain(all_contents, chain_name)
+            background_tasks.add_task(delete_chain_after_delay(model, chain_name))
+            return JSONResponse(content={"success": True, "message": "Chain created. Will be deleted after 2 hours."})
+        elif file_extension == "txt":
+            all_contents.decode("utf-8")
+        return JSONResponse(content={"success": False, "error": "Unsupported file format"})
+    except Exception as e:
+        return JSONResponse(content={"success": False, "error": str(e)})
+
requirements.txt
CHANGED
@@ -1,2 +1,27 @@
-
-
+langchain==0.1.8
+python-dotenv==1.0.0
+
+# pdf
+PyPDF2==3.0.1
+pypdf==3.17.3
+
+# embedding
+InstructorEmbedding==1.0.1
+torch==2.2.1
+tqdm==4.66.2
+sentence-transformers==2.2.2
+faiss-cpu
+
+# mongodb
+pymongo==4.6.1
+
+# API-END point
+fastapi==0.109.2
+fastapi-cors
+uvicorn[standard]==0.17.*
+python-multipart==0.0.9
+PyJWT==2.8.0
+
+
+huggingface_hub
+typing
verifyToken.py
ADDED
@@ -0,0 +1,26 @@
+from fastapi import HTTPException, Header, status
+from appConfig import ENV_VAR, LOG
+import jwt
+
+async def verify_token_and_role(authorization: str = Header(None)):
+    try:
+        if not authorization or not authorization.startswith("Bearer "):
+            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Token not provided or invalid")
+
+        token = authorization.split("Bearer ")[1]
+
+        try:
+            verified = jwt.decode(token, ENV_VAR.JWT_SECRET, algorithms=["HS256"])
+            LOG.debug("Token verified successfully")
+        except jwt.ExpiredSignatureError:
+            LOG.debug("Token expired")
+            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Token expired")
+
+        if "role" not in verified or verified["role"] not in ["user", "admin"]:
+            LOG.error("Insufficient permissions")
+            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient permissions")
+
+        return verified
+    except Exception as e:
+        LOG.error(f"An error occurred: {e}")
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
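For context on how verify_token_and_role is exercised end to end, here is a hedged client-side sketch: it mints an HS256 token carrying the "username" and "role" claims that verifyToken.py checks and main.py reads, then posts to /ask. The secret value, username, base URL, and the requests dependency are illustrative assumptions, not part of this commit.

```python
# client_example.py — illustrative sketch only, not part of this commit.
import jwt        # PyJWT (listed in requirements.txt)
import requests   # assumed extra dependency for this example only

JWT_SECRET = "change-me"  # assumption: must match the JWT_SECRET value in .env

# verifyToken.py requires a "role" of "user" or "admin"; main.py falls back to
# token["username"] as the chain/collection name when chain_name is omitted.
token = jwt.encode({"username": "demo-user", "role": "user"}, JWT_SECRET, algorithm="HS256")

resp = requests.post(
    "http://localhost:8000/ask",  # assumed local address (port from docker-compose.yml)
    json={"query": "Summarise the uploaded document."},
    headers={"Authorization": f"Bearer {token}"},
)
print(resp.json())
```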