# Spaces:
# Running
# Running
import os | |
import json | |
import requests | |
import gradio as gr | |
import threading | |
import time | |
import PyPDF2 | |
import chromadb | |
import shutil | |
from pydantic import BaseModel, Field | |
from typing import Dict | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_huggingface import HuggingFaceEmbeddings | |
# Bearer token sent to BASE_URL by classify_query; read from the "mistral" env var
# (None when the variable is unset).
API_KEY = os.environ.get("mistral")
BASE_URL = "https://api.together.xyz"
# Store user inputs: the organisation name and per-level rules entered in the UI,
# all starting out as empty strings (classify_query refuses to run until filled).
user_inputs = dict.fromkeys(("organization", "rules_l1", "rules_l2", "rules_l3"), "")
# Function to classify query
def classify_query(query: str, timeout: float = 60.0) -> str:
    """Classify a customer query into support level 1, 2 or 3 with an LLM.

    The system prompt is built from the organisation name and the per-level
    rules stored in ``user_inputs``. Through ``response_format``, the model is
    asked for a JSON object with a string ``title`` and an integer ``level``.

    Args:
        query: The customer's message.
        timeout: Seconds to wait for the chat-completions API before giving up.

    Returns:
        The raw message content returned by the model. It is expected to be a
        JSON string like ``{"title": "...", "level": 1}``; it is not parsed or
        validated here.

    Raises:
        ValueError: If any rule field is still empty, or *query* is blank.
        RuntimeError: If the API key environment variable is not set.
        requests.HTTPError: If the API answers with an error status.
        requests.Timeout: If the API does not answer within *timeout* seconds.
    """
    if not all(user_inputs.values()):
        raise ValueError("Please fill all input fields first.")
    # Don't spend an API call on an empty message.
    if not query or not query.strip():
        raise ValueError("Please enter a customer query.")
    # Fail early with a clear message instead of sending "Bearer None".
    if not API_KEY:
        raise RuntimeError("API key missing: set the 'mistral' environment variable.")

    system_prompt = f"""You are a Customer Query Classification Agent for {user_inputs['organization']}.
What is considered Level 1 Query (Requires no account info just provided documents by the admin is enough to answer):
{user_inputs['rules_l1']}
What is considered Level 2 Query (Requires account info and provided documents by the admin is enough to answer):
{user_inputs['rules_l2']}
What is considered as Level 3 Query (Immediate Escalation to Human Customer Service Agents):
{user_inputs['rules_l3']}
Classify the following customer query and provide the output in JSON format:
```json
{{
"title": "title of the query in under 10 words",
"level": "1 or 2 or 3"
}}
```"""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": query},
    ]
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }
    data = {
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "messages": messages,
        "temperature": 0.7,
        # JSON mode with a schema so the reply is machine-readable.
        "response_format": {
            "type": "json_object",
            "schema": {
                "type": "object",
                "properties": {
                    "title": {"type": "string"},
                    "level": {"type": "integer"},
                },
                "required": ["title", "level"],
            },
        },
    }
    # NOTE(review): Together's documented OpenAI-compatible route is
    # /v1/chat/completions; confirm this unversioned path is still served.
    # An explicit timeout keeps a stalled API from hanging the UI worker forever.
    response = requests.post(
        f"{BASE_URL}/chat/completions", headers=headers, json=data, timeout=timeout
    )
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]
# Function to convert PDF to text
def pdf_to_text(file_path):
    """Extract the text of every page of a PDF and concatenate it.

    Args:
        file_path: Path of the PDF file on disk.

    Returns:
        The text of all pages in page order, joined with no separator.
        Pages that yield no text contribute an empty string.
    """
    # ``with`` closes the file even if PyPDF2 raises on a malformed PDF; the
    # previous explicit open()/close() pair leaked the handle in that case.
    # Pages are read lazily from the stream, so extraction stays inside the block.
    with open(file_path, "rb") as pdf_file:
        reader = PyPDF2.PdfReader(pdf_file)
        # ``or ""`` defensively guards against a None/empty result from
        # extract_text() (behaviour varies by PyPDF2 version; confirm).
        # join() avoids quadratic string concatenation on long documents.
        return "".join(page.extract_text() or "" for page in reader.pages)
