|
from collections import defaultdict |
|
import logging |
|
from pathlib import Path |
|
from typing import Dict |
|
from typing import Iterable |
|
from typing import List |
|
from typing import Optional |
|
from typing import Tuple |
|
|
|
import numpy as np |
|
import torch |
|
from torch.nn.parallel import data_parallel |
|
from torch.utils.data import DataLoader |
|
from typeguard import check_argument_types |
|
|
|
from espnet2.fileio.datadir_writer import DatadirWriter |
|
from espnet2.fileio.npy_scp import NpyScpWriter |
|
from espnet2.torch_utils.device_funcs import to_device |
|
from espnet2.torch_utils.forward_adaptor import ForwardAdaptor |
|
from espnet2.train.abs_espnet_model import AbsESPnetModel |
|
|
|
|
|
@torch.no_grad()
def collect_stats(
    model: AbsESPnetModel,
    train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    output_dir: Path,
    ngpu: Optional[int],
    log_interval: Optional[int],
    write_collected_feats: bool,
) -> None:
    """Perform on collect_stats mode.

    Running for deriving the shape information from data
    and gathering statistics.
    This method is used before executing train().

    Args:
        model: Model whose ``collect_feats()`` is used to derive the features.
        train_iter: Iterable yielding ``(keys, batch)`` pairs for the train set.
        valid_iter: Iterable yielding ``(keys, batch)`` pairs for the valid set.
        output_dir: Directory under which shape files and statistics are written
            (one subdirectory per mode, ``train`` and ``valid``).
        ngpu: Number of GPUs. ``None`` is treated the same as 0 (CPU).
        log_interval: Log progress every this many iterations. If ``None``,
            it is derived from the iterator length (or falls back to 100).
        write_collected_feats: If ``True``, additionally dump each collected
            feature as an npy file under ``<output_dir>/<mode>/collect_feats``.
    """
    assert check_argument_types()

    # NOTE(fix): the annotation allows ngpu=None, but the comparisons below
    # (`ngpu > 0`, `ngpu <= 1`) raise TypeError for None. Normalize to 0 (CPU).
    if ngpu is None:
        ngpu = 0

    npy_scp_writers = {}
    try:
        for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):
            if log_interval is None:
                try:
                    log_interval = max(len(itr) // 20, 10)
                except TypeError:
                    # The iterator has no __len__ (e.g. an iterable-style loader).
                    log_interval = 100

            sum_dict = defaultdict(int)
            sq_dict = defaultdict(int)
            count_dict = defaultdict(int)

            with DatadirWriter(output_dir / mode) as datadir_writer:
                for iiter, (keys, batch) in enumerate(itr, 1):
                    batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")

                    # 1. Write the per-utterance shape of each (unpadded) tensor.
                    for name in batch:
                        if name.endswith("_lengths"):
                            continue
                        for i, (key, data) in enumerate(zip(keys, batch[name])):
                            if f"{name}_lengths" in batch:
                                # Strip the zero-padding region before taking the shape.
                                lg = int(batch[f"{name}_lengths"][i])
                                data = data[:lg]
                            datadir_writer[f"{name}_shape"][key] = ",".join(
                                map(str, data.shape)
                            )

                    # 2. Extract features from the batch.
                    if ngpu <= 1:
                        data = model.collect_feats(**batch)
                    else:
                        # data_parallel can only parallelize "forward()", so the
                        # model is wrapped to expose collect_feats as forward.
                        data = data_parallel(
                            ForwardAdaptor(model, "collect_feats"),
                            (),
                            range(ngpu),
                            module_kwargs=batch,
                        )

                    # 3. Accumulate sum, squared sum, and frame count per key.
                    for key, v in data.items():
                        for i, (uttid, seq) in enumerate(zip(keys, v.cpu().numpy())):
                            if f"{key}_lengths" in data:
                                length = data[f"{key}_lengths"][i]
                                # seq: (Length, Dim, ...) -> truncate padding
                                seq = seq[:length]
                            else:
                                # seq: (Dim, ...) -> (1, Dim, ...) so the
                                # reduction over axis 0 counts one "frame".
                                seq = seq[None]

                            sum_dict[key] += seq.sum(0)
                            sq_dict[key] += (seq ** 2).sum(0)
                            count_dict[key] += len(seq)

                            # 4. [Option] Dump the derived feature as an npy file.
                            if write_collected_feats:
                                # Lazily create one writer per (key, mode).
                                if (key, mode) not in npy_scp_writers:
                                    p = output_dir / mode / "collect_feats"
                                    npy_scp_writers[(key, mode)] = NpyScpWriter(
                                        p / f"data_{key}", p / f"{key}.scp"
                                    )
                                npy_scp_writers[(key, mode)][uttid] = seq

                    if iiter % log_interval == 0:
                        logging.info(f"Niter: {iiter}")

            # 5. Save accumulated statistics for each feature key.
            for key in sum_dict:
                np.savez(
                    output_dir / mode / f"{key}_stats.npz",
                    count=count_dict[key],
                    sum=sum_dict[key],
                    sum_square=sq_dict[key],
                )

            # Record which keys appeared; `batch` here is the last batch seen,
            # assumed representative of every batch's key set.
            with (output_dir / mode / "batch_keys").open("w", encoding="utf-8") as f:
                f.write(
                    "\n".join(filter(lambda x: not x.endswith("_lengths"), batch)) + "\n"
                )
            with (output_dir / mode / "stats_keys").open("w", encoding="utf-8") as f:
                f.write("\n".join(sum_dict) + "\n")
    finally:
        # NOTE(fix): close the scp writers so buffered scp entries are flushed
        # even when iteration fails part-way; the original relied on GC.
        for writer in npy_scp_writers.values():
            writer.close()
|
|