MaziyarPanahi's picture
new update
25f188a unverified
raw
history blame
7.32 kB
import gradio as gr
import time
import requests
import json
import os
from urllib3.util.retry import Retry
from requests.adapters import HTTPAdapter
API_URL = os.getenv("API_URL")
API_KEY = os.getenv("API_KEY")
print(f"API_URL: {API_URL}")
# Log only whether the key is present: printing the secret itself would leak the
# bearer token into the process logs.
print(f"API_KEY: {'<set>' if API_KEY else '<not set>'}")
if not API_URL:
    print("Warning: API_URL is not set; chat requests will fail.")
# Chat-completions endpoint. A trailing slash on API_URL is stripped so the
# path separator is not doubled.
url = f"{(API_URL or '').rstrip('/')}/v1/chat/completions"
# The headers for the HTTP request
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
    "Authorization": f"Bearer {API_KEY}",
}
def is_valid_json(data):
    """Attempt to decode *data* as JSON.

    Returns a two-item tuple:
      * ``(True, value)`` where ``value`` is the result of ``json.loads(data)``;
      * ``(False, message)`` where ``message`` is the text of the
        ``ValueError`` raised by ``json.loads`` (``json.JSONDecodeError`` is
        a subclass of ``ValueError``).
    """
    try:
        outcome = (True, json.loads(data))
    except ValueError as err:
        outcome = (False, str(err))
    return outcome
