christianweyer's picture
Adding "Model" string to radio group
72e682b verified
raw
history blame
3.76 kB
import gradio as gr
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
def get_stream_chat_completion(
    message, chat_history, model, api_key, system=None, **kwargs
):
    """Stream the assistant's reply to ``message`` from the Mistral chat API.

    Builds the conversation (optional system prompt, prior turns, new user
    message), opens a streaming chat request and yields the text deltas.

    Args:
        message: The new user message.
        chat_history: Prior turns as ``(user_message, bot_message)`` pairs in
            the order they were exchanged. Either half may be ``None``; such
            halves are skipped.
        model: Mistral model name passed to ``chat_stream``.
        api_key: Mistral AI API key used to construct the client.
        system: Optional system prompt, sent as the first message when not
            ``None``.
        **kwargs: Extra options forwarded unchanged to
            ``MistralClient.chat_stream`` (e.g. ``temperature``, ``top_p``,
            ``max_tokens``).

    Yields:
        Each non-``None`` content delta, as a string, in arrival order.
    """
    messages = []
    if system is not None:
        messages.append(ChatMessage(role="system", content=system))

    for human_message, bot_message in chat_history:
        # Gradio can store a turn whose reply is None (e.g. when a streaming
        # callback produced nothing). ChatMessage presumably requires string
        # content (verify against the installed mistralai version), so skip
        # missing halves rather than failing every subsequent request.
        if human_message is not None:
            messages.append(ChatMessage(role="user", content=human_message))
        if bot_message is not None:
            messages.append(ChatMessage(role="assistant", content=bot_message))

    messages.append(ChatMessage(role="user", content=message))

    client = MistralClient(api_key=api_key)
    for chunk in client.chat_stream(model=model, messages=messages, **kwargs):
        # Defensive: never index into an empty choices list.
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content is not None:
            yield content
def respond_stream(
    message,
    chat_history,
    api_key,
    model,
    temperature,
    top_p,
    max_tokens,
    system,
):
    """Gradio ``ChatInterface`` callback that streams a Mistral reply.

    Args:
        message: The new user message.
        chat_history: Prior ``(user, bot)`` turns supplied by Gradio.
        api_key: Mistral AI API key from the password textbox.
        model: Selected model name.
        temperature: Sampling temperature from the slider.
        top_p: Nucleus-sampling threshold from the slider.
        max_tokens: Token limit from the slider (converted to ``int``).
        system: System prompt text; an empty string means "no system prompt".

    Yields:
        The accumulated response text after each received chunk, so the chat
        window updates progressively. On failure a ``gr.Warning`` is shown and
        whatever text arrived so far (possibly ``""``) is yielded last.
    """
    if not (api_key and api_key.strip()):
        gr.Warning("Error: Please enter your Mistral AI API key")
        yield ""
        return

    response = ""
    received_anything = False
    try:
        for chunk in get_stream_chat_completion(
            message=message,
            chat_history=chat_history,
            model=model,
            # Pasted keys often carry stray whitespace, which would corrupt
            # the auth header; real keys contain none.
            api_key=api_key.strip(),
            temperature=temperature,
            top_p=top_p,
            max_tokens=int(max_tokens),
            system=system if system else None,
        ):
            response += chunk
            received_anything = True
            yield response
    except Exception as err:  # UI boundary: report instead of crashing the stream
        # API failures (e.g. an HTTP 401 for a rejected key) are expected to
        # surface here as exceptions rather than as an empty stream. Show the
        # error and keep any partial text. GeneratorExit (user pressed stop)
        # is a BaseException and is deliberately not caught.
        gr.Warning(f"Error: {err}")
        yield response
        return

    if not received_anything:
        # The request succeeded but the model streamed no content.
        gr.Warning("Error: The model returned an empty response")
        yield ""
# Custom stylesheet handed to gr.Blocks(css=...) below.
# - ".header-text p" styles the title paragraph of the Markdown component
#   created with elem_classes="header-text".
# - ".header-logo" is not referenced by any component visible in this file.
# - ".image-container img" caps rendered image width at 80px; presumably this
#   targets Gradio's internal wrapper around the logo image (verify against
#   the installed Gradio version's DOM).
css = (
    "\n"
    ".header-text p {line-height: 80px !important; text-align: left; font-size: 26px;}\n"
    ".header-logo {text-align: left;}\n"
    ".image-container img {max-width: 80px; height: auto;}\n"
)
# Build the playground UI: a header, connection and model settings, sampling
# controls, a system prompt, the streaming chat, and a footer. Then launch it.
with gr.Blocks(title="Mistral Playground", css=css) as mistral_playground:
    # Header: logo on the left, title on the right.
    with gr.Row():
        with gr.Column(scale=1, min_width=80):
            gr.Image(
                "tt-logo.jpg",
                show_download_button=False,
                show_share_button=False,
                interactive=False,
                show_label=False,
                elem_id="thinktecture-logo",
                container=False,
            )
        with gr.Column(scale=11):
            gr.Markdown("Thinktecture Mistral AI Playground", elem_classes="header-text")

    # Credentials and model selection.
    with gr.Row(variant='panel'):
        api_key = gr.Textbox(
            type='password',
            placeholder='Your Mistral AI API key',
            lines=1,
            label="Mistral AI API Key",
        )
        model = gr.Radio(
            label="Model",
            choices=[
                "open-mistral-7b",
                "open-mixtral-8x7b",
                "mistral-small-latest",
                "mistral-medium-latest",
                "mistral-large-latest",
            ],
            value="mistral-medium-latest",
        )

    # Sampling parameters forwarded to the chat completion call. The steps
    # are chosen so every default lies on the slider grid (minimum + k * step).
    # Otherwise the browser snaps it to a neighbouring value (e.g. 0.21, 4001).
    with gr.Row(variant='panel'):
        temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.2, step=0.01, label="Temperature")
        top_p = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label="Top P", value=0.95)
        max_tokens = gr.Slider(minimum=1, maximum=16000, step=1, label="Max Tokens", value=4000)

    with gr.Row(variant='panel'):
        system = gr.Textbox(lines=2, label="System Message", value="You are a helpful AI assistant.")

    # The additional_inputs order must match respond_stream's parameters
    # after (message, chat_history).
    with gr.Row(variant='panel'):
        gr.ChatInterface(
            respond_stream,
            additional_inputs=[
                api_key,
                model,
                temperature,
                top_p,
                max_tokens,
                system,
            ],
        )

    # Footer. Links open in a new tab (target='_blank'); rel='noopener
    # noreferrer' keeps the opened page from reaching back via window.opener.
    with gr.Row():
        gr.HTML(
            value=(
                "<p style='margin-top: 1rem; margin-bottom: 1rem; text-align: center;'>"
                "Developed by Marco Frodl, Principal Consultant for Generative AI @ "
                "<a href='https://go.mfr.one/tt-en' target='_blank' rel='noopener noreferrer'>Thinktecture AG</a>"
                " -- Released 02/26/2024 -- More about me on my "
                "<a href='https://go.mfr.one/marcofrodl-en' target='_blank' rel='noopener noreferrer'>profile page</a>"
                "</p>"
            )
        )
mistral_playground.launch()