File size: 3,994 Bytes
ae81e0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
Logging utilities to make terminal slightly more delightful
"""
import rich.syntax
import rich.tree

from omegaconf import OmegaConf, DictConfig, ListConfig


def _format_arg(arg_name: str, cutoff=2) -> str:
    if arg_name is None:
        return arg_name
    arg_name = str(arg_name)
    
    # Hardcode to handle backslash
    name_splits = arg_name.split('/')
    if len(name_splits) > 1:
        return name_splits[-1]
    # Abbreviate based on underscore
    name_splits = arg_name.split('_')
    if len(name_splits) > 1:
        return ''.join([s[0] for s in name_splits])
    else:
        return arg_name[:cutoff]


def print_header(x: str) -> None:
    """
    Print x framed by a dashed rule above and below.

    Each rule is made of '-' characters and is exactly as long as x.
    """
    rule = '-' * len(x)
    # One print call, three lines: rule, text, rule
    print(rule, x, rule, sep='\n')


def print_args(args, return_dict=False, verbose=True):
    """
    Print the arguments passed to the script as a small tree.

    Every attribute of ``args`` whose name does not start with an
    underscore is listed (in the order returned by ``dir``).

    Args:
        args: Object whose public attributes are reported.
        return_dict: When True, return the collected ``{name: value}`` dict.
        verbose: When True, print a 'ARGPARSE ARGS' title followed by one
            line per attribute.

    Returns:
        The attribute dict when ``return_dict`` is True, otherwise None.
    """
    attributes = [a for a in dir(args) if not a.startswith('_')]
    arg_dict = {}
    last_ix = len(attributes) - 1
    if verbose:
        print('ARGPARSE ARGS')
    for ix, attr in enumerate(attributes):
        value = getattr(args, attr)  # read once; reused for print and dict
        if verbose:
            # Box-drawing branch glyphs: corner for the last entry, tee otherwise
            fancy = '└─' if ix == last_ix else '├─'
            print(f'{fancy} {attr}: {value}')
        arg_dict[attr] = value
    if return_dict:
        return arg_dict


def update_description_metrics(description: str, metrics: dict):
    """
    Append ' | split/metric: value' entries to a progress-bar description.

    Values are formatted with three decimals. Entries under the 'test'
    split are left out.

    Returns:
        The description with the metric entries appended.
    """
    suffix = ''.join(
        f' | {split}/{metric_name}: {metric:.3f}'
        for split in metrics
        if split != 'test'  # No look
        for metric_name, metric in metrics[split].items()
    )
    return description + suffix

        
# Detect the runtime front end (Jupyter, IPython, or plain terminal); used to control how the tqdm progress bar looks
def type_of_script():
    """
    Report which kind of front end is running this code.

    The decision is based on the type name of the object returned by
    ``get_ipython()``, a name that is only defined when running under
    IPython.

    Returns:
        'jupyter' when the shell type name contains 'zmqshell',
        'ipython' when it contains 'terminal',
        'terminal' when ``get_ipython`` is not defined,
        None for any other shell type.
    """
    try:
        shell = get_ipython()
    except NameError:
        # get_ipython does not exist outside IPython: plain interpreter
        return 'terminal'
    ipy_str = str(type(shell))
    if 'zmqshell' in ipy_str:
        return 'jupyter'
    if 'terminal' in ipy_str:
        return 'ipython'
    # Unrecognised IPython shell: keep the historical None result
    return None
    
# Progress bar 
def update_pbar_display(metrics, batch_ix, pbar, prefix, batch_size, accum_iter=1):
    """
    Set the progress bar description from the running batch metrics.

    Metric formatting:
      * 'correct'    -> 'correct/total' count plus the percentage correct/total
      * 'acc'        -> shown as is, three decimals
      * 'perplexity' -> shown as is, scientific notation
      * 'total'      -> not shown on its own (only used as the denominator)
      * anything else is divided by metrics['total'], three decimals

    Args:
        metrics: Mapping of metric name to numeric value. Must contain
            'total' whenever 'correct' or any averaged metric is present.
        batch_ix: Current batch index (cast to int for display).
        pbar: Progress bar; must support ``len()`` and ``set_description``.
        prefix: Label shown before the batch counter.
        batch_size: Batch size shown in the description.
        accum_iter: Number of batches gradients are accumulated over.
    """
    # Imported here because the module does not import Decimal at file level;
    # without it the 'perplexity' branch raises NameError.
    from decimal import Decimal

    description = f'└── {prefix} batch {int(batch_ix)}/{len(pbar)} [batch size: {batch_size} - grad. accum. over {accum_iter} batch(es)]'
    for metric_name, metric in metrics.items():
        if metric_name == 'correct':
            description += f' | {metric_name} (acc. %): {int(metric):>5d}/{int(metrics["total"])} = {metric / metrics["total"] * 100:.3f}%'
        elif metric_name == 'acc':
            description += f' | {metric_name}: {metric:.3f}'
        elif metric_name in ['perplexity']:  # , 'bpc']:
            description += f' | {metric_name}: {Decimal(metric):.3E}'
        elif metric_name != 'total':
            # Running sums are shown as a per-sample average
            description += f' | {metric_name}: {metric / metrics["total"]:.3f}'
    pbar.set_description(description)
    
    
def print_config(config: DictConfig,
                 resolve: bool = True,
                 name: str = 'CONFIG') -> None:
    """Print a DictConfig as a Rich tree, one branch per top-level key.

    Each branch holds the section's content rendered with YAML syntax
    highlighting. DictConfig and ListConfig sections are dumped with
    ``OmegaConf.to_yaml``; any other value is shown via ``str()``.

    Args:
        config (DictConfig): Configuration to display.
        resolve (bool, optional): Forwarded to ``OmegaConf.to_yaml`` for
            DictConfig / ListConfig sections.
        name (str, optional): Label of the tree root.
    """
    style = "bright"  # "dim"
    tree = rich.tree.Tree(name, style=style, guide_style=style)

    for field in config.keys():
        branch = tree.add(field, style=style, guide_style=style)

        section = config.get(field)
        rendered = str(section)
        # Nested OmegaConf containers get a proper YAML dump instead
        if isinstance(section, (DictConfig, ListConfig)):
            rendered = OmegaConf.to_yaml(section, resolve=resolve)

        branch.add(rich.syntax.Syntax(rendered, "yaml"))

    rich.print(tree)