Spaces:
Sleeping
Sleeping
import gradio as gr | |
import anthropic | |
import json | |
import requests | |
import warnings | |
import logging | |
import os | |
import pandas as pd | |
from dotenv import load_dotenv | |
# Load environment variables (e.g. ANTHROPIC_API_KEY) from a .env file via python-dotenv
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Initialize Anthropic client with API key
client = anthropic.Client(api_key=os.getenv('ANTHROPIC_API_KEY'))

# Claude model used for every request; can be overridden with the ANTHROPIC_MODEL env var.
MODEL_NAME = os.getenv("ANTHROPIC_MODEL", "claude-3-5-sonnet-20240620")

# Base URL of the FastAPI service that implements the tools. It can be overridden with
# BLACKBIRD_BASE_URL (e.g. to target a local instance). A trailing slash is stripped
# because endpoint paths are appended as f"{BASE_URL}/{endpoint}".
BASE_URL = os.getenv("BLACKBIRD_BASE_URL", "https://dwb2023-blackbird-svc.hf.space").rstrip("/")
# Tool definitions exposed to Claude. Each entry has a ``name``, a natural-language
# ``description`` the model reads when deciding whether to call the tool, and an
# ``input_schema`` (JSON Schema) describing the arguments it must supply.


def _string_property(description, enum=None):
    """Return a JSON-Schema string property, optionally restricted to ``enum`` values."""
    prop = {"type": "string"}
    if enum is not None:
        prop["enum"] = enum
    prop["description"] = description
    return prop


def _object_schema(properties, required):
    """Return a JSON-Schema object with the given properties and required keys."""
    return {"type": "object", "properties": properties, "required": required}


def _user_lookup_schema():
    """Build the input schema shared by tools that find a user by email, phone, or username.

    A new dict is returned on every call so the tools never share mutable schema objects.
    """
    return _object_schema(
        {
            "key": _string_property(
                "The attribute to search for a user by (email, phone, or username).",
                enum=["email", "phone", "username"],
            ),
            "value": _string_property("The value to match for the specified attribute."),
        },
        ["key", "value"],
    )


tools = [
    {
        "name": "get_user",
        "description": "Looks up a user by email, phone, or username.",
        "input_schema": _user_lookup_schema(),
    },
    {
        "name": "get_order_by_id",
        "description": "Retrieves the details of a specific order based on the order ID.",
        "input_schema": _object_schema(
            {"order_id": _string_property("The unique identifier for the order.")},
            ["order_id"],
        ),
    },
    {
        "name": "get_customer_orders",
        "description": "Retrieves the list of orders belonging to a user based on a user's customer id.",
        "input_schema": _object_schema(
            {"customer_id": _string_property("The customer_id belonging to the user")},
            ["customer_id"],
        ),
    },
    {
        "name": "cancel_order",
        "description": "Cancels an order based on a provided order_id. Only orders that are 'processing' can be cancelled.",
        "input_schema": _object_schema(
            {"order_id": _string_property("The order_id pertaining to a particular order")},
            ["order_id"],
        ),
    },
    {
        "name": "update_user_contact",
        "description": "Updates a user's email and/or phone number.",
        "input_schema": _object_schema(
            {
                "user_id": _string_property("The ID of the user"),
                "email": _string_property("The new email address of the user"),
                "phone": _string_property("The new phone number of the user"),
            },
            ["user_id"],
        ),
    },
    {
        "name": "get_user_info",
        "description": "Retrieves a user's information along with their order history based on email, phone, or username.",
        "input_schema": _user_lookup_schema(),
    },
]
# Tool-service requests are sent with verify=False, which makes urllib3 emit an
# InsecureRequestWarning on every call; ignore that category process-wide.
warnings.simplefilter("ignore", requests.urllib3.exceptions.InsecureRequestWarning)
def process_tool_call(tool_name, tool_input, timeout=30):
    """Execute a Claude tool call by POSTing it to the BlackBird FastAPI service.

    Args:
        tool_name: Name of the tool Claude requested.
        tool_input: Arguments Claude supplied for the tool; sent as the JSON request body.
        timeout: Seconds to wait for the service before giving up. Without a timeout,
            an unresponsive service would block the chat turn indefinitely.

    Returns:
        The decoded JSON body on HTTP 200. Otherwise a dict ``{"error": <message>}``,
        so the failure can be reported back to the model instead of aborting the turn.
    """
    # Tool name -> service endpoint path (update_user_contact maps to /update_user).
    tool_endpoints = {
        "get_user": "get_user",
        "get_order_by_id": "get_order_by_id",
        "get_customer_orders": "get_customer_orders",
        "cancel_order": "cancel_order",
        "update_user_contact": "update_user",
        "get_user_info": "get_user_info"
    }
    endpoint = tool_endpoints.get(tool_name)
    if endpoint is None:
        logger.error(f"Invalid tool name: {tool_name}")
        return {"error": "Invalid tool name"}

    try:
        # NOTE(review): verify=False disables TLS certificate verification. It is kept for
        # compatibility; consider verify=True since the service is served over HTTPS.
        response = requests.post(
            f"{BASE_URL}/{endpoint}", json=tool_input, verify=False, timeout=timeout
        )
    except requests.RequestException as e:
        # Connection errors, timeouts, etc. become a tool error, not a crashed turn.
        logger.error(f"Tool call failed: {e}")
        return {"error": str(e)}

    if response.status_code != 200:
        logger.error(f"Tool call failed: {response.text}")
        return {"error": response.text}

    try:
        return response.json()
    except ValueError:
        # requests' JSON decode errors subclass ValueError.
        logger.error(f"Tool call returned non-JSON body: {response.text}")
        return {"error": f"Invalid JSON response from {endpoint}"}
