import os
import json
import random
import argparse
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset
from huggingface_hub import upload_folder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from collections import Counter
from transformers.integrations import TensorBoardCallback
from transformers import (
Wav2Vec2FeatureExtractor, HubertConfig, HubertForSequenceClassification,
Trainer, TrainingArguments,
EarlyStoppingCallback
)
MODEL = "ntu-spml/distilhubert" # base model
FEATURE_EXTRACTOR = Wav2Vec2FeatureExtractor.from_pretrained(MODEL) # feature extractor of the base model
seed = 123
MAX_DURATION = 1.00 # Maximum audio duration in seconds
SAMPLING_RATE = FEATURE_EXTRACTOR.sampling_rate # 16 kHz
token = os.getenv("HF_TOKEN")
config_file = "models_config.json"
batch_size = 1024 # TODO: check whether this is still needed
num_workers = 12 # CPU cores
class AudioDataset(Dataset):
def __init__(self, dataset_path, label2id, filter_white_noise, undersample_normal):
self.dataset_path = dataset_path
self.label2id = label2id
self.file_paths = []
        self.filter_white_noise = filter_white_noise # NOTE: stored but never applied in this file; see is_white_noise() below
self.labels = []
for label_dir, label_id in self.label2id.items():
label_path = os.path.join(self.dataset_path, label_dir)
if os.path.isdir(label_path):
for file_name in os.listdir(label_path):
audio_path = os.path.join(label_path, file_name)
self.file_paths.append(audio_path)
self.labels.append(label_id)
if undersample_normal and self.label2id:
self.undersample_normal_class()
    def undersample_normal_class(self):
        normal_label = self.label2id.get('1s_normal')
        label_counts = Counter(self.labels)
        other_counts = [count for label, count in label_counts.items() if label != normal_label]
        if other_counts: # Ensure there are other classes before taking the max
            target_count = max(other_counts)
            normal_indices = [i for i, label in enumerate(self.labels) if label == normal_label]
            # set gives O(1) membership checks; min() guards against sampling more items than exist
            keep_indices = set(random.sample(normal_indices, min(target_count, len(normal_indices))))
            new_file_paths = []
            new_labels = []
            for i, (path, label) in enumerate(zip(self.file_paths, self.labels)):
                if label != normal_label or i in keep_indices:
                    new_file_paths.append(path)
                    new_labels.append(label)
            self.file_paths = new_file_paths
            self.labels = new_labels
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
audio_path = self.file_paths[idx]
label = self.labels[idx]
input_values = self.preprocess_audio(audio_path)
return {
"input_values": input_values,
"labels": torch.tensor(label)
}
def preprocess_audio(self, audio_path):
waveform, sample_rate = torchaudio.load(
audio_path,
normalize=True,
)
        if sample_rate != SAMPLING_RATE: # Resample if not 16 kHz
            resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1: # Downmix stereo to mono
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6) # Peak-normalize; TODO: try removing this since load() already normalizes, but without the 1e-6 epsilon accuracy is terrible!!
        max_length = int(SAMPLING_RATE * MAX_DURATION)
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length] # Truncate
        else:
            waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[1])) # Pad
        inputs = FEATURE_EXTRACTOR(
            waveform.squeeze(),
            sampling_rate=SAMPLING_RATE, # set explicitly, just in case
return_tensors="pt",
)
return inputs.input_values.squeeze()
def is_white_noise(audio): # Heuristic behind the filter_white_noise flag; not wired into AudioDataset in this file
mean = torch.mean(audio)
std = torch.std(audio)
return torch.abs(mean) < 0.001 and std < 0.01
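# A quick illustration of the heuristic (synthetic tensors chosen for demonstration,
# not taken from the dataset): low-amplitude Gaussian noise passes both thresholds
# with overwhelming probability, while a full-scale tone fails the std check.
#   is_white_noise(0.005 * torch.randn(16000))                # |mean|~0, std~0.005 -> True
#   is_white_noise(torch.sin(torch.linspace(0, 100, 16000)))  # std~0.7 -> False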
def seed_everything(): # TODO: check whether anything else needs seeding
    random.seed(seed) # also seed Python's random, used below for the split, undersampling, and test sampling
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.deterministic = True # For reproducibility
    # torch.backends.cudnn.benchmark = False # For reproducibility
def build_label_mappings(dataset_path):
label2id = {}
id2label = {}
label_id = 0
    for label_dir in sorted(os.listdir(dataset_path)): # sorted for deterministic label ids across runs
if os.path.isdir(os.path.join(dataset_path, label_dir)):
label2id[label_dir] = label_id
id2label[label_id] = label_dir
label_id += 1
return label2id, id2label
def compute_class_weights(labels):
class_counts = Counter(labels)
total_samples = len(labels)
class_weights = {cls: total_samples / count for cls, count in class_counts.items()}
return [class_weights[label] for label in labels]
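# Worked example (hypothetical labels, for illustration only): with
# labels = [0, 0, 0, 1], the counts are {0: 3, 1: 1} and the total is 4, so the
# per-sample weights come out to [4/3, 4/3, 4/3, 4.0] -- the rare class is drawn
# roughly 3x more often by the WeightedRandomSampler built below.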
def create_dataloader(dataset_path, filter_white_noise, undersample_normal, test_size=0.2, shuffle=True, pin_memory=True):
label2id, id2label = build_label_mappings(dataset_path)
dataset = AudioDataset(dataset_path, label2id, filter_white_noise, undersample_normal)
dataset_size = len(dataset)
indices = list(range(dataset_size))
random.shuffle(indices)
split_idx = int(dataset_size * (1 - test_size))
train_indices = indices[:split_idx]
test_indices = indices[split_idx:]
train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset, test_indices)
labels = [dataset.labels[i] for i in train_indices]
class_weights = compute_class_weights(labels)
sampler = WeightedRandomSampler(
weights=class_weights,
num_samples=len(train_dataset),
replacement=True
)
train_dataloader = DataLoader(
train_dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers, pin_memory=pin_memory
)
test_dataloader = DataLoader(
test_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory
)
return train_dataloader, test_dataloader, id2label
def load_model(model_path, id2label, num_labels):
config = HubertConfig.from_pretrained(
pretrained_model_name_or_path=model_path,
num_labels=num_labels,
id2label=id2label,
finetuning_task="audio-classification"
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HubertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path=model_path,
config=config,
        torch_dtype=torch.float32, # TODO: check whether float32 is required or float16 would work
)
model.to(device)
return model
def train_params(dataset_path, filter_white_noise, undersample_normal):
train_dataloader, test_dataloader, id2label = create_dataloader(dataset_path, filter_white_noise, undersample_normal)
model = load_model(MODEL, id2label, num_labels=len(id2label))
return model, train_dataloader, test_dataloader, id2label
def predict_params(dataset_path, model_path, filter_white_noise, undersample_normal):
_, _, id2label = create_dataloader(dataset_path, filter_white_noise, undersample_normal)
model = load_model(model_path, id2label, num_labels=len(id2label))
return model, id2label
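# Hypothetical standalone-inference sketch (the repo id and paths below are
# assumptions for illustration, not taken from this file):
#   model, id2label = predict_params(
#       dataset_path="data/",                    # only needed to rebuild the label mapping
#       model_path="A-POR-LOS-8000/classifier",  # a fine-tuned checkpoint
#       filter_white_noise=True,
#       undersample_normal=False,
#   )
#   model.eval()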
def compute_metrics(pred):
labels = pred.label_ids
preds = pred.predictions.argmax(-1)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
acc = accuracy_score(labels, preds)
cm = confusion_matrix(labels, preds)
return {
'accuracy': acc,
'f1': f1,
'precision': precision,
'recall': recall,
'confusion_matrix': cm.tolist()
}
def main(training_args, output_dir, dataset_path, filter_white_noise, undersample_normal):
seed_everything()
model, train_dataloader, test_dataloader, id2label = train_params(dataset_path, filter_white_noise, undersample_normal)
    early_stopping_callback = EarlyStoppingCallback( # requires load_best_model_at_end=True and metric_for_best_model in TrainingArguments
early_stopping_patience=5,
early_stopping_threshold=0.001
)
trainer = Trainer(
model=model,
args=training_args,
compute_metrics=compute_metrics,
train_dataset=train_dataloader.dataset,
eval_dataset=test_dataloader.dataset,
callbacks=[TensorBoardCallback, early_stopping_callback]
)
    torch.cuda.empty_cache() # free GPU memory before training
    trainer.train() # pass resume_from_checkpoint=True to resume from a checkpoint
    os.makedirs(output_dir, exist_ok=True)
    trainer.save_model(output_dir) # Save the model locally
    eval_results = trainer.evaluate()
    print(f"Evaluation results: {eval_results}")
    trainer.push_to_hub(token=token) # Push the model to the user profile
    upload_folder(repo_id=f"A-POR-LOS-8000/{output_dir}", folder_path=output_dir, token=token) # also upload to the organization repo
def predict(audio_path):
waveform, sample_rate = torchaudio.load(audio_path, normalize=True)
if sample_rate != SAMPLING_RATE:
resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)
waveform = resampler(waveform)
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6)
max_length = int(SAMPLING_RATE * MAX_DURATION)
if waveform.shape[1] > max_length:
waveform = waveform[:, :max_length]
else:
waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[1]))
inputs = FEATURE_EXTRACTOR(
waveform.squeeze(),
sampling_rate=SAMPLING_RATE,
return_tensors="pt",
)
with torch.no_grad():
logits = model(inputs.input_values.to(model.device)).logits
predicted_class_id = logits.argmax().item()
predicted_label = id2label[predicted_class_id]
return predicted_label, logits
    test_subset = test_dataloader.dataset # Subset over the test indices
    test_samples = random.sample(
        [test_subset.dataset.file_paths[i] for i in test_subset.indices], 15
    ) # sample from the test split only; the bare .dataset.file_paths would include training files
for sample in test_samples:
predicted_label, logits = predict(sample)
print(f"File: {sample}")
print(f"Predicted label: {predicted_label}")
print(f"Logits: {logits}")
print("---")
def load_config(model_name):
with open(config_file, 'r') as f:
config = json.load(f)
model_config = config[model_name]
training_args = TrainingArguments(**model_config["training_args"])
model_config["training_args"] = training_args
return model_config
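# Illustrative shape of models_config.json (a sketch, not the actual file; only the
# keys read above -- "training_args", "output_dir", "dataset_path" -- are known from
# this script, and every value below is a placeholder):
# {
#   "class": {
#     "training_args": {"output_dir": "classifier", "num_train_epochs": 10},
#     "output_dir": "classifier",
#     "dataset_path": "data/classifier"
#   },
#   "mon": {...}
# }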
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--n", choices=["mon", "class"],
        required=True, help="Choose which model to train"
)
args = parser.parse_args()
config = load_config(args.n)
training_args = config["training_args"]
output_dir = config["output_dir"]
dataset_path = config["dataset_path"]
if args.n == "mon":
filter_white_noise = False
undersample_normal = False
elif args.n == "class":
filter_white_noise = True
undersample_normal = True
main(training_args, output_dir, dataset_path, filter_white_noise, undersample_normal)
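# Example invocation (assumes models_config.json provides entries for both choices):
#   python model.py --n class   # filter_white_noise=True, undersample_normal=True
#   python model.py --n mon     # filter_white_noise=False, undersample_normal=False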