import os
import json
import random
import argparse
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from huggingface_hub import upload_folder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers.integrations import TensorBoardCallback
from transformers import (
    Wav2Vec2FeatureExtractor, HubertConfig, HubertForSequenceClassification,
    Trainer, TrainingArguments,
    EarlyStoppingCallback
)

MODEL = "ntu-spml/distilhubert" # modelo base utilizado, para usar otro basta con cambiar esto
FEATURE_EXTRACTOR = Wav2Vec2FeatureExtractor.from_pretrained(MODEL)
seed = 123
MAX_DURATION = 1.00
SAMPLING_RATE = FEATURE_EXTRACTOR.sampling_rate # 16000
token = os.getenv("HF_TOKEN") # TODO: try storing the token in a local file instead
config_file = "models_config.json"
clasificador = "class" # CLI key for the classifier model
monitor = "mon" # CLI key for the monitor model
batch_size = 16
num_workers = 12

class AudioDataset(Dataset):
    def __init__(self, dataset_path, label2id, filter_white_noise):
        self.dataset_path = dataset_path
        self.label2id = label2id
        self.filter_white_noise = filter_white_noise
        self.file_paths = []
        self.labels = []
        for label_dir, label_id in self.label2id.items():
            label_path = os.path.join(self.dataset_path, label_dir)
            if os.path.isdir(label_path):
                for file_name in os.listdir(label_path):
                    audio_path = os.path.join(label_path, file_name)
                    self.file_paths.append(audio_path)
                    self.labels.append(label_id)
        self.file_paths.sort(key=lambda x: x.split('_part')[0]) # keep chunks of the same source file adjacent; unclear whether this affects training

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        audio_path = self.file_paths[idx]
        label = self.labels[idx]
        input_values = self.preprocess_audio(audio_path)
        return {
            "input_values": input_values,
            "labels": torch.tensor(label)
        }

    def preprocess_audio(self, audio_path):
        waveform, sample_rate = torchaudio.load(
            audio_path,
            normalize=True, # converts samples to float32
        )
        if sample_rate != SAMPLING_RATE: # resample if the file is not 16 kHz
            resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1: # downmix stereo to mono
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6) # peak-normalize; the 1e-6 epsilon avoids division by zero on silent clips (accuracy collapses without it)
        max_length = int(SAMPLING_RATE * MAX_DURATION)
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length] # truncate to MAX_DURATION
        else:
            waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[1])) # zero-pad up to MAX_DURATION
        inputs = FEATURE_EXTRACTOR(
            waveform.squeeze(),
            sampling_rate=SAMPLING_RATE, # set explicitly, just in case
            return_tensors="pt",
            # max_length=int(SAMPLING_RATE * MAX_DURATION),
            # truncation=True, # handled manually above
            # padding=True, # handled manually above
        )
        return inputs.input_values.squeeze()

def is_white_noise(audio):
    mean = torch.mean(audio)
    std = torch.std(audio)
    return torch.abs(mean) < 0.001 and std < 0.01
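
# Note: AudioDataset stores `filter_white_noise` but never applies it. A minimal
# sketch of how it could be wired into AudioDataset.__init__ (assumption: near-silent
# clips should be dropped at load time, at the cost of one extra decode per file):
#
#     if self.filter_white_noise:
#         kept = [(p, l) for p, l in zip(self.file_paths, self.labels)
#                 if not is_white_noise(torchaudio.load(p, normalize=True)[0])]
#         self.file_paths = [p for p, _ in kept]
#         self.labels = [l for _, l in kept]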

def seed_everything():
    random.seed(seed) # also seed Python's RNG: create_dataloader shuffles indices with random.shuffle
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # documented deterministic values are ':16:8' or ':4096:8'

def build_label_mappings(dataset_path):
    label2id = {}
    id2label = {}
    label_id = 0
    for label_dir in sorted(os.listdir(dataset_path)): # sort so label ids are stable across runs
        if os.path.isdir(os.path.join(dataset_path, label_dir)):
            label2id[label_dir] = label_id
            id2label[label_id] = label_dir
            label_id += 1
    return label2id, id2label
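
# Illustrative result (folder names are hypothetical): a dataset laid out as
# dataset_path/crying/ and dataset_path/silence/ yields
# label2id = {"crying": 0, "silence": 1} and id2label = {0: "crying", 1: "silence"},
# with ids assigned in sorted directory order.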

def create_dataloader(dataset_path, filter_white_noise, test_size=0.2, shuffle=True, pin_memory=True):
    label2id, id2label = build_label_mappings(dataset_path)
    dataset = AudioDataset(dataset_path, label2id, filter_white_noise)
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    random.shuffle(indices)
    split_idx = int(dataset_size * (1 - test_size))
    train_indices = indices[:split_idx]
    test_indices = indices[split_idx:]
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    test_dataset = torch.utils.data.Subset(dataset, test_indices)
    train_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory
    )
    return train_dataloader, test_dataloader, label2id, id2label
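
# Note: main() passes only train_dataloader.dataset / test_dataloader.dataset to
# Trainer, which builds its own DataLoaders from TrainingArguments, so the
# batch_size and num_workers configured above do not affect the training loop.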

def load_model(model_path, label2id, id2label, num_labels):
    config = HubertConfig.from_pretrained(
        pretrained_model_name_or_path=model_path,
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label,
        finetuning_task="audio-classification"
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = HubertForSequenceClassification.from_pretrained( # TODO: review parameters; possible optimizations
        pretrained_model_name_or_path=model_path,
        config=config,
        torch_dtype=torch.float32,
    )
    model.to(device)
    return model

def train_params(dataset_path, filter_white_noise):
    train_dataloader, test_dataloader, label2id, id2label = create_dataloader(dataset_path, filter_white_noise)
    model = load_model(MODEL, label2id, id2label, num_labels=len(id2label))
    return model, train_dataloader, test_dataloader, id2label

def predict_params(dataset_path, model_path, filter_white_noise):
    _, _, label2id, id2label = create_dataloader(dataset_path, filter_white_noise)
    model = load_model(model_path, label2id, id2label, num_labels=len(id2label))
    return model, id2label

def compute_metrics(eval_pred):
    predictions = torch.argmax(torch.tensor(eval_pred.predictions), dim=-1)
    references = torch.tensor(eval_pred.label_ids)
    accuracy = accuracy_score(references, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(references, predictions, average='weighted')
    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

def main(training_args, output_dir, dataset_path, filter_white_noise):
    seed_everything()
    model, train_dataloader, test_dataloader, _ = train_params(dataset_path, filter_white_noise)
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataloader.dataset,
        eval_dataset=test_dataloader.dataset,
        callbacks=[TensorBoardCallback(), EarlyStoppingCallback(early_stopping_patience=3)]
    )
    torch.cuda.empty_cache() # free cached GPU memory before training
    trainer.train() # arguments can be adjusted here to resume a previous run
    # trainer.save_model(output_dir) # save the model locally
    os.makedirs(output_dir, exist_ok=True) # create the output folder
    trainer.push_to_hub(token=token) # push the model to the user profile
    upload_folder(repo_id=f"A-POR-LOS-8000/{output_dir}", folder_path=output_dir, token=token) # also upload to the organization repo

def load_config(model_name):
    with open(config_file, 'r') as f:
        config = json.load(f)
    model_config = config[model_name]
    training_args = TrainingArguments(**model_config["training_args"])
    model_config["training_args"] = training_args
    return model_config
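
# Expected shape of models_config.json (top-level keys match the --n choices; all
# values below are illustrative, not taken from the real config):
#
# {
#     "class": {
#         "output_dir": "classifier-output",
#         "dataset_path": "data/classifier",
#         "training_args": {"output_dir": "classifier-output",
#                           "eval_strategy": "epoch", "save_strategy": "epoch",
#                           "load_best_model_at_end": true,
#                           "metric_for_best_model": "accuracy"}
#     },
#     "mon": {...}
# }
#
# Note that EarlyStoppingCallback requires load_best_model_at_end=True plus an
# evaluation/save strategy and metric_for_best_model, so those must be set here.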

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n", choices=[monitor, clasificador],
        required=True, help="Choose which model to train"
    )
    args = parser.parse_args()
    config = load_config(args.n)
    training_args = config["training_args"]
    output_dir = config["output_dir"]
    dataset_path = config["dataset_path"]
    if args.n == monitor:
        filter_white_noise = False
    elif args.n == clasificador:
        filter_white_noise = True
    main(training_args, output_dir, dataset_path, filter_white_noise)
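
# Example invocation (script name is illustrative):
#   HF_TOKEN=hf_xxx python train.py --n class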