File size: 3,994 Bytes
ae81e0f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
"""Logging utilities to make the terminal slightly more delightful."""
from decimal import Decimal

import rich.syntax
import rich.tree
from omegaconf import OmegaConf, DictConfig, ListConfig
def _format_arg(arg_name: str, cutoff=2) -> str:
if arg_name is None:
return arg_name
arg_name = str(arg_name)
# Hardcode to handle backslash
name_splits = arg_name.split('/')
if len(name_splits) > 1:
return name_splits[-1]
# Abbreviate based on underscore
name_splits = arg_name.split('_')
if len(name_splits) > 1:
return ''.join([s[0] for s in name_splits])
return arg_name[:cutoff]
def print_header(x: str) -> None:
    """Print a header with a line above and below.

    Args:
        x: Header text. The dashed rules are as wide as ``len(x)``.
    """
    rule = '-' * len(x)
    print(rule)
    # The header text itself goes between the two rules.
    print(x)
    print(rule)
def print_args(args, return_dict=False, verbose=True):
    """Print the arguments passed to the script.

    Every attribute of ``args`` listed by ``dir()`` whose name does not start
    with an underscore is collected, in the order ``dir()`` returns them.

    Args:
        args: Object whose public attributes are listed.
        return_dict: When true, return the collected ``{name: value}`` dict.
        verbose: When true, print one ``<glyph> name: value`` line per
            attribute, using a closing tree glyph for the last line.

    Returns:
        The attribute dict if ``return_dict`` is true, otherwise ``None``.
    """
    attributes = [a for a in dir(args) if not a.startswith('_')]
    # Collect values regardless of ``verbose`` so that a silent call with
    # ``return_dict=True`` still returns the full mapping.
    arg_dict = {attr: getattr(args, attr) for attr in attributes}

    if verbose:
        last_ix = len(attributes) - 1
        for ix, attr in enumerate(attributes):
            # Closing corner for the final entry, tee for all earlier ones.
            fancy = '└─' if ix == last_ix else '├─'
            print(f'{fancy} {attr}: {arg_dict[attr]}')

    if return_dict:
        return arg_dict
def update_description_metrics(description: str, metrics: dict):
    """Append per-split metric values to a progress-bar description.

    Args:
        description: Text to extend.
        metrics: Mapping of split name to a mapping of metric name to a
            numeric value. The ``'test'`` split is skipped.

    Returns:
        ``description`` followed by one `` | split/metric: value`` fragment
        (three decimals) per metric of every non-test split.
    """
    pieces = [description]
    for split, split_metrics in metrics.items():
        if split == 'test':
            # Test-split numbers are deliberately kept off the display.
            continue
        pieces.extend(
            f' | {split}/{metric_name}: {metric:.3f}'
            for metric_name, metric in split_metrics.items()
        )
    return ''.join(pieces)
# Control how tqdm progress bar looks
def type_of_script():
    """Report which kind of interpreter front-end is running this code.

    Returns:
        ``'jupyter'`` if the type name of the active IPython shell contains
        ``'zmqshell'``, ``'ipython'`` if it contains ``'terminal'``, and
        ``'terminal'`` otherwise, including when no IPython shell exists.
    """
    try:
        # ``get_ipython`` is not a Python builtin; outside an IPython session
        # the name is undefined and looking it up raises NameError.
        shell = get_ipython()  # noqa: F821
    except NameError:
        return 'terminal'

    ipy_str = str(type(shell))
    if 'zmqshell' in ipy_str:
        return 'jupyter'
    if 'terminal' in ipy_str:
        return 'ipython'
    return 'terminal'
# Progress bar
def update_pbar_display(metrics, batch_ix, pbar, prefix, batch_size, accum_iter=1):
    """Set a progress bar's description from the running batch metrics.

    The description starts with the batch counter, batch size and gradient
    accumulation count, followed by one `` | name: value`` fragment per metric:

      * ``'correct'``: shown as ``correct/total`` plus the percentage.
      * ``'acc'``: shown as-is with three decimals.
      * ``'perplexity'``: shown in scientific notation via ``Decimal``.
      * ``'total'``: never shown on its own; used as the denominator.
      * anything else: shown divided by ``metrics['total']``.

    Args:
        metrics: Mapping of metric name to value. Must contain ``'total'``
            whenever ``'correct'`` or an otherwise-unhandled metric is present.
        batch_ix: Current batch index (displayed as ``int(batch_ix)``).
        pbar: Progress bar supporting ``len()`` and ``set_description``
            (presumably a tqdm bar; confirm against callers).
        prefix: Label placed before the batch counter.
        batch_size: Batch size to display.
        accum_iter: Number of batches gradients are accumulated over.
    """
    def _per_total(value):
        # Guard the division so an empty first batch (total == 0) shows 0
        # instead of raising ZeroDivisionError.
        total = metrics['total']
        return value / total if total else 0.0

    description = (
        f'└── {prefix} batch {int(batch_ix)}/{len(pbar)} '
        f'[batch size: {batch_size} - grad. accum. over {accum_iter} batch(es)]'
    )
    for metric_name, metric in metrics.items():
        if metric_name == 'correct':
            description += (
                f' | {metric_name} (acc. %): {int(metric):>5d}/{int(metrics["total"])}'
                f' = {_per_total(metric) * 100:.3f}%'
            )
        elif metric_name == 'acc':
            description += f' | {metric_name}: {metric:.3f}'
        elif metric_name == 'perplexity':
            description += f' | {metric_name}: {Decimal(metric):.3E}'
        elif metric_name != 'total':
            description += f' | {metric_name}: {_per_total(metric):.3f}'

    # Without this call the description built above would be thrown away.
    pbar.set_description(description)
def print_config(config: DictConfig,
                 resolve: bool = True,
                 name: str = 'CONFIG') -> None:
    """Print the content of a DictConfig as a Rich tree.

    Each top-level key of ``config`` becomes one branch of the tree, in the
    order ``config.keys()`` yields them. The value under the branch is shown
    with YAML syntax highlighting.

    Args:
        config (DictConfig): Configuration to display, e.g. one composed by Hydra.
        resolve (bool, optional): Whether to resolve reference fields when a
            nested DictConfig/ListConfig section is dumped to YAML.
        name (str, optional): Label of the tree root.
    """
    style = "bright"  # alternative considered by the original author: "dim"
    tree = rich.tree.Tree(name, style=style, guide_style=style)

    for key in config.keys():
        node = tree.add(key, style=style, guide_style=style)
        section = config.get(key)

        # Plain values are shown via str(); nested containers are dumped as YAML.
        rendered = str(section)
        if isinstance(section, (DictConfig, ListConfig)):
            rendered = OmegaConf.to_yaml(section, resolve=resolve)

        node.add(rich.syntax.Syntax(rendered, "yaml"))

    rich.print(tree)