|
|
|
"""CHULA Gino_Parkinson.ipynb |
|
|
|
Automatically generated by Colaboratory. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1XPgGZILiBbDji5G0dHoFV7OQaUwGM3HJ |
|
""" |
|
|
|
!pip install SoundFile transformers scikit-learn |
|
|
|
from google.colab import drive |
|
drive.mount('/content/drive') |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
|
|
import os |
|
import soundfile as sf |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import Dataset, DataLoader |
|
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
|
from sklearn.model_selection import train_test_split |
|
import re |
|
from collections import Counter |
|
from sklearn.metrics import classification_report |
|
|
|
|
|
class DysarthriaDataset(Dataset): |
|
def __init__(self, data, labels, max_length=100000): |
|
self.data = data |
|
self.labels = labels |
|
self.max_length = max_length |
|
self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
    def __getitem__(self, idx):
        # Fall back to the next file if this one cannot be read.
        try:
            wav_data, _ = sf.read(self.data[idx])
        except Exception as e:
            print(f"Error opening file: {self.data[idx]} ({e}). Skipping...")
            return self.__getitem__((idx + 1) % len(self.data))
        inputs = self.processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)
        # Zero-pad short clips up to max_length; truncate longer ones.
        if self.max_length - input_values.shape[-1] > 0:
            input_values = torch.cat([input_values, torch.zeros((self.max_length - input_values.shape[-1],))], dim=-1)
        else:
            input_values = input_values[:self.max_length]
        return {"input_values": input_values}, self.labels[idx]
|
|
|
|
|
|
|
def train(model, dataloader, criterion, optimizer, device, loss_vals, epochs, current_epoch): |
|
model.train() |
|
running_loss = 0 |
|
|
|
    for i, (inputs, labels) in enumerate(dataloader):
        # The dataset yields 1-D clips, so the collated batch is already
        # (batch, max_length); squeeze() would break a final batch of size 1.
        inputs = {key: value.to(device) for key, value in inputs.items()}
        labels = labels.to(device)
|
|
|
optimizer.zero_grad() |
|
logits = model(**inputs).logits |
|
loss = criterion(logits, labels) |
|
loss.backward() |
|
optimizer.step() |
|
|
|
|
|
loss_vals.append(loss.item()) |
|
running_loss += loss.item() |
|
|
|
if i % 10 == 0: |
|
plt.clf() |
|
plt.plot(loss_vals) |
|
plt.xlim([0, len(dataloader)*epochs]) |
|
plt.ylim([0, max(loss_vals) + 2]) |
|
plt.xlabel('Training Iterations') |
|
plt.ylabel('Loss') |
|
plt.title(f"Training Loss at Epoch {current_epoch + 1}") |
|
plt.pause(0.001) |
|
|
|
avg_loss = running_loss / len(dataloader) |
|
print(f"Average Loss after Epoch {current_epoch + 1}: {avg_loss}\n") |
|
return avg_loss |
|
|
|
def predict(model, file_path, processor, device, max_length=100000):
    model.eval()
    with torch.no_grad():
        wav_data, _ = sf.read(file_path)
        inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)

        # Same pad/truncate-to-max_length logic as DysarthriaDataset.__getitem__.
        input_values = inputs.input_values.squeeze(0)
        if max_length - input_values.shape[-1] > 0:
            input_values = torch.cat([input_values, torch.zeros((max_length - input_values.shape[-1],))], dim=-1)
        else:
            input_values = input_values[:max_length]
        input_values = input_values.unsqueeze(0).to(device)
        inputs = {"input_values": input_values}

        logits = model(**inputs).logits
        logits = logits.squeeze()
        predicted_class_id = torch.argmax(logits, dim=-1).item()

    return predicted_class_id
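
# predict() returns the raw class id: 1 = dysarthria, 0 = non-dysarthria,
# matching the labels assigned when the file lists are built below.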
|
|
|
def evaluate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0
    correct_predictions = 0
    total_predictions = 0
    wrong_files = []
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(dataloader):
            inputs = {key: value.to(device) for key, value in inputs.items()}
            labels = labels.to(device)

            logits = model(**inputs).logits
            loss = criterion(logits, labels)
            running_loss += loss.item()

            _, predicted = torch.max(logits, 1)
            correct_predictions += (predicted == labels).sum().item()
            total_predictions += labels.size(0)

            # Offset in-batch indices to dataset indices so the right file
            # paths are reported (assumes the DataLoader keeps shuffle=False).
            batch_offset = batch_idx * dataloader.batch_size
            wrong_idx = (predicted != labels).nonzero().squeeze().cpu().numpy()
            if wrong_idx.ndim > 0:
                for idx in wrong_idx:
                    wrong_files.append(dataloader.dataset.data[batch_offset + idx])
            elif wrong_idx.size > 0:
                wrong_files.append(dataloader.dataset.data[batch_offset + int(wrong_idx)])

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    avg_loss = running_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions

    return avg_loss, accuracy, wrong_files, np.array(all_labels), np.array(all_predictions)
|
|
|
def get_wav_files(base_path): |
|
wav_files = [] |
|
for subject_folder in os.listdir(base_path): |
|
subject_path = os.path.join(base_path, subject_folder) |
|
if os.path.isdir(subject_path): |
|
for wav_file in os.listdir(subject_path): |
|
if wav_file.endswith('.wav'): |
|
wav_files.append(os.path.join(subject_path, wav_file)) |
|
|
|
return wav_files |
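
# Expected layout: base_path/<subject_folder>/<recording>.wav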
|
|
|
def get_torgo_data(dysarthria_path, non_dysarthria_path): |
|
dysarthria_files = [os.path.join(dysarthria_path, f) for f in os.listdir(dysarthria_path) if f.endswith('.wav')] |
|
non_dysarthria_files = [os.path.join(non_dysarthria_path, f) for f in os.listdir(non_dysarthria_path) if f.endswith('.wav')] |
|
|
|
data = dysarthria_files + non_dysarthria_files |
|
labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files) |
|
|
|
train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2, stratify=labels) |
|
train_data, val_data, train_labels, val_labels = train_test_split(train_data, train_labels, test_size=0.25, stratify=train_labels) |
|
|
|
return train_data, val_data, test_data, train_labels, val_labels, test_labels |
|
|
|
dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS" |
|
non_dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS" |
|
|
|
dysarthria_files = get_wav_files(dysarthria_path) |
|
non_dysarthria_files = get_wav_files(non_dysarthria_path) |
|
|
|
|
|
|
|
data = dysarthria_files + non_dysarthria_files |
|
labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files) |
|
|
|
train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2, stratify=labels) |
|
train_data, val_data, train_labels, val_labels = train_test_split(train_data, train_labels, test_size=0.25, stratify=train_labels) |
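
# The two-stage split above yields roughly 60% train / 20% validation / 20% test,
# stratified by class at each stage.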
|
train_dataset = DysarthriaDataset(train_data, train_labels) |
|
test_dataset = DysarthriaDataset(test_data, test_labels) |
|
val_dataset = DysarthriaDataset(val_data, val_labels) |
|
|
|
train_loader = DataLoader(train_dataset, batch_size=16, drop_last=False) |
|
test_loader = DataLoader(test_dataset, batch_size=16, drop_last=False) |
|
validation_loader = DataLoader(val_dataset, batch_size=16, drop_last=False) |
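
# Note: these DataLoaders keep the default shuffle=False; evaluate() relies on
# that fixed order to map misclassified batch items back to file paths.
# For training, shuffle=True would normally be preferable.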
|
|
|
""" dysarthria_path = "/content/drive/MyDrive/torgo_data/dysarthria_male/training" |
|
non_dysarthria_path = "/content/drive/MyDrive/torgo_data/non_dysarthria_male/training" |
|
|
|
dysarthria_files = [os.path.join(dysarthria_path, f) for f in os.listdir(dysarthria_path) if f.endswith('.wav')] |
|
non_dysarthria_files = [os.path.join(non_dysarthria_path, f) for f in os.listdir(non_dysarthria_path) if f.endswith('.wav')] |
|
|
|
data = dysarthria_files + non_dysarthria_files |
|
labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files) |
|
|
|
train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2) |
|
|
|
train_dataset = DysarthriaDataset(train_data, train_labels) |
|
test_dataset = DysarthriaDataset(test_data, test_labels) |
|
|
|
train_loader = DataLoader(train_dataset, batch_size=8, drop_last=True) |
|
test_loader = DataLoader(test_dataset, batch_size=8, drop_last=True) |
|
validation_loader = DataLoader(test_dataset, batch_size=8, drop_last=True) |
|
|
|
dysarthria_validation_path = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation" |
|
non_dysarthria_validation_path = "/content/drive/MyDrive/torgo_data/non_dysarthria_male/validation" |
|
|
|
dysarthria_validation_files = [os.path.join(dysarthria_validation_path, f) for f in os.listdir(dysarthria_validation_path) if f.endswith('.wav')] |
|
non_dysarthria_validation_files = [os.path.join(non_dysarthria_validation_path, f) for f in os.listdir(non_dysarthria_validation_path) if f.endswith('.wav')] |
|
|
|
validation_data = dysarthria_validation_files + non_dysarthria_validation_files |
|
validation_labels = [1] * len(dysarthria_validation_files) + [0] * len(non_dysarthria_validation_files)""" |
|
|
|
|
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device) |
|
|
|
model_path = "/content/dysarthria_classifier1.pth" |
|
if os.path.exists(model_path): |
|
print(f"Loading saved model {model_path}") |
|
    model.load_state_dict(torch.load(model_path, map_location=device))
|
|
|
criterion = nn.CrossEntropyLoss() |
|
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) |
|
|
|
from torch.optim.lr_scheduler import StepLR |
|
|
|
scheduler = StepLR(optimizer, step_size=5, gamma=0.1) |
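
# StepLR multiplies the learning rate by gamma every step_size epochs, so with
# the settings above: 1e-5 for epochs 1-5, 1e-6 for epochs 6-10, and so on.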
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
epochs = 10 |
|
plt.ion() |
|
fig, ax = plt.subplots() |
|
x_vals = np.arange(len(train_loader)*epochs) |
|
loss_vals = [] |
|
for epoch in range(epochs): |
|
train_loss = train(model, train_loader, criterion, optimizer, device, loss_vals, epochs, epoch) |
|
print(f"Epoch {epoch + 1}, Train Loss: {train_loss}") |
|
|
|
val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device) |
|
print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}") |
|
print("Misclassified Files") |
|
for file_path in wrong_files: |
|
print(file_path) |
|
|
|
|
|
sentence_pattern = re.compile(r"_(\d+)\.wav$") |
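    # Filenames are expected to end in "_<number>.wav"; the trailing number
    # identifies the sentence prompt, so errors can be grouped per sentence.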
|
|
|
sentence_counts = Counter() |
|
for file_path in wrong_files: |
|
match = sentence_pattern.search(file_path) |
|
if match: |
|
sentence_number = int(match.group(1)) |
|
sentence_counts[sentence_number] += 1 |
|
|
|
total_wrong = len(wrong_files) |
|
print("Total wrong files:", total_wrong) |
|
print() |
|
|
|
for sentence_number, count in sentence_counts.most_common(): |
|
percent = count / total_wrong * 100 |
|
print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)") |
|
scheduler.step() |
|
print(classification_report(true_labels, pred_labels, target_names=['non_dysarthria', 'dysarthria'])) |
|
audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav" |
|
predicted_label = predict(model, audio_file, train_dataset.processor, device) |
|
print(f"Predicted label: {predicted_label}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch.save(model.state_dict(), "dysarthria_classifier1.pth") |
|
print("Predicting...") |
|
|
|
"""#audio aug""" |
|
|
|
!pip install audiomentations |
|
from audiomentations import Compose, PitchShift, TimeStretch |
|
|
|
augmenter = Compose([ |
|
PitchShift(min_semitones=-2, max_semitones=2, p=0.1), |
|
TimeStretch(min_rate=0.9, max_rate=1.1, p=0.1) |
|
]) |
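
# A minimal sketch of what the augmenter does to a single clip (dummy data,
# not a file from the dataset): each transform fires with probability p (0.1 here).
_dummy_wave = np.random.uniform(-1.0, 1.0, 16000).astype(np.float32)  # 1 s at 16 kHz
_augmented_wave = augmenter(samples=_dummy_wave, sample_rate=16000)
print(_dummy_wave.shape, _augmented_wave.shape)  # lengths may differ after TimeStretch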
|
|
|
|
|
|
|
|
|
|
|
from transformers import get_linear_schedule_with_warmup

model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)

model_path = "/content/models/my_model_06/pytorch_model.bin"
if os.path.exists(model_path):
    print(f"Loading saved model {model_path}")
    model.load_state_dict(torch.load(model_path, map_location=device))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# Size the schedule for the augmented run below (20 epochs), and create it only
# after the optimizer it wraps exists.
epochs = 20
num_training_steps = epochs * len(train_loader)

# Warm up over the first 30% of steps, then decay linearly. A schedule sized in
# batches is normally stepped once per batch; the loop below steps it per epoch.
num_warmup_steps = int(num_training_steps * 0.3)

scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
|
|
import numpy as np |
|
|
|
def trainaug(model, dataloader, criterion, optimizer, device, loss_vals, epochs, current_epoch): |
|
model.train() |
|
running_loss = 0 |
|
|
|
    for i, (inputs, labels) in enumerate(dataloader):
        inputs = {key: value.to(device) for key, value in inputs.items() if torch.is_tensor(value)}
        labels = labels.to(device)

        # Augment each clip on the CPU. TimeStretch changes the number of
        # samples, so pad/trim each augmented clip back to its original length
        # before stacking the batch again.
        augmented_audio = []
        for audio in inputs['input_values']:
            audio_np = audio.cpu().numpy()
            augmented = augmenter(samples=audio_np, sample_rate=16000)
            if len(augmented) < len(audio_np):
                augmented = np.pad(augmented, (0, len(audio_np) - len(augmented)))
            else:
                augmented = augmented[:len(audio_np)]
            augmented_audio.append(augmented)

        inputs['input_values'] = torch.from_numpy(np.array(augmented_audio)).to(device)
|
|
|
optimizer.zero_grad() |
|
logits = model(**inputs).logits |
|
loss = criterion(logits, labels) |
|
loss.backward() |
|
optimizer.step() |
|
|
|
|
|
loss_vals.append(loss.item()) |
|
running_loss += loss.item() |
|
|
|
if i % 10 == 0: |
|
plt.clf() |
|
plt.plot(loss_vals) |
|
plt.xlim([0, len(dataloader)*epochs]) |
|
plt.ylim([0, max(loss_vals) + 2]) |
|
plt.xlabel('Training Iterations') |
|
plt.ylabel('Loss') |
|
plt.title(f"Training Loss at Epoch {current_epoch + 1}") |
|
plt.pause(0.001) |
|
|
|
avg_loss = running_loss / len(dataloader) |
|
print(f"Average Loss after Epoch {current_epoch + 1}: {avg_loss}\n") |
|
return avg_loss |
|
|
|
epochs = 20 |
|
plt.ion() |
|
fig, ax = plt.subplots() |
|
x_vals = np.arange(len(train_loader)*epochs) |
|
loss_vals = [] |
|
for epoch in range(epochs): |
|
train_loss = trainaug(model, train_loader, criterion, optimizer, device, loss_vals, epochs, epoch) |
|
print(f"Epoch {epoch + 1}, Train Loss: {train_loss}") |
|
|
|
val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device) |
|
print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}") |
|
print("Misclassified Files") |
|
for file_path in wrong_files: |
|
print(file_path) |
|
|
|
|
|
sentence_pattern = re.compile(r"_(\d+)\.wav$") |
|
|
|
sentence_counts = Counter() |
|
for file_path in wrong_files: |
|
match = sentence_pattern.search(file_path) |
|
if match: |
|
sentence_number = int(match.group(1)) |
|
sentence_counts[sentence_number] += 1 |
|
|
|
total_wrong = len(wrong_files) |
|
print("Total wrong files:", total_wrong) |
|
print() |
|
|
|
for sentence_number, count in sentence_counts.most_common(): |
|
percent = count / total_wrong * 100 |
|
print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)") |
|
scheduler.step() |
|
print(classification_report(true_labels, pred_labels, target_names=['non_dysarthria', 'dysarthria'])) |
|
audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re |
|
from collections import Counter |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from sklearn.metrics import classification_report |
|
|
|
|
|
sentence_pattern = re.compile(r"_(\d+)\.wav$") |
|
|
|
|
|
total_sentence_counts = Counter()

# Count how often each sentence occurs in the evaluated (validation) set, so
# the per-sentence error rates printed below are normalized by the right totals.
for file_path in validation_loader.dataset.data:
    match = sentence_pattern.search(file_path)
    if match:
        sentence_number = int(match.group(1))
        total_sentence_counts[sentence_number] += 1
|
|
|
epochs = 1 |
|
plt.ion() |
|
fig, ax = plt.subplots() |
|
x_vals = np.arange(len(train_loader)*epochs) |
|
loss_vals = [] |
|
|
|
for epoch in range(epochs): |
|
|
|
|
|
|
|
val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device) |
|
print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}") |
|
print("Misclassified Files") |
|
for file_path in wrong_files: |
|
print(file_path) |
|
|
|
|
|
sentence_counts = Counter() |
|
|
|
for file_path in wrong_files: |
|
match = sentence_pattern.search(file_path) |
|
if match: |
|
sentence_number = int(match.group(1)) |
|
sentence_counts[sentence_number] += 1 |
|
|
|
print("Total wrong files:", len(wrong_files)) |
|
print() |
|
|
|
for sentence_number, count in sentence_counts.most_common(): |
|
percent = count / total_sentence_counts[sentence_number] * 100 |
|
print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)") |
|
|
|
scheduler.step() |
|
print(classification_report(true_labels, pred_labels, target_names=['non_dysarthria', 'dysarthria'])) |
|
|
|
torch.save(model.state_dict(), "dysarthria_classifier2.pth") |
|
|
|
save_dir = "models/my_model_06" |
|
model.save_pretrained(save_dir) |
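
# To restore this checkpoint later (a sketch; assumes save_dir still exists):
# model = Wav2Vec2ForSequenceClassification.from_pretrained(save_dir).to(device)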
|
|
|
"""## Cross testing |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
epochs = 1 |
|
plt.ion() |
|
fig, ax = plt.subplots() |
|
x_vals = np.arange(len(train_loader)*epochs) |
|
loss_vals = [] |
|
for epoch in range(epochs): |
|
|
|
|
|
|
|
val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device) |
|
print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}") |
|
print("Misclassified Files") |
|
for file_path in wrong_files: |
|
print(file_path) |
|
|
|
|
|
sentence_pattern = re.compile(r"_(\d+)\.wav$") |
|
|
|
sentence_counts = Counter() |
|
for file_path in wrong_files: |
|
match = sentence_pattern.search(file_path) |
|
if match: |
|
sentence_number = int(match.group(1)) |
|
sentence_counts[sentence_number] += 1 |
|
|
|
total_wrong = len(wrong_files) |
|
print("Total wrong files:", total_wrong) |
|
print() |
|
|
|
for sentence_number, count in sentence_counts.most_common(): |
|
percent = count / total_wrong * 100 |
|
print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)") |
|
scheduler.step() |
|
print(classification_report(true_labels, pred_labels, target_names=['non_dysarthria', 'dysarthria'])) |
|
audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav" |
|
predicted_label = predict(model, audio_file, train_dataset.processor, device) |
|
print(f"Predicted label: {predicted_label}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""## DEBUGGING""" |
|
|
|
dysarthria_path = "/content/drive/MyDrive/torgo_data/dysarthria_male/training" |
|
non_dysarthria_path = "/content/drive/MyDrive/torgo_data/non_dysarthria_male/training" |
|
|
|
dysarthria_files = [os.path.join(dysarthria_path, f) for f in os.listdir(dysarthria_path) if f.endswith('.wav')] |
|
non_dysarthria_files = [os.path.join(non_dysarthria_path, f) for f in os.listdir(non_dysarthria_path) if f.endswith('.wav')] |
|
|
|
data = dysarthria_files + non_dysarthria_files |
|
labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files) |
|
|
|
train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2) |
|
|
|
train_dataset = DysarthriaDataset(train_data, train_labels) |
|
test_dataset = DysarthriaDataset(test_data, test_labels) |
|
|
|
train_loader = DataLoader(train_dataset, batch_size=4, drop_last=True) |
|
test_loader = DataLoader(test_dataset, batch_size=4, drop_last=True) |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
|
|
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device) |
|
|
|
max_length = 100_000 |
|
processor = train_dataset.processor |
|
|
|
model.eval() |
|
audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav" |
|
|
|
|
|
|
|
wav_data, _ = sf.read(audio_file) |
|
inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True) |
|
input_values = inputs.input_values.squeeze(0) |
|
if max_length - input_values.shape[-1] > 0: |
|
input_values = torch.cat([input_values, torch.zeros((max_length - input_values.shape[-1],))], dim=-1) |
|
else: |
|
input_values = input_values[:max_length] |
|
|
|
input_values = input_values.unsqueeze(0).to(device) |
|
input_values.shape |
|
|
|
with torch.no_grad(): |
|
outputs = model(**{"input_values": input_values}) |
|
logits = outputs.logits |
|
|
|
input_values.shape, logits.shape |
|
|
|
import torch.nn.functional as F |
|
|
|
logits = logits.squeeze() |
|
predicted_class_id = torch.argmax(logits, dim=-1) |
|
predicted_class_id |
|
|
|
"""Cross testing |
|
|
|
##origial code |
|
""" |
|
|
|
import os |
|
import soundfile as sf |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import Dataset, DataLoader |
|
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
|
from sklearn.model_selection import train_test_split |
|
|
|
|
|
class DysarthriaDataset(Dataset): |
|
def __init__(self, data, labels, max_length=100000): |
|
self.data = data |
|
self.labels = labels |
|
self.max_length = max_length |
|
self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
    def __getitem__(self, idx):
        # Fall back to the next file if this one cannot be read.
        try:
            wav_data, _ = sf.read(self.data[idx])
        except Exception as e:
            print(f"Error opening file: {self.data[idx]} ({e}). Skipping...")
            return self.__getitem__((idx + 1) % len(self.data))
        inputs = self.processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)
        # Zero-pad short clips up to max_length; truncate longer ones.
        if self.max_length - input_values.shape[-1] > 0:
            input_values = torch.cat([input_values, torch.zeros((self.max_length - input_values.shape[-1],))], dim=-1)
        else:
            input_values = input_values[:self.max_length]
        return {"input_values": input_values}, self.labels[idx]
|
|
|
|
|
|
|
def train(model, dataloader, criterion, optimizer, device, ax, loss_vals, x_vals, fig,train_loader,epochs): |
|
model.train() |
|
running_loss = 0 |
|
|
|
for i, (inputs, labels) in enumerate(dataloader): |
|
        inputs = {key: value.to(device) for key, value in inputs.items()}
|
labels = labels.to(device) |
|
|
|
optimizer.zero_grad() |
|
logits = model(**inputs).logits |
|
loss = criterion(logits, labels) |
|
loss.backward() |
|
optimizer.step() |
|
|
|
|
|
loss_vals.append(loss.item()) |
|
running_loss += loss.item() |
|
|
|
if i: |
|
|
|
ax.clear() |
|
ax.set_xlim([0, len(train_loader)*epochs]) |
|
ax.set_xlabel('Training Iterations') |
|
ax.set_ylim([0, max(loss_vals) + 2]) |
|
ax.set_ylabel('Loss') |
|
ax.plot(x_vals[:len(loss_vals)], loss_vals) |
|
fig.canvas.draw() |
|
plt.pause(0.001) |
|
|
|
avg_loss = running_loss / len(dataloader) |
|
print(avg_loss) |
|
print("\n") |
|
return avg_loss |
|
|
|
|
|
|
|
def main(): |
|
dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS/training" |
|
non_dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS/training" |
|
|
|
dysarthria_files = get_wav_files(dysarthria_path) |
|
non_dysarthria_files = get_wav_files(non_dysarthria_path) |
|
|
|
data = dysarthria_files + non_dysarthria_files |
|
labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files) |
|
|
|
train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2) |
|
|
|
train_dataset = DysarthriaDataset(train_data, train_labels) |
|
test_dataset = DysarthriaDataset(test_data, test_labels) |
|
|
|
train_loader = DataLoader(train_dataset, batch_size=8, drop_last=True) |
|
test_loader = DataLoader(test_dataset, batch_size=8, drop_last=True) |
|
validation_loader = DataLoader(test_dataset, batch_size=8, drop_last=True) |
|
|
|
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
criterion = nn.CrossEntropyLoss() |
|
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5) |
|
dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS/testing" |
|
non_dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS/testing" |
|
|
|
dysarthria_validation_files = get_wav_files(dysarthria_validation_path) |
|
non_dysarthria_validation_files = get_wav_files(non_dysarthria_validation_path) |
|
|
|
validation_data = dysarthria_validation_files + non_dysarthria_validation_files |
|
validation_labels = [1] * len(dysarthria_validation_files) + [0] * len(non_dysarthria_validation_files) |
|
|
|
epochs = 10 |
|
fig, ax = plt.subplots() |
|
x_vals = np.arange(len(train_loader)*epochs) |
|
    loss_vals = []
|
for epoch in range(epochs): |
|
        train_loss = train(model, train_loader, criterion, optimizer, device, ax, loss_vals, x_vals, fig, train_loader, epochs)
|
print(f"Epoch {epoch + 1}, Train Loss: {train_loss}") |
|
|
|
val_loss, val_accuracy, wrong_files = evaluate(model, validation_loader, criterion, device) |
|
print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}") |
|
print("Misclassified Files") |
|
for file_path in wrong_files: |
|
print(file_path) |
|
|
|
|
|
sentence_pattern = re.compile(r"_(\d+)\.wav$") |
|
|
|
sentence_counts = Counter() |
|
for file_path in wrong_files: |
|
match = sentence_pattern.search(file_path) |
|
if match: |
|
sentence_number = int(match.group(1)) |
|
sentence_counts[sentence_number] += 1 |
|
|
|
total_wrong = len(wrong_files) |
|
print("Total wrong files:", total_wrong) |
|
print() |
|
|
|
for sentence_number, count in sentence_counts.most_common(): |
|
percent = count / total_wrong * 100 |
|
print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)") |
|
|
|
|
|
torch.save(model.state_dict(), "dysarthria_classifier4.pth") |
|
print("Predicting...") |
|
|
|
|
|
|
|
|
|
|
|
def predict(model, file_path, processor, device, max_length=100000):
    model.eval()
    with torch.no_grad():
        wav_data, _ = sf.read(file_path)
        inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)

        # Same pad/truncate-to-max_length logic as DysarthriaDataset.__getitem__.
        input_values = inputs.input_values.squeeze(0)
        if max_length - input_values.shape[-1] > 0:
            input_values = torch.cat([input_values, torch.zeros((max_length - input_values.shape[-1],))], dim=-1)
        else:
            input_values = input_values[:max_length]
        input_values = input_values.unsqueeze(0).to(device)
        inputs = {"input_values": input_values}

        logits = model(**inputs).logits
        logits = logits.squeeze()
        predicted_class_id = torch.argmax(logits, dim=-1).item()

    return predicted_class_id
|
def evaluate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0
    correct_predictions = 0
    total_predictions = 0
    wrong_files = []
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(dataloader):
            inputs = {key: value.to(device) for key, value in inputs.items()}
            labels = labels.to(device)

            logits = model(**inputs).logits
            loss = criterion(logits, labels)
            running_loss += loss.item()

            _, predicted = torch.max(logits, 1)
            correct_predictions += (predicted == labels).sum().item()
            total_predictions += labels.size(0)

            # Offset in-batch indices to dataset indices so the right file
            # paths are reported (assumes the DataLoader keeps shuffle=False).
            batch_offset = batch_idx * dataloader.batch_size
            wrong_idx = (predicted != labels).nonzero().squeeze().cpu().numpy()
            if wrong_idx.ndim > 0:
                for idx in wrong_idx:
                    wrong_files.append(dataloader.dataset.data[batch_offset + idx])
            elif wrong_idx.size > 0:
                wrong_files.append(dataloader.dataset.data[batch_offset + int(wrong_idx)])

    avg_loss = running_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions
    return avg_loss, accuracy, wrong_files
|
|
|
|
|
|
|
def get_wav_files(base_path): |
|
wav_files = [] |
|
for subject_folder in os.listdir(base_path): |
|
subject_path = os.path.join(base_path, subject_folder) |
|
if os.path.isdir(subject_path): |
|
for wav_file in os.listdir(subject_path): |
|
if wav_file.endswith('.wav'): |
|
wav_files.append(os.path.join(subject_path, wav_file)) |
|
return wav_files |
|
if __name__ == "__main__": |
|
main() |