import os
from typing import Iterator, Optional

import gradio as gr

from dialog import get_dialog_box
from gateway import check_server_health, request_generation

# CONSTANTS
MAX_NEW_TOKENS: int = 2048
# Seconds between the periodic server-health checks that toggle the UI.
HEALTH_CHECK_INTERVAL_S: int = 10

# GET ENVIRONMENT VARIABLES
# NOTE(review): this is None when API_ENDPOINT is unset; how the gateway
# helpers react to a None endpoint is not visible from this file.
CLOUD_GATEWAY_API = os.getenv("API_ENDPOINT")


def toggle_ui():
    """Toggle the main UI and the dialog box based on the server health.

    Returns:
        tuple: two ``gr.update`` objects. The first sets the visibility of the
        main UI and the second that of the dialog box; exactly one of them is
        visible.
    """
    healthy = bool(check_server_health(cloud_gateway_api=CLOUD_GATEWAY_API))
    # Healthy -> show main UI, hide dialog; unhealthy -> the reverse.
    return gr.update(visible=healthy), gr.update(visible=not healthy)


def generate(
    message: str,
    chat_history: list,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Send a request to the backend, fetch the streaming response and emit it to the UI.

    Args:
        message (str): input message from the user.
        chat_history (list): chat history of the session as supplied by
            ``gr.ChatInterface``. NOTE(review): it is currently NOT forwarded
            to the backend -- only ``message`` and ``system_prompt`` are sent,
            so the model does not see earlier turns. Confirm this is intended.
        system_prompt (str): system prompt.
        max_new_tokens (int, optional): maximum number of tokens to generate,
            ignoring the number of tokens in the prompt. Defaults to 1024.
        temperature (float, optional): the value used to modulate the next
            token probabilities. Defaults to 0.6.
        top_p (float, optional): if set to a float < 1, only the smallest set
            of most probable tokens with probabilities that add up to top_p or
            higher are kept for generation. Defaults to 0.9.
        top_k (int, optional): the number of highest probability vocabulary
            tokens to keep for top-k filtering. Defaults to 50.
        repetition_penalty (float, optional): the parameter for repetition
            penalty. 1.0 means no penalty. Defaults to 1.2.

    Yields:
        str: the cumulative response text so far. Each yield replaces the
        previous message in the UI, so the full text is re-emitted rather
        than only the new chunk.
    """
    outputs = []
    for text in request_generation(
        message=message,
        system_prompt=system_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        cloud_gateway_api=CLOUD_GATEWAY_API,
    ):
        outputs.append(text)
        yield "".join(outputs)


chat_interface = gr.ChatInterface(
    fn=generate,
    # Order must match generate()'s parameters after (message, chat_history).
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max New Tokens",
            minimum=1,
            maximum=MAX_NEW_TOKENS,
            step=1,
            value=1024,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.1,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'."],
    ],
    cache_examples=False,
)

with gr.Blocks(css="style.css", fill_height=True) as demo:
    # Check the server status once at build time to choose the initial layout;
    # the timer below keeps it up to date afterwards.
    visibility = bool(check_server_health(cloud_gateway_api=CLOUD_GATEWAY_API))

    # Container for the main interface
    with gr.Column(visible=visibility, elem_id="main_ui") as main_ui:
        gr.Markdown("""
# Llama-3 8B Chat

This Space is an Alpha release that demonstrates [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model running on AMD MI210 infrastructure. The space is built with Meta Llama 3 [License](https://www.llama.com/llama3/license/). Feel free to play with it!
""")
        chat_interface.render()

    # Dialog box shown instead of the main UI while the server is unhealthy
    with gr.Row(visible=(not visibility), elem_id="dialog_box") as dialog_box:
        # Add spinner and message
        get_dialog_box()

    # Re-check server health every HEALTH_CHECK_INTERVAL_S seconds and update the UI
    timer = gr.Timer(value=HEALTH_CHECK_INTERVAL_S)
    timer.tick(fn=toggle_ui, outputs=[main_ui, dialog_box])


def _int_from_env(name: str) -> Optional[int]:
    """Return environment variable *name* parsed as an int.

    Returns None when the variable is unset or blank. Raises ValueError naming
    the variable when it is set but is not a valid integer.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return None
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from err


if __name__ == "__main__":
    # Only pass the settings that are configured, so that an unset variable
    # falls back to Gradio's own defaults instead of crashing on int(None).
    queue_kwargs = {}
    max_size = _int_from_env("QUEUE")
    if max_size is not None:
        queue_kwargs["max_size"] = max_size
    concurrency_limit = _int_from_env("CONCURRENCY_LIMIT")
    if concurrency_limit is not None:
        queue_kwargs["default_concurrency_limit"] = concurrency_limit
    demo.queue(**queue_kwargs).launch()