# System prompt sent with every Claude request. The model's final text reply is shown
# to the user verbatim, so the prompt asks for user-facing text only (no visible
# reasoning). It also lists every capability backed by a tool, including contact updates.
system_prompt = """
You are a customer support chat bot for an online retailer called BlackBird.
Your job is to help users look up their account and orders, cancel orders, and update their contact information (email or phone).
Be helpful and brief in your responses.
You have access to a set of tools, but only use them when needed.
If you do not have enough information to use a tool correctly, ask the user follow-up questions to get the required inputs.
Do not call any of the tools unless you have the required data from the user.
Reply only with the user-facing response; do not include your internal reasoning or thinking in your reply.
"""
def simple_chat(user_message, history):
    """Run one user turn through Claude, executing any tools it requests.

    Args:
        user_message: The new message typed by the user.
        history: Gradio chatbot history as a list of ``(user, assistant)`` pairs.
            ``None`` is treated as an empty history. The new turn is appended to
            this list in place.

    Returns:
        ``(history, "", messages)``: the updated history, an empty string used to
        clear the input textbox, and the API messages exchanged for this turn.
        The messages are plain dicts, so they are JSON-serialisable.
    """
    if history is None:
        history = []

    # Rebuild the API conversation from the visible chat history. Tool calls from
    # earlier turns are not replayed; only the text shown in the chat is kept.
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": user_message})

    full_response = ""
    MAX_ITERATIONS = 5
    iteration_count = 0
    while iteration_count < MAX_ITERATIONS:
        try:
            logger.info(f"Sending messages to API: {json.dumps(messages, indent=2)}")
            response = client.messages.create(
                model=MODEL_NAME,
                system=system_prompt,
                max_tokens=4096,
                tools=tools,
                messages=messages,
            )

            if response.stop_reason == "tool_use":
                # The assistant turn must be sent back with its tool_use blocks intact, and
                # the next user turn must carry one tool_result per tool_use, matched by id.
                assistant_blocks = []
                tool_results = []
                for block in response.content:
                    if block.type == "text":
                        assistant_blocks.append({"type": "text", "text": block.text})
                    elif block.type == "tool_use":
                        assistant_blocks.append({
                            "type": "tool_use",
                            "id": block.id,
                            "name": block.name,
                            "input": block.input,
                        })
                        tool_result = process_tool_call(block.name, block.input)
                        tool_results.append({
                            "type": "tool_result",
                            "tool_use_id": block.id,
                            # tool_result content must be a string (or content blocks).
                            "content": json.dumps(tool_result),
                        })
                        full_response += f"\nUsing tool: {block.name}\n"
                messages.append({"role": "assistant", "content": assistant_blocks})
                messages.append({"role": "user", "content": tool_results})
                iteration_count += 1
                continue

            # Final answer: concatenate all text blocks. content[0] is not guaranteed
            # to be a text block, and there may be several.
            assistant_message = "".join(
                block.text for block in response.content if block.type == "text"
            )
            full_response += assistant_message
            messages.append({"role": "assistant", "content": assistant_message})
            break
        except anthropic.BadRequestError as e:
            logger.error(f"BadRequestError: {str(e)}")
            full_response = f"Error: {str(e)}"
            break
        except Exception as e:
            logger.error(f"Unexpected error: {str(e)}")
            full_response = f"An unexpected error occurred: {str(e)}"
            break

    logger.info(f"Final messages: {json.dumps(messages, indent=2)}")
    if iteration_count == MAX_ITERATIONS:
        logger.warning("Maximum iterations reached in simple_chat")
    history.append((user_message, full_response))
    return history, "", messages  # Return messages as well
