Petr Tsvetkov
Switch to the special commit rewriting dataset
928b43c
raw
history blame
6.82 kB
import html
import json
import os
import random
import uuid
from datetime import datetime
from difflib import ndiff

import gradio as gr

from data_loader import load_data
from hf_dataset_saver_builder import get_dataset_saver
# Hugging Face credentials and the dataset that receives survey feedback,
# both taken from the environment (None when the variable is unset).
HF_TOKEN = os.getenv('HF_REWRITING_TOKEN')
HF_DATASET = os.getenv('HF_REWRITING_DATASET')

# All samples to be annotated, loaded once at startup.
data = load_data()
n_samples = len(data)

# Persists submitted feedback rows; set up against the UI components further below.
saver = get_dataset_saver(HF_TOKEN, HF_DATASET, private=True)
def convert_diff_to_unified(diff_string):
    """Render a JSON-encoded list of file modifications as one unified-diff text.

    Args:
        diff_string: JSON text decoding to a list of objects, each having
            ``old_path``, ``new_path`` and ``diff`` keys.

    Returns:
        For every modification a ``--- old`` / ``+++ new`` header followed by its
        diff body; the per-file sections are separated by a newline.
    """
    sections = []
    for mod in json.loads(diff_string):
        sections.append(f'--- {mod["old_path"]}\n+++ {mod["new_path"]}\n{mod["diff"]}')
    return "\n".join(sections)
def get_diff2html_view(raw_diff):
    """Wrap a unified diff in the HTML scaffold used for the diff viewer.

    The diff text is placed in a hidden ``#diff-raw`` element, and an empty
    ``#diff-view`` element inside a ``d2h-view-wrapper`` is provided as a render
    target. Presumably a script loaded via head.html renders the diff into it —
    that script is not visible here.

    The diff is HTML-escaped because it comes from dataset commits and routinely
    contains ``<``, ``>`` and ``&`` (generics, HTML/XML sources, ...). Unescaped,
    the browser would parse those as markup, which corrupts the diff text and
    lets attributes such as ``onerror`` run script in the annotator's browser.
    NOTE(review): the consumer should read ``#diff-raw`` via ``textContent``,
    which decodes the entities back to the original text, not ``innerHTML`` —
    confirm in head.html.

    Args:
        raw_diff: Unified diff text.

    Returns:
        The HTML snippet as a string.
    """
    # str() keeps the original f-string behaviour for non-str input before escaping.
    escaped_diff = html.escape(str(raw_diff))
    markup = f"""
    <div style='width:100%; height:1400px; overflow:auto; position: relative'>
        <div id='diff-raw' hidden>{escaped_diff}</div>
        <div class="d2h-view-wrapper">
            <div id='diff-view'></div>
        </div>
    </div>
    """
    return markup
def get_github_link_md(repo, hash):
    """Return a Markdown link to commit ``hash`` in the GitHub repository ``repo``.

    ``repo`` is interpolated directly after ``https://github.com/``.
    """
    commit_url = 'https://github.com/{}/commit/{}'.format(repo, hash)
    return '[See the commit on Github](' + commit_url + ')'
def char_diff_obj(change_type, pos, character, timestamp):
    """Package one character-level edit event as a dict.

    Keys, in order: ``type`` (the change tag), ``pos``, ``char`` and ``timestamp``.
    """
    return dict(type=change_type, pos=pos, char=character, timestamp=timestamp)
