Petr Tsvetkov
Add labels highlighting the total # of samples and the commit message textbox
039521b
import json
import os
import random
import uuid
from datetime import datetime
from difflib import ndiff
import gradio as gr
from data_loader import load_data
from hf_dataset_saver_builder import get_dataset_saver
# Access token and target dataset for persisting feedback come from the
# environment; each is None when its variable is not set.
HF_TOKEN = os.getenv('HF_REWRITING_TOKEN')
HF_DATASET = os.getenv('HF_REWRITING_DATASET')

# The samples to annotate are loaded once, when the module is imported.
data = load_data()
n_samples = len(data)

# Saver object used further down to set up and flag the feedback records.
saver = get_dataset_saver(
    HF_TOKEN,
    HF_DATASET,
    private=True,
    separate_dirs=True,
)
def convert_diff_to_unified(diff_string):
    """Render a JSON-encoded list of file modifications as a single diff text.

    ``diff_string`` must decode to an iterable of mappings, each providing the
    keys ``old_path``, ``new_path`` and ``diff``. Every mapping is turned into
    a ``---`` / ``+++`` header pair followed by its diff body, and the per-file
    sections are joined with a single newline.

    Raises whatever ``json.loads`` raises for malformed input, and ``KeyError``
    when a mapping lacks one of the three keys.
    """
    sections = []
    for entry in json.loads(diff_string):
        old_path = entry["old_path"]
        new_path = entry["new_path"]
        body = entry["diff"]
        sections.append(f'--- {old_path}\n+++ {new_path}\n{body}')
    return "\n".join(sections)
def get_diff2html_view(raw_diff):
    """Build the HTML fragment that carries a raw diff into the page.

    The diff text is placed inside a hidden ``diff-raw`` element, next to an
    empty ``diff-view`` container inside a scrollable wrapper.
    NOTE(review): presumably a client-side script (not visible in this file)
    reads ``diff-raw`` and renders into ``diff-view`` — confirm.

    Args:
        raw_diff: The diff text to embed; converted with ``str()`` first.

    Returns:
        The HTML fragment as a string.
    """
    # Imported locally so the name is available without touching file imports.
    from html import escape

    # Diffs of real code routinely contain "<", ">" and "&". Left unescaped,
    # the browser would parse them as markup (a literal "</div>" in the diff
    # would even close the hidden container early). Escaped, they stay plain
    # text and the element's text content is exactly the original diff.
    safe_diff = escape(str(raw_diff), quote=False)
    return f"""
<div style='width:100%; height:1400px; overflow:auto; position: relative'>
<div id='diff-raw' hidden>{safe_diff}</div>
<div class="d2h-view-wrapper">
<div id='diff-view'></div>
</div>
</div>
"""
def get_github_link_md(repo, hash):
    """Return a Markdown link to commit ``hash`` of the GitHub repository ``repo``.

    ``repo`` is inserted after ``https://github.com/`` and ``hash`` after
    ``/commit/``; neither value is validated or escaped.
    """
    commit_url = f'https://github.com/{repo}/commit/{hash}'
    return f'[See the commit on Github]({commit_url})'
def char_diff_obj(change_type, pos, character, timestamp):
    """Pack one character-level change into a dict with short keys.

    Keys of the result: ``t`` (change type), ``p`` (position), ``c``
    (character) and ``ts`` (timestamp). Values are stored as given.
    """
    return dict(t=change_type, p=pos, c=character, ts=timestamp)
def _empty_commit_view():
    """Return blank values for every element of the commit view tuple.

    The order is the same as in the tuple returned by ``update_commit_view``.
    """
    return (
        "",  # github_link_md
        "",  # diff_view
        "",  # repo_val
        "",  # hash_val
        "",  # diff_loaded_timestamp
        "",  # summary_md
        "",  # commit_message_start
        "",  # commit_message
        "",  # commit_message_prev
        [],  # commit_message_history
    )


