"""Compute text-similarity metrics for the synthetic commit-message dataset.

For every row and every metric in ``METRICS`` this script computes:

* ``<metric>_related``     -- ``commit_msg_start`` scored against ``commit_msg_end``;
* ``<metric>_independent`` -- ``commit_msg_start`` scored against ``reference``;

and the Pearson / Spearman correlation between those two columns. The result
is written back over ``config.SYNTHETIC_DATASET_ARTIFACT``.
"""
import evaluate
import pandas as pd
from tqdm import tqdm

import config
from api_wrappers import hf_data_loader

BLEU = evaluate.load('bleu', cache_dir=config.CACHE_DIR)


def bleu_fn(pred, ref):
    """Return the BLEU score of a single prediction against a single reference."""
    return BLEU.compute(predictions=[pred], references=[ref])["bleu"]


METEOR = evaluate.load('meteor', cache_dir=config.CACHE_DIR)


def meteor_fn(pred, ref):
    """Return the METEOR score of a single prediction against a single reference."""
    return METEOR.compute(predictions=[pred], references=[ref])["meteor"]


ROUGE = evaluate.load('rouge', cache_dir=config.CACHE_DIR)


def rouge1_fn(pred, ref):
    """Return the ROUGE-1 score of a single prediction against a single reference."""
    return ROUGE.compute(predictions=[pred], references=[ref])["rouge1"]


def rouge2_fn(pred, ref):
    """Return the ROUGE-2 score of a single prediction against a single reference."""
    return ROUGE.compute(predictions=[pred], references=[ref])["rouge2"]


BERTSCORE = evaluate.load('bertscore', cache_dir=config.CACHE_DIR)


def bertscore_fn(pred, ref):
    """Return the BERTScore F1 (distilbert-base-uncased) of one prediction/reference pair."""
    # compute() returns one F1 per pair; exactly one pair is passed, so take element 0.
    return BERTSCORE.compute(predictions=[pred], references=[ref], model_type="distilbert-base-uncased")["f1"][0]


# Metric name -> scorer(pred, ref). The names become prefixes of the output columns.
METRICS = {
    "bleu": bleu_fn,
    "meteor": meteor_fn,
    "rouge1": rouge1_fn,
    "rouge2": rouge2_fn,
    "bertscore": bertscore_fn,
}


def attach_references(df):
    """Left-join the dataset ``reference`` column onto ``df`` by ``(hash, repo)``.

    Rows of ``df`` without a match get a missing value in ``reference``. The
    returned frame has a fresh default index, with ``hash`` and ``repo`` as
    regular columns.
    """
    reference_df = hf_data_loader.load_full_commit_dataset_as_pandas().set_index(["hash", "repo"])[["reference"]]
    # main() writes its output over its own input artifact, so on a re-run ``df``
    # already carries a "reference" column. DataFrame.join raises on overlapping
    # column names, so drop the stale copy and attach a fresh one.
    df = df.drop(columns=["reference"], errors="ignore").set_index(["hash", "repo"])
    return df.join(other=reference_df, how="left").reset_index()


def compute_metrics(df):
    """Add per-row metric scores and their correlations to ``df``.

    For each entry in ``METRICS`` four columns are added:

    * ``{metric}_related``: score of ``commit_msg_start`` against ``commit_msg_end``;
    * ``{metric}_independent``: score of ``commit_msg_start`` against ``reference``;
    * ``{metric}_pearson`` / ``{metric}_spearman``: correlation between the two
      columns above. Each is a single dataset-level number broadcast to every row.

    ``df`` is modified in place and also returned.
    """
    tqdm.pandas()  # registers DataFrame.progress_apply (apply with a progress bar)

    def score_rows(fn, col_pred, col_ref):
        # Apply the scorer row by row; fn is bound per call, so no late-binding issue.
        return df.progress_apply(lambda row: fn(row[col_pred], row[col_ref]), axis=1)

    for metric, metric_fn in METRICS.items():
        print(f"Computing {metric}")
        related_col = f"{metric}_related"
        independent_col = f"{metric}_independent"

        df[related_col] = score_rows(metric_fn, "commit_msg_start", "commit_msg_end")
        df[independent_col] = score_rows(metric_fn, "commit_msg_start", "reference")

        df[f"{metric}_pearson"] = df[related_col].corr(df[independent_col], method="pearson")
        df[f"{metric}_spearman"] = df[related_col].corr(df[independent_col], method="spearman")
    return df


def transform(df):
    """Attach references to ``df`` and compute all metrics; return the result."""
    print("Computing metrics")
    df = attach_references(df)
    df = compute_metrics(df)
    print("Done")
    return df


def main():
    """Load the synthetic dataset artifact, add metrics, and overwrite it in place."""
    df = pd.read_csv(config.SYNTHETIC_DATASET_ARTIFACT, index_col=[0])
    df = transform(df)
    df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)


if __name__ == '__main__':
    main()