File size: 4,771 Bytes
5953ef9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import re
import statistics as sts
from collections import defaultdict
from pathlib import Path

from rex.utils.dict import get_dict_content
from rex.utils.io import load_json
from rich.console import Console
from rich.table import Table

# Root directory holding one sub-directory per few-shot run.
inputs_dir = Path("mirror_fewshot_outputs")
# Run directories are named "<prefix>_<task>_seed<seed>_<shot>shot"; the
# commented-out pattern selects the single-task runs instead.
# regex = re.compile(r"Mirror_SingleTask_(.*?)_seed(\d+)_(\d+)shot")
regex = re.compile(r"Mirror_wPT_woInst_(.*?)_seed(\d+)_(\d+)shot")

# task -> shot (digit string captured by the regex, e.g. "5") -> one score per seed
results = defaultdict(lambda: defaultdict(list))

# os.listdir returns entries in arbitrary order; sorting makes the row order
# of the printed table reproducible.
for dirname in sorted(os.listdir(inputs_dir)):
    dpath = inputs_dir / dirname
    re_matched = regex.match(dirname)
    if not (dpath.is_dir() and re_matched):
        continue  # not a run directory for this experiment
    # The seed only distinguishes repeated runs; scores are pooled per shot.
    task, _seed, shot = re_matched.groups()
    results_json_p = dpath / "measures" / "test.final.json"
    metrics = load_json(results_json_p)
    # The task-name prefix decides which F1 score(s) to read from the metrics
    # (get_dict_content is given a dotted key path — presumably nested lookup).
    if "Ent_" in task:
        results[task][shot].append(
            get_dict_content(metrics, "metrics.ent.micro.f1")
        )
    elif "Rel_" in task or "ABSA_" in task:
        results[task][shot].append(
            get_dict_content(metrics, "metrics.rel.rel.micro.f1")
        )
    elif "Event_" in task:
        # Event runs produce two table rows: trigger and argument F1.
        results[task + "_Trigger"][shot].append(
            get_dict_content(metrics, "metrics.event.trigger_cls.f1")
        )
        results[task + "_Arg"][shot].append(
            get_dict_content(metrics, "metrics.event.arg_cls.f1")
        )
    else:
        raise RuntimeError(
            f"Cannot tell which metric to report for task {task!r} "
            f"(directory {dirname!r}); expected an Ent_/Rel_/ABSA_/Event_ task"
        )

# Shot settings present in the data, in numeric order ("1" < "5" < "10").
# Deriving the columns from the data (instead of hard-coding 1/5/10) and
# filling gaps with "-" keeps every cell under its own header even when a
# task is missing one of the shot settings.
all_shots = sorted(
    {shot for per_shot in results.values() for shot in per_shot},
    key=int,
)


def _mean_std(scores):
    """Format scores as ``mean±std``, scaled by 100 for display.

    ``statistics.stdev`` raises ``StatisticsError`` for fewer than two
    values, so a single run is shown as its mean alone.
    """
    mean = sts.mean(scores)
    if len(scores) < 2:
        return f"{100 * mean:.2f}"
    return f"{100 * mean:.2f}±{100 * sts.stdev(scores):.2f}"


table = Table(title="Few-shot results")
table.add_column("Task", justify="center")
for shot in all_shots:
    table.add_column(f"{shot}-shot", justify="right")
table.add_column("Avg.", justify="right")

for task, per_shot in results.items():
    all_seeds = []
    shot_results = []
    for shot in all_shots:
        # .get rather than [] so the defaultdict is not grown with empty lists.
        seeds = per_shot.get(shot)
        if not seeds:
            shot_results.append("-")
            continue
        all_seeds.extend(seeds)
        shot_results.append(_mean_std(seeds))
    # "Avg." pools every seed of every shot setting for this task.
    shot_results.append(f"{100 * sts.mean(all_seeds):.2f}" if all_seeds else "-")
    table.add_row(task, *shot_results)

console = Console()
console.print(table)

"""
                            Few-shot results wPT wInst
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
┃        Task         ┃      1-shot ┃      5-shot ┃    10-shot ┃  Avg. ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
│     Ent_CoNLL03     │  77.50±1.64 │  82.73±2.29 │ 84.48±1.62 │ 81.57 │
│     Rel_CoNLL04     │ 34.66±10.52 │  52.23±3.16 │ 58.68±1.77 │ 48.52 │
│ Event_ACE05_Trigger │  49.50±3.59 │ 65.61±19.29 │ 60.68±2.45 │ 58.60 │
│   Event_ACE05_Arg   │  23.46±1.66 │ 48.32±28.91 │ 41.90±1.95 │ 37.89 │
│     ABSA_16res      │  67.06±0.56 │ 73.51±14.75 │ 68.70±1.46 │ 69.76 │
└─────────────────────┴─────────────┴─────────────┴────────────┴───────┘

                           Few-shot results wPT woInst
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
┃        Task         ┃      1-shot ┃     5-shot ┃    10-shot ┃  Avg. ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
│     Ent_CoNLL03     │  76.33±1.74 │ 82.50±1.87 │ 84.47±1.18 │ 81.10 │
│ woInst_Rel_CoNLL04  │  34.86±6.20 │ 48.00±4.44 │ 55.65±2.53 │ 46.17 │
│     Rel_CoNLL04     │ 26.83±15.22 │ 47.39±3.60 │ 55.38±2.41 │ 43.20 │
│ Event_ACE05_Trigger │  46.60±1.09 │ 57.21±3.51 │ 59.67±3.20 │ 54.49 │
│   Event_ACE05_Arg   │  21.60±3.61 │ 34.43±3.63 │ 39.62±2.60 │ 31.88 │
│     ABSA_16res      │  8.10±18.11 │ 52.73±5.52 │ 57.32±1.73 │ 39.38 │
└─────────────────────┴─────────────┴────────────┴────────────┴───────┘
"""