import json
import os
import random
import threading
import uuid
from pathlib import Path
from urllib.parse import parse_qs

import gradio as gr
from datasets import load_dataset
from dotenv import load_dotenv
from huggingface_hub import Repository

from utils import force_git_push

# These variables are for storing the MTurk HITs in a Hugging Face dataset.
if Path(".env").is_file():
    load_dotenv(".env")
DATASET_REPO_URL = os.getenv("DATASET_REPO_URL")
FORCE_PUSH = os.getenv("FORCE_PUSH")
HF_TOKEN = os.getenv("HF_TOKEN")
PROMPT_TEMPLATES = Path("prompt_templates")

DATA_FILENAME = "data.jsonl"
DATA_FILE = os.path.join("data", DATA_FILENAME)
repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN)

ds = load_dataset("HuggingFaceH4/instruction-pilot-outputs", split="train", use_auth_token=HF_TOKEN)

TOTAL_CNT = 10  # How many user inputs per HIT

# This function pushes the HIT data written in data.jsonl to our Hugging Face
# dataset every minute. Adjust the frequency to suit your needs.
PUSH_FREQUENCY = 60


def asynchronous_push(f_stop):
    if repo.is_repo_clean():
        print("Repo currently clean. Ignoring push_to_hub")
    else:
        repo.git_add(auto_lfs_track=True)
        repo.git_commit("Auto commit by space")
        if FORCE_PUSH == "yes":
            force_git_push(repo)
        else:
            repo.git_push()
    if not f_stop.is_set():
        # Call again in PUSH_FREQUENCY seconds.
        threading.Timer(PUSH_FREQUENCY, asynchronous_push, [f_stop]).start()


f_stop = threading.Event()
asynchronous_push(f_stop)

demo = gr.Blocks()


def random_sample_with_least_annotated_examples_first():
    """Return the example with the fewest annotations so far, breaking ties randomly."""
    with open(DATA_FILE, "r") as f:
        annotations = f.readlines()
    id_to_count = {}
    for line in annotations:
        annotation = json.loads(line)
        # Only include annotations by actual turkers in the count.
        if annotation["assignmentId"] != "":
            example_id = annotation["id"]
            id_to_count[example_id] = id_to_count.get(example_id, 0) + 1
    ds_with_annotation_counts = ds.map(
        lambda example: {"annotation_count": id_to_count.get(example["id"], 0)}
    )
    # Shuffle before the (stable) sort so that ties are broken randomly.
    ds_with_annotation_counts = ds_with_annotation_counts.shuffle()
    ds_with_annotation_counts = ds_with_annotation_counts.sort("annotation_count")
    example = ds_with_annotation_counts.select([0])[0]
    # We only want to give the annotator 2 choices, so we sample 2 outputs without replacement.
    example["outputs"] = random.sample(example["outputs"], 2)
    return example


def prompt_pretty_markdown(prompt):
    prompt = prompt.replace("Input:", "\n\nInput:\n\n")
    return prompt


with demo:
    dummy = gr.Textbox(visible=False)  # Dummy textbox for passing the assignmentId.

    initial_sample = random_sample_with_least_annotated_examples_first()

    # We keep track of state as a JSON blob.
    state_dict = {
        "taskId": str(uuid.uuid4()),
        "assignmentId": "",
        "cnt": 0,
        "data": [initial_sample],
    }
    state = gr.JSON(state_dict, visible=False)

    gr.Markdown("# Choose the more helpful response for the input")
    gr.Markdown("By 'helpful', we mean whichever answer you personally find more useful.")

    def _select_response(selected_response, state, dummy):
        if selected_response == "":
            # Don't do anything if the worker hasn't selected a response yet.
            # One gr.update() per component output, plus state and dummy
            # passed through unchanged.
            return (
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                state,
                dummy,
            )
        state["cnt"] += 1
        state_display = f"Submissions left in HIT: {state['cnt']}/{TOTAL_CNT}"
        done = state["cnt"] == TOTAL_CNT
        state["data"][-1]["selected_response"] = selected_response
        if done:
            # Write the HIT data to our local dataset because the worker has
            # now submitted everything.
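            # Each entry appended below is one JSON object per line (JSONL).
            # A hypothetical record, with illustrative values only:
            # {"assignmentId": "3G7QM...", "taskId": "b51c...", "id": 12,
            #  "prompt": "...", "outputs": [...], "selected_response": "(a) ..."}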
            with open(DATA_FILE, "a") as jsonlfile:
                json_data_with_assignment_id = [
                    json.dumps(
                        dict(
                            {"assignmentId": state["assignmentId"], "taskId": state["taskId"]},
                            **datum,
                        )
                    )
                    for datum in state["data"]
                ]
                jsonlfile.write("\n".join(json_data_with_assignment_id) + "\n")
        # dummy holds window.location.search; strip the leading "?" before parsing.
        query = parse_qs(dummy[1:])
        if "assignmentId" in query and query["assignmentId"][0] != "ASSIGNMENT_ID_NOT_AVAILABLE":
            # It seems that someone is using this app on MTurk. We need to
            # store the assignmentId in the state before submit_hit_button
            # is clicked, and we can do that here in _select_response. We need
            # to save the assignmentId so that the turker can get credit for
            # their HIT.
            state["assignmentId"] = query["assignmentId"][0]
            toggle_final_submit = gr.update(visible=done)
            toggle_final_submit_preview = gr.update(visible=False)
        else:
            toggle_final_submit_preview = gr.update(visible=done)
            toggle_final_submit = gr.update(visible=False)
        toggle_submit_response_button = gr.update(visible=not done)

        new_sample = random_sample_with_least_annotated_examples_first()
        new_outputs = [obj["output"] for obj in new_sample["outputs"]]
        state["data"].append(new_sample)
        past_conversation = gr.update(value=prompt_pretty_markdown(new_sample["prompt"]))
        select_response = gr.update(
            choices=[
                "(a) " + new_outputs[0],
                "(b) " + new_outputs[1],
                "(c) Both (a) and (b) are similarly good",
                "(d) Both (a) and (b) are similarly bad",
            ],
            value="",
        )

        return (
            past_conversation,
            select_response,
            toggle_submit_response_button,
            toggle_final_submit,
            toggle_final_submit_preview,
            state_display,
            state,
            dummy,
        )

    # Input fields
    gr.Markdown("Prompt")
    past_conversation = gr.Markdown(value=prompt_pretty_markdown(initial_sample["prompt"]))
    initial_outputs = [obj["output"] for obj in initial_sample["outputs"]]
    gr.Markdown("Select the most helpful response")
    select_response = gr.Radio(
        choices=[
            "(a) " + initial_outputs[0],
            "(b) " + initial_outputs[1],
            "(c) Both (a) and (b) are similarly good",
            "(d) Both (a) and (b) are similarly bad",
        ],
        label="",
    )
    submit_response_button = gr.Button("Submit Response")
    submit_hit_button = gr.Button("Submit HIT", visible=False)
    submit_hit_button_preview = gr.Button(
        "Submit Work (preview mode; no MTurk HIT credit, but your examples will still be stored)",
        visible=False,
    )
    state_display = gr.Markdown(f"Submissions left in HIT: 0/{TOTAL_CNT}")

    # Button event handlers
    get_window_location_search_js = """
        function(select_response, state, dummy) {
            return [select_response, state, window.location.search];
        }
        """

    submit_response_button.click(
        _select_response,
        inputs=[select_response, state, dummy],
        outputs=[
            past_conversation,
            select_response,
            submit_response_button,
            submit_hit_button,
            submit_hit_button_preview,
            state_display,
            state,
            dummy,
        ],
        _js=get_window_location_search_js,
    )

    post_hit_js = """
        function(state) {
            // If there is an assignmentId, then the submitter is on MTurk
            // and has accepted the HIT, so we need to submit their HIT.
            // Note that this posts to the MTurk *sandbox* endpoint; use the
            // production endpoint (or the turkSubmitTo query parameter) for
            // live HITs.
            const form = document.createElement('form');
            form.action = 'https://workersandbox.mturk.com/mturk/externalSubmit';
            form.method = 'post';
            for (const key in state) {
                const hiddenField = document.createElement('input');
                hiddenField.type = 'hidden';
                hiddenField.name = key;
                hiddenField.value = state[key];
                form.appendChild(hiddenField);
            }
            document.body.appendChild(form);
            form.submit();
            return state;
        }
        """

    submit_hit_button.click(
        lambda state: state,
        inputs=[state],
        outputs=[state],
        _js=post_hit_js,
    )

    refresh_app_js = """
        function(state) {
            // The following line reloads the app so the user can
            // complete another preview-mode "HIT".
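            // Assigning location.href to itself forces a fresh navigation;
            // window.location.reload() would work equally well here.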
            window.location.href = window.location.href;
            return state;
        }
        """

    submit_hit_button_preview.click(
        lambda state: state,
        inputs=[state],
        outputs=[state],
        _js=refresh_app_js,
    )

demo.launch()
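
# Note: when this app is served to workers as an MTurk ExternalQuestion, MTurk
# loads it in an iframe with assignmentId (along with hitId, workerId, and
# turkSubmitTo) appended as query parameters, which is how the
# window.location.search value captured above carries the worker's assignmentId.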