def update_commit_view(sample_ind):
    """Compute the values shown for the sample at index ``sample_ind`` of ``data``.

    Reads the ``mods``, ``repo``, ``hash``, ``summary`` and ``prediction``
    fields of the record.

    Returns:
        A 10-tuple: GitHub link (Markdown), diff HTML, repo, hash, ISO
        timestamp of when the view was built, summary text, and the predicted
        commit message three times (start value, editable value, previous
        value), followed by an empty edit history. When ``sample_ind`` is at
        or past the number of samples, a tuple of blank values of the same
        shape is returned instead, so callers can always concatenate the
        result onto another tuple.
    """
    if sample_ind >= n_samples:
        return _empty_commit_view()

    record = data[sample_ind]
    repo_val = record['repo']
    hash_val = record['hash']
    commit_message = record['prediction']

    return (
        get_github_link_md(repo_val, hash_val),
        get_diff2html_view(convert_diff_to_unified(record['mods'])),
        repo_val,
        hash_val,
        datetime.now().isoformat(),
        f"{record['summary']}",
        commit_message,  # start value, kept for reference
        commit_message,  # value the user edits
        commit_message,  # previous value, used for diffing edits
        [],              # fresh edit history
    )


def next_sample(current_sample_ind, shuffled_idx):
    """Advance to the next position in the session's shuffled sample order.

    Args:
        current_sample_ind: Position (in ``shuffled_idx``) of the sample that
            was just labeled or skipped.
        shuffled_idx: Sequence mapping positions to indices into ``data``.

    Returns:
        ``(new_position,) + view`` where ``view`` is the tuple produced by
        ``update_commit_view``. Once every sample has been visited the
        position is ``n_samples`` and the view is blank; further calls keep
        returning that same finished state.
    """
    next_ind = current_sample_ind + 1
    # Valid positions are 0 .. len(shuffled_idx) - 1. Stepping past the last
    # one used to index shuffled_idx out of range; treat it as "finished".
    if next_ind >= min(n_samples, len(shuffled_idx)):
        return (n_samples,) + _empty_commit_view()
    return (next_ind,) + update_commit_view(shuffled_idx[next_ind])
# Extra markup for the page <head>, read once at start-up and passed to
# gr.Blocks(head=...). The encoding is pinned so decoding does not depend on
# the locale of the host.
with open("head.html", encoding="utf-8") as head_file:
    head_html = head_file.read()

