|
import os |
|
import random |
|
import uuid |
|
from datetime import datetime |
|
from difflib import ndiff |
|
|
|
import gradio as gr |
|
|
|
from data_loader import load_data |
|
from hf_dataset_saver_builder import get_dataset_saver |
|
|
|
# Credentials and target dataset for storing annotations, taken from the
# environment (None when a variable is unset).
HF_TOKEN = os.getenv('HF_REWRITING_TOKEN')
HF_DATASET = os.getenv('HF_REWRITING_DATASET')

# Every sample to annotate, loaded once at import time.
data = load_data()
n_samples = len(data)

# Saver used by the submit handler (saver.flag) to persist feedback rows;
# the dataset is requested as private.
saver = get_dataset_saver(HF_TOKEN, HF_DATASET, private=True)
|
|
|
|
|
def convert_diff_to_unified(diff):
    """Render a list of per-file modifications as a single unified-diff text.

    Each element of ``diff`` must provide the keys ``"old_path"``,
    ``"new_path"`` and ``"diff"``. For every file a ``--- <old_path>`` /
    ``+++ <new_path>`` header is emitted, followed by that file's diff body.
    The per-file chunks are joined with a single newline.
    """
    file_chunks = []
    for mod in diff:
        header = f'--- {mod["old_path"]}\n+++ {mod["new_path"]}\n'
        file_chunks.append(f'{header}{mod["diff"]}')
    return "\n".join(file_chunks)
|
|
|
|
|
def get_diff2html_view(raw_diff):
    """Build the HTML container in which the page renders a diff.

    The raw unified diff is stored, HTML-escaped, in a hidden ``#diff-raw``
    element. ``#diff-view`` (inside a ``d2h-view-wrapper``) is the initially
    empty render target.

    NOTE(review): the script that fills ``#diff-view`` is not in this file
    (presumably in head.html). Escaping assumes it reads ``#diff-raw`` through
    ``textContent``/``innerText``, which decode the entities back to the
    original diff text. Confirm.

    Args:
        raw_diff: unified diff text (non-str values are converted with str(),
            as the previous f-string interpolation did).

    Returns:
        An HTML fragment string.
    """
    from html import escape

    # Diffs routinely contain markup (HTML/JSX/XML files, ``a<b`` in code).
    # Inserted verbatim, the browser would parse it as real elements. That
    # mangles the stored diff text and puts tags/attributes coming from the
    # diff live into the page.
    escaped_diff = escape(str(raw_diff))
    return f"""
<div style='width:100%; height:1400px; overflow:auto; position: relative'>
    <div id='diff-raw' hidden>{escaped_diff}</div>
    <div class="d2h-view-wrapper">
        <div id='diff-view'></div>
    </div>
</div>
"""
|
|
|
|
|
def get_github_link_md(repo, hash):
    """Return a Markdown link to commit ``hash`` of GitHub repository ``repo``.

    Both values are placed verbatim into
    ``https://github.com/<repo>/commit/<hash>``. ``repo`` is presumably in
    ``owner/name`` form.
    """
    commit_url = "https://github.com/{}/commit/{}".format(repo, hash)
    return "[See the commit on Github]({})".format(commit_url)
|
|
|
|
|
def char_diff_obj(change_type, pos, character, timestamp):
    """Package a single character-level edit event as a dict.

    The dict has the keys ``type`` (change marker), ``pos``, ``char`` and
    ``timestamp``, in that order.
    """
    return dict(type=change_type, pos=pos, char=character, timestamp=timestamp)
|
|
|
|
|
def update_commit_view(sample_ind):
    """Compute the values displayed for dataset record ``sample_ind``.

    Returns None when ``sample_ind >= n_samples``. Otherwise returns a 9-tuple
    ordered like the ``commit_view`` outputs: GitHub link markdown, diff HTML,
    repo, commit hash, load timestamp (ISO string), initial message, current
    message, previous message, and an empty edit history. The three message
    slots all start as the record's ``prediction``.
    """
    if sample_ind >= n_samples:
        return None

    record = data[sample_ind]
    diff_html = get_diff2html_view(convert_diff_to_unified(record['mods']))

    repo = record['repo']
    commit_hash = record['hash']
    link_md = get_github_link_md(repo, commit_hash)

    loaded_at = datetime.now().isoformat()

    # Start, current and previous message all begin as the model prediction;
    # the character-level edit history starts empty.
    message = record['prediction']
    return (link_md, diff_html, repo, commit_hash, loaded_at,
            message, message, message, [])
|
|
|
|
|
def next_sample(current_sample_ind, shuffled_idx):
    """Advance the session to the next sample in its shuffled order.

    Args:
        current_sample_ind: number of samples this session has already moved
            past (the progress slider value).
        shuffled_idx: this session's permutation of dataset indices; it has
            exactly ``n_samples`` entries.

    Returns:
        A tuple ``(new_sample_ind, *view_values)`` matching the outputs
        ``[current_sample_sld] + commit_view``. When no sample is left, the
        counter is set to ``n_samples`` and every view component receives a
        no-op ``gr.update()``, so the page stays as it is.
    """
    # Number of values update_commit_view returns (== len(commit_view)).
    commit_view_size = 9

    next_ind = current_sample_ind + 1
    if next_ind >= n_samples:
        # shuffled_idx[n_samples] would raise IndexError. Returning None is not
        # an option either: the event has one output per slot and Gradio
        # requires a value for each of them.
        return (n_samples,) + (gr.update(),) * commit_view_size

    return (next_ind,) + update_commit_view(shuffled_idx[next_ind])
|
|
|
|
|
# Custom markup for the page <head>, passed to gr.Blocks(head=...).
# An explicit encoding keeps decoding identical across platforms/locales,
# instead of depending on the locale's preferred encoding.
with open("head.html", encoding="utf-8") as head_file:
    head_html = head_file.read()
