import os
import json
import random
import argparse
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset
from huggingface_hub import upload_folder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from collections import Counter
from transformers.integrations import TensorBoardCallback
from transformers import (
    Wav2Vec2FeatureExtractor,
    HubertConfig,
    HubertForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback
)

MODEL = "ntu-spml/distilhubert"  # base model
FEATURE_EXTRACTOR = Wav2Vec2FeatureExtractor.from_pretrained(MODEL)  # feature extractor of the base model
seed = 123
MAX_DURATION = 1.00  # maximum duration of the audio clips, in seconds
SAMPLING_RATE = FEATURE_EXTRACTOR.sampling_rate  # 16 kHz
token = os.getenv("HF_TOKEN")
config_file = "models_config.json"
batch_size = 1024  # TODO: check whether this is still needed
num_workers = 12  # CPU cores


class AudioDataset(Dataset):
    def __init__(self, dataset_path, label2id, filter_white_noise, undersample_normal):
        self.dataset_path = dataset_path
        self.label2id = label2id
        self.file_paths = []
        self.filter_white_noise = filter_white_noise  # stored but not applied here; see the sketch after this class
        self.labels = []
        for label_dir, label_id in self.label2id.items():
            label_path = os.path.join(self.dataset_path, label_dir)
            if os.path.isdir(label_path):
                for file_name in os.listdir(label_path):
                    audio_path = os.path.join(label_path, file_name)
                    self.file_paths.append(audio_path)
                    self.labels.append(label_id)
        if undersample_normal and self.label2id:
            self.undersample_normal_class()

    def undersample_normal_class(self):
        normal_label = self.label2id.get('1s_normal')
        label_counts = Counter(self.labels)
        other_counts = [count for label, count in label_counts.items() if label != normal_label]
        if other_counts:  # ensure there are other classes before taking the max
            target_count = max(other_counts)
            normal_indices = [i for i, label in enumerate(self.labels) if label == normal_label]
            # cap at the number of available normal samples so random.sample cannot raise;
            # a set makes the membership test below O(1)
            keep_indices = set(random.sample(normal_indices, min(target_count, len(normal_indices))))
            new_file_paths = []
            new_labels = []
            for i, (path, label) in enumerate(zip(self.file_paths, self.labels)):
                if label != normal_label or i in keep_indices:
                    new_file_paths.append(path)
                    new_labels.append(label)
            self.file_paths = new_file_paths
            self.labels = new_labels

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        audio_path = self.file_paths[idx]
        label = self.labels[idx]
        input_values = self.preprocess_audio(audio_path)
        return {
            "input_values": input_values,
            "labels": torch.tensor(label)
        }

    def preprocess_audio(self, audio_path):
        waveform, sample_rate = torchaudio.load(
            audio_path,
            normalize=True,
        )
        if sample_rate != SAMPLING_RATE:  # resample if not 16 kHz
            resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:  # if stereo, convert to mono
            waveform = waveform.mean(dim=0, keepdim=True)
        # TODO: try removing this, since normalize=True already scales the signal;
        # without it (and the 1e-6 guard against division by zero) accuracy is terrible!!
        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6)
        max_length = int(SAMPLING_RATE * MAX_DURATION)
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length]  # truncate
        else:
            waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[1]))  # pad
        inputs = FEATURE_EXTRACTOR(
            waveform.squeeze(),
            sampling_rate=SAMPLING_RATE,  # set explicitly, just in case
            return_tensors="pt",
        )
        return inputs.input_values.squeeze()


def is_white_noise(audio):
    mean = torch.mean(audio)
    std = torch.std(audio)
    return torch.abs(mean) < 0.001 and std < 0.01
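
# `filter_white_noise` is stored on the dataset but never applied, and
# is_white_noise is never called. A minimal sketch of how the two could be
# wired together, assuming the intent is to drop white-noise clips once at
# startup (this helper is hypothetical, not part of the original pipeline,
# and it loads every file once, so it costs one pass over the dataset):
def drop_white_noise_clips(dataset: AudioDataset) -> None:
    """Hypothetical: remove clips that is_white_noise flags, in place."""
    if not dataset.filter_white_noise:
        return
    keep_paths, keep_labels = [], []
    for path, label in zip(dataset.file_paths, dataset.labels):
        waveform, _ = torchaudio.load(path, normalize=True)
        if not is_white_noise(waveform):
            keep_paths.append(path)
            keep_labels.append(label)
    dataset.file_paths = keep_paths
    dataset.labels = keep_labels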
def seed_everything():  # TODO: check whether anything else is needed
    random.seed(seed)  # the train/test split uses random.shuffle/random.sample, so seed Python's RNG too
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.deterministic = True  # for reproducibility
    # torch.backends.cudnn.benchmark = False  # for reproducibility


def build_label_mappings(dataset_path):
    label2id = {}
    id2label = {}
    label_id = 0
    for label_dir in os.listdir(dataset_path):
        if os.path.isdir(os.path.join(dataset_path, label_dir)):
            label2id[label_dir] = label_id
            id2label[label_id] = label_dir
            label_id += 1
    return label2id, id2label


def compute_class_weights(labels):
    class_counts = Counter(labels)
    total_samples = len(labels)
    class_weights = {cls: total_samples / count for cls, count in class_counts.items()}
    return [class_weights[label] for label in labels]
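
# Worked example for the inverse-frequency weights above (illustrative): for
# labels [0, 0, 0, 1], class counts are {0: 3, 1: 1} and total is 4, so
# compute_class_weights returns [4/3, 4/3, 4/3, 4.0]. WeightedRandomSampler
# then draws the single class-1 sample three times as often as any individual
# class-0 sample, balancing the classes in expectation:
#     assert compute_class_weights([0, 0, 0, 1]) == [4/3, 4/3, 4/3, 4.0]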
def create_dataloader(dataset_path, filter_white_noise, undersample_normal, test_size=0.2, shuffle=True, pin_memory=True):
    label2id, id2label = build_label_mappings(dataset_path)
    dataset = AudioDataset(dataset_path, label2id, filter_white_noise, undersample_normal)
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    random.shuffle(indices)
    split_idx = int(dataset_size * (1 - test_size))
    train_indices = indices[:split_idx]
    test_indices = indices[split_idx:]
    train_dataset = Subset(dataset, train_indices)
    test_dataset = Subset(dataset, test_indices)
    labels = [dataset.labels[i] for i in train_indices]
    class_weights = compute_class_weights(labels)
    sampler = WeightedRandomSampler(
        weights=class_weights,
        num_samples=len(train_dataset),
        replacement=True
    )
    train_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, sampler=sampler,
        num_workers=num_workers, pin_memory=pin_memory
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers, pin_memory=pin_memory
    )
    return train_dataloader, test_dataloader, id2label


def load_model(model_path, id2label, num_labels):
    config = HubertConfig.from_pretrained(
        pretrained_model_name_or_path=model_path,
        num_labels=num_labels,
        id2label=id2label,
        finetuning_task="audio-classification"
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = HubertForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=model_path,
        config=config,
        torch_dtype=torch.float32,  # TODO: check whether float32 is required or float16 would work
    )
    model.to(device)
    return model


def train_params(dataset_path, filter_white_noise, undersample_normal):
    train_dataloader, test_dataloader, id2label = create_dataloader(dataset_path, filter_white_noise, undersample_normal)
    model = load_model(MODEL, id2label, num_labels=len(id2label))
    return model, train_dataloader, test_dataloader, id2label


def predict_params(dataset_path, model_path, filter_white_noise, undersample_normal):
    _, _, id2label = create_dataloader(dataset_path, filter_white_noise, undersample_normal)
    model = load_model(model_path, id2label, num_labels=len(id2label))
    return model, id2label


def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    cm = confusion_matrix(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'confusion_matrix': cm.tolist()
    }


def main(training_args, output_dir, dataset_path, filter_white_noise, undersample_normal):
    seed_everything()
    model, train_dataloader, test_dataloader, id2label = train_params(dataset_path, filter_white_noise, undersample_normal)
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=5,
        early_stopping_threshold=0.001
    )  # note: requires load_best_model_at_end=True and an eval strategy in TrainingArguments
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataloader.dataset,
        eval_dataset=test_dataloader.dataset,
        callbacks=[TensorBoardCallback, early_stopping_callback]
    )
    torch.cuda.empty_cache()  # free GPU memory
    trainer.train()  # pass resume_from_checkpoint to continue a previous run
    os.makedirs(output_dir, exist_ok=True)
    trainer.save_model(output_dir)  # save the model locally
    eval_results = trainer.evaluate()
    print(f"Evaluation results: {eval_results}")
    trainer.push_to_hub(token=token)  # push the model to the user profile
    upload_folder(repo_id=f"A-POR-LOS-8000/{output_dir}", folder_path=output_dir, token=token)  # upload to the organization as well

    def predict(audio_path):
        waveform, sample_rate = torchaudio.load(audio_path, normalize=True)
        if sample_rate != SAMPLING_RATE:
            resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6)
        max_length = int(SAMPLING_RATE * MAX_DURATION)
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length]
        else:
            waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[1]))
        inputs = FEATURE_EXTRACTOR(
            waveform.squeeze(),
            sampling_rate=SAMPLING_RATE,
            return_tensors="pt",
        )
        with torch.no_grad():
            logits = model(inputs.input_values.to(model.device)).logits
        predicted_class_id = logits.argmax().item()
        predicted_label = id2label[predicted_class_id]
        return predicted_label, logits

    # sample from the test split only, not from the full file list
    test_subset = test_dataloader.dataset
    test_files = [test_subset.dataset.file_paths[i] for i in test_subset.indices]
    test_samples = random.sample(test_files, min(15, len(test_files)))
    for sample in test_samples:
        predicted_label, logits = predict(sample)
        print(f"File: {sample}")
        print(f"Predicted label: {predicted_label}")
        print(f"Logits: {logits}")
        print("---")


def load_config(model_name):
    with open(config_file, 'r') as f:
        config = json.load(f)
    model_config = config[model_name]
    training_args = TrainingArguments(**model_config["training_args"])
    model_config["training_args"] = training_args
    return model_config


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n", choices=["mon", "class"], required=True,
        help="Choose which model to train"
    )
    args = parser.parse_args()
    config = load_config(args.n)
    training_args = config["training_args"]
    output_dir = config["output_dir"]
    dataset_path = config["dataset_path"]
    if args.n == "mon":
        filter_white_noise = False
        undersample_normal = False
    elif args.n == "class":
        filter_white_noise = True
        undersample_normal = True
    main(training_args, output_dir, dataset_path, filter_white_noise, undersample_normal)
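
# load_config expects a models_config.json next to this script, keyed by the
# --n choices, with per-model "dataset_path", "output_dir", and a
# "training_args" dict that is expanded directly into TrainingArguments.
# An illustrative sketch of that layout (all values are placeholders, not the
# real configuration):
# {
#     "mon": {
#         "dataset_path": "path/to/monitor/dataset",
#         "output_dir": "path/to/monitor/output",
#         "training_args": {"output_dir": "...", "load_best_model_at_end": true, "...": "..."}
#     },
#     "class": { "..." : "..." }
# }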