File size: 23,887 Bytes
c968fc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import numpy as np
import yaml
import copy
from tqdm import tqdm
from torchaudio.compliance import kaldi
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from fairseq import checkpoint_utils
from transformers import AutoModel, Wav2Vec2FeatureExtractor

from utils.io_optim import (
    TorchaudioDataset,
    LibrosaDataset,
    FFmpegDataset,
    collate_batch,
)
import whisper
from modules.wenet_extractor.utils.init_model import init_model
from modules.wenet_extractor.utils.checkpoint import load_checkpoint

"""
    Extractor for content features
    1. whisper
    2. contentvec
    3. wenet
    4. mert

    Pipeline:
        in preprocess.py:
            call extract_utt_content_features_dataloader() to extract content features for each utterance
            extract_utt_content_features_dataloader() wraps the following steps:
                1. load the model (whisper, contentvec, wenet, mert)
                2. extract the content features
                3. save the content features into files
        in svc_dataset.py:
            align the content features to the given target length with
            offline_resolution_transformation() (one utterance) or ReTrans() (a padded batch)

"""

"""
    Extractor Usage:
        1. initialize an instance of extractor
            extractor = WhisperExtractor(cfg)
        2. load the specified model
            extractor.load_model()
        3. extract the content features of a batch of waveforms
            extractor.extract_content_features(wavs)
            (WenetExtractor also needs the valid lengths: extractor.extract_content_features(wavs, lens))
        4. save the content features
            extractor.save_feature(utt, content_feature) for a single utterance
"""


