Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
from text_generation import Client | |
# HF-hosted endpoint for testing purposes (requires an HF API token, read from
# the API_TOKEN environment variable).
API_TOKEN = os.environ.get("API_TOKEN", None)

# Only attach an Authorization header when a token is configured; formatting a
# missing token would otherwise send the literal header "Bearer None".
_CLIENT_HEADERS = {
    "Accept": "application/json",
    "Content-Type": "application/json",
}
if API_TOKEN:
    _CLIENT_HEADERS["Authorization"] = f"Bearer {API_TOKEN}"

CURRENT_CLIENT = Client(
    "https://afrts4trc759c6eq.us-east-1.aws.endpoints.huggingface.cloud/generate_stream",
    timeout=120,
    headers=_CLIENT_HEADERS,
)

# Prompt defaults pre-filled in the UI; each can be overridden via an environment variable.
DEFAULT_HEADER = os.environ.get("HEADER", "")
DEFAULT_USER_NAME = os.environ.get("USER_NAME", "user")
DEFAULT_ASSISTANT_NAME = os.environ.get("ASSISTANT_NAME", "assistant")
DEFAULT_SEPARATOR = os.environ.get("SEPARATOR", "<|im_end|>")

# One user turn followed by the opening of the assistant turn; `response` is
# left empty for the turn the model should complete.
PROMPT_TEMPLATE = "<|im_start|>{user_name}\n{query}{separator}\n<|im_start|>{assistant_name}\n{response}"