|
|
|
with gr.Blocks(theme=gr.themes.Soft(), head=head_html, css="style_overrides.css") as application:
    # Hidden per-session state: identity of the displayed commit and this
    # session's random sample order.
    repo_val = gr.Textbox(interactive=False, label='repo', visible=False)
    hash_val = gr.Textbox(interactive=False, label='hash', visible=False)
    shuffled_idx_val = gr.JSON(visible=False)

    with gr.Row():
        with gr.Accordion("Help"):
            # Explicit encoding so the guide decodes the same on every platform.
            with open("survey_guide.md", encoding="utf-8") as content_file:
                gr.Markdown(content_file.read())

    with gr.Row():
        # Progress counter: samples this session has labeled or skipped.
        current_sample_sld = gr.Slider(minimum=0, maximum=n_samples, step=1,
                                       value=0,
                                       interactive=False,
                                       label='sample_ind',
                                       info=f"Samples labeled/skipped (out of {n_samples})",
                                       show_label=False,
                                       container=False,
                                       scale=5)
        with gr.Column(scale=1):
            skip_btn = gr.Button("Skip the current sample")

    with gr.Row():
        with gr.Column(scale=2):
            github_link = gr.Markdown()
            diff_view = gr.HTML()
        with gr.Column(scale=1):
            # Hidden companions of the editable message: the message as first
            # shown, the message before the latest edit, and the edit history.
            commit_msg_start = gr.TextArea(label="commit_msg_start", visible=False)
            commit_msg = gr.TextArea(label="commit_msg_end", show_label=False,
                                     info="Commit message (can be scrollable)")
            commit_msg_prev = gr.TextArea(visible=False)
            commit_msg_history = gr.JSON(label="commit_msg_history", visible=False)

            submit_btn = gr.Button("Submit")

    session_val = gr.Textbox(info='Session', interactive=False, container=True, show_label=False,
                             label='session')

    with gr.Row(visible=False):
        sample_loaded_timestamp = gr.Textbox(info="Sample loaded", label='loaded_ts', interactive=False,
                                             container=True, show_label=False)
        # Value callable re-evaluated every second (every=1.0); its current text
        # is used as the timestamp for message edits and for submissions.
        now_timestamp = gr.Textbox(info="Current time",
                                   interactive=False, container=True, show_label=False,
                                   value=lambda: datetime.now().isoformat(), every=1.0,
                                   label='submitted_ts')

    # Components filled by update_commit_view(); the order must match its tuple.
    commit_view = [
        github_link,
        diff_view,
        repo_val,
        hash_val,
        sample_loaded_timestamp,
        commit_msg_start,
        commit_msg,
        commit_msg_prev,
        commit_msg_history,
    ]

    feedback_metadata = [
        session_val,
        repo_val,
        hash_val,
        sample_loaded_timestamp,
        now_timestamp,
    ]

    feedback_form = [
        commit_msg_start,
        commit_msg,
        commit_msg_history,
    ]

    # Components whose values make up each row passed to saver.flag() below:
    # sample counter, then metadata, then the form (presumably the saved
    # column layout; saver is project code).
    saver.setup([current_sample_sld] + feedback_metadata + feedback_form, "feedback")

    skip_btn.click(next_sample, inputs=[current_sample_sld, shuffled_idx_val],
                   outputs=[current_sample_sld] + commit_view)

    def submit(current_sample, shuffled_idx, *args):
        """Save feedback for the displayed sample, then advance to the next one.

        ``args`` are the values of ``feedback_metadata + feedback_form``, in
        that order.
        """
        # Counter values at or beyond n_samples do not correspond to any entry
        # of shuffled_idx, so there is no valid sample to save feedback for.
        if current_sample < n_samples:
            saver.flag((current_sample,) + args)
        return next_sample(current_sample, shuffled_idx)

    submit_btn.click(
        submit,
        inputs=[current_sample_sld, shuffled_idx_val] + feedback_metadata + feedback_form,
        outputs=[current_sample_sld] + commit_view,
    )

    def on_commit_msg_changed(message, prev_message, history, timestamp):
        """Record character-level insertions/deletions between message versions.

        Every '+'/'-' item of ``difflib.ndiff(prev_message, message)`` becomes a
        ``char_diff_obj`` stamped with ``timestamp`` and is appended to the
        history. ``pos`` is the item's index in the ndiff output, so unchanged,
        deleted and inserted characters all advance it.

        Returns:
            ``(message, history)``: the new message becomes the next
            ``prev_message``, plus the extended history.
        """
        # Tolerate unset (None) hidden components instead of crashing.
        history = list(history or [])
        for i, s in enumerate(ndiff(prev_message or "", message or "")):
            diff = char_diff_obj(s[0], i, s[-1], timestamp)
            if diff['type'] in ('+', '-'):
                history.append(diff)
        return message, history

    commit_msg.change(on_commit_msg_changed,
                      inputs=[commit_msg, commit_msg_prev, commit_msg_history, now_timestamp],
                      outputs=[commit_msg_prev, commit_msg_history])

    def init_session(current_sample):
        """Start a session on page load.

        Creates a new session UUID and a fresh random order over all samples.
        Returns both, followed by the view for position ``current_sample``
        (the slider value, 0 on a fresh page).
        """
        session = str(uuid.uuid4())
        shuffled_idx = list(range(n_samples))
        random.shuffle(shuffled_idx)
        return (session, shuffled_idx) + update_commit_view(shuffled_idx[current_sample])

    application.load(init_session,
                     inputs=[current_sample_sld],
                     outputs=[session_val, shuffled_idx_val] + commit_view)

application.launch()
|
|