def update_commit_view(sample_ind):
    """Build the values displayed for dataset sample ``sample_ind``.

    Returns:
        ``None`` when ``sample_ind >= n_samples``. Otherwise a 9-tuple:
        (GitHub link Markdown, diff HTML, repo, commit hash, load timestamp,
        start message, current message, previous message, edit history).
        The three message slots all start as the record's ``prediction``,
        and the history starts empty.
        NOTE(review): callers concatenate this result onto a tuple, so the
        ``None`` branch would raise TypeError there if it were ever reached.
    """
    if sample_ind >= n_samples:
        return None

    sample = data[sample_ind]
    # 'mods' holds the JSON-encoded file modifications of the commit.
    rendered_diff = get_diff2html_view(convert_diff_to_unified(sample['mods']))
    repo, commit_hash = sample['repo'], sample['hash']
    link_md = get_github_link_md(repo, commit_hash)
    # Naive local-time ISO timestamp of when this sample was prepared.
    loaded_at = datetime.now().isoformat()
    message = sample['prediction']

    return (
        link_md,
        rendered_diff,
        repo,
        commit_hash,
        loaded_at,
        message,  # start message
        message,  # current message
        message,  # "previous" message used for per-change diffs
        [],       # fresh edit history
    )
# Number of values update_commit_view() returns, i.e. of components in commit_view.
# Keep in sync with both.
_COMMIT_VIEW_SIZE = 9


def next_sample(current_sample_ind, shuffled_idx):
    """Advance to the next sample in this session's shuffled order.

    Args:
        current_sample_ind: Number of samples labeled/skipped so far (the slider
            value). This is also the position in ``shuffled_idx`` of the sample
            currently shown.
        shuffled_idx: This session's permutation of dataset indices.

    Returns:
        ``(new_ind,) + view``, where ``view`` is update_commit_view() for the next
        sample. Once every sample has been processed, the counter is clamped to
        ``n_samples`` and each view component gets a no-op ``gr.update()``. This
        guarantees Gradio always receives exactly one value per output.
    """
    # Slider values may arrive as floats; list indexing needs an int.
    next_ind = int(current_sample_ind) + 1
    # The bound check must follow the increment: the last valid position is
    # n_samples - 1, so shuffled_idx[n_samples] must never be read.
    if next_ind >= n_samples:
        return (n_samples,) + tuple(gr.update() for _ in range(_COMMIT_VIEW_SIZE))
    return (next_ind,) + update_commit_view(shuffled_idx[next_ind])
# Extra <head> markup for the page, passed to gr.Blocks(head=...) below.
# Decode explicitly as UTF-8 so the result does not depend on the host's locale.
with open("head.html", encoding="utf-8") as head_file:
    head_html = head_file.read()