class AudioPretrainedModelFeaturesExtractor:
    """Common machinery shared by the pretrained content-feature extractors.

    A subclass wraps one pretrained model and provides ``load_model`` and
    ``extract_content_features``. Supported extractor types are ``"whisper"``,
    ``"contentvec"``, ``"wenet"`` and ``"mert"``. This base class supplies:

    * the reduced hop sizes ``source_hop`` / ``target_hop``. These are used to
      resample feature frames onto the ``cfg.preprocess.hop_size`` frame rate,
      via ``offline_resolution_transformation`` (one numpy utterance) or
      ``ReTrans`` (a padded torch batch);
    * trimming of padded features to the utterance duration
      (``get_valid_features``) and saving them as ``.npy`` (``save_feature``).
    """

    # Extractor types that have a frame-shift entry in cfg.preprocess.
    _SUPPORTED_TYPES = ("whisper", "contentvec", "wenet", "mert")

    def __init__(self, cfg, extractor_type):
        self.cfg = cfg
        self.extractor_type = extractor_type
        self.model = None
        self.init_for_retrans()

    def _frameshift(self):
        """Return the feature frame shift of ``self.extractor_type``.

        The value is in the same time unit as ``utt["Duration"]`` (seconds,
        judging by the 20 ms / 40 ms notes in this class).

        Raises:
            NotImplementedError: for an unknown extractor type.
        """
        pre = self.cfg.preprocess
        if self.extractor_type == "whisper":
            return pre.whisper_frameshift * pre.whisper_downsample_rate  # 20ms
        if self.extractor_type == "contentvec":
            return pre.contentvec_frameshift  # 20ms
        if self.extractor_type == "wenet":
            return pre.wenet_frameshift * pre.wenet_downsample_rate  # 40ms
        if self.extractor_type == "mert":
            return pre.mert_frameshift
        raise NotImplementedError

    def init_for_retrans(self):
        """Precompute the hop sizes used by the resolution transformations.

        ``source_hop`` is the feature frame shift expressed in samples at
        ``cfg.preprocess.sample_rate``. ``target_hop`` is
        ``cfg.preprocess.hop_size``. Both are divided by their GCD so the
        repeat/average resampling works on the smallest equivalent integers.
        """
        target_hop = self.cfg.preprocess.hop_size

        assert self.extractor_type in self._SUPPORTED_TYPES
        source_hop = self._frameshift() * self.cfg.preprocess.sample_rate
        # round() rather than int(): float products such as 0.29 * 100 evaluate
        # to 28.999..., which int() would truncate to a wrong hop size.
        source_hop = int(round(source_hop))
        factor = np.gcd(source_hop, target_hop)

        self.source_hop = source_hop // factor
        self.target_hop = target_hop // factor

    def offline_resolution_transformation(self, content, target_len):
        """Resample one utterance's features to exactly ``target_len`` frames.

        Each source frame is repeated ``source_hop`` times, then consecutive
        groups of ``target_hop`` rows are averaged. If the result is shorter
        than ``target_len``, the last frame is repeated to fill the gap.

        args:
            content: np.ndarray, (source_len, dim)
            target_len: target length
        return:
            mapped_feature: (target_len, dim)
        """
        source_hop = self.source_hop
        target_hop = self.target_hop

        _, width = content.shape
        # Only the source frames needed to cover target_len target frames.
        source_len = min(target_len * target_hop // source_hop + 1, len(content))

        # Largest multiple of target_hop within source_len * source_hop
        # (const ~= target_len * target_hop).
        const = source_len * source_hop // target_hop * target_hop

        # (len(content) * source_hop, dim)
        up_sampling_feats = np.repeat(content, source_hop, axis=0)
        # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim)
        down_sampling_feats = np.average(
            up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
        )

        err = abs(target_len - len(down_sampling_feats))
        if err > 8:
            self.log_for_ReTrans(err)

        if len(down_sampling_feats) < target_len:
            # (1, dim) -> (err, dim): pad by repeating the last frame
            end = down_sampling_feats[-1][None, :].repeat(err, axis=0)
            down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0)

        # (target_len, dim)
        return down_sampling_feats[:target_len]

    def log_for_ReTrans(self, err):
        """Store ``err`` in ``{processed_dir}/align_max_err.log`` if it is a new maximum.

        The file holds a single integer. A missing, unreadable or non-integer
        file is reset to ``0`` before the comparison.
        """
        err_log_path = os.path.join(
            self.cfg.preprocess.processed_dir, "align_max_err.log"
        )
        try:
            with open(err_log_path, "r") as f:
                err_num = int(f.read())
        except (OSError, ValueError):
            # No log yet, or its content is not an integer: start from 0.
            with open(err_log_path, "w") as f:
                f.write("0")
            err_num = 0
        if err > err_num:
            with open(err_log_path, "w") as f:
                f.write(str(err))

    def ReTrans(self, source_feats, padded_target_len):
        """
        Resolution Transformation for mismatched frames alignment.

        Batched torch version of ``offline_resolution_transformation``: repeat
        each frame ``source_hop`` times, average groups of ``target_hop``, and
        pad with the last frame up to ``padded_target_len``.

        TODO: Merge the offline resolution_transformation into one

        args:
            source_feats: Tensor, (B, padded_source_len, D)
            padded_target_len: int, the maximum target length in a batch
        return:
            mapped_feature: Tensor, (B, padded_target_len, D)
        """
        source_hop = self.source_hop
        target_hop = self.target_hop

        # (B, padded_source_len, D)
        B, padded_source_len, D = source_feats.shape

        # select the valid content from padded feature
        source_len = min(
            padded_target_len * target_hop // source_hop + 1, padded_source_len
        )

        # const ~= padded_target_len * target_hop (padded wav's duration)
        const = source_len * source_hop // target_hop * target_hop

        # (B, padded_source_len, D) -> (B, padded_source_len * source_hop, D) -> (B, const, D)
        up_sampling_feats = torch.repeat_interleave(source_feats, source_hop, dim=1)[
            :, :const
        ]
        # (B, const, D) -> (B, const/target_hop, target_hop, D) -> (B, const/target_hop, D)
        down_sampling_feats = torch.mean(
            up_sampling_feats.reshape(B, -1, target_hop, D), dim=2
        )

        err = abs(padded_target_len - down_sampling_feats.shape[1])
        if err > 8:
            self.log_for_ReTrans(err)

        if down_sampling_feats.shape[1] < padded_target_len:
            # (B, 1, D) -> (B, err, D)
            end = down_sampling_feats[:, -1, :][:, None, :].repeat_interleave(
                err, dim=1
            )
            # -> (B, padded_target_len, D)
            down_sampling_feats = torch.cat([down_sampling_feats, end], dim=1)

        # (B, padded_target_len, D)
        return down_sampling_feats[:, :padded_target_len]

    def get_valid_features(self, utt, content_feature):
        """Trim a padded feature matrix to the frames covered by ``utt["Duration"]``.

        Args:
            utt (dict): metadata item; only ``"Duration"`` is read.
            content_feature: 2-D array/tensor, (num_frames, dim).

        Returns:
            The first ``ceil((duration - frameshift) / frameshift) + 1`` rows.

        Raises:
            NotImplementedError: for an unknown extractor type.
        """
        duration = utt["Duration"]
        frameshift = self._frameshift()

        # calculate the number of valid frames
        num_frames = int(np.ceil((duration - frameshift) / frameshift)) + 1
        assert (
            len(content_feature.shape) == 2
        ), "content feature shape error, it should be (num_frames, dim)"
        return content_feature[:num_frames, :]

    def save_feature(self, utt, content_feature):
        """Save one utterance's features to ``{processed_dir}/{Dataset}/{extractor_type}/{Uid}.npy``.

        Args:
            utt (dict): one item in metadata; reads ``"Uid"``, ``"Dataset"`` and ``"Duration"``.
            content_feature (tensor): content feature of one utterance, (num_frames, dim);
                trimmed with ``get_valid_features`` before saving.
        """
        uid = utt["Uid"]
        assert self.extractor_type is not None
        out_dir = os.path.join(
            self.cfg.preprocess.processed_dir, utt["Dataset"], self.extractor_type
        )
        os.makedirs(out_dir, exist_ok=True)
        save_path = os.path.join(out_dir, uid + ".npy")

        content_feature = self.get_valid_features(utt, content_feature)
        np.save(save_path, content_feature.cpu().detach().numpy())


