Hugo Flores commited on
Commit
cc3a37b
1 Parent(s): 260b46d

rm old eval script

Browse files
Files changed (1) hide show
  1. scripts/exp/eval.py +0 -124
scripts/exp/eval.py DELETED
@@ -1,124 +0,0 @@
1
- import glob
2
- import imp
3
- import os
4
- from pathlib import Path
5
-
6
- import argbind
7
- import audiotools
8
- import numpy as np
9
- import pandas as pd
10
- import torch
11
- from flatten_dict import flatten
12
- from rich.progress import track
13
- from torch.utils.tensorboard import SummaryWriter
14
-
15
- import wav2wav
16
-
17
- train = imp.load_source("train", str(Path(__file__).absolute().parent / "train.py"))
18
-
19
-
20
- @argbind.bind(without_prefix=True)
21
- def evaluate(
22
- args,
23
- model_tag: str = "ckpt/best",
24
- device: str = "cuda",
25
- exp: str = None,
26
- overwrite: bool = False,
27
- ):
28
- assert exp is not None
29
-
30
- sisdr_loss = audiotools.metrics.distance.SISDRLoss()
31
- stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
32
- mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
33
-
34
- with audiotools.util.chdir(exp):
35
- vampnet = wav2wav.modules.vampnet.transformer.VampNet.load(
36
- f"{model_tag}/vampnet/package.pth"
37
- )
38
- vampnet = vampnet.to(device)
39
- if vampnet.cond_dim > 0:
40
- condnet = wav2wav.modules.condnet.transformer.CondNet.load(
41
- f"{model_tag}/condnet/package.pth"
42
- )
43
- condnet = condnet.to(device)
44
- else:
45
- condnet = None
46
-
47
- vqvae = wav2wav.modules.generator.Generator.load(
48
- f"{model_tag}/vqvae/package.pth"
49
- )
50
-
51
- _, _, test_data = train.build_datasets(args, vqvae.sample_rate)
52
-
53
- with audiotools.util.chdir(exp):
54
- datasets = {
55
- "test": test_data,
56
- }
57
-
58
- metrics_path = Path(f"{model_tag}/metrics")
59
- metrics_path.mkdir(parents=True, exist_ok=True)
60
-
61
- for key, dataset in datasets.items():
62
- csv_path = metrics_path / f"{key}.csv"
63
- if csv_path.exists() and not overwrite:
64
- break
65
- metrics = []
66
- for i in track(range(len(dataset))):
67
- # TODO: for coarse2fine
68
- # grab the signal
69
- # mask all the codebooks except the conditioning ones
70
- # and infer
71
- # then compute metrics
72
- # for a baseline, just use the coarsest codebook
73
-
74
- try:
75
- visqol = audiotools.metrics.quality.visqol(
76
- enhanced, clean, "audio"
77
- ).item()
78
- except:
79
- visqol = None
80
-
81
- sisdr = sisdr_loss(enhanced, clean)
82
- stft = stft_loss(enhanced, clean)
83
- mel = mel_loss(enhanced, clean)
84
-
85
- metrics.append(
86
- {
87
- "visqol": visqol,
88
- "sisdr": sisdr.item(),
89
- "stft": stft.item(),
90
- "mel": mel.item(),
91
- "dataset": key,
92
- "condition": exp,
93
- }
94
- )
95
- print(metrics[-1])
96
-
97
- transform_args = flatten(item["transform_args"], "dot")
98
- for k, v in transform_args.items():
99
- if torch.is_tensor(v):
100
- if len(v.shape) == 0:
101
- metrics[-1][k] = v.item()
102
-
103
- metrics = pd.DataFrame.from_dict(metrics)
104
- with open(csv_path, "w") as f:
105
- metrics.to_csv(f)
106
-
107
- data = summary(model_tag).to_dict()
108
- metrics = {}
109
- for k1, v1 in data.items():
110
- for k2, v2 in v1.items():
111
- metrics[f"metrics/{k2}/{k1}"] = v2
112
-
113
- # Number of steps to record
114
- writer = SummaryWriter(log_dir=metrics_path)
115
- num_steps = 10
116
- for k, v in metrics.items():
117
- for i in range(num_steps):
118
- writer.add_scalar(k, v, i)
119
-
120
-
121
- if __name__ == "__main__":
122
- args = argbind.parse_args()
123
- with argbind.scope(args):
124
- evaluate(args)