# HuggingFace Spaces demo: side-by-side comparison of an OpenAI chatbot
# and a fine-tuned, preference-optimized model served from a HF endpoint.
import json
import os
import random

import gradio as gr
from langchain.schema import AIMessage, HumanMessage
from langchain_huggingface import HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, SecretStr
class OAAPIKey(BaseModel):
    """Pydantic container validating the OpenAI API key as a masked secret."""

    openai_api_key: SecretStr
class HFAPIKey(BaseModel):
    """Pydantic container validating the HuggingFace API token as a masked secret."""

    huggingface_api_key: SecretStr
def set_openai_api_key(api_key: SecretStr):
    """Export the OpenAI key to the environment and return a chat client.

    Args:
        api_key: The OpenAI API key wrapped in a pydantic SecretStr.

    Returns:
        A ChatOpenAI client pinned to gpt-3.5-turbo-0125 with temperature 1.0.
    """
    # ChatOpenAI reads OPENAI_API_KEY from the environment.
    os.environ["OPENAI_API_KEY"] = api_key.get_secret_value()
    return ChatOpenAI(temperature=1.0, model="gpt-3.5-turbo-0125")
def set_huggingface_api_key(api_key: SecretStr):
    """Export the HuggingFace token to the environment and return an endpoint client.

    Args:
        api_key: The HuggingFace API token wrapped in a pydantic SecretStr.

    Returns:
        A HuggingFaceEndpoint client for the dedicated inference endpoint,
        configured for low-temperature generation that stops at the next
        human turn marker.
    """
    # HuggingFaceEndpoint reads HUGGINGFACEHUB_API_TOKEN from the environment.
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key.get_secret_value()
    endpoint_url = (
        "https://a0km823u69omaqm7.us-east-1.aws.endpoints.huggingface.cloud"
    )
    return HuggingFaceEndpoint(
        endpoint_url=endpoint_url,
        max_new_tokens=512,
        top_k=10,
        top_p=0.95,
        typical_p=0.95,
        temperature=0.01,
        repetition_penalty=1.03,
        # Stop generating as soon as the model starts a new human turn.
        stop_sequences=["<|human|>"],
    )
def predict(
    message: str,
    chat_history_openai: list[tuple[str, str]],
    chat_history_huggingface: list[tuple[str, str]],
    openai_api_key: SecretStr,
    huggingface_api_key: SecretStr,
):
    """Send *message* to both models and append each reply to its own history.

    Args:
        message: The new user message.
        chat_history_openai: (human, ai) turn pairs shown in the OpenAI pane.
        chat_history_huggingface: (human, ai) turn pairs shown in the
            fine-tuned model's pane.
        openai_api_key: Raw OpenAI API key (validated via OAAPIKey).
        huggingface_api_key: Raw HuggingFace token (validated via HFAPIKey).

    Returns:
        ("", updated_openai_history, updated_huggingface_history) — the empty
        string clears the message input in the Gradio UI.
    """
    # Run both raw keys through their pydantic models before use.
    openai_key_model = OAAPIKey(openai_api_key=openai_api_key)
    huggingface_key_model = HFAPIKey(huggingface_api_key=huggingface_api_key)

    openai_llm = set_openai_api_key(api_key=openai_key_model.openai_api_key)
    huggingface_llm = set_huggingface_api_key(
        api_key=huggingface_key_model.huggingface_api_key
    )

    # --- OpenAI: rebuild the conversation as LangChain chat messages.
    history_langchain_format_openai = []
    for human, ai in chat_history_openai:
        history_langchain_format_openai.append(HumanMessage(content=human))
        history_langchain_format_openai.append(AIMessage(content=ai))
    history_langchain_format_openai.append(HumanMessage(content=message))
    openai_response = openai_llm.invoke(input=history_langchain_format_openai)

    # --- HuggingFace endpoint: rebuild the conversation as a tagged prompt.
    # BUG FIX: the original iterated chat_history_openai here, feeding the
    # fine-tuned model the OpenAI bot's conversation instead of its own.
    history_langchain_format_huggingface = []
    for human, ai in chat_history_huggingface:
        history_langchain_format_huggingface.append(
            f"\n<|human|> {human}\n<|ai|> {ai}"
        )
    history_langchain_format_huggingface.append(f"\n<|human|> {message}\n<|ai|>")
    huggingface_response = huggingface_llm.invoke(
        input=history_langchain_format_huggingface
    )
    # Trim anything generated past a turn boundary.
    # NOTE(review): the endpoint's stop sequence is "<|human|>" but this
    # splits on "Human:" — confirm which marker the model actually emits.
    huggingface_response = huggingface_response.split("Human:")[0].strip()

    chat_history_openai.append((message, openai_response.content))
    chat_history_huggingface.append((message, huggingface_response))
    return "", chat_history_openai, chat_history_huggingface
# Seed prompts for the message dropdown: curated top r/AskBaking entries.
with open("askbakingtop.json") as fh:
    ask_baking_msgs = json.load(fh)
# Build the side-by-side comparison UI and launch it.
with gr.Blocks() as demo:
    # API key inputs, one column per provider.
    with gr.Row():
        with gr.Column(scale=1):
            openai_api_key = gr.Textbox(
                label="Please enter your OpenAI API key",
                type="password",
                elem_id="lets-chat-openai-api-key",
            )
        with gr.Column(scale=1):
            huggingface_api_key = gr.Textbox(
                label="Please enter your HuggingFace API key",
                type="password",
                elem_id="lets-chat-huggingface-api-key",
            )

    # Message entry: three randomly sampled prompts, custom text allowed.
    with gr.Row():
        options = [ask["history"] for ask in random.sample(ask_baking_msgs, k=3)]
        msg = gr.Dropdown(
            options,
            label="Please enter your message",
            interactive=True,
            multiselect=False,
            allow_custom_value=True,
        )

    # One chat pane per model.
    with gr.Row():
        with gr.Column(scale=1):
            chatbot_openai = gr.Chatbot(label="OpenAI Chatbot π’")
        with gr.Column(scale=1):
            chatbot_huggingface = gr.Chatbot(
                label="Your own fine-tuned preference optimized Chatbot πͺ"
            )

    with gr.Row():
        submit_button = gr.Button("Submit")
    with gr.Row():
        clear = gr.ClearButton([msg])

    def respond(
        message: str,
        chat_history_openai: list[tuple[str, str]],
        chat_history_huggingface: list[tuple[str, str]],
        openai_api_key: SecretStr,
        huggingface_api_key: SecretStr,
    ):
        """Thin click handler forwarding the UI state to `predict`."""
        return predict(
            message=message,
            chat_history_openai=chat_history_openai,
            chat_history_huggingface=chat_history_huggingface,
            openai_api_key=openai_api_key,
            huggingface_api_key=huggingface_api_key,
        )

    submit_button.click(
        fn=respond,
        inputs=[
            msg,
            chatbot_openai,
            chatbot_huggingface,
            openai_api_key,
            huggingface_api_key,
        ],
        outputs=[msg, chatbot_openai, chatbot_huggingface],
    )

demo.launch()