import diff_match_patch as dmp_module | |
from tqdm import tqdm | |
from api_wrappers import hf_data_loader | |
def get_annotated_diff(start_text, end_text):
    """Return a semantically cleaned diff of two texts as annotated chunks.

    Args:
        start_text: The text before the edit.
        end_text: The text after the edit.

    Returns:
        A list of ``[chunk, marker]`` pairs in diff order, where ``marker``
        is ``'-'`` for diff operation ``-1``, ``'+'`` for operation ``1``
        and ``None`` for operation ``0``.
    """
    differ = dmp_module.diff_match_patch()
    edits = differ.diff_main(start_text, end_text)
    # The return value is not used: the cleanup rewrites `edits` in place.
    differ.diff_cleanupSemantic(edits)

    # Translate diff-match-patch operation codes into display markers.
    markers = {-1: '-', 0: None, 1: '+'}

    annotated = []
    for op_code, chunk in edits:
        annotated.append([chunk, markers[op_code]])
    return annotated
def annotated_diff_for_row(row):
    """Compute the annotated diff between a row's start and end texts.

    The start text is taken from ``commit_msg_start`` when that key is
    present in ``row`` and from ``G_text`` otherwise; the end text is taken
    from ``commit_msg_end`` when present and from ``E_text`` otherwise.

    Args:
        row: A mapping-like object supporting ``in`` and ``[]`` by key.

    Returns:
        The result of ``get_annotated_diff`` for the selected text pair.
    """
    key_preferences = (
        ("commit_msg_start", "G_text"),
        ("commit_msg_end", "E_text"),
    )
    texts = []
    for preferred_key, fallback_key in key_preferences:
        # Only the key that is actually chosen is looked up.
        chosen_key = preferred_key if preferred_key in row else fallback_key
        texts.append(row[chosen_key])

    start, end = texts
    return get_annotated_diff(start, end)
def data_with_annotated_diffs():
    """Load the synthetic dataset and attach an annotated diff to each row.

    Only the rows selected by the ``is_related`` mask are kept. The result
    is a copy of that selection with an extra ``annotated_diff`` column
    holding the output of ``annotated_diff_for_row`` for each row.

    Returns:
        The filtered DataFrame including the ``annotated_diff`` column.
    """
    # Must run before `progress_apply` is used below.
    tqdm.pandas()

    synthetic = hf_data_loader.load_synthetic_as_pandas()
    # Copy so that adding a column does not write into a view of `synthetic`.
    related = synthetic.loc[synthetic.is_related].copy()
    related['annotated_diff'] = related.progress_apply(
        annotated_diff_for_row, axis=1
    )
    return related