# Function to handle file upload and save embeddings to ChromaDB
def handle_file_upload(files, collection_name):
    """Chunk, embed and store uploaded PDFs in a new persistent Chroma collection.

    Each PDF is copied into ``chabot_pdfs/``, converted to text, and split into
    1000-character chunks with 100 characters of overlap. The chunks are
    embedded with ``thenlper/gte-small`` and added to a newly created
    collection in ``./db``. Chunk ids are ``"<file name>_<chunk index>"``.

    Args:
        files: Uploaded files from ``gr.Files``. Each item exposes its temporary
            path as ``.name``; a plain path string is accepted too. May be
            ``None`` or empty when nothing was uploaded.
        collection_name: Name of the collection to create. It must not already
            exist.

    Returns:
        A human-readable status message for the UI.
    """
    if not collection_name:
        return "Please provide a collection name."
    # Check before creating the collection, so an empty submit doesn't
    # permanently use up the collection name (names must be unique).
    if not files:
        return "Please upload at least one PDF file."
    os.makedirs('chabot_pdfs', exist_ok=True)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
    # Initialize Chroma DB client
    client = chromadb.PersistentClient(path="./db")
    try:
        collection = client.create_collection(name=collection_name)
    except ValueError as e:
        # NOTE(review): newer chromadb releases may raise a non-ValueError for
        # duplicate names; confirm against the pinned version.
        return f"Error creating collection: {e}. Please try a different collection name."

    skipped = []
    for file in files:
        source_path = getattr(file, "name", file)
        file_name = os.path.basename(source_path)
        file_path = os.path.join('chabot_pdfs', file_name)
        shutil.copy(source_path, file_path)  # Copy the file instead of saving
        chunks = text_splitter.split_text(pdf_to_text(file_path))
        if not chunks:
            # Image-only or empty PDFs yield no text. Chroma's add() rejects
            # empty id lists, so skip the file and report it instead.
            skipped.append(file_name)
            continue
        collection.add(
            # One batched embedding call per file instead of one call per chunk.
            embeddings=embeddings.embed_documents(chunks),
            documents=chunks,
            ids=[f"{file_name}_{i}" for i in range(len(chunks))],
        )
    if skipped:
        return (
            "Files uploaded and processed successfully. "
            "No extractable text found in: " + ", ".join(skipped)
        )
    return "Files uploaded and processed successfully."
# Function to search vector database
def search_vector_database(query, collection_name, n_results=2):
    """Return the chunks of a Chroma collection that best match *query*.

    Args:
        query: Free-text search query; it is embedded with ``thenlper/gte-small``.
        collection_name: Name of an existing collection in ``./db``.
        n_results: Maximum number of chunks to return (default 2).

    Returns:
        The matching chunk texts separated by a blank line, or a
        human-readable error message.
    """
    if not collection_name:
        return "Please provide a collection name."
    # An empty query gives a meaningless similarity search; report it instead.
    if not query or not query.strip():
        return "Please provide a search query."
    embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
    client = chromadb.PersistentClient(path="./db")
    try:
        collection = client.get_collection(name=collection_name)
    except ValueError as e:
        # NOTE(review): newer chromadb releases may raise a non-ValueError for a
        # missing collection; confirm against the pinned version.
        return f"Error accessing collection: {e}. Make sure the collection name is correct."
    query_vector = embeddings.embed_query(query)
    results = collection.query(
        query_embeddings=[query_vector], n_results=n_results, include=["documents"]
    )
    # results["documents"] holds one list of chunks per query embedding (just
    # one here). The previous nested join therefore put only a single newline
    # between chunks. Flatten so every chunk is separated by a blank line.
    return "\n\n".join(doc for docs in results["documents"] for doc in docs)
