# Mirror / src/get_avg_results.py
# Last commit: 5953ef9 "update" by Spico (file size 4.77 kB)
import os
import re
import statistics as sts
from collections import defaultdict
from pathlib import Path
from rex.utils.dict import get_dict_content
from rex.utils.io import load_json
from rich.console import Console
from rich.table import Table
# Directory containing one sub-directory per few-shot training run.
inputs_dir = Path("mirror_fewshot_outputs")
# Run directories look like "<variant>_<task>_seed<N>_<K>shot"; the regex
# prefix selects which experiment variant is aggregated. The commented-out
# pattern targets the "Mirror_SingleTask_" runs instead.
# regex = re.compile(r"Mirror_SingleTask_(.*?)_seed(\d+)_(\d+)shot")
regex = re.compile(r"Mirror_wPT_woInst_(.*?)_seed(\d+)_(\d+)shot")

# task -> shot (digit string such as "5") -> one F1 score per seed.
# NOTE(review): scores are presumably fractions in [0, 1]; confirm against
# the metrics JSON written by the training runs.
results = defaultdict(lambda: defaultdict(list))
# os.listdir() returns entries in arbitrary order; sorting makes the row
# order of the printed table reproducible across runs and machines.
for dirname in sorted(os.listdir(inputs_dir)):
    dpath = inputs_dir / dirname
    re_matched = regex.match(dirname)
    if not (dpath.is_dir() and re_matched):
        continue
    task, seed, shot = re_matched.groups()
    results_json_p = dpath / "measures" / "test.final.json"
    metrics = load_json(results_json_p)
    if "Ent_" in task:
        results[task][shot].append(
            get_dict_content(metrics, "metrics.ent.micro.f1")
        )
    elif "Rel_" in task or "ABSA_" in task:
        # ABSA runs are scored with the same relation-level F1 as Rel_ tasks.
        results[task][shot].append(
            get_dict_content(metrics, "metrics.rel.rel.micro.f1")
        )
    elif "Event_" in task:
        # Event tasks produce two table rows: trigger and argument F1.
        results[task + "_Trigger"][shot].append(
            get_dict_content(metrics, "metrics.event.trigger_cls.f1")
        )
        results[task + "_Arg"][shot].append(
            get_dict_content(metrics, "metrics.event.arg_cls.f1")
        )
    else:
        raise RuntimeError(
            f"Unrecognised task {task!r} in run directory {dirname!r}: "
            "expected a name containing Ent_, Rel_, ABSA_ or Event_"
        )
def _format_scores(scores):
    """Return ``scores`` as a percentage string ``"mean±stdev"``.

    ``statistics.stdev`` needs at least two data points and raises
    ``StatisticsError`` otherwise, so a cell backed by a single seed shows
    only the mean instead of aborting the whole report.
    """
    mean = 100 * sts.mean(scores)
    if len(scores) < 2:
        return f"{mean:.2f}"
    return f"{mean:.2f}±{100 * sts.stdev(scores):.2f}"


# Shot settings present anywhere in ``results``, in numeric order ("10" after
# "5"). Building the columns from the data keeps every row aligned with its
# header even when a task lacks a shot setting or a new one is added.
shot_settings = sorted(
    {shot for per_shot in results.values() for shot in per_shot}, key=int
)

table = Table(title="Few-shot results")
table.add_column("Task", justify="center")
for shot in shot_settings:
    table.add_column(f"{shot}-shot", justify="right")
table.add_column("Avg.", justify="right")

for task, per_shot in results.items():
    all_seeds = []
    shot_results = []
    for shot in shot_settings:
        # .get() rather than [] so the defaultdict does not grow empty entries.
        seeds = per_shot.get(shot)
        if not seeds:
            shot_results.append("-")  # no runs for this task at this shot
            continue
        all_seeds.extend(seeds)
        shot_results.append(_format_scores(seeds))
    # Pooled mean over every seed of every shot; it equals the mean of the
    # per-shot means only when each shot has the same number of seeds.
    shot_results.append(f"{100 * sts.mean(all_seeds):.2f}")
    table.add_row(task, *shot_results)

console = Console()
console.print(table)
"""
Few-shot results wPT wInst
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
┃        Task         ┃      1-shot ┃      5-shot ┃    10-shot ┃  Avg. ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
│     Ent_CoNLL03     │  77.50±1.64 │  82.73±2.29 │ 84.48±1.62 │ 81.57 │
│     Rel_CoNLL04     │ 34.66±10.52 │  52.23±3.16 │ 58.68±1.77 │ 48.52 │
│ Event_ACE05_Trigger │  49.50±3.59 │ 65.61±19.29 │ 60.68±2.45 │ 58.60 │
│   Event_ACE05_Arg   │  23.46±1.66 │ 48.32±28.91 │ 41.90±1.95 │ 37.89 │
│     ABSA_16res      │  67.06±0.56 │ 73.51±14.75 │ 68.70±1.46 │ 69.76 │
└─────────────────────┴─────────────┴─────────────┴────────────┴───────┘
Few-shot results wPT woInst
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
┃        Task         ┃      1-shot ┃     5-shot ┃    10-shot ┃  Avg. ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
│     Ent_CoNLL03     │  76.33±1.74 │ 82.50±1.87 │ 84.47±1.18 │ 81.10 │
│ woInst_Rel_CoNLL04  │  34.86±6.20 │ 48.00±4.44 │ 55.65±2.53 │ 46.17 │
│     Rel_CoNLL04     │ 26.83±15.22 │ 47.39±3.60 │ 55.38±2.41 │ 43.20 │
│ Event_ACE05_Trigger │  46.60±1.09 │ 57.21±3.51 │ 59.67±3.20 │ 54.49 │
│   Event_ACE05_Arg   │  21.60±3.61 │ 34.43±3.63 │ 39.62±2.60 │ 31.88 │
│     ABSA_16res      │  8.10±18.11 │ 52.73±5.52 │ 57.32±1.73 │ 39.38 │
└─────────────────────┴─────────────┴────────────┴────────────┴───────┘
"""