import gradio as gr
from huggingface_hub import InferenceClient

SYSTEM_MESSAGE_DEFAULT = "You are a friendly Chatbot."
MAX_TOKENS_DEFAULT = 512
TEMPERATURE_DEFAULT = 0.7
TOP_P_DEFAULT = 0.95

inference_client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


def respond(
    user_message: str,
    conversation_history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Respond to a user message given the conversation history and other parameters.

    Args:
        user_message (str): The user's message.
        conversation_history (list[tuple[str, str]]): The conversation history as
            (user, assistant) pairs, supplied by gr.ChatInterface.
        system_message (str): The system prompt sent to the model ahead of the conversation.
        max_tokens (int): The maximum number of tokens to generate in the response.
        temperature (float): The sampling temperature to use when generating text.
        top_p (float): The top-p (nucleus sampling) value to use when generating text.

    Yields:
        str: The assistant response accumulated so far, so the UI can stream it.
    """
    messages = [{"role": "system", "content": system_message}]

    # Convert the (user, assistant) history pairs into chat-completion messages
    for user_input, assistant_response in conversation_history:
        if user_input:
            messages.append({"role": "user", "content": user_input})
        if assistant_response:
            messages.append({"role": "assistant", "content": assistant_response})

    # Append the new user message
    messages.append({"role": "user", "content": user_message})

    # Initialize response string
    response = ""

    # Stream the completion from the inference client
    for chunk in inference_client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # The final streamed chunk may carry no content, so guard against None
        token = chunk.choices[0].delta.content or ""
        response += token

        # Yield the partial response; gr.ChatInterface renders it incrementally
        # and appends the final version to the chat history itself
        yield response


# Chatbot interface definition
chatbot_interface = gr.ChatInterface(
    fn=respond,
    chatbot=gr.Chatbot(height=600),
    additional_inputs=[
        gr.Textbox(
            value=SYSTEM_MESSAGE_DEFAULT,
            label="System message",
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=MAX_TOKENS_DEFAULT,
            step=1,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=TEMPERATURE_DEFAULT,
            step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=TOP_P_DEFAULT,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    chatbot_interface.launch()
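

# A minimal sketch of exercising `respond` outside the UI, e.g. for a smoke
# test. It assumes the public HF Inference API is reachable for
# HuggingFaceH4/zephyr-7b-beta (an HF token, e.g. via the HF_TOKEN environment
# variable, may be required depending on rate limits). Kept commented out so it
# does not interfere with `chatbot_interface.launch()` above:
#
#     for partial in respond(
#         "Hello!",
#         [],  # empty conversation history
#         SYSTEM_MESSAGE_DEFAULT,
#         MAX_TOKENS_DEFAULT,
#         TEMPERATURE_DEFAULT,
#         TOP_P_DEFAULT,
#     ):
#         print(partial)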