# File: Mirror/src/get_avg_results.py
# Last commit: 5953ef9 ("update") by Spico
# Size: 4.77 kB
import os
import re
import statistics as sts
from collections import defaultdict
from pathlib import Path
from rex.utils.dict import get_dict_content
from rex.utils.io import load_json
from rich.console import Console
from rich.table import Table
# Directory holding one sub-directory per few-shot run.
inputs_dir = Path("mirror_fewshot_outputs")

# Prefix of the run directories to aggregate. Change it to report another
# experiment variant, e.g. "Mirror_SingleTask" instead of "Mirror_wPT_woInst".
RUN_PREFIX = "Mirror_wPT_woInst"
# Run directory names look like ``<RUN_PREFIX>_<task>_seed<seed>_<shot>shot``;
# the three groups capture (task, seed, shot) as strings. The prefix is
# escaped so it is always matched literally.
regex = re.compile(re.escape(RUN_PREFIX) + r"_(.*?)_seed(\d+)_(\d+)shot")

# task -> shot -> list of per-seed scores
results = defaultdict(lambda: defaultdict(list))
# Collect the test-set F1 of every run directory whose name matches `regex`.
# Directories are visited in sorted order so the table rows (and the order of
# seeds within a cell) are deterministic; directory listing order is arbitrary.
for dpath in sorted(inputs_dir.iterdir()):
    re_matched = regex.match(dpath.name)
    if not (dpath.is_dir() and re_matched):
        continue
    task, _seed, shot = re_matched.groups()
    results_json_p = dpath / "measures" / "test.final.json"
    metrics = load_json(results_json_p)
    # get_dict_content presumably resolves a dotted key path in the nested
    # metrics dict (rex helper) -- confirm against rex if the schema changes.
    if "Ent_" in task:
        # Entity tasks: entity micro-F1.
        results[task][shot].append(
            get_dict_content(metrics, "metrics.ent.micro.f1")
        )
    elif "Rel_" in task or "ABSA_" in task:
        # Relation and ABSA tasks share the relation micro-F1 metric.
        results[task][shot].append(
            get_dict_content(metrics, "metrics.rel.rel.micro.f1")
        )
    elif "Event_" in task:
        # Event tasks are reported as two rows: trigger and argument F1.
        results[task + "_Trigger"][shot].append(
            get_dict_content(metrics, "metrics.event.trigger_cls.f1")
        )
        results[task + "_Arg"][shot].append(
            get_dict_content(metrics, "metrics.event.arg_cls.f1")
        )
    else:
        raise RuntimeError(
            f"Unrecognized task {task!r} in run directory {dpath}; "
            "expected it to contain 'Ent_', 'Rel_', 'ABSA_' or 'Event_'"
        )
def _format_cell(scores):
    """Return ``mean±stdev`` of *scores* as percentages for one shot setting.

    ``statistics.stdev`` needs at least two data points, so a setting with a
    single seed only shows its mean instead of raising ``StatisticsError``.
    """
    mean = sts.mean(scores)
    if len(scores) < 2:
        return f"{100 * mean:.2f}"
    return f"{100 * mean:.2f}±{100 * sts.stdev(scores):.2f}"


# Shot columns come from the settings actually present, sorted numerically, so
# every value lands under its own header even if a task lacks some setting.
all_shots = sorted(
    {shot for shot_map in results.values() for shot in shot_map}, key=int
)

table = Table(title="Few-shot results")
table.add_column("Task", justify="center")
for shot in all_shots:
    table.add_column(f"{shot}-shot", justify="right")
table.add_column("Avg.", justify="right")

for task, shot_map in results.items():
    row = []
    all_seeds = []
    for shot in all_shots:
        seeds = shot_map.get(shot)
        if not seeds:
            # This task has no runs for this shot setting.
            row.append("-")
            continue
        all_seeds.extend(seeds)
        row.append(_format_cell(seeds))
    # "Avg." pools every seed of every shot setting available for the task.
    row.append(f"{100 * sts.mean(all_seeds):.2f}")
    table.add_row(task, *row)

console = Console()
console.print(table)
"""
Few-shot results wPT wInst
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
┃        Task         ┃      1-shot ┃      5-shot ┃    10-shot ┃  Avg. ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
│     Ent_CoNLL03     │  77.50±1.64 │  82.73±2.29 │ 84.48±1.62 │ 81.57 │
│     Rel_CoNLL04     │ 34.66±10.52 │  52.23±3.16 │ 58.68±1.77 │ 48.52 │
│ Event_ACE05_Trigger │  49.50±3.59 │ 65.61±19.29 │ 60.68±2.45 │ 58.60 │
│   Event_ACE05_Arg   │  23.46±1.66 │ 48.32±28.91 │ 41.90±1.95 │ 37.89 │
│     ABSA_16res      │  67.06±0.56 │ 73.51±14.75 │ 68.70±1.46 │ 69.76 │
└─────────────────────┴─────────────┴─────────────┴────────────┴───────┘
Few-shot results wPT woInst
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
┃        Task         ┃      1-shot ┃     5-shot ┃    10-shot ┃  Avg. ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
│     Ent_CoNLL03     │  76.33±1.74 │ 82.50±1.87 │ 84.47±1.18 │ 81.10 │
│ woInst_Rel_CoNLL04  │  34.86±6.20 │ 48.00±4.44 │ 55.65±2.53 │ 46.17 │
│     Rel_CoNLL04     │ 26.83±15.22 │ 47.39±3.60 │ 55.38±2.41 │ 43.20 │
│ Event_ACE05_Trigger │  46.60±1.09 │ 57.21±3.51 │ 59.67±3.20 │ 54.49 │
│   Event_ACE05_Arg   │  21.60±3.61 │ 34.43±3.63 │ 39.62±2.60 │ 31.88 │
│     ABSA_16res      │  8.10±18.11 │ 52.73±5.52 │ 57.32±1.73 │ 39.38 │
└─────────────────────┴─────────────┴────────────┴────────────┴───────┘
"""