# -*- coding: utf-8 -*-
"""CHULA Gino_Parkinson.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1XPgGZILiBbDji5G0dHoFV7OQaUwGM3HJ
"""
!pip install SoundFile transformers scikit-learn
from google.colab import drive
drive.mount('/content/drive')
import matplotlib.pyplot as plt
import numpy as np
import os
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from sklearn.model_selection import train_test_split
import re
from collections import Counter
from sklearn.metrics import classification_report
# Custom Dataset class
class DysarthriaDataset(Dataset):
    def __init__(self, data, labels, max_length=100000):
        self.data = data
        self.labels = labels
        self.max_length = max_length
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            wav_data, _ = sf.read(self.data[idx])
        except Exception as e:
            print(f"Error opening file: {self.data[idx]} ({e}). Skipping...")
            return self.__getitem__((idx + 1) % len(self.data))
        inputs = self.processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)  # squeeze the batch dimension
        # Zero-pad short clips up to max_length and truncate long ones, so every
        # item comes out with the same fixed length.
        if self.max_length - input_values.shape[-1] > 0:
            input_values = torch.cat([input_values, torch.zeros((self.max_length - input_values.shape[-1],))], dim=-1)
        else:
            input_values = input_values[:self.max_length]
        # Return the label as a single integer.
        return {"input_values": input_values}, self.labels[idx]
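# Quick hedged sanity check of the fixed-length logic above (illustrative only;
# the 80_000-sample clip is made up): shorter clips are zero-padded to
# max_length, longer ones are truncated, so batches always stack cleanly.
_short = torch.randn(80_000)
_padded = torch.cat([_short, torch.zeros(100_000 - _short.shape[-1])], dim=-1)
assert _padded.shape[-1] == 100_000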
def train(model, dataloader, criterion, optimizer, device, loss_vals, epochs, current_epoch):
    model.train()
    running_loss = 0
    for i, (inputs, labels) in enumerate(dataloader):
        # No squeeze here: the DataLoader already yields (batch, max_length), and
        # squeezing would drop the batch dimension when the last batch has size 1.
        inputs = {key: value.to(device) for key, value in inputs.items()}
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = model(**inputs).logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Append the loss value to the running list.
        loss_vals.append(loss.item())
        running_loss += loss.item()
        if i % 10 == 0:  # update the plot every 10 iterations
            plt.clf()  # clear the previous plot
            plt.plot(loss_vals)
            plt.xlim([0, len(dataloader) * epochs])
            plt.ylim([0, max(loss_vals) + 2])
            plt.xlabel('Training Iterations')
            plt.ylabel('Loss')
            plt.title(f"Training Loss at Epoch {current_epoch + 1}")
            plt.pause(0.001)  # pause to update the plot
    avg_loss = running_loss / len(dataloader)
    print(f"Average Loss after Epoch {current_epoch + 1}: {avg_loss}\n")
    return avg_loss
def predict(model, file_path, processor, device, max_length=100000):
    model.eval()
    with torch.no_grad():
        wav_data, _ = sf.read(file_path)
        inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)  # squeeze the batch dimension
        # Same fixed-length padding/truncation as in the dataset.
        if max_length - input_values.shape[-1] > 0:
            input_values = torch.cat([input_values, torch.zeros((max_length - input_values.shape[-1],))], dim=-1)
        else:
            input_values = input_values[:max_length]
        input_values = input_values.unsqueeze(0).to(device)
        inputs = {"input_values": input_values}
        logits = model(**inputs).logits
        # Remove the batch dimension and take the argmax over the two classes.
        logits = logits.squeeze()
        predicted_class_id = torch.argmax(logits, dim=-1).item()
    return predicted_class_id
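# predict() returns the argmax class id, following the label convention used
# when the datasets are built below: 1 = dysarthria, 0 = non-dysarthria.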
def evaluate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0
    correct_predictions = 0
    total_predictions = 0
    wrong_files = []
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(dataloader):
            inputs = {key: value.to(device) for key, value in inputs.items()}
            labels = labels.to(device)
            logits = model(**inputs).logits
            loss = criterion(logits, labels)
            running_loss += loss.item()
            _, predicted = torch.max(logits, 1)
            correct_predictions += (predicted == labels).sum().item()
            total_predictions += labels.size(0)
            # Map batch-local indices of misclassified items back to dataset
            # indices; indexing the dataset with the batch-local index alone
            # only works for the first batch.
            batch_start = batch_idx * dataloader.batch_size
            wrong_idx = (predicted != labels).nonzero().squeeze(-1).cpu().numpy()
            for idx in wrong_idx:
                wrong_files.append(dataloader.dataset.data[batch_start + idx])
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())
    avg_loss = running_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions
    return avg_loss, accuracy, wrong_files, np.array(all_labels), np.array(all_predictions)
def get_wav_files(base_path):
    wav_files = []
    for subject_folder in os.listdir(base_path):
        subject_path = os.path.join(base_path, subject_folder)
        if os.path.isdir(subject_path):
            for wav_file in os.listdir(subject_path):
                if wav_file.endswith('.wav'):
                    wav_files.append(os.path.join(subject_path, wav_file))
    return wav_files
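# get_wav_files expects the layout base_path/<subject_folder>/<recording>.wav;
# .wav files sitting directly under base_path are not picked up.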
def get_torgo_data(dysarthria_path, non_dysarthria_path):
    dysarthria_files = [os.path.join(dysarthria_path, f) for f in os.listdir(dysarthria_path) if f.endswith('.wav')]
    non_dysarthria_files = [os.path.join(non_dysarthria_path, f) for f in os.listdir(non_dysarthria_path) if f.endswith('.wav')]
    data = dysarthria_files + non_dysarthria_files
    labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files)
    train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2, stratify=labels)
    train_data, val_data, train_labels, val_labels = train_test_split(train_data, train_labels, test_size=0.25, stratify=train_labels)  # 0.25 x 0.8 = 0.2
    return train_data, val_data, test_data, train_labels, val_labels, test_labels
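# The two-stage split yields roughly 60% train / 20% validation / 20% test:
# 20% is held out first, then 0.25 of the remaining 80% (= 20%) goes to validation.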
dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS"
non_dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS"
dysarthria_files = get_wav_files(dysarthria_path)
non_dysarthria_files = get_wav_files(non_dysarthria_path)
data = dysarthria_files + non_dysarthria_files
labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files)
train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2, stratify=labels)
train_data, val_data, train_labels, val_labels = train_test_split(train_data, train_labels, test_size=0.25, stratify=train_labels) # 0.25 x 0.8 = 0.2
train_dataset = DysarthriaDataset(train_data, train_labels)
test_dataset = DysarthriaDataset(test_data, test_labels)
val_dataset = DysarthriaDataset(val_data, val_labels) # Create a validation dataset
train_loader = DataLoader(train_dataset, batch_size=16, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=16, drop_last=False)
validation_loader = DataLoader(val_dataset, batch_size=16, drop_last=False) # Use the validation dataset for the validation_loader
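# Because every item is padded or truncated to the same max_length, the default
# DataLoader collate can stack the tensors and no custom collate_fn is needed.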
""" dysarthria_path = "/content/drive/MyDrive/torgo_data/dysarthria_male/training"
non_dysarthria_path = "/content/drive/MyDrive/torgo_data/non_dysarthria_male/training"
dysarthria_files = [os.path.join(dysarthria_path, f) for f in os.listdir(dysarthria_path) if f.endswith('.wav')]
non_dysarthria_files = [os.path.join(non_dysarthria_path, f) for f in os.listdir(non_dysarthria_path) if f.endswith('.wav')]
data = dysarthria_files + non_dysarthria_files
labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files)
train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2)
train_dataset = DysarthriaDataset(train_data, train_labels)
test_dataset = DysarthriaDataset(test_data, test_labels)
train_loader = DataLoader(train_dataset, batch_size=8, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=8, drop_last=True)
validation_loader = DataLoader(test_dataset, batch_size=8, drop_last=True)
dysarthria_validation_path = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation"
non_dysarthria_validation_path = "/content/drive/MyDrive/torgo_data/non_dysarthria_male/validation"
dysarthria_validation_files = [os.path.join(dysarthria_validation_path, f) for f in os.listdir(dysarthria_validation_path) if f.endswith('.wav')]
non_dysarthria_validation_files = [os.path.join(non_dysarthria_validation_path, f) for f in os.listdir(non_dysarthria_validation_path) if f.endswith('.wav')]
validation_data = dysarthria_validation_files + non_dysarthria_validation_files
validation_labels = [1] * len(dysarthria_validation_files) + [0] * len(non_dysarthria_validation_files)"""
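# The triple-quoted block above preserves the earlier TORGO-based data setup for reference.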
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device)
# model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
# The classifier layer of Wav2Vec2ForCTC is excluded from the model's forward
# method (i.e., model(**inputs)), which is why the output had 32 labels instead
# of 2 even after the classifier was replaced. Hugging Face instead offers the
# Wav2Vec2 model with an adjustable classification head on top (set via num_labels).
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
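# Optional sanity check: the sequence-classification head should now emit two
# logits per example.
print(model.config.num_labels)  # expected: 2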
model_path = "/content/dysarthria_classifier1.pth"
if os.path.exists(model_path):
    print(f"Loading saved model {model_path}")
    model.load_state_dict(torch.load(model_path))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
from torch.optim.lr_scheduler import StepLR
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
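# With lr=1e-5, step_size=5, and gamma=0.1, the learning rate drops to 1e-6
# after epoch 5 (and would reach 1e-7 after epoch 10); scheduler.get_last_lr()
# reports the current value.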
# dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS/testing"
# non_dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS/testing"
# dysarthria_validation_files = get_wav_files(dysarthria_validation_path)
# non_dysarthria_validation_files = get_wav_files(non_dysarthria_validation_path)
# validation_data = dysarthria_validation_files + non_dysarthria_validation_files
# validation_labels = [1] * len(dysarthria_validation_files) + [0] * len(non_dysarthria_validation_files)
epochs = 10
plt.ion()
fig, ax = plt.subplots()
x_vals = np.arange(len(train_loader)*epochs)
loss_vals = []
for epoch in range(epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device, loss_vals, epochs, epoch)
    print(f"Epoch {epoch + 1}, Train Loss: {train_loss}")
    val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device)
    print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}")
    print("Misclassified Files")
    for file_path in wrong_files:
        print(file_path)
    sentence_pattern = re.compile(r"_(\d+)\.wav$")
    sentence_counts = Counter()
    for file_path in wrong_files:
        match = sentence_pattern.search(file_path)
        if match:
            sentence_number = int(match.group(1))
            sentence_counts[sentence_number] += 1
    total_wrong = len(wrong_files)
    print("Total wrong files:", total_wrong)
    print()
    for sentence_number, count in sentence_counts.most_common():
        percent = count / total_wrong * 100
        print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)")
    scheduler.step()
print(classification_report(true_labels, pred_labels, target_names=['non_dysarthria', 'dysarthria']))
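# A confusion matrix is a compact complement to the report above (rows = true
# class, columns = predicted class); a short sketch reusing the arrays already
# returned by evaluate().
from sklearn.metrics import confusion_matrix
print(confusion_matrix(true_labels, pred_labels))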
# Test on a specific audio file.
audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
predicted_label = predict(model, audio_file, train_dataset.processor, device)
print(f"Predicted label: {predicted_label}")
torch.save(model.state_dict(), "dysarthria_classifier1.pth")
print("Predicting...")
"""#audio aug"""
!pip install audiomentations
from audiomentations import Compose, PitchShift, TimeStretch
augmenter = Compose([
    PitchShift(min_semitones=-2, max_semitones=2, p=0.1),
    TimeStretch(min_rate=0.9, max_rate=1.1, p=0.1)
])
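# Minimal sketch of applying the augmenter outside the training loop (the dummy
# one-second clip is made up): audiomentations expects a float32 numpy array
# and the sample rate of the audio.
_dummy = np.random.uniform(-1, 1, 16000).astype(np.float32)
_augmented = augmenter(samples=_dummy, sample_rate=16000)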
# from torch.optim.lr_scheduler import StepLR
# scheduler = StepLR(optimizer, step_size=2, gamma=0.5)
from transformers import get_linear_schedule_with_warmup

model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)

model_path = "/content/models/my_model_06/pytorch_model.bin"
if os.path.exists(model_path):
    print(f"Loading saved model {model_path}")
    model.load_state_dict(torch.load(model_path))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

epochs = 20  # must match the augmented training run below
# Total number of training steps: epochs times the number of batches per epoch.
num_training_steps = epochs * len(train_loader)
# Number of warmup steps: a fraction of the total; 0.3 here (0.1 is also a common choice).
num_warmup_steps = int(num_training_steps * 0.3)
# Create the scheduler from the optimizer it will actually step; in the original
# ordering the scheduler wrapped an optimizer that was re-created afterwards,
# so scheduler.step() never affected the optimizer used for training.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
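# The resulting schedule ramps the learning rate linearly from 0 up to 1e-5 over
# the first 30% of steps, then decays it linearly back to 0 by the final step.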
import numpy as np
def trainaug(model, dataloader, criterion, optimizer, device, loss_vals, epochs, current_epoch):
    model.train()
    running_loss = 0
    for i, (inputs, labels) in enumerate(dataloader):
        inputs = {key: value.to(device) for key, value in inputs.items() if torch.is_tensor(value)}
        labels = labels.to(device)
        # Apply audio augmentation clip by clip.
        augmented_audio = []
        for audio in inputs['input_values']:
            # The augmenter works on numpy arrays, so convert the tensor first.
            audio_np = audio.cpu().numpy()
            augmented = augmenter(samples=audio_np, sample_rate=16000)  # 16 kHz audio
            # Defensively pad or trim back to the fixed length in case an
            # augmentation changes the clip length.
            target_len = audio_np.shape[-1]
            if augmented.shape[-1] < target_len:
                augmented = np.pad(augmented, (0, target_len - augmented.shape[-1]))
            else:
                augmented = augmented[:target_len]
            augmented_audio.append(augmented)
        # Convert the list of numpy arrays back to a tensor.
        inputs['input_values'] = torch.from_numpy(np.array(augmented_audio)).to(device)
        optimizer.zero_grad()
        logits = model(**inputs).logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Append the loss value to the running list.
        loss_vals.append(loss.item())
        running_loss += loss.item()
        if i % 10 == 0:  # update the plot every 10 iterations
            plt.clf()  # clear the previous plot
            plt.plot(loss_vals)
            plt.xlim([0, len(dataloader) * epochs])
            plt.ylim([0, max(loss_vals) + 2])
            plt.xlabel('Training Iterations')
            plt.ylabel('Loss')
            plt.title(f"Training Loss at Epoch {current_epoch + 1}")
            plt.pause(0.001)  # pause to update the plot
    avg_loss = running_loss / len(dataloader)
    print(f"Average Loss after Epoch {current_epoch + 1}: {avg_loss}\n")
    return avg_loss
epochs = 20
plt.ion()
fig, ax = plt.subplots()
x_vals = np.arange(len(train_loader)*epochs)
loss_vals = []
for epoch in range(epochs):
    train_loss = trainaug(model, train_loader, criterion, optimizer, device, loss_vals, epochs, epoch)
    print(f"Epoch {epoch + 1}, Train Loss: {train_loss}")
    val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device)
    print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}")
    print("Misclassified Files")
    for file_path in wrong_files:
        print(file_path)
    sentence_pattern = re.compile(r"_(\d+)\.wav$")
    sentence_counts = Counter()
    for file_path in wrong_files:
        match = sentence_pattern.search(file_path)
        if match:
            sentence_number = int(match.group(1))
            sentence_counts[sentence_number] += 1
    total_wrong = len(wrong_files)
    print("Total wrong files:", total_wrong)
    print()
    for sentence_number, count in sentence_counts.most_common():
        percent = count / total_wrong * 100
        print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)")
    scheduler.step()
print(classification_report(true_labels, pred_labels, target_names=['non_dysarthria', 'dysarthria']))
# Test on a specific audio file (kept commented out).
audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
# predicted_label = predict(model, audio_file, train_dataset.processor, device)
# print(f"Predicted label: {predicted_label}")
import re
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report
# Define the pattern to extract the sentence number from the file path
sentence_pattern = re.compile(r"_(\d+)\.wav$")
# Counter for the total number of each sentence type in the dataset
total_sentence_counts = Counter()
for file_path in train_loader.dataset.data:  # access the file paths directly
    match = sentence_pattern.search(file_path)
    if match:
        sentence_number = int(match.group(1))
        total_sentence_counts[sentence_number] += 1
epochs = 1
plt.ion()
fig, ax = plt.subplots()
x_vals = np.arange(len(train_loader)*epochs)
loss_vals = []
for epoch in range(epochs):
    # train_loss = trainaug(model, train_loader, criterion, optimizer, device, loss_vals, epochs, epoch)
    # print(f"Epoch {epoch + 1}, Train Loss: {train_loss}")
    val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device)
    print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}")
    print("Misclassified Files")
    for file_path in wrong_files:
        print(file_path)
    # Counter for the misclassified sentences.
    sentence_counts = Counter()
    for file_path in wrong_files:
        match = sentence_pattern.search(file_path)
        if match:
            sentence_number = int(match.group(1))
            sentence_counts[sentence_number] += 1
    print("Total wrong files:", len(wrong_files))
    print()
    for sentence_number, count in sentence_counts.most_common():
        # Here the percentage is relative to how often that sentence occurs in the data.
        percent = count / total_sentence_counts[sentence_number] * 100
        print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)")
    scheduler.step()
print(classification_report(true_labels, pred_labels, target_names=['non_dysarthria', 'dysarthria']))
torch.save(model.state_dict(), "dysarthria_classifier2.pth")
save_dir = "models/my_model_06"
model.save_pretrained(save_dir)
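# A model saved with save_pretrained() can be reloaded later with
# Wav2Vec2ForSequenceClassification.from_pretrained(save_dir).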
"""## Cross testing
"""
# dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS/testing"
# non_dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS/testing"
# dysarthria_validation_files = get_wav_files(dysarthria_validation_path)
# non_dysarthria_validation_files = get_wav_files(non_dysarthria_validation_path)
# validation_data = dysarthria_validation_files + non_dysarthria_validation_files
# validation_labels = [1] * len(dysarthria_validation_files) + [0] * len(non_dysarthria_validation_files)
epochs = 1
plt.ion()
fig, ax = plt.subplots()
x_vals = np.arange(len(train_loader)*epochs)
loss_vals = []
for epoch in range(epochs):
    # train_loss = train(model, train_loader, criterion, optimizer, device, loss_vals, epochs, epoch)
    # print(f"Epoch {epoch + 1}, Train Loss: {train_loss}")
    val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device)
    print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}")
    print("Misclassified Files")
    for file_path in wrong_files:
        print(file_path)
    sentence_pattern = re.compile(r"_(\d+)\.wav$")
    sentence_counts = Counter()
    for file_path in wrong_files:
        match = sentence_pattern.search(file_path)
        if match:
            sentence_number = int(match.group(1))
            sentence_counts[sentence_number] += 1
    total_wrong = len(wrong_files)
    print("Total wrong files:", total_wrong)
    print()
    for sentence_number, count in sentence_counts.most_common():
        percent = count / total_wrong * 100
        print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)")
    scheduler.step()
print(classification_report(true_labels, pred_labels, target_names=['non_dysarthria', 'dysarthria']))
# Test on a specific audio file.
audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
predicted_label = predict(model, audio_file, train_dataset.processor, device)
print(f"Predicted label: {predicted_label}")
"""## DEBUGGING"""
dysarthria_path = "/content/drive/MyDrive/torgo_data/dysarthria_male/training"
non_dysarthria_path = "/content/drive/MyDrive/torgo_data/non_dysarthria_male/training"
dysarthria_files = [os.path.join(dysarthria_path, f) for f in os.listdir(dysarthria_path) if f.endswith('.wav')]
non_dysarthria_files = [os.path.join(non_dysarthria_path, f) for f in os.listdir(non_dysarthria_path) if f.endswith('.wav')]
data = dysarthria_files + non_dysarthria_files
labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files)
train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2)
train_dataset = DysarthriaDataset(train_data, train_labels)
test_dataset = DysarthriaDataset(test_data, test_labels)
train_loader = DataLoader(train_dataset, batch_size=4, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=4, drop_last=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device)
# model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
max_length = 100_000
processor = train_dataset.processor
model.eval()
audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
# predicted_label = predict(model, audio_file, train_dataset.processor, device)
# print(f"Predicted label: {predicted_label}")
wav_data, _ = sf.read(audio_file)
inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
input_values = inputs.input_values.squeeze(0) # Squeeze the batch dimension
if max_length - input_values.shape[-1] > 0:
    input_values = torch.cat([input_values, torch.zeros((max_length - input_values.shape[-1],))], dim=-1)
else:
    input_values = input_values[:max_length]
input_values = input_values.unsqueeze(0).to(device)
print(input_values.shape)
with torch.no_grad():
    outputs = model(**{"input_values": input_values})
    logits = outputs.logits
print(input_values.shape, logits.shape)
# Remove the batch dimension and take the argmax over the two classes.
logits = logits.squeeze()
predicted_class_id = torch.argmax(logits, dim=-1)
print(predicted_class_id)
"""Cross testing
##origial code
"""
import os
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from sklearn.model_selection import train_test_split
# Custom Dataset class
class DysarthriaDataset(Dataset):
    def __init__(self, data, labels, max_length=100000):
        self.data = data
        self.labels = labels
        self.max_length = max_length
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            wav_data, _ = sf.read(self.data[idx])
        except Exception as e:
            print(f"Error opening file: {self.data[idx]} ({e}). Skipping...")
            return self.__getitem__((idx + 1) % len(self.data))
        inputs = self.processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)  # squeeze the batch dimension
        # Zero-pad short clips up to max_length, truncate long ones.
        if self.max_length - input_values.shape[-1] > 0:
            input_values = torch.cat([input_values, torch.zeros((self.max_length - input_values.shape[-1],))], dim=-1)
        else:
            input_values = input_values[:self.max_length]
        # Return the label as a single integer.
        return {"input_values": input_values}, self.labels[idx]
def train(model, dataloader, criterion, optimizer, device, ax, loss_vals, x_vals, fig, train_loader, epochs):
    model.train()
    running_loss = 0
    for i, (inputs, labels) in enumerate(dataloader):
        inputs = {key: value.squeeze().to(device) for key, value in inputs.items()}
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = model(**inputs).logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Append the loss value to the list.
        loss_vals.append(loss.item())
        running_loss += loss.item()
        if i:
            # Update the plot.
            ax.clear()
            ax.set_xlim([0, len(train_loader) * epochs])
            ax.set_xlabel('Training Iterations')
            ax.set_ylim([0, max(loss_vals) + 2])
            ax.set_ylabel('Loss')
            ax.plot(x_vals[:len(loss_vals)], loss_vals)
            fig.canvas.draw()
            plt.pause(0.001)
    avg_loss = running_loss / len(dataloader)
    print(avg_loss)
    print("\n")
    return avg_loss
def main():
    dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS/training"
    non_dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS/training"
    dysarthria_files = get_wav_files(dysarthria_path)
    non_dysarthria_files = get_wav_files(non_dysarthria_path)
    data = dysarthria_files + non_dysarthria_files
    labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files)
    train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2)
    train_dataset = DysarthriaDataset(train_data, train_labels)
    test_dataset = DysarthriaDataset(test_data, test_labels)
    train_loader = DataLoader(train_dataset, batch_size=8, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=8, drop_last=True)
    validation_loader = DataLoader(test_dataset, batch_size=8, drop_last=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The classifier layer of Wav2Vec2ForCTC is excluded from the model's forward
    # method (i.e., model(**inputs)), which is why the output had 32 labels instead
    # of 2 even after the classifier was replaced. Hugging Face instead offers the
    # Wav2Vec2 model with an adjustable classification head on top (set via num_labels).
    model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
    # model_path = "/content/dysarthria_classifier3.pth"
    # if os.path.exists(model_path):
    #     print(f"Loading saved model {model_path}")
    #     model.load_state_dict(torch.load(model_path))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
    dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS/testing"
    non_dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS/testing"
    dysarthria_validation_files = get_wav_files(dysarthria_validation_path)
    non_dysarthria_validation_files = get_wav_files(non_dysarthria_validation_path)
    validation_data = dysarthria_validation_files + non_dysarthria_validation_files
    validation_labels = [1] * len(dysarthria_validation_files) + [0] * len(non_dysarthria_validation_files)
    epochs = 10
    fig, ax = plt.subplots()
    x_vals = np.arange(len(train_loader) * epochs)
    loss_vals = []
    for epoch in range(epochs):
        train_loss = train(model, train_loader, criterion, optimizer, device, ax, loss_vals, x_vals, fig, train_loader, epoch + 1)
        print(f"Epoch {epoch + 1}, Train Loss: {train_loss}")
        val_loss, val_accuracy, wrong_files = evaluate(model, validation_loader, criterion, device)
        print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}")
        print("Misclassified Files")
        for file_path in wrong_files:
            print(file_path)
        sentence_pattern = re.compile(r"_(\d+)\.wav$")
        sentence_counts = Counter()
        for file_path in wrong_files:
            match = sentence_pattern.search(file_path)
            if match:
                sentence_number = int(match.group(1))
                sentence_counts[sentence_number] += 1
        total_wrong = len(wrong_files)
        print("Total wrong files:", total_wrong)
        print()
        for sentence_number, count in sentence_counts.most_common():
            percent = count / total_wrong * 100
            print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)")
    torch.save(model.state_dict(), "dysarthria_classifier4.pth")
    print("Predicting...")
    # Test on a specific audio file
    # audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
    # predicted_label = predict(model, audio_file, train_dataset.processor, device)
    # print(f"Predicted label: {predicted_label}")
def predict(model, file_path, processor, device, max_length=100000):
    model.eval()
    with torch.no_grad():
        wav_data, _ = sf.read(file_path)
        inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)  # squeeze the batch dimension
        # Same fixed-length padding/truncation as in the dataset.
        if max_length - input_values.shape[-1] > 0:
            input_values = torch.cat([input_values, torch.zeros((max_length - input_values.shape[-1],))], dim=-1)
        else:
            input_values = input_values[:max_length]
        input_values = input_values.unsqueeze(0).to(device)
        inputs = {"input_values": input_values}
        logits = model(**inputs).logits
        # Remove the batch dimension and take the argmax over the two classes.
        logits = logits.squeeze()
        predicted_class_id = torch.argmax(logits, dim=-1).item()
    return predicted_class_id
def evaluate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0
    correct_predictions = 0
    total_predictions = 0
    wrong_files = []
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(dataloader):
            inputs = {key: value.squeeze().to(device) for key, value in inputs.items()}
            labels = labels.to(device)
            logits = model(**inputs).logits
            loss = criterion(logits, labels)
            running_loss += loss.item()
            _, predicted = torch.max(logits, 1)
            correct_predictions += (predicted == labels).sum().item()
            total_predictions += labels.size(0)
            # Map batch-local indices of misclassified items back to dataset indices.
            batch_start = batch_idx * dataloader.batch_size
            wrong_idx = (predicted != labels).nonzero().squeeze(-1).cpu().numpy()
            for idx in wrong_idx:
                wrong_files.append(dataloader.dataset.data[batch_start + idx])
    avg_loss = running_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions
    return avg_loss, accuracy, wrong_files
def get_wav_files(base_path):
    wav_files = []
    for subject_folder in os.listdir(base_path):
        subject_path = os.path.join(base_path, subject_folder)
        if os.path.isdir(subject_path):
            for wav_file in os.listdir(subject_path):
                if wav_file.endswith('.wav'):
                    wav_files.append(os.path.join(subject_path, wav_file))
    return wav_files
if __name__ == "__main__":
main()