with gr.Blocks(theme=gr.themes.Soft(), head=head_html, css="style_overrides.css") as application:
    # NOTE(review): the source this was taken from had lost its indentation; the
    # layout nesting below is a reconstruction — confirm it against the deployed page.

    # Hidden state: identifiers of the displayed commit and this session's sample order.
    repo_val = gr.Textbox(interactive=False, label='repo', visible=False)
    hash_val = gr.Textbox(interactive=False, label='hash', visible=False)
    shuffled_idx_val = gr.JSON(visible=False)

    with gr.Row():
        with gr.Accordion("Help"):
            # Decode explicitly as UTF-8 so the guide renders the same on any host locale.
            with open("survey_guide.md", encoding="utf-8") as content_file:
                gr.Markdown(content_file.read())

    with gr.Row():
        # Progress counter. Its value is also the position in the shuffled order
        # of the sample currently shown.
        current_sample_sld = gr.Slider(minimum=0, maximum=n_samples, step=1,
                                       value=0,
                                       interactive=False,
                                       label='sample_ind',
                                       info=f"Samples labeled/skipped (out of {n_samples})",
                                       show_label=False,
                                       container=False,
                                       scale=5)
        with gr.Column(scale=1):
            skip_btn = gr.Button("Skip the current sample")

    with gr.Row():
        with gr.Column(scale=2):
            github_link = gr.Markdown()
            diff_view = gr.HTML()
        with gr.Column(scale=1):
            commit_msg_start = gr.TextArea(label="commit_msg_start", visible=False)
            commit_msg = gr.TextArea(label="commit_msg_end", show_label=False,
                                     info="Commit message (can be scrollable)")
            # Hidden: the previous message value and the list of character edits.
            commit_msg_prev = gr.TextArea(visible=False)
            commit_msg_history = gr.JSON(label="commit_msg_history", visible=False)
            submit_btn = gr.Button("Submit")

            session_val = gr.Textbox(info='Session', interactive=False, container=True, show_label=False,
                                     label='session')
            with gr.Row(visible=False):
                sample_loaded_timestamp = gr.Textbox(info="Sample loaded", label='loaded_ts', interactive=False,
                                                     container=True, show_label=False)
                # Callable value re-evaluated every 0.1 s. It timestamps submissions
                # and each recorded edit.
                now_timestamp = gr.Textbox(info="Current time",
                                           interactive=False, container=True, show_label=False,
                                           value=lambda: datetime.now().isoformat(), every=0.1,
                                           label='submitted_ts')

    # Outputs filled by update_commit_view(), in the order of its return tuple.
    commit_view = [
        github_link,
        diff_view,
        repo_val,
        hash_val,
        sample_loaded_timestamp,
        commit_msg_start,
        commit_msg,
        commit_msg_prev,
        commit_msg_history
    ]

    # Context saved with every submission.
    feedback_metadata = [
        session_val,
        repo_val,
        hash_val,
        sample_loaded_timestamp,
        now_timestamp
    ]

    # The annotator's actual feedback: starting message, final message, edit history.
    feedback_form = [
        commit_msg_start,
        commit_msg,
        commit_msg_history
    ]

    # Column order here must match the row tuple built in submit().
    saver.setup([current_sample_sld] + feedback_metadata + feedback_form, "feedback")

    skip_btn.click(next_sample, inputs=[current_sample_sld, shuffled_idx_val],
                   outputs=[current_sample_sld] + commit_view)

    def submit(current_sample, shuffled_idx, *args):
        """Save the current feedback row, then advance to the next sample.

        ``args`` are the values of feedback_metadata + feedback_form, in that order.
        Nothing is saved once the counter has reached ``n_samples``. Every sample
        has been handled by then, so further clicks would only duplicate the
        last row.
        """
        if int(current_sample) < n_samples:
            saver.flag((current_sample,) + args)
        return next_sample(current_sample, shuffled_idx)

    submit_btn.click(
        submit,
        inputs=[current_sample_sld, shuffled_idx_val] + feedback_metadata + feedback_form,
        outputs=[current_sample_sld] + commit_view
    )

    def on_commit_msg_changed(message, prev_message, history, timestamp):
        """Append the character insertions/deletions since the last change to ``history``.

        ``ndiff`` compares the two strings as sequences of characters. ``pos`` is
        the index in ndiff's merged output, which interleaves unchanged, deleted
        and inserted characters. It is not an offset into either string.

        Returns:
            ``(message, history)``. The message becomes the next "previous" value.
        """
        for i, s in enumerate(ndiff(prev_message, message)):
            # ndiff entries look like "+ c", "- c" or "  c": tag first, character last.
            diff = char_diff_obj(s[0], i, s[-1], timestamp)
            if diff['type'] in ('+', '-'):
                print(diff)
                history.append(diff)
        return message, history

    commit_msg.change(on_commit_msg_changed, inputs=[commit_msg, commit_msg_prev, commit_msg_history,
                                                     now_timestamp],
                      outputs=[commit_msg_prev, commit_msg_history])

    def init_session(current_sample):
        """Start a session: a new id, a fresh random sample order, and the view of the current position."""
        session = str(uuid.uuid4())
        shuffled_idx = list(range(n_samples))
        random.shuffle(shuffled_idx)
        # Slider values may arrive as floats; list indexing needs an int.
        return (session, shuffled_idx) + update_commit_view(shuffled_idx[int(current_sample)])

    application.load(init_session,
                     inputs=[current_sample_sld],
                     outputs=[session_val, shuffled_idx_val] + commit_view)

application.launch()