# New function to handle login
# Mock account records returned by handle_login, keyed by product name.
# NOTE(review): the e-mail values look redacted in this copy; confirm the
# intended demo addresses.
_DEMO_ACCOUNTS = {
    "NeoBank": {
        "user_id": "NB782940",
        "user_name": "john_doe123",
        "full_name": "John Doe",
        "email": "[email protected]",
        "balance": 2875.43,
        "transactions": [
            {"date": "2024-06-20", "description": "Coffee Shop", "amount": -4.50},
            {"date": "2024-06-19", "description": "Grocery Store", "amount": -85.22},
            {"date": "2024-06-18", "description": "Salary Deposit", "amount": 2500.00},
        ],
    },
    "CryptoInvest": {
        "user_id": "CI549217",
        "user_name": "crypto_enthusiast",
        "full_name": "Alice Johnson",
        "email": "[email protected]",
        "portfolio": {
            "BTC": {"amount": 0.025, "value": 7500.00},
            "ETH": {"amount": 1.2, "value": 2100.00},
            "SOL": {"amount": 5.8, "value": 450.50},
        },
        "transactions": [
            {"date": "2024-06-22", "description": "Bought ETH", "amount": -500.00},
            {"date": "2024-06-20", "description": "Sold BTC", "amount": 1200.00},
        ],
    },
    "RoboAdvisor": {
        "user_id": "RA385712",
        "user_name": "jane_smith",
        "full_name": "Jane Smith",
        "email": "[email protected]",
        "risk_tolerance": "moderate",
        "portfolio_value": 15800.75,
        "allocations": {
            "stocks": 0.60,
            "bonds": 0.30,
            "real_estate": 0.10,
        },
        "recent_activity": [
            {"date": "2024-06-21", "description": "Dividends received", "amount": 32.50},
            {"date": "2024-06-15", "description": "Portfolio rebalanced"},
        ],
    },
    "PeerLend": {
        "user_id": "PL916350",
        "user_name": "bob_williams",
        "full_name": "Bob Williams",
        "email": "[email protected]",
        "account_type": "borrower",
        "loan_amount": 5000.00,
        "interest_rate": 7.8,
        "monthly_payment": 150.30,
        "payment_history": [
            {"date": "2024-06-22", "status": "paid"},
            {"date": "2024-05-22", "status": "paid"},
            {"date": "2024-04-22", "status": "paid"},
        ],
    },
    "InsureTech": {
        "user_id": "IT264805",
        "user_name": "eva_brown4",
        "full_name": "Eva Brown",
        "email": "[email protected]",
        "policy_type": "auto",
        "coverage_details": {
            "liability": "50/100/50",
            "collision": "500 deductible",
            "comprehensive": "100 deductible",
        },
        "premium": 85.50,
        "next_payment": "2024-07-10",
        "claims": [],
    },
}


def handle_login(username, password):
    """Return the demo account data as JSON when the hard-coded credentials match.

    Args:
        username: Submitted username (only ``"admin"`` is accepted).
        password: Submitted password (only ``"password"`` is accepted).

    Returns:
        A pretty-printed JSON object mapping each product name (e.g.
        ``"NeoBank"``) to its demo account record, or the string
        ``"Invalid username or password"``.
    """
    # This is a simple example. In a real application, you'd want to use secure authentication methods.
    if username == "admin" and password == "password":
        # Serialise with json.dumps so the payload is always valid JSON. The
        # previous hand-written literal lacked its enclosing braces and could
        # not be parsed by API clients.
        return json.dumps(_DEMO_ACCOUNTS, indent=4)
    return "Invalid username or password"
# Gradio interface
# Markdown help text shown under each tab.
_CLASSIFY_API_DOCS = """
### API Endpoint Details
- **URL:** `http://0.0.0.0:7860/classify`
- **Method:** POST
- **Request Body:** JSON with a single key `query`
- **Example Usage:**
```python
from gradio_client import Client
client = Client("http://0.0.0.0:7860/")
result = client.predict(
"Hello!!", # str in 'Customer Query' Textbox component
api_name="/classify_and_display"
)
print(result)
```
"""

_UPLOAD_WARNING = """
### Warning
If you encounter an error when uploading files, try changing the collection name and upload again.
Each collection name must be unique.
"""

_SEARCH_API_DOCS = """
### API Endpoint Details
- **URL:** `http://0.0.0.0:7860/search_vector_database`
- **Method:** POST
- **Example Usage:**
```python
from gradio_client import Client
client = Client("http://0.0.0.0:7860/")
result = client.predict(
"search query", # str in 'Search Query' Textbox component
"name of collection given in ui", # str in 'Collection Name' Textbox component
api_name="/search_vector_database"
)
print(result)
```
"""

