import os
import sys

import gradio as gr
from gradio.themes.utils import sizes
from text_generation import Client

# todo: remove and replace by the actual js file instead
from share_btn import share_js
from utils import (
    get_file_as_string,
    get_sections,
    get_url_from_env_or_default_path,
    preview,
)
from constants import (
    DEFAULT_STARCODER_API_PATH,
    DEFAULT_STARCODER_BASE_API_PATH,
    FIM_MIDDLE,
    FIM_PREFIX,
    FIM_SUFFIX,
    END_OF_TEXT,
    MIN_TEMPERATURE,
)

HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Gracefully exit the app if HF_TOKEN is not set, printing the error and the
# expected behavior to stderr (instead of raising an exception)
if not HF_TOKEN:
    ERR_MSG = """
Please set the HF_TOKEN environment variable with your Hugging Face API token.
You can get one by signing up at https://huggingface.co/join and then visiting
https://huggingface.co/settings/tokens."""
    print(ERR_MSG, file=sys.stderr)
    # gr.errors.GradioError(ERR_MSG)
    # gr.close_all(verbose=False)
    sys.exit(1)

API_URL = get_url_from_env_or_default_path("STARCODER_API", DEFAULT_STARCODER_API_PATH)
API_URL_BASE = get_url_from_env_or_default_path(
    "STARCODER_BASE_API", DEFAULT_STARCODER_BASE_API_PATH
)
preview("StarCoder Model's URL", API_URL)
preview("StarCoderBase Model's URL", API_URL_BASE)
preview("HF Token", HF_TOKEN, ofuscate=True)

DEFAULT_PORT = 7860
FIM_INDICATOR = "<FILL_HERE>"  # sentinel that triggers fill-in-the-middle mode

# Load the static assets (Markdown, CSS and SVG icons) as strings
STATIC_PATH = "static"
FORMATS = get_file_as_string("formats.md", path=STATIC_PATH)
CSS = get_file_as_string("styles.css", path=STATIC_PATH)
community_icon_svg = get_file_as_string("community_icon.svg", path=STATIC_PATH)
loading_icon_svg = get_file_as_string("loading_icon.svg", path=STATIC_PATH)

# todo: evaluate making STATIC_PATH the default path instead of the current one
README = get_file_as_string("README.md")

# Slice the different sections from the README
readme_sections = get_sections(README, "---")
manifest, description, disclaimer = readme_sections[:3]

theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("Rubik"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    text_size=sizes.text_lg,
)

HEADERS = {
    "Authorization": f"Bearer {HF_TOKEN}",
}

client = Client(API_URL, headers=HEADERS)
client_base = Client(API_URL_BASE, headers=HEADERS)


def generate(
    prompt,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
    version="StarCoder",
):
    # Clamp the temperature to the allowed floor: zero (or near-zero) values
    # are rejected by the inference endpoint
    temperature = max(float(temperature), MIN_TEMPERATURE)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    if fim_mode := FIM_INDICATOR in prompt:
        try:
            prefix, suffix = prompt.split(FIM_INDICATOR)
        except ValueError as err:
            print(str(err), file=sys.stderr)
            raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!") from err
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"

    model_client = client if version == "StarCoder" else client_base
    stream = model_client.generate_stream(prompt, **generate_kwargs)

    output = prefix if fim_mode else prompt
    for response in stream:
        if response.token.text == END_OF_TEXT:
            if fim_mode:
                output += suffix
            else:
                return output
        else:
            output += response.token.text
        # todo: log this value while in debug mode
        # previous_token = response.token.text
        yield output
    return output
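# A minimal sketch of the FIM prompt assembly performed above, assuming the
# sentinel values "<fim_prefix>", "<fim_suffix>" and "<fim_middle>" used by
# StarCoder's tokenizer (the actual values come from constants.py):
#
#     prompt = "def add(a, b):\n    <FILL_HERE>\n"
#     prefix, suffix = prompt.split("<FILL_HERE>")
#     fim_prompt = f"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"
#     # -> "<fim_prefix>def add(a, b):\n    <fim_suffix>\n<fim_middle>"
#
# The model then generates the missing middle; the suffix is re-appended to
# the streamed output once END_OF_TEXT is seen.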
# todo: move it into the README too
examples = [
    "X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.1)\n\n# Train a logistic regression model, predict the labels on the test set and compute the accuracy score",
    "// Returns every other value in the array as a new array.\nfunction everyOther(arr) {",
    "def alternating(list1, list2):\n    results = []\n    for i in range(min(len(list1), len(list2))):\n        results.append(list1[i])\n        results.append(list2[i])\n    if len(list1) > len(list2):\n        <FILL_HERE>\n    else:\n        results.extend(list2[i+1:])\n    return results",
]


def process_example(args):
    # Exhaust the generator so only the final, complete output is returned
    output = None
    for output in generate(args):
        pass
    return output


with gr.Blocks(theme=theme, analytics_enabled=False, css=CSS) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                instruction = gr.Textbox(
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", lines=30)
                with gr.Row():
                    with gr.Column():
                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Row():
                                column_1, column_2 = gr.Column(), gr.Column()
                                with column_1:
                                    temperature = gr.Slider(
                                        label="Temperature",
                                        value=0.2,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values produce more diverse outputs",
                                    )
                                    max_new_tokens = gr.Slider(
                                        label="Max new tokens",
                                        value=256,
                                        minimum=0,
                                        maximum=8192,
                                        step=64,
                                        interactive=True,
                                        info="The maximum number of new tokens",
                                    )
                                with column_2:
                                    top_p = gr.Slider(
                                        label="Top-p (nucleus sampling)",
                                        value=0.90,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values sample more low-probability tokens",
                                    )
                                    repetition_penalty = gr.Slider(
                                        label="Repetition penalty",
                                        value=1.2,
                                        minimum=1.0,
                                        maximum=2.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Penalize repeated tokens",
                                    )
                    with gr.Column():
                        version = gr.Dropdown(
                            ["StarCoderBase", "StarCoder"],
                            value="StarCoder",
                            label="Version",
                            info="",
                        )
                gr.Markdown(disclaimer)
                with gr.Group(elem_id="share-btn-container"):
                    community_icon = gr.HTML(community_icon_svg, visible=True)
                    loading_icon = gr.HTML(loading_icon_svg, visible=True)
                    share_button = gr.Button(
                        "Share to community", elem_id="share-btn", visible=True
                    )
                gr.Examples(
                    examples=examples,
                    inputs=[instruction],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )
        gr.Markdown(FORMATS)

    submit.click(
        generate,
        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty, version],
        outputs=[output],
    )
    share_button.click(None, [], [], _js=share_js)

demo.queue(concurrency_count=16).launch(debug=True, server_port=DEFAULT_PORT)
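# To run this demo locally (a sketch; the file name `app.py` is an assumption):
#
#     HF_TOKEN=<your token> python app.py
#
# then open http://localhost:7860 (DEFAULT_PORT) in a browser.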