Petr Tsvetkov committed
Commit f1b08a8 · 1 Parent(s): 9d943c1
Compute & compare metrics
generate_synthetic_dataset.py
CHANGED
@@ -1,26 +1,14 @@
 import config
 from api_wrappers import hf_data_loader
-from generation_steps import synthetic_end_to_start,
+from generation_steps import synthetic_end_to_start, synthetic_start_to_end, metrics_analysis


 def run():
     df = hf_data_loader.load_processed_rewriting_dataset_as_pandas()
-    print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
-    print()

-    print(f"End -> start synthesis:")
-    print(f"GENERATION_MULTIPLIER = {synthetic_end_to_start.GENERATION_MULTIPLIER}")
-    print(f"REL_INSERTIONS_THRESHOLD = {synthetic_end_to_start.REL_INSERTIONS_THRESHOLD}")
-    print(f"GENERATION_ATTEMPTS = {synthetic_end_to_start.GENERATION_ATTEMPTS}")
     df = synthetic_end_to_start.transform(df)
-    print("Done")
-
-    print(f"Start -> send synthesis:")
-    print(f"GENERATION_MULTIPLIER = {synthetic_start_to_end.GENERATION_MULTIPLIER}")
-    print(f"REL_DELETIONS_THRESHOLD = {synthetic_start_to_end.REL_DELETIONS_THRESHOLD}")
-    print(f"GENERATION_ATTEMPTS = {synthetic_start_to_end.GENERATION_ATTEMPTS}")
     df = synthetic_start_to_end.transform(df)
+    df = metrics_analysis.transform(df)

     df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
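After this change, run() is a plain three-stage pipeline: the per-stage logging moved into the transform functions themselves, and a metrics step now runs before the artifact is written. As a quick sanity check of the output, one could peek at the metric columns the new metrics_analysis step adds. This snippet is hypothetical and not part of the commit; the column names follow the added module shown below:

# hypothetical check, not part of this commit: inspect the metric columns
# written by the new metrics_analysis step
import pandas as pd

import config

df = pd.read_csv(config.SYNTHETIC_DATASET_ARTIFACT, index_col=[0])
print([c for c in df.columns if c.startswith("bleu")])
# expected: ['bleu_related', 'bleu_independent', 'bleu_pearson', 'bleu_spearman']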
generation_steps/metrics_analysis.py
ADDED
@@ -0,0 +1,102 @@
+import evaluate
+import pandas as pd
+from tqdm import tqdm
+
+import config
+from api_wrappers import hf_data_loader
+
+BLEU = evaluate.load('bleu', cache_dir=config.CACHE_DIR)
+
+
+def bleu_fn(pred, ref):
+    return BLEU.compute(predictions=[pred], references=[ref])["bleu"]
+
+
+METEOR = evaluate.load('meteor', cache_dir=config.CACHE_DIR)
+
+
+def meteor_fn(pred, ref):
+    return METEOR.compute(predictions=[pred], references=[ref])["meteor"]
+
+
+ROUGE = evaluate.load('rouge', cache_dir=config.CACHE_DIR)
+
+
+def rouge1_fn(pred, ref):
+    return ROUGE.compute(predictions=[pred], references=[ref])["rouge1"]
+
+
+def rouge2_fn(pred, ref):
+    return ROUGE.compute(predictions=[pred], references=[ref])["rouge2"]
+
+
+BERTSCORE = evaluate.load('bertscore', cache_dir=config.CACHE_DIR)
+
+
+def bertscore_fn(pred, ref):
+    return BERTSCORE.compute(predictions=[pred], references=[ref], model_type="distilbert-base-uncased")["f1"][0]
+
+
+METRICS = {
+    "bleu": bleu_fn,
+    "meteor": meteor_fn,
+    "rouge1": rouge1_fn,
+    "rouge2": rouge2_fn,
+    "bertscore": bertscore_fn
+}
+
+
+def attach_references(df):
+    reference_df = hf_data_loader.load_full_commit_dataset_as_pandas().set_index(["hash", "repo"])[["reference"]]
+    df = df.set_index(["hash", "repo"])
+    return df.join(other=reference_df, how="left").reset_index()
+
+
+def compute_metrics(df):
+    tqdm.pandas()
+
+    def apply_metric_fn_to_row(row, fn, col_pred, col_ref):
+        return fn(row[col_pred], row[col_ref])
+
+    for metric in METRICS:
+        print(f"Computing {metric}")
+        metric_fn = METRICS[metric]
+        df[f"{metric}_related"] = df.progress_apply(
+            lambda row: apply_metric_fn_to_row(row=row,
+                                               fn=metric_fn,
+                                               col_pred="commit_msg_start",
+                                               col_ref="commit_msg_end"),
+            axis=1
+        )
+        df[f"{metric}_independent"] = df.progress_apply(
+            lambda row: apply_metric_fn_to_row(row=row,
+                                               fn=metric_fn,
+                                               col_pred="commit_msg_start",
+                                               col_ref="reference"),
+            axis=1
+        )
+
+        df[f"{metric}_pearson"] = df[f"{metric}_related"].corr(df[f"{metric}_independent"], method="pearson")
+        df[f"{metric}_spearman"] = df[f"{metric}_related"].corr(df[f"{metric}_independent"], method="spearman")
+
+    return df
+
+
+def transform(df):
+    print("Computing metrics")
+
+    df = attach_references(df)
+    df = compute_metrics(df)
+
+    print("Done")
+    return df
+
+
+def main():
+    df = pd.read_csv(config.SYNTHETIC_DATASET_ARTIFACT, index_col=[0])
+    df = transform(df)
+    df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
+
+
+if __name__ == '__main__':
+    main()
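Each metric function in the new module wraps a Hugging Face evaluate metric and scores one prediction/reference pair at a time; compute_metrics applies it twice per row, once against commit_msg_end (the *_related column) and once against the independent reference message (the *_independent column), and then stores the Pearson/Spearman correlation of the two columns, which is a single scalar broadcast to every row. A minimal standalone sketch of the single-pair scoring pattern, using toy strings and the library's default cache directory rather than the repository's config:

# toy illustration of the single-pair scoring used above (not the repo's config/cache)
import evaluate

bleu = evaluate.load("bleu")
meteor = evaluate.load("meteor")

pred = "fix typo in readme"
ref = "fix a typo in the README file"

print(bleu.compute(predictions=[pred], references=[ref])["bleu"])
print(meteor.compute(predictions=[pred], references=[ref])["meteor"])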
generation_steps/synthetic_end_to_start.py
CHANGED
@@ -1,12 +1,13 @@
 import pandas as pd
 from tqdm import tqdm

+import config
 import generate_annotated_diffs
 import statistics
 from api_wrappers import grazie_wrapper
 from generation_steps import examples

-GENERATION_MULTIPLIER =
+GENERATION_MULTIPLIER = 3
 REL_INSERTIONS_THRESHOLD = 0.5
 GENERATION_ATTEMPTS = 5

@@ -62,6 +63,12 @@ COLS_TO_KEEP = ["hash", "repo", "commit_msg_end", "mods", "session"]


 def transform(df):
+    print(f"End -> start synthesis:")
+    print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
+    print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
+    print(f"REL_INSERTIONS_THRESHOLD = {REL_INSERTIONS_THRESHOLD}")
+    print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")
+
     df['end_to_start'] = False

     generated_data = {
@@ -83,4 +90,17 @@ def transform(df):
     generated_df = pd.DataFrame.from_dict(generated_data)
     generated_df['end_to_start'] = True

-
+    result = pd.concat([df, generated_df], ignore_index=True)
+
+    print("Done")
+    return result
+
+
+def main():
+    df = pd.read_csv(config.SYNTHETIC_DATASET_ARTIFACT, index_col=[0])
+    df = transform(df)
+    df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
+
+
+if __name__ == '__main__':
+    main()
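The new tail of transform() appends the generated rows to the input frame and relies on the end_to_start flag to tell originals and synthetic rows apart. A tiny self-contained illustration of that concat-and-flag pattern, with made-up values rather than the dataset schema:

# minimal sketch of the concat-and-flag pattern; values are made up
import pandas as pd

df = pd.DataFrame({"commit_msg_end": ["original message"], "end_to_start": [False]})
generated_df = pd.DataFrame({"commit_msg_end": ["synthetic message"], "end_to_start": [True]})

result = pd.concat([df, generated_df], ignore_index=True)
print(result)
# two rows: the original (end_to_start=False) and the generated one (end_to_start=True)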
generation_steps/synthetic_start_to_end.py
CHANGED
@@ -1,12 +1,13 @@
 import pandas as pd
 from tqdm import tqdm

+import config
 import generate_annotated_diffs
 import statistics
 from api_wrappers import grazie_wrapper
 from generation_steps import examples

-GENERATION_MULTIPLIER =
+GENERATION_MULTIPLIER = 3
 REL_DELETIONS_THRESHOLD = 0.75
 GENERATION_ATTEMPTS = 5

@@ -62,6 +63,12 @@ COLS_TO_KEEP = ["hash", "repo", "commit_msg_start", "mods", "session", "end_to_s


 def transform(df):
+    print(f"Start -> send synthesis:")
+    print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
+    print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
+    print(f"REL_DELETIONS_THRESHOLD = {REL_DELETIONS_THRESHOLD}")
+    print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")
+
     df['start_to_end'] = False

     generated_data = {
@@ -83,4 +90,17 @@ def transform(df):
     generated_df = pd.DataFrame.from_dict(generated_data)
     generated_df['start_to_end'] = True

-
+    result = pd.concat([df, generated_df], ignore_index=True)
+
+    print("Done")
+    return result
+
+
+def main():
+    df = pd.read_csv(config.SYNTHETIC_DATASET_ARTIFACT, index_col=[0])
+    df = transform(df)
+    df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
+
+
+if __name__ == '__main__':
+    main()
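Both synthesis modules, as well as the new metrics_analysis module, now expose the same main() entry point (load the saved artifact, transform it, write it back), so a single stage can be re-run without repeating the whole pipeline. Re-running just the start -> end stage amounts to the following, which mirrors the new main() added in this commit:

# equivalent to the new main() in synthetic_start_to_end: re-run one stage against the saved artifact
import pandas as pd

import config
from generation_steps import synthetic_start_to_end

df = pd.read_csv(config.SYNTHETIC_DATASET_ARTIFACT, index_col=[0])
df = synthetic_start_to_end.transform(df)
df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)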