_LOGIN_API_DOCS = """
### API Endpoint Details
- **URL:** `http://0.0.0.0:7860/handle_login`
- **Method:** POST
- **Example Usage:**
```python
from gradio_client import Client
client = Client("http://0.0.0.0:7860/")
result = client.predict(
"admin", # str in 'Username' Textbox component
"password", # str in 'Password' Textbox component
api_name="/handle_login"
)
print(result)
```
"""


def gradio_interface():
    """Build the admin dashboard and launch it on 0.0.0.0:7860.

    Tabs:
      * "Query Classifier Agent": save the organisation and level rules, then
        classify customer queries with ``classify_query``.
      * "Organization Documentation Agent": upload PDFs into a Chroma
        collection and search it.
      * "Account Information": demo login returning mock account data.
    """

    def save_rules(org, r1, r2, r3):
        """Store the organisation name and level rules that classify_query reads."""
        # user_inputs is a dict, so it must be updated by key. The previous
        # setattr() calls raised AttributeError, so the rules were never saved
        # and classify_query always rejected the request.
        user_inputs.update(organization=org, rules_l1=r1, rules_l2=r2, rules_l3=r3)

    with gr.Blocks(theme='gl198976/The-Rounded') as interface:
        gr.Markdown("# Admin Dashboard🧖🏻♀️")
        with gr.Tab("Query Classifier Agent"):
            with gr.Row():
                with gr.Column():
                    organization_input = gr.Textbox(label="Organization Name")
                    rules_l1_input = gr.Textbox(label="Rules for Level 1 Query", lines=5)
                    rules_l2_input = gr.Textbox(label="Rules for Level 2 Query", lines=5)
                    rules_l3_input = gr.Textbox(label="Rules for Level 3 Query", lines=5)
                    submit_btn = gr.Button("Submit Rules")
                with gr.Column():
                    query_input = gr.Textbox(label="Customer Query")
                    classification_output = gr.Textbox(label="Classification Result")
                    classify_btn = gr.Button("Classify Query")
                    gr.Markdown(_CLASSIFY_API_DOCS)
            submit_btn.click(
                save_rules,
                inputs=[organization_input, rules_l1_input, rules_l2_input, rules_l3_input],
            )
            # An explicit api_name makes the endpoint documented above actually
            # exist; previously no event was registered under that name.
            classify_btn.click(
                classify_query,
                inputs=[query_input],
                outputs=[classification_output],
                api_name="classify_and_display",
            )
        with gr.Tab("Organization Documentation Agent"):
            gr.Markdown(_UPLOAD_WARNING)
            with gr.Row():
                with gr.Column():
                    collection_name_input = gr.Textbox(label="Collection Name", placeholder="Enter a unique name for this collection")
                    file_upload = gr.Files(file_types=[".pdf"], label="Upload PDFs")
                    upload_btn = gr.Button("Upload and Process Files")
                    upload_status = gr.Textbox(label="Upload Status", interactive=False)
                with gr.Column():
                    search_query_input = gr.Textbox(label="Search Query")
                    search_output = gr.Textbox(label="Search Results", lines=10)
                    search_btn = gr.Button("Search")
                    gr.Markdown(_SEARCH_API_DOCS)
            upload_btn.click(handle_file_upload, inputs=[file_upload, collection_name_input], outputs=[upload_status])
            search_btn.click(
                search_vector_database,
                inputs=[search_query_input, collection_name_input],
                outputs=[search_output],
                api_name="search_vector_database",
            )
        with gr.Tab("Account Information"):
            with gr.Row():
                with gr.Column():
                    username_input = gr.Textbox(label="Username")
                    password_input = gr.Textbox(label="Password", type="password")
                    login_btn = gr.Button("Login")
                with gr.Column():
                    account_info_output = gr.Textbox(label="Account Info", lines=20)
                    gr.Markdown(_LOGIN_API_DOCS)
            login_btn.click(
                handle_login,
                inputs=[username_input, password_input],
                outputs=[account_info_output],
                api_name="handle_login",
            )
    interface.launch(server_name="0.0.0.0", server_port=7860)
# Script entry point: build the dashboard and start the Gradio server
# (gradio_interface calls launch() on 0.0.0.0:7860 itself).
if __name__ == "__main__":
    gradio_interface()