# JavaScript passed to gr.Blocks(js=...): unless the URL already carries
# `__theme=light`, it sets that query parameter and navigates to the new URL.
force_light_theme_js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
"""
# Build the survey UI: components and event listeners are declared inside the
# Blocks context, and the app is launched after the context is closed.
# NOTE(review): the nesting of rows/columns below was reconstructed from the
# component order and their `scale` values — confirm it matches the intended
# layout.
with gr.Blocks(theme=gr.themes.Soft(), head=head_html, css="style_overrides.css",
               js=force_light_theme_js_func) as application:
    # Hidden state: identity of the displayed commit and the session's
    # shuffled sample order.
    repo_val = gr.Textbox(interactive=False, label='repo', visible=False)
    hash_val = gr.Textbox(interactive=False, label='hash', visible=False)
    shuffled_idx_val = gr.JSON(visible=False)

    with gr.Row():
        with gr.Accordion("Help"):
            # Explicit encoding: do not depend on the host locale.
            with open("survey_guide.md", encoding="utf-8") as content_file:
                gr.Markdown(content_file.read())

    with gr.Row():
        # Read-only progress indicator; its value is also the position of the
        # current sample within the shuffled order.
        current_sample_sld = gr.Slider(minimum=0, maximum=n_samples, step=1,
                                       value=0,
                                       interactive=False,
                                       label='sample_ind',
                                       info="Samples labeled/skipped",
                                       show_label=False,
                                       container=False,
                                       scale=5)
        with gr.Column(scale=1):
            gr.Markdown(value=f"#### Total number of samples: {n_samples}")
        with gr.Column(scale=1):
            skip_btn = gr.Button("Skip the current sample")

    with gr.Row():
        with gr.Column(scale=2):
            github_link = gr.Markdown()
            diff_view = gr.HTML()
        with gr.Column(scale=1):
            with gr.Accordion("Commit summary (AI generated)", open=False):
                commit_summary = gr.Markdown()
            # Hidden copy of the message as it was when the sample was loaded.
            commit_msg_start = gr.TextArea(label="commit_msg_start", visible=False)
            gr.Markdown(value="#### Please, edit the message in the text box below")
            commit_msg = gr.TextArea(label="commit_msg_end", show_label=False,
                                     info="Commit message (can be scrollable)")
            # Hidden helpers for tracking edits: the previous text and the
            # accumulated list of character-level changes.
            commit_msg_prev = gr.TextArea(visible=False)
            commit_msg_history = gr.JSON(label="commit_msg_history", visible=False)
            submit_btn = gr.Button("Submit")
            session_val = gr.Textbox(info='Session', interactive=False, container=True, show_label=False,
                                     label='session')

    with gr.Row(visible=False):
        sample_loaded_timestamp = gr.Textbox(info="Sample loaded", label='loaded_ts', interactive=False,
                                             container=True, show_label=False)
        # The value is a callable re-evaluated every 0.1 s; it is submitted
        # together with the form under the label 'submitted_ts'.
        now_timestamp = gr.Textbox(info="Current time",
                                   interactive=False, container=True, show_label=False,
                                   value=lambda: datetime.now().isoformat(), every=0.1,
                                   label='submitted_ts')

    # Components refreshed whenever another sample is shown. The order must
    # match the tuple returned by update_commit_view.
    commit_view = [
        github_link,
        diff_view,
        repo_val,
        hash_val,
        sample_loaded_timestamp,
        commit_summary,
        commit_msg_start,
        commit_msg,
        commit_msg_prev,
        commit_msg_history,
    ]

    # Components whose values are stored with every submitted record.
    feedback_metadata = [
        session_val,
        repo_val,
        hash_val,
        sample_loaded_timestamp,
        now_timestamp,
    ]

    feedback_form = [
        commit_msg_start,
        commit_msg,
        commit_msg_history,
    ]

    saver.setup([current_sample_sld] + feedback_metadata + feedback_form, "feedback")

    skip_btn.click(next_sample, inputs=[current_sample_sld, shuffled_idx_val],
                   outputs=[current_sample_sld] + commit_view)

    def submit(current_sample, shuffled_idx, *args):
        """Store the feedback for the current sample, then show the next one.

        ``args`` are the values of ``feedback_metadata + feedback_form`` in
        that order; they are flagged together with the sample position.
        """
        # A position of n_samples (or more) means no sample is displayed, so
        # there is nothing meaningful to record.
        if current_sample < n_samples:
            saver.flag((current_sample,) + args)
        return next_sample(current_sample, shuffled_idx)

    submit_btn.click(
        submit,
        inputs=[current_sample_sld, shuffled_idx_val] + feedback_metadata + feedback_form,
        outputs=[current_sample_sld] + commit_view
    )

    def on_commit_msg_changed(message, prev_message, history):
        """Append the character-level changes between two message versions.

        The previous and current texts are compared character by character
        with ``difflib.ndiff``; each added ('+') or removed ('-') character is
        recorded with its position in the ndiff output and a shared timestamp.

        Returns:
            ``(message, updated_history)`` — the message becomes the new
            "previous" text. The incoming history is not mutated.
        """
        timestamp = datetime.now().isoformat()
        # Treat missing values as empty so the handler cannot fail on None.
        changes = [
            char_diff_obj(op[0], pos, op[-1], timestamp)
            for pos, op in enumerate(ndiff(prev_message or "", message or ""))
            if op[0] in ('+', '-')
        ]
        updated_history = list(history or [])
        updated_history.extend(changes)
        return message, updated_history

    commit_msg.change(on_commit_msg_changed, inputs=[commit_msg, commit_msg_prev, commit_msg_history],
                      outputs=[commit_msg_prev, commit_msg_history])

    def init_session(current_sample):
        """Create a session id and a shuffled sample order, and load the first view.

        Returns ``(session_id, shuffled_order) + view`` where ``view`` comes
        from ``update_commit_view`` for the sample at position ``current_sample``.
        """
        session = str(uuid.uuid4())
        shuffled_idx = list(range(n_samples))
        random.shuffle(shuffled_idx)
        return (session, shuffled_idx) + update_commit_view(shuffled_idx[current_sample])

    application.load(init_session,
                     inputs=[current_sample_sld],
                     outputs=[session_val, shuffled_idx_val] + commit_view)

application.launch()