import json import os import gradio as gr from huggingface_hub import Repository from text_generation import Client HF_TOKEN = os.environ.get("TRL_TOKEN", None) API_URL = os.environ.get("API_URL") theme = gr.themes.Monochrome( primary_hue="indigo", secondary_hue="blue", neutral_hue="slate", radius_size=gr.themes.sizes.radius_sm, font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"], ) if HF_TOKEN: repo = Repository( local_dir="data", clone_from="trl-lib/stack-llama-prompts", use_auth_token=HF_TOKEN, repo_type="dataset" ) client = Client( API_URL, headers={"Authorization": f"Bearer {HF_TOKEN}"}, ) PROMPT_TEMPLATE = """Question: {prompt}\n\nAnswer:""" def save_inputs_and_outputs(inputs, outputs, generate_kwargs): with open(os.path.join("data", "prompts.jsonl"), "a") as f: json.dump({"inputs": inputs, "outputs": outputs, "generate_kwargs": generate_kwargs}, f, ensure_ascii=False) f.write("\n") commit_url = repo.push_to_hub() def generate(instruction, temperature=0.9, max_new_tokens=256, top_p=0.95, top_k=100): formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction) temperature = float(temperature) top_p = float(top_p) generate_kwargs = dict( temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, do_sample=True, truncate=999, seed=42, stop_sequences=[""], ) stream = client.generate_stream( formatted_instruction, **generate_kwargs, ) output = "" for response in stream: output += response.token.text yield output if HF_TOKEN: print("Pushing prompt and completion to the Hub") save_inputs_and_outputs(formatted_instruction, output, generate_kwargs) return output examples = [ "A llama is in my lawn. How do I get rid of him?", "How do I create an array in C++ which contains all even numbers between 1 and 10?", "How can I sort a list in Python?", "How can I write a Java function to generate the nth Fibonacci number?", "How many helicopters can a llama eat in one sitting?", ] def process_example(args): for x in generate(args): pass return x with gr.Blocks(theme=theme, analytics_enabled=False, css=".generating {visibility: hidden}") as demo: with gr.Column(): gr.Markdown( """