class WhisperExtractor(AudioPretrainedModelFeaturesExtractor):
    """Content features taken from the audio encoder of a Whisper model."""

    def __init__(self, config):
        super(WhisperExtractor, self).__init__(config, extractor_type="whisper")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _whisper_download_root(self):
        """Directory handed to ``whisper.load_model`` as ``download_root``.

        Returns None when ``whisper_model_path`` is not configured. For an
        existing checkpoint file (e.g. "pretrained/whisper/medium.pt") this is
        its directory, and for an existing directory (e.g. "pretrained/whisper")
        the directory itself. A path that does not exist yet is used as the
        download target, with a trailing ``*.pt`` file name stripped.
        """
        if "whisper_model_path" not in self.cfg.preprocess:
            return None
        model_path = self.cfg.preprocess.whisper_model_path
        if os.path.isfile(model_path):
            return os.path.dirname(model_path)
        if os.path.isdir(model_path):
            return model_path
        return os.path.dirname(model_path) if model_path.endswith(".pt") else model_path

    def load_model(self):
        """Load ``cfg.preprocess.whisper_model`` and store it in eval mode on ``self.model``."""
        print("Loading Whisper Model...")

        model = whisper.load_model(
            self.cfg.preprocess.whisper_model,
            self.device,
            self._whisper_download_root(),
        )

        on_gpu = torch.cuda.is_available()
        print("Using GPU...\n" if on_gpu else "Using CPU...\n")
        if on_gpu:
            model = model.cuda()

        self.model = model.eval()

    def extract_content_features(self, wavs):
        """Encode a batch of waveforms with the Whisper audio encoder.

        Args:
            wavs: tensor (batch_size, T); padded or trimmed by
                ``whisper.pad_or_trim`` before the log-mel transform.

        Returns:
            The output of ``self.model.embed_audio`` (noted as
            (batch, 1500, 1024) for the medium model).
        """
        mel_batch = whisper.log_mel_spectrogram(
            whisper.pad_or_trim(wavs), device=self.model.device
        )
        with torch.no_grad():
            return self.model.embed_audio(mel_batch)