repo = None  # not used in this file; kept in case other code references it
def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Flatten a chat history plus a new input into one raw prompt string.

    Every user turn is forced to start with ``user_name`` and every model turn
    with ``sep + assistant_name``; model turns are right-stripped and closed
    with ``sep``. The result ends with ``sep`` and the right-stripped
    ``assistant_name`` so the model continues as the assistant.

    Note: ``generate`` does not call this; it builds its prompt from
    ``PROMPT_TEMPLATE`` instead.
    """

    def _with_prefix(text, prefix):
        # Prepend `prefix` unless the text already carries it.
        return text if text.startswith(prefix) else prefix + text

    history_text = "".join(
        _with_prefix(user_turn, user_name)
        + _with_prefix(model_turn, sep + assistant_name).rstrip()
        + sep
        for user_turn, model_turn in chatbot
    )
    new_turn = _with_prefix(inputs, user_name)
    return preprompt + history_text + new_turn + sep + assistant_name.rstrip()
def has_no_history(chatbot, history):
    """Return True when both the chatbot display and the history are empty/falsy."""
    return not (chatbot or history)
def _history_to_chat(history):
    """Pair a flat [user, assistant, user, assistant, ...] list into display tuples.

    A trailing unpaired entry (a user message still awaiting its reply) is
    omitted; every element is whitespace-stripped for display.
    """
    return [
        (history[i].strip(), history[i + 1].strip())
        for i in range(0, len(history) - 1, 2)
    ]


def generate(
    user_message,
    chatbot,
    history,
    temperature,
    top_p,
    max_new_tokens,
    repetition_penalty,
    header,
    user_name,
    assistant_name,
    separator,
):
    """Stream an assistant reply to ``user_message`` given the conversation so far.

    The prompt is ``header`` followed by every past (query, response) pair in
    ``chatbot`` rendered with ``PROMPT_TEMPLATE``, then the new query with an
    empty response slot. Tokens are streamed from ``CURRENT_CLIENT``.

    ``history`` is a flat list alternating user/assistant messages and is
    mutated in place: the user message is appended, then exactly one
    assistant entry (empty if the model produced no visible text).

    Yields:
        (chat pairs, history, user_message, "") tuples, matching the Gradio
        outputs (chatbot, history state, last user message, cleared textbox).
        For empty input, the unchanged ``chatbot`` and ``history`` are yielded
        once and the model is not queried.
    """
    # Don't query the model (or touch the history) when the input is empty.
    if not user_message:
        print("Empty input")
        yield chatbot, history, user_message, ""
        return

    history.append(user_message)

    # Rebuild the whole conversation: header, every past exchange closed with
    # the separator, then the new query whose response slot the model fills.
    prompt = header
    for past_query, past_response in chatbot or []:
        prompt += (
            PROMPT_TEMPLATE.format(
                user_name=user_name,
                query=past_query,
                assistant_name=assistant_name,
                response=(past_response or "").rstrip(),
                separator=separator,
            )
            + separator
            + "\n"
        )
    prompt += PROMPT_TEMPLATE.format(
        user_name=user_name,
        query=user_message,
        assistant_name=assistant_name,
        response="",
        separator=separator,
    )

    # The text_generation client validates these parameters (temperature > 0,
    # 0 < top_p < 1, repetition_penalty > 0), so map out-of-range UI values to
    # the nearest accepted setting instead of failing the request.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    top_p = None if top_p >= 1.0 else max(top_p, 1e-2)  # None = no nucleus filtering
    repetition_penalty = float(repetition_penalty)
    if repetition_penalty <= 0:
        repetition_penalty = None  # None = no penalty
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max(int(max_new_tokens), 1),
        top_p=top_p,
        top_k=40,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        truncate=1024,
        # Stop on the separator actually used in the prompt (UI-configurable).
        stop_sequences=[separator or DEFAULT_SEPARATOR],
    )

    stream = CURRENT_CLIENT.generate_stream(prompt, **generate_kwargs)
    output = ""
    reply_started = False
    for response in stream:
        token = response.token
        if token.special:
            continue
        output += token.text
        # Track the reply slot explicitly rather than by stream index: if the
        # first streamed token is special, an index check would never create
        # the slot and the user's message would be overwritten.
        if reply_started:
            history[-1] = output
        else:
            history.append(" " + output)
            reply_started = True
        yield _history_to_chat(history), history, user_message, ""

    if not reply_started:
        # No visible text came back; still record an empty reply so `history`
        # keeps alternating user/assistant for the next turn.
        history.append("")
        yield _history_to_chat(history), history, user_message, ""
def clear_chat():
    """Reset the conversation: fresh empty lists for the chatbot and the history state."""
    cleared_chatbot = []
    cleared_history = []
    return cleared_chatbot, cleared_history
# Page heading rendered via gr.HTML. The emoji is the croissant (U+1F950); the
# scraped text showed it mis-decoded as "π₯".
title = """<h1 align="center">CroissantLLMChat Playground 🥐</h1>"""

# Styles keyed by elem_id: "chat-message" is the Chatbot component's id.
# NOTE(review): no component in this file uses elem_id "banner-image".
custom_css = """
#banner-image {
    display: block;
    margin-left: auto;
    margin-right: auto;
}
#chat-message {
    font-size: 14px;
    min-height: 300px;
}
"""
# Build the playground UI: intro text, chat display, message box with
# Send/Clear buttons, and accordions for sampling parameters and prompt format.
with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                ## Demo platform for 🥐 CroissantLLMChat

                ### Usage recommendations

                We recommend testing the chat model for open-ended writing tasks, tips, translations, etc.

                We find direct instructions work best, and performance tends to drop after the first round of interactions.

                We limit the length of each message to 256 tokens by default (this can be changed in the settings below), as well as the length of the entire conversation, so clear the chat between tests!

                ### Errors

                The model is very small (1.3B parameters), about 130 times smaller than GPT-3. As such, its generalist Chat version naturally exhibits reduced understanding, reasoning and knowledge capacities, and may still exhibit undesired behavior such as hallucinations or (rarely) toxicity.

                For industrial applications, we recommend fine-tuning the model; we trained this Chat version to allow experimentation and to showcase its capabilities for its size.

                ### More info

                🗞️ The blogpost: https://huggingface.co/blog/manu/croissant-llm-blog

                📄 The 45-page report with lots of gems: https://arxiv.org/abs/2402.00786

                🤗 Models, Data, Demo: https://huggingface.co/croissantllm

                ###
                """
            )
    with gr.Row():
        with gr.Group():
            output = gr.Markdown()
            chatbot = gr.Chatbot(elem_id="chat-message", label="Chat")
    with gr.Row():
        with gr.Column(scale=3):
            user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input")
            with gr.Row():
                send_button = gr.Button("Send", elem_id="send-btn", visible=True)
                clear_chat_button = gr.Button("Clear chat", elem_id="clear-btn", visible=True)
            with gr.Accordion(label="Parameters", open=False, elem_id="parameters-accordion"):
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.3,
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                # top_p must be strictly positive for the backend.
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.9,
                    minimum=0.05,
                    maximum=1,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                # Generating zero tokens is meaningless, so start at one step.
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=256,
                    minimum=8,
                    maximum=512,
                    step=8,
                    interactive=True,
                    info="The maximum number of new tokens",
                )
                # repetition_penalty must be strictly positive for the backend.
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty",
                    value=1.05,
                    minimum=0.05,
                    maximum=2,
                    step=0.05,
                    interactive=True,
                    info="The parameter for repetition penalty. 1.0 means no penalty.",
                )
            with gr.Accordion(label="Prompt", open=False, elem_id="prompt-accordion"):
                header = gr.Textbox(
                    label="Header instructions",
                    value=DEFAULT_HEADER,
                    interactive=True,
                    info="Instructions given to the assistant at the beginning of the prompt",
                )
                user_name = gr.Textbox(
                    label="User name",
                    value=DEFAULT_USER_NAME,
                    interactive=True,
                    info="Name to be given to the user in the prompt",
                )
                assistant_name = gr.Textbox(
                    label="Assistant name",
                    value=DEFAULT_ASSISTANT_NAME,
                    interactive=True,
                    info="Name to be given to the assistant in the prompt",
                )
                separator = gr.Textbox(
                    label="Separator",
                    value=DEFAULT_SEPARATOR,
                    interactive=True,
                    info="Character to be used when the speaker changes in the prompt",
                )
    # Flat alternating user/assistant message list, plus the last submitted message.
    history = gr.State([])
    last_user_message = gr.State("")

    # Pressing Enter in the textbox and clicking Send run the same handler with
    # the same inputs/outputs, so share one definition of each.
    generate_inputs = [
        user_message,
        chatbot,
        history,
        temperature,
        top_p,
        max_new_tokens,
        repetition_penalty,
        header,
        user_name,
        assistant_name,
        separator,
    ]
    generate_outputs = [chatbot, history, last_user_message, user_message]
    user_message.submit(generate, inputs=generate_inputs, outputs=generate_outputs)
    send_button.click(generate, inputs=generate_inputs, outputs=generate_outputs)
    clear_chat_button.click(clear_chat, outputs=[chatbot, history])

# The queue is enabled before launching; `generate` is a generator that streams partial replies.
demo.queue().launch()