with gr.Blocks() as demo:
    markup = gr.Markdown(
        """
# Mistral 7B Instruct v0.2
This is a demo of the Mistral 7B Instruct quantized model in GGUF (Q2) hosted on K8s cluster.
The original models can be found [MaziyarPanahi/Mistral-7B-Instruct-v0.2-GGUF](https://huggingface.co/MaziyarPanahi/Mistral-7B-Instruct-v0.2-GGUF)"""
    )
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(lines=1, label="User Message")
    clear = gr.Button("Clear")
    with gr.Row():
        with gr.Column(scale=2):
            system_prompt_input = gr.Textbox(
                label="System Prompt",
                placeholder="Type system prompt here...",
                value="You are a helpful assistant.",
            )
            temperature_input = gr.Slider(
                label="Temperature", minimum=0.0, maximum=1.0, value=0.9, step=0.01
            )
            max_new_tokens_input = gr.Slider(
                label="Max New Tokens", minimum=0, maximum=1024, value=256, step=1
            )
        with gr.Column(scale=2):
            top_p_input = gr.Slider(
                label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.01
            )
            top_k_input = gr.Slider(
                label="Top K", minimum=1, maximum=100, value=50, step=1
            )
            repetition_penalty_input = gr.Slider(
                label="Repetition Penalty",
                minimum=1.0,
                maximum=2.0,
                value=1.1,
                step=0.01,
            )

    def update_globals(
        system_prompt, temperature, max_new_tokens, top_p, top_k, repetition_penalty
    ):
        """Store the generation settings in module-level ``global_*`` variables.

        NOTE(review): nothing visible here wires this function to an event, and
        ``bot`` receives its settings as arguments instead, so this looks like
        dead code. Confirm before removing.
        """
        global global_system_prompt, global_temperature, global_max_new_tokens, global_top_p, global_repetition_penalty, global_top_k
        global_system_prompt = system_prompt
        global_temperature = temperature
        global_max_new_tokens = max_new_tokens
        global_top_p = top_p
        global_top_k = top_k
        global_repetition_penalty = repetition_penalty

    def user(user_message, history):
        """Append the submitted message as a new turn and clear the textbox.

        Returns ``("", new_history)``. The empty string resets ``msg``.
        ``new_history`` ends with ``[user_message, None]``, and the ``None``
        assistant slot is filled in later by ``bot``. A ``None`` history is
        treated as empty.
        """
        return "", (history or []) + [[user_message, None]]

    def _build_messages(history, system_prompt):
        """Convert ``[[user_text, assistant_text], ...]`` history into chat messages.

        The system prompt (or a default) comes first. It is followed by every
        non-empty user and assistant turn in order, so the model sees the whole
        conversation rather than only the user's side of it.
        """
        messages = [
            {
                "content": (
                    system_prompt if system_prompt else "You are a helpful assistant."
                ),
                "role": "system",
            }
        ]
        for turn in history:
            user_text, assistant_text = turn[0], turn[1]
            if user_text:
                messages.append({"content": user_text, "role": "user"})
            if assistant_text:
                messages.append({"content": assistant_text, "role": "assistant"})
        return messages

    def _make_session():
        """Create a requests session that retries 5xx responses with backoff.

        The adapter is mounted for both ``http://`` and ``https://``.
        """
        methods = ["HEAD", "GET", "OPTIONS", "POST"]
        retry_kwargs = {
            "total": 5,  # Total number of retries to allow
            "backoff_factor": 1,  # A backoff factor to apply between attempts
            "status_forcelist": [500, 502, 503, 504],  # Statuses that force a retry
        }
        try:
            # urllib3 >= 1.26 name; ``method_whitelist`` was removed in urllib3 2.x.
            retries = Retry(allowed_methods=methods, **retry_kwargs)
        except TypeError:
            # Older urllib3 only understands the legacy keyword.
            retries = Retry(method_whitelist=methods, **retry_kwargs)
        adapter = HTTPAdapter(max_retries=retries)
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        return session

    def _iter_stream_deltas(response):
        """Yield each non-empty ``choices[0].delta.content`` from a streamed response.

        Lines are processed as follows:
          * Blank keep-alive lines and SSE comment lines (leading ``:``) are
            skipped.
          * A ``data:`` prefix is removed from the line.
          * The ``[DONE]`` sentinel ends the stream.
          * Payloads that are not valid JSON are logged and skipped.
        """
        for raw_line in response.iter_lines():
            if not raw_line:
                continue  # Filter out keep-alive new lines
            line = raw_line.decode("utf-8")
            if line.startswith(":"):
                continue  # SSE comment / heartbeat
            if line.startswith("data:"):
                # Remove the literal prefix. str.lstrip("data: ") would strip a
                # *set* of characters rather than the prefix itself.
                line = line[len("data:"):].strip()
            if line == "[DONE]":
                break
            ok, parsed = is_valid_json(line)
            if not ok:
                print(f"Error decoding JSON: {parsed} data: {line}")
                continue
            if not isinstance(parsed, dict):
                continue
            # ``or [{}]`` also covers an empty "choices" list, which would
            # otherwise raise IndexError.
            choices = parsed.get("choices") or [{}]
            delta = choices[0].get("delta") or {}
            content = delta.get("content")
            if content:  # Ensure there's content to emit
                yield content

    def bot(
        history,
        system_prompt,
        temperature,
        max_new_tokens,
        top_p,
        top_k,
        repetition_penalty,
    ):
        """Stream an assistant reply into the last turn of ``history``.

        The conversation is POSTed to ``url`` with ``"stream": True``. Each
        received text delta is appended to ``history[-1][1]``, and the mutated
        ``history`` is yielded after each delta so the Chatbot re-renders
        progressively.

        Request or HTTP errors are printed and end the stream. The reply keeps
        whatever text had already arrived. If ``history`` is empty, the
        generator yields nothing.
        """
        print(f"History in bot: {history}")
        print(f"System Prompt: {system_prompt}")
        print(f"Temperature: {temperature}")
        print(f"Max New Tokens: {max_new_tokens}")
        print(f"Top P: {top_p}")
        print(f"Top K: {top_k}")
        print(f"Repetition Penalty: {repetition_penalty}")
        if not history:
            return
        history_messages = _build_messages(history, system_prompt)
        history[-1][1] = ""
        print(history_messages)
        payload = {
            "messages": history_messages,
            "stream": True,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "seed": 42,
            "repeat_penalty": repetition_penalty,
            "chat_format": "mistral-instruct",
            "max_tokens": max_new_tokens,
        }
        try:
            # Both context managers guarantee the session and the streamed
            # connection are released, including when the generator is closed
            # early.
            with _make_session() as session:
                with session.post(
                    url,
                    headers=headers,
                    data=json.dumps(payload),
                    stream=True,
                    timeout=(10, 30),
                ) as response:
                    # Surface non-retried HTTP errors (e.g. 4xx) instead of
                    # silently parsing an error body.
                    response.raise_for_status()
                    for delta_content in _iter_stream_deltas(response):
                        history[-1][1] += delta_content
                        time.sleep(0.05)
                        yield history
        except requests.exceptions.RequestException as e:
            print(f"An error occurred: {e}")

    msg.submit(
        user, [msg, chatbot], [msg, chatbot], queue=True, concurrency_limit=10
    ).then(
        bot,
        inputs=[
            chatbot,
            system_prompt_input,
            temperature_input,
            max_new_tokens_input,
            top_p_input,
            top_k_input,
            repetition_penalty_input,
        ],
        outputs=chatbot,
    )
    clear.click(lambda: None, None, chatbot, queue=False)
# Queue configuration. The intended effects of these options (see the Gradio
# docs) are:
#   * max_size caps the number of pending requests;
#   * default_concurrency_limit applies to events that set no limit of their
#     own;
#   * api_open=False is intended to keep the queue closed to non-UI clients.
demo.queue(api_open=False, max_size=20, default_concurrency_limit=20)
if __name__ == "__main__":
    # Start the server without a public share link and without the API page.
    demo.launch(share=False, show_api=False)