class ContentvecExtractor(AudioPretrainedModelFeaturesExtractor):
    """Content features from a fairseq ContentVec checkpoint (``cfg.preprocess.contentvec_file``)."""

    def __init__(self, cfg):
        super(ContentvecExtractor, self).__init__(cfg, extractor_type="contentvec")

    def load_model(self):
        """Load the checkpoint, set eval mode and move the model to GPU when available."""
        assert self.model is None
        ckpt_path = self.cfg.preprocess.contentvec_file
        print("Load Contentvec Model...")

        models, _saved_cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
            [ckpt_path],
            suffix="",
        )
        model = models[0]
        model.eval()

        if torch.cuda.is_available():
            model = model.cuda()

        self.model = model

    def extract_content_features(self, wavs):
        """Extract content features from a batch of dataloader waveforms.

        Args:
            wavs: tensor (batch, T), zero-padded to the batch's max length.

        Returns:
            ``final_proj`` applied to the layer-12 features; noted as
            (batch, T', 256) in the original code.
        """
        assert self.model is not None, "load model first!"
        device = next(self.model.parameters()).device
        wavs = wavs.to(device)  # (batch, max_len)
        # NOTE(review): padding is inferred from exact-zero samples, so genuine
        # digital silence inside an utterance is masked as well.
        padding_mask = torch.eq(wavs, torch.zeros_like(wavs))
        with torch.no_grad():
            logits = self.model.extract_features(
                source=wavs, padding_mask=padding_mask, output_layer=12
            )
            feats = self.model.final_proj(logits[0])
        return feats


class WenetExtractor(AudioPretrainedModelFeaturesExtractor):
    """Content features from the encoder of a WeNet model.

    Waveforms are turned into Kaldi-style fbank/MFCC features, configured by
    the ``dataset_conf`` section of the WeNet YAML config, and then passed
    through ``model.encoder_extractor``.
    """

    def __init__(self, config):
        super(WenetExtractor, self).__init__(config, extractor_type="wenet")
        # Feature-extraction settings (WeNet ``dataset_conf``); set by load_model().
        self.extract_conf = None

    def load_model(self):
        """Read the WeNet config, build the model, load its checkpoint and set eval mode."""
        wenet_cfg = self.cfg.preprocess.wenet_config
        wenet_model_path = self.cfg.preprocess.wenet_model_path
        # load Wenet config
        with open(wenet_cfg, "r") as w:
            wenet_configs = yaml.load(w, Loader=yaml.FullLoader)
        self.extract_conf = copy.deepcopy(wenet_configs["dataset_conf"])
        print("Loading Wenet Model...")
        self.model = init_model(wenet_configs)
        load_checkpoint(self.model, wenet_model_path)

        if torch.cuda.is_available():
            print("Using GPU...\n")
            self.model = self.model.cuda()
        else:
            print("Using CPU...\n")

        self.model = self.model.eval()

    def extract_content_features(self, wavs, lens):
        """Extract content features from a batch of dataloader waveforms.

        Args:
            wavs: tensor, whose shape is (B, T), zero-padded. The Kaldi feature
                calls below assume 16 kHz audio (sample_frequency=16000).
            lens: valid length in samples of each row of ``wavs``, indexable
                by batch position.

        Returns:
            The output of ``self.model.encoder_extractor`` for the padded
            feature batch.
        """
        assert self.extract_conf is not None, "load model first!"
        device = next(self.model.parameters()).device

        # Extract fbank/mfcc features by kaldi
        feats_type = self.extract_conf.get("feats_type", "fbank")
        assert feats_type in ["fbank", "mfcc"]
        if feats_type == "fbank":
            feat_conf = self.extract_conf.get("fbank_conf", {})
        else:
            # WeNet's dataset_conf names this section ``mfcc_conf`` (matching
            # ``fbank_conf``); a legacy ``mfcc`` key is still accepted.
            feat_conf = self.extract_conf.get(
                "mfcc_conf", self.extract_conf.get("mfcc", {})
            )

        feats_list = []
        lengths_list = []
        with torch.no_grad():
            for idx, wav in enumerate(wavs):
                # wav: (T) -> keep only the valid samples
                wav = wav[: lens[idx]].to(device)

                # pad one frame (160 samples = 10 ms at 16 kHz) to compensate
                # for the frame cut off after feature extraction
                pad_tensor = torch.zeros(160, device=wav.device)
                wav = torch.cat((wav, pad_tensor), dim=-1)
                # rescale to the 16-bit integer range (assumes input in [-1, 1])
                wav *= 1 << 15

                wav = wav.unsqueeze(0)  # (T) -> (1, T)
                if feats_type == "fbank":
                    feat = kaldi.fbank(
                        wav,
                        sample_frequency=16000,
                        num_mel_bins=feat_conf["num_mel_bins"],
                        frame_length=feat_conf["frame_length"],
                        frame_shift=feat_conf["frame_shift"],
                        dither=feat_conf["dither"],
                    )
                else:
                    feat = kaldi.mfcc(
                        wav,
                        sample_frequency=16000,
                        num_mel_bins=feat_conf["num_mel_bins"],
                        frame_length=feat_conf["frame_length"],
                        frame_shift=feat_conf["frame_shift"],
                        dither=feat_conf["dither"],
                        num_ceps=feat_conf.get("num_ceps", 40),
                        high_freq=feat_conf.get("high_freq", 0.0),
                        low_freq=feat_conf.get("low_freq", 20.0),
                    )
                feats_list.append(feat)
                lengths_list.append(feat.shape[0])

            feats_lengths = torch.tensor(lengths_list, dtype=torch.int32, device=device)
            # (batch, len, num_mel_bins or num_ceps)
            feats_tensor = pad_sequence(feats_list, batch_first=True).to(device)

            features = self.model.encoder_extractor(
                feats_tensor,
                feats_lengths,
                decoding_chunk_size=-1,
                num_decoding_left_chunks=-1,
                simulate_streaming=False,
            )
        return features


