"""Inspired by the SantaCoder demo Huggingface space.
Link: https://huggingface.co/spaces/bigcode/santacoder-demo/tree/main/app.py
"""
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

REPO = "replit/replit-code-v1-3b"

description = """# Code Completion with replit-code-v1-3b

The replit-code-v1-3b model is a 2.7B-parameter LLM trained on 20 languages from the Stack Dedup v1.2 dataset. You can click the button several times to keep completing your code."""

token = os.environ["HUB_TOKEN"]
device = "cuda" if torch.cuda.is_available() else "cpu"

# special tokens defined by the replit-code tokenizer
PAD_TOKEN = "<|pad|>"
EOS_TOKEN = "<|endoftext|>"
UNK_TOKEN = "<|unk|>"
MAX_INPUT_TOKENS = 1024  # max number of tokens taken from the context

tokenizer = AutoTokenizer.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True)
tokenizer.truncation_side = "left"  # if the prompt is truncated, keep its last N tokens rather than its first
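
# A quick illustration of the truncation behaviour configured above (a sketch,
# not executed by the app): with truncation_side="left", a prompt longer than
# MAX_INPUT_TOKENS keeps only its *last* 1024 tokens, so the model always sees
# the code closest to the insertion point:
#   ids = tokenizer.encode(long_prompt, max_length=MAX_INPUT_TOKENS, truncation=True)
#   assert len(ids) <= MAX_INPUT_TOKENS  # head of the prompt dropped, tail kept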

if device == "cuda":
    # on GPU, load the weights in bfloat16 to reduce memory use; on CPU, keep the default float32
    model = AutoModelForCausalLM.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True, low_cpu_mem_usage=True).to(device, dtype=torch.bfloat16)
else:
    model = AutoModelForCausalLM.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True, low_cpu_mem_usage=True)
model.eval()

custom_css = """
.gradio-container {
    background-color: #0D1525;
    color: white;
}
#orange-button {
    background: #F26207 !important;
    color: white;
}
.cm-gutters {
    border: none !important;
}
"""


def post_processing(prompt, completion):
    # concatenate prompt and completion so the result can be written back into the editor
    # (an earlier revision wrapped both in styled HTML for a separate output pane)
    return prompt + completion

def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42, top_p=0.9, top_k=None, use_cache=True, repetition_penalty=1.0):
    # truncate the prompt to MAX_INPUT_TOKENS if it's too long
    x = tokenizer.encode(prompt, return_tensors="pt", max_length=MAX_INPUT_TOKENS, truncation=True).to(device)
    print("Prompt shape: ", x.shape)  # logged so prompt sizes are visible in the Space logs in prod
    set_seed(seed)
    y = model.generate(
        x,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        top_p=top_p,
        top_k=top_k,
        use_cache=use_cache,
        repetition_penalty=repetition_penalty,
    )
    completion = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # drop the echoed prompt so only the newly generated text is appended
    completion = completion[len(prompt):]
    return post_processing(prompt, completion)
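
# Illustrative call (a sketch; assumes the model is loaded and HUB_TOKEN is set;
# the Space itself only invokes code_generation through the click event wired up below):
#   code_generation("def fib(n):", max_new_tokens=48)
#   -> "def fib(n):" followed by up to 48 newly sampled tokens, as a single string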

demo = gr.Blocks(css=custom_css)

with demo:
    gr.Markdown(value=description)
    with gr.Row():
        input_col, settings_col = gr.Column(scale=6), gr.Column(scale=6)
        with input_col:
            code = gr.Code(lines=28, label="Input", value="def sieve_eratosthenes(n):")
        with settings_col:
            with gr.Accordion("Generation Settings", open=True):
                max_new_tokens = gr.Slider(
                    minimum=8,
                    maximum=128,
                    step=1,
                    value=48,
                    label="Max Tokens",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.5,
                    step=0.1,
                    value=0.2,
                    label="Temperature",
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=1.9,
                    step=0.1,
                    value=1.0,
                    label="Repetition Penalty. 1.0 means no penalty.",
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=1000,
                    step=1,
                    label="Random Seed",
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    value=0.9,
                    label="Top P",
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=64,
                    step=1,
                    value=4,
                    label="Top K",
                )
                use_cache = gr.Checkbox(label="Use Cache", value=True)
    with gr.Row():
        run = gr.Button(elem_id="orange-button", value="Generate More Code")

    # (an earlier layout rendered the completion into a separate gr.HTML output;
    # the completion is now written straight back into the code editor instead)
    event = run.click(
        code_generation,
        [code, max_new_tokens, temperature, seed, top_p, top_k, use_cache, repetition_penalty],
        code,
        api_name="predict",
    )

demo.queue(max_size=40).launch()
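
# To try the app outside the Space (the filename app.py is an assumption,
# following the Spaces convention): HUB_TOKEN=<your HF token> python app.py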