def messages_to_dataframe(messages):
    """Flatten a list of chat API messages into a table for the analysis view.

    Args:
        messages: Sequence of dicts with ``role`` and ``content`` keys. ``content``
            is either a string or a list of content-block dicts.

    Returns:
        A DataFrame with one row per message and these columns:

        - ``role``
        - ``content``: a string; list content is JSON-encoded.
        - ``tool_use`` and ``tool_result``: the JSON of any such blocks in that
          message, newline-joined when there are several, else None.

        Empty input yields an empty frame with the same columns.
    """
    columns = ['role', 'content', 'tool_use', 'tool_result']
    rows = []
    for msg in messages:
        content = msg['content']
        tool_uses = []
        tool_results = []
        # tool_use blocks appear in assistant turns and tool_result blocks in user
        # turns, so list content is inspected regardless of role.
        if isinstance(content, list):
            for item in content:
                if not isinstance(item, dict):
                    continue
                if item.get('type') == 'tool_use':
                    tool_uses.append(json.dumps(item))
                elif item.get('type') == 'tool_result':
                    tool_results.append(json.dumps(item))
        rows.append({
            'role': msg['role'],
            'content': content if isinstance(content, str) else json.dumps(content),
            # Keep every block instead of overwriting with the last one.
            'tool_use': '\n'.join(tool_uses) or None,
            'tool_result': '\n'.join(tool_results) or None,
        })
    return pd.DataFrame(rows, columns=columns)
def submit_message(message, history):
    """Gradio submit handler: run one chat turn and build its analysis table.

    Returns ``(history, "", dataframe)``: the updated chat history, an empty string
    that clears the input box, and a per-message table of the API conversation.
    """
    updated_history, _cleared_input, conversation = simple_chat(message, history)
    analysis = messages_to_dataframe(conversation)
    print(analysis)  # For console output
    return updated_history, "", analysis
# Gradio UI: chat window, input box, clear button, and a table of the raw API messages.
with gr.Blocks() as demo:
    gr.Markdown("# BlackBird Customer Support Chat")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your message")
    clear = gr.Button("Clear")
    df_output = gr.Dataframe(label="Conversation Analysis")

    # submit_message returns "" as its second output, which already clears the textbox,
    # so no follow-up event is needed.
    submit_event = msg.submit(submit_message, [msg, chatbot], [chatbot, msg, df_output])

    # NOTE(review): the email-like text below looks like a scraping placeholder for real
    # example addresses; confirm the intended values.
    example_inputs = [
        "What's the status of my orders? My Customer id is 2837622",
        "Can you confirm my customer info and order status? My email is [email protected]",
        "I'd like to cancel an order",
        "Can you update my email address to [email protected]?",
    ]
    examples = gr.Examples(
        examples=example_inputs,
        inputs=msg
    )

    # Reset both the chat and the analysis table. An empty list (rather than None)
    # keeps the chatbot history iterable for the next submit.
    clear.click(
        lambda: ([], pd.DataFrame(columns=["role", "content", "tool_use", "tool_result"])),
        None,
        [chatbot, df_output],
        queue=False,
    )

demo.launch()