class MertExtractor(AudioPretrainedModelFeaturesExtractor):
    """Content features from a Hugging Face MERT model (``cfg.preprocess.mert_model``).

    The feature of an utterance is the hidden state of layer
    ``cfg.preprocess.mert_feature_layer``.
    """

    def __init__(self, cfg):
        super(MertExtractor, self).__init__(cfg, extractor_type="mert")
        self.preprocessor = None

    def load_model(self):
        """Load the MERT model and its feature extractor; move the model to GPU when available."""
        assert self.model is None
        assert self.preprocessor is None

        print("Loading MERT Model: ...", self.cfg.preprocess.mert_model)

        model_name = self.cfg.preprocess.mert_model
        model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        # inference only
        model.eval()

        if torch.cuda.is_available():
            model = model.cuda()
        preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(
            model_name, trust_remote_code=True
        )

        self.model = model
        self.preprocessor = preprocessor

    def extract_content_features(self, wavs):
        """Extract MERT features utterance by utterance.

        Args:
            wavs: tensor (batch, T)

        Returns:
            list with one (frame_len, dim) tensor per row of ``wavs``.

        Raises:
            AssertionError: if the model is not loaded, or the preprocessor's
                sampling rate differs from ``cfg.preprocess.mert_sample_rate``.
        """
        assert (
            self.model is not None and self.preprocessor is not None
        ), "load model first!"
        with torch.no_grad():
            sample_rate = self.preprocessor.sampling_rate
            device = next(self.model.parameters()).device
            assert (
                sample_rate == self.cfg.preprocess.mert_sample_rate
            ), "mert sample rate mismatch, expected {}, got {}".format(
                self.cfg.preprocess.mert_sample_rate, sample_rate
            )
            mert_features = []
            # wav: (len)
            for wav in wavs:
                if torch.is_tensor(wav):
                    wav = wav.cpu().numpy()
                # Process only this utterance; passing the whole batch here
                # yields a (B, frame_len, dim) state that save_feature rejects.
                # {input_values: tensor, attention_mask: tensor}
                inputs = self.preprocessor(
                    wav, sampling_rate=sample_rate, return_tensors="pt"
                ).to(device)

                outputs = self.model(**inputs, output_hidden_states=True)
                # (1, frame_len, dim) -> (frame_len, dim)
                feature = outputs.hidden_states[
                    self.cfg.preprocess.mert_feature_layer
                ].squeeze(0)
                mert_features.append(feature)

        return mert_features


