Spaces:
Paused
Paused
import asyncio
import glob
import os
import shutil
import subprocess
from typing import Any, Dict, List

# import torch
from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInstructEmbeddings
# from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.schema import LLMResult
from langchain.vectorstores import Chroma
from pydantic import BaseModel

from constants import (
    CHROMA_SETTINGS,
    EMBEDDING_MODEL_NAME,
    MODEL_BASENAME,
    MODEL_ID,
    PATH_NAME_SOURCE_DIRECTORY,
    PERSIST_DIRECTORY,
)
from load_models import load_model
class Predict(BaseModel):
    """Request body carrying the user's question for the QA chain."""

    # Free-text question; `predict` rejects it when empty.
    prompt: str
class Delete(BaseModel):
    """Request body naming a source document to remove."""

    # File name, interpreted relative to the source-documents directory.
    filename: str
class MyCustomAsyncHandler(AsyncCallbackHandler):
    """Callback handler that prints LLM lifecycle events to stdout."""

    # NOTE(review): this hook is synchronous while the other two are
    # coroutines; confirm whether the callback manager in use expects
    # `async def` here as well.
    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Print each token as it is produced."""
        print(f" token: {token}")

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Print a marker when generation starts."""
        # `serialized` is intentionally not indexed: nothing here needs its
        # contents, and an unconditional key lookup would raise KeyError
        # whenever that key is missing.
        print("start")

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Print a marker when generation finishes."""
        print("finish")
# Device auto-detection is disabled (the torch import at the top of the file
# is commented out), so the device is pinned to CUDA.
# if torch.backends.mps.is_available():
#     DEVICE_TYPE = "mps"
# elif torch.cuda.is_available():
#     DEVICE_TYPE = "cuda"
# else:
#     DEVICE_TYPE = "cpu"
DEVICE_TYPE = "cuda"
SHOW_SOURCES = True

EMBEDDINGS = HuggingFaceInstructEmbeddings(
    model_name=EMBEDDING_MODEL_NAME,
    model_kwargs={"device": DEVICE_TYPE},
)

# load the vectorstore
DB = Chroma(
    persist_directory=PERSIST_DIRECTORY,
    embedding_function=EMBEDDINGS,
    client_settings=CHROMA_SETTINGS,
)
RETRIEVER = DB.as_retriever()

# Callback handlers are passed as instances: handing over the class itself
# would leave its methods unbound when the callback machinery invokes them.
LLM = load_model(
    device_type=DEVICE_TYPE,
    model_id=MODEL_ID,
    model_basename=MODEL_BASENAME,
    stream=True,
    callbacks=[MyCustomAsyncHandler()],
)
# Prompt text for the "stuff" chain. {history} is filled from the
# conversation memory below; {context} and {question} are the other two
# declared input variables.
template = """you are a helpful, respectful and honest assistant. When answering questions, you should only use the documents provided.
You should only answer the topics that appear in these documents.
Always answer in the most helpful and reliable way possible, if you don't know the answer to a question, just say you don't know, don't try to make up an answer,
don't share false information. you should use no more than 15 sentences and all your answers should be as concise as possible.
Always say "Thank you for asking!" at the end of your answer.
Context: {history} \n {context}
Question: {question}
"""

QA_CHAIN_PROMPT = PromptTemplate(
    template=template,
    input_variables=["history", "context", "question"],
)

# The memory's keys line up with the {question} and {history} placeholders.
memory = ConversationBufferMemory(memory_key="history", input_key="question")

QA = RetrievalQA.from_chain_type(
    llm=LLM,
    retriever=RETRIEVER,
    chain_type="stuff",
    return_source_documents=SHOW_SOURCES,
    chain_type_kwargs={"prompt": QA_CHAIN_PROMPT, "memory": memory},
)
# Two applications: `app` is the outer one, `api_app` is mounted under /api.
api_app = FastAPI(title="api app")
app = FastAPI(title="homepage-app")

# The /api mount is registered before the static mount at "/".
# NOTE(review): presumably the order matters because "/" would otherwise
# match /api requests first; confirm before reordering.
app.mount("/api", api_app, name="api")
app.mount("/", StaticFiles(directory="static", html=True), name="static")
def run_ingest_route():
    """Rebuild the vector store from scratch and reload the QA chain.

    Removes the persisted index directory, runs ``ingest.py`` in a
    subprocess, then rebinds the module-level ``DB``, ``RETRIEVER`` and
    ``QA`` objects so later requests use the fresh index.

    Returns:
        dict: ``{"response": "The training was successfully completed"}``.

    Raises:
        HTTPException: 500 if the persist directory is missing or cannot be
            removed; 400 if the ingest script exits with a non-zero status;
            500 for any other unexpected failure.
    """
    global DB
    global RETRIEVER
    global QA
    try:
        if not os.path.exists(PERSIST_DIRECTORY):
            raise HTTPException(status_code=500, detail="The directory does not exist")
        try:
            shutil.rmtree(PERSIST_DIRECTORY)
        except OSError as e:
            raise HTTPException(
                status_code=500, detail=f"Error: {e.filename} - {e.strerror}."
            ) from e

        run_langest_commands = ["python", "ingest.py"]
        # The device flag is only forwarded for CPU runs.
        if DEVICE_TYPE == "cpu":
            run_langest_commands.extend(["--device_type", DEVICE_TYPE])

        result = subprocess.run(run_langest_commands, capture_output=True)
        if result.returncode != 0:
            # Surface the script's stderr so the caller can see why it failed.
            stderr_text = result.stderr.decode("utf-8", errors="replace")
            raise HTTPException(
                status_code=400, detail=f"Script execution failed: {stderr_text}"
            )

        # load the vectorstore
        DB = Chroma(
            persist_directory=PERSIST_DIRECTORY,
            embedding_function=EMBEDDINGS,
            client_settings=CHROMA_SETTINGS,
        )
        RETRIEVER = DB.as_retriever()
        QA = RetrievalQA.from_chain_type(
            llm=LLM,
            chain_type="stuff",
            retriever=RETRIEVER,
            return_source_documents=SHOW_SOURCES,
            chain_type_kwargs={
                "prompt": QA_CHAIN_PROMPT,
                "memory": memory,
            },
        )
        return {"response": "The training was successfully completed"}
    except HTTPException:
        # Deliberate HTTP errors raised above keep their status code and
        # detail instead of being re-wrapped as a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}") from e
def get_files():
    """List the entries of the source-documents directory.

    Returns:
        dict: ``{"directory": <absolute directory path>, "files": [<paths>]}``
        where the paths are whatever ``glob`` matches for ``*`` inside the
        directory.
    """
    source_dir = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
    pattern = os.path.join(source_dir, "*")
    return {"directory": source_dir, "files": glob.glob(pattern)}
def delete_source_route(data: Delete):
    """Delete one file from the source-documents directory.

    Args:
        data: Request body whose ``filename`` is resolved relative to the
            source-documents directory.

    Returns:
        dict: ``{"message": "<path> has been deleted."}`` on success.

    Raises:
        HTTPException: 400 if the name resolves outside the source
            directory, the file does not exist, or removal fails.
    """
    filename = data.filename
    path_source_documents = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
    file_to_delete = f"{path_source_documents}/{filename}"

    # The filename comes from the client: refuse anything (e.g. "../x" or a
    # symlink) that resolves outside the source-documents directory.
    resolved_dir = os.path.realpath(path_source_documents)
    resolved_target = os.path.realpath(file_to_delete)
    if not resolved_target.startswith(resolved_dir + os.sep):
        raise HTTPException(status_code=400, detail="Invalid filename")

    if not os.path.exists(file_to_delete):
        message = f"The file {file_to_delete} does not exist."
        print(message)
        raise HTTPException(status_code=400, detail=message)

    try:
        os.remove(file_to_delete)
    except OSError as e:
        # Build the text first: print() returns None, so it must not be
        # passed as the exception detail.
        message = f"error: {e}."
        print(message)
        raise HTTPException(status_code=400, detail=message) from e

    message = f"{file_to_delete} has been deleted."
    print(message)
    return {"message": message}
async def predict(data: Predict):
    """Answer a prompt with the retrieval QA chain.

    Args:
        data: Request body carrying the user's ``prompt``.

    Returns:
        dict: ``{"response": {"Prompt": ..., "Answer": ..., "Sources": [...]}}``
        where each source is a ``(file name, page content)`` pair taken from
        the retrieved documents.

    Raises:
        HTTPException: 400 if the prompt is empty.
    """
    user_prompt = data.prompt
    if not user_prompt:
        raise HTTPException(status_code=400, detail="Prompt Incorrect")

    # NOTE(review): this is a synchronous call inside a coroutine, so it
    # occupies the event loop until the chain returns.
    res = QA(user_prompt)
    print(res)
    answer, docs = res["result"], res["source_documents"]

    sources = [
        (os.path.basename(str(document.metadata["source"])), str(document.page_content))
        for document in docs
    ]
    # The chain result is indexed like a mapping and is not a chain, so it
    # offers no `.stream`; calling it would raise AttributeError after the
    # answer had already been computed.
    prompt_response_dict = {
        "Prompt": user_prompt,
        "Answer": answer,
        "Sources": sources,
    }
    return {"response": prompt_response_dict}
async def create_upload_file(file: UploadFile):
    """Store an uploaded document in the source-documents directory.

    Args:
        file: The uploaded file.

    Returns:
        dict: ``{"filename": <stored file name>}``.

    Raises:
        HTTPException: 400 if the file exceeds 10 MB, its content type is
            not in the allow-list, or its name is empty or not a plain
            file name.
    """
    allowed_content_types = (
        "text/plain",
        "text/markdown",
        "text/x-markdown",
        "text/csv",
        "application/msword",
        "application/pdf",
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "text/x-python",
        "application/x-python-code",
    )

    # Get the file size (in bytes) by seeking to the end.
    file.file.seek(0, 2)
    file_size = file.file.tell()
    # move the cursor back to the beginning
    await file.seek(0)
    if file_size > 10 * 1024 * 1024:
        # more than 10 MB
        raise HTTPException(status_code=400, detail="File too large")

    if file.content_type not in allowed_content_types:
        raise HTTPException(status_code=400, detail="Invalid file type")

    # The file name is client-controlled: keep only its final component so a
    # name like "../../x" or an absolute path cannot leave the upload dir.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid filename")

    upload_dir = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
    os.makedirs(upload_dir, exist_ok=True)
    dest = os.path.join(upload_dir, safe_name)
    with open(dest, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    return {"filename": safe_name}
async def websocket_endpoint(websocket: WebSocket):
    """Echo every text message received on the socket back to the client.

    Args:
        websocket: The incoming WebSocket connection; it is accepted here.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message text was: {data}")
    except WebSocketDisconnect:
        # The client closed the connection; leave the loop quietly instead
        # of letting the disconnect propagate as an unhandled error.
        return