def _extract_and_save_feature(
    cfg,
    metadata,
    num_workers,
    dataset_name,
    feat_name,
    dataset_cls,
    sample_rate_key,
    extractor_cls,
    pass_lens=False,
):
    """Extract one content-feature type for ``metadata`` and save it per utterance.

    Does nothing when ``{processed_dir}/{dataset_name}/{feat_name}`` already
    holds exactly one file per utterance.

    Args:
        cfg: config object; reads ``cfg.preprocess.<sample_rate_key>``,
            ``pin_memory``, ``content_feature_batch_size`` and ``processed_dir``.
        metadata: list of utterance dicts.
        num_workers: DataLoader worker count.
        dataset_name: sub-directory of ``processed_dir`` for this dataset.
        feat_name: feature sub-directory name (e.g. ``"whisper"``).
        dataset_cls: waveform dataset class used to load audio.
        sample_rate_key: name of the ``cfg.preprocess`` attribute holding the
            loading sample rate. It is read only if extraction is needed.
        extractor_cls: extractor class to instantiate with ``cfg``.
        pass_lens: whether ``extract_content_features`` also takes ``lens``.
    """
    feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, feat_name)
    os.makedirs(feat_dir, exist_ok=True)
    if len(os.listdir(feat_dir)) == len(metadata):
        return

    waveforms = dataset_cls(
        cfg,
        dataset_name,
        getattr(cfg.preprocess, sample_rate_key),
        metadata=metadata,
    )
    data_loader = DataLoader(
        waveforms,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=cfg.preprocess.pin_memory,
        batch_size=cfg.preprocess.content_feature_batch_size,
        collate_fn=collate_batch,
        drop_last=False,
    )
    extractor = extractor_cls(cfg)
    extractor.load_model()
    for batch_metadata, wavs, lens in tqdm(data_loader):
        if pass_lens:
            batch_content_features = extractor.extract_content_features(wavs, lens)
        else:
            batch_content_features = extractor.extract_content_features(wavs)
        for index, utt in enumerate(batch_metadata):
            extractor.save_feature(utt, batch_content_features[index])


def extract_utt_content_features_dataloader(cfg, metadata, num_workers):
    """Extract every enabled content feature for the utterances of one dataset.

    Each ``cfg.preprocess.extract_{whisper,contentvec,wenet,mert}_feature``
    flag enables one extractor. Features are written by ``save_feature`` under
    ``{processed_dir}/{dataset}/{feature}/``.

    Args:
        cfg: config object.
        metadata: list of utterance dicts. The dataset name is taken from
            ``metadata[0]["Dataset"]``, so all items are assumed to share it.
            An empty list is a no-op.
        num_workers: DataLoader worker count.
    """
    if not metadata:
        return
    dataset_name = metadata[0]["Dataset"]

    # (feature name, enable flag, sample-rate key, dataset class, extractor class, pass lens)
    feature_specs = (
        ("whisper", "extract_whisper_feature", "whisper_sample_rate", FFmpegDataset, WhisperExtractor, False),
        ("contentvec", "extract_contentvec_feature", "contentvec_sample_rate", LibrosaDataset, ContentvecExtractor, False),
        ("wenet", "extract_wenet_feature", "wenet_sample_rate", TorchaudioDataset, WenetExtractor, True),
        ("mert", "extract_mert_feature", "mert_sample_rate", TorchaudioDataset, MertExtractor, False),
    )

    with torch.no_grad():
        for (
            feat_name,
            flag,
            sample_rate_key,
            dataset_cls,
            extractor_cls,
            pass_lens,
        ) in feature_specs:
            if getattr(cfg.preprocess, flag):
                _extract_and_save_feature(
                    cfg,
                    metadata,
                    num_workers,
                    dataset_name,
                    feat_name,
                    dataset_cls,
                    sample_rate_key,
                    extractor_cls,
                    pass_lens=pass_lens,
                )