spycoder committed
Commit
2642927
1 Parent(s): 73cab25

Upload chula_gino_parkinson.py

Files changed (1): chula_gino_parkinson.py (+881, -0)
chula_gino_parkinson.py ADDED
# -*- coding: utf-8 -*-
"""CHULA Gino_Parkinson.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1XPgGZILiBbDji5G0dHoFV7OQaUwGM3HJ
"""

!pip install SoundFile transformers scikit-learn

from google.colab import drive
drive.mount('/content/drive')

import matplotlib.pyplot as plt
import numpy as np

import os
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from sklearn.model_selection import train_test_split
import re
from collections import Counter
from sklearn.metrics import classification_report

# Custom Dataset class
class DysarthriaDataset(Dataset):
    def __init__(self, data, labels, max_length=100000):
        self.data = data
        self.labels = labels
        self.max_length = max_length
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            wav_data, _ = sf.read(self.data[idx])
        except Exception:
            print(f"Error opening file: {self.data[idx]}. Skipping...")
            return self.__getitem__((idx + 1) % len(self.data))
        inputs = self.processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)  # Squeeze the batch dimension
        # Zero-pad up to max_length, or truncate if the clip is longer
        if self.max_length - input_values.shape[-1] > 0:
            input_values = torch.cat([input_values, torch.zeros((self.max_length - input_values.shape[-1],))], dim=-1)
        else:
            input_values = input_values[:self.max_length]

        # Remove unsqueezing the channel dimension
        # input_values = input_values.unsqueeze(0)

        # label = torch.zeros(32, dtype=torch.long)
        # label[self.labels[idx]] = 1

        ### CHANGES: simply return the label as a single integer
        return {"input_values": input_values}, self.labels[idx]
        # return {"input_values": input_values, "audio_path": self.data[idx]}, self.labels[idx]
        ###

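# --- Editor's illustration (not part of the original notebook, kept commented out so it
# --- does not affect the run): a minimal sketch of how DysarthriaDataset is used.
# --- The file paths and labels below are hypothetical placeholders.
# example_ds = DysarthriaDataset(
#     data=["/content/example_0001.wav", "/content/example_0002.wav"],  # hypothetical paths
#     labels=[1, 0],
# )
# inputs_dict, label = example_ds[0]
# print(inputs_dict["input_values"].shape)  # torch.Size([100000]) after zero-padding / truncation
# print(label)                              # 1
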
def train(model, dataloader, criterion, optimizer, device, loss_vals, epochs, current_epoch):
    model.train()
    running_loss = 0

    for i, (inputs, labels) in enumerate(dataloader):
        inputs = {key: value.squeeze().to(device) for key, value in inputs.items()}
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(**inputs).logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # append loss value to list
        loss_vals.append(loss.item())
        running_loss += loss.item()

        if i % 10 == 0:  # Update the plot every 10 iterations
            plt.clf()  # Clear the previous plot
            plt.plot(loss_vals)
            plt.xlim([0, len(dataloader)*epochs])
            plt.ylim([0, max(loss_vals) + 2])
            plt.xlabel('Training Iterations')
            plt.ylabel('Loss')
            plt.title(f"Training Loss at Epoch {current_epoch + 1}")
            plt.pause(0.001)  # Pause to update the plot

    avg_loss = running_loss / len(dataloader)
    print(f"Average Loss after Epoch {current_epoch + 1}: {avg_loss}\n")
    return avg_loss

def predict(model, file_path, processor, device, max_length=100000):  ### CHANGES: added max_length as an argument.
    model.eval()
    with torch.no_grad():
        wav_data, _ = sf.read(file_path)
        inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        # inputs = {key: value.squeeze().to(device) for key, value in inputs.items()}

        ### NEW CODES HERE
        input_values = inputs.input_values.squeeze(0)  # Squeeze the batch dimension
        if max_length - input_values.shape[-1] > 0:
            input_values = torch.cat([input_values, torch.zeros((max_length - input_values.shape[-1],))], dim=-1)
        else:
            input_values = input_values[:max_length]
        input_values = input_values.unsqueeze(0).to(device)
        inputs = {"input_values": input_values}
        ###

        logits = model(**inputs).logits
        # _, predicted = torch.max(logits, dim=0)

        ### NEW CODES HERE
        # Remove the batch dimension.
        logits = logits.squeeze()
        predicted_class_id = torch.argmax(logits, dim=-1).item()
        ###

        # return predicted.item()
        return predicted_class_id

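# --- Editor's illustration (not part of the original notebook, kept commented out):
# --- a hypothetical call to predict(), mapping the returned class id to a readable name.
# label_names = {0: "non_dysarthria", 1: "dysarthria"}
# example_wav = "/content/example_0001.wav"  # hypothetical path
# class_id = predict(model, example_wav, train_dataset.processor, device)
# print(label_names[class_id])
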
def evaluate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0
    correct_predictions = 0
    total_predictions = 0
    wrong_files = []
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(dataloader):
            inputs = {key: value.squeeze().to(device) for key, value in inputs.items()}
            labels = labels.to(device)

            logits = model(**inputs).logits
            loss = criterion(logits, labels)
            running_loss += loss.item()

            _, predicted = torch.max(logits, 1)
            correct_predictions += (predicted == labels).sum().item()
            total_predictions += labels.size(0)

            # Offset batch-local indices so the reported paths point at the right
            # items of the (unshuffled) dataset rather than always at the first batch.
            batch_start = batch_idx * dataloader.batch_size
            wrong_idx = (predicted != labels).nonzero().squeeze().cpu().numpy()
            if wrong_idx.ndim > 0:
                for idx in wrong_idx:
                    wrong_files.append(dataloader.dataset.data[batch_start + idx])
            elif wrong_idx.size > 0:
                wrong_files.append(dataloader.dataset.data[batch_start + int(wrong_idx)])

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    avg_loss = running_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions

    return avg_loss, accuracy, wrong_files, np.array(all_labels), np.array(all_predictions)

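# --- Editor's illustration (not part of the original notebook, kept commented out):
# --- the label/prediction arrays returned by evaluate() can be fed straight into
# --- scikit-learn metrics, e.g. a confusion matrix.
# from sklearn.metrics import confusion_matrix
# loss_, acc_, wrong_, y_true, y_pred = evaluate(model, validation_loader, criterion, device)
# print(confusion_matrix(y_true, y_pred))  # rows = true class, columns = predicted class
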
def get_wav_files(base_path):
    wav_files = []
    for subject_folder in os.listdir(base_path):
        subject_path = os.path.join(base_path, subject_folder)
        if os.path.isdir(subject_path):
            for wav_file in os.listdir(subject_path):
                if wav_file.endswith('.wav'):
                    wav_files.append(os.path.join(subject_path, wav_file))

    return wav_files

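# --- Editor's note (not part of the original notebook): get_wav_files() assumes one
# --- sub-folder per subject under base_path, e.g. (hypothetical layout)
# ---     base_path/SUBJECT_01/recording_1.wav
# ---     base_path/SUBJECT_02/recording_1.wav
# --- and returns a flat list of all .wav paths found one level down.
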
def get_torgo_data(dysarthria_path, non_dysarthria_path):
    dysarthria_files = [os.path.join(dysarthria_path, f) for f in os.listdir(dysarthria_path) if f.endswith('.wav')]
    non_dysarthria_files = [os.path.join(non_dysarthria_path, f) for f in os.listdir(non_dysarthria_path) if f.endswith('.wav')]

    data = dysarthria_files + non_dysarthria_files
    labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files)

    train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2, stratify=labels)
    train_data, val_data, train_labels, val_labels = train_test_split(train_data, train_labels, test_size=0.25, stratify=train_labels)  # 0.25 x 0.8 = 0.2

    return train_data, val_data, test_data, train_labels, val_labels, test_labels

dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS"
non_dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS"

dysarthria_files = get_wav_files(dysarthria_path)
non_dysarthria_files = get_wav_files(non_dysarthria_path)

data = dysarthria_files + non_dysarthria_files
labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files)

train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2, stratify=labels)
train_data, val_data, train_labels, val_labels = train_test_split(train_data, train_labels, test_size=0.25, stratify=train_labels)  # 0.25 x 0.8 = 0.2
train_dataset = DysarthriaDataset(train_data, train_labels)
test_dataset = DysarthriaDataset(test_data, test_labels)
val_dataset = DysarthriaDataset(val_data, val_labels)  # Create a validation dataset

train_loader = DataLoader(train_dataset, batch_size=16, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=16, drop_last=False)
validation_loader = DataLoader(val_dataset, batch_size=16, drop_last=False)  # Use the validation dataset for the validation_loader

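# --- Editor's illustration (not part of the original notebook, kept commented out):
# --- a quick sanity check of the split sizes and class balance before training.
# print(len(train_data), len(val_data), len(test_data))
# print(Counter(train_labels), Counter(val_labels), Counter(test_labels))
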
""" dysarthria_path = "/content/drive/MyDrive/torgo_data/dysarthria_male/training"
non_dysarthria_path = "/content/drive/MyDrive/torgo_data/non_dysarthria_male/training"

dysarthria_files = [os.path.join(dysarthria_path, f) for f in os.listdir(dysarthria_path) if f.endswith('.wav')]
non_dysarthria_files = [os.path.join(non_dysarthria_path, f) for f in os.listdir(non_dysarthria_path) if f.endswith('.wav')]

data = dysarthria_files + non_dysarthria_files
labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files)

train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2)

train_dataset = DysarthriaDataset(train_data, train_labels)
test_dataset = DysarthriaDataset(test_data, test_labels)

train_loader = DataLoader(train_dataset, batch_size=8, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=8, drop_last=True)
validation_loader = DataLoader(test_dataset, batch_size=8, drop_last=True)

dysarthria_validation_path = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation"
non_dysarthria_validation_path = "/content/drive/MyDrive/torgo_data/non_dysarthria_male/validation"

dysarthria_validation_files = [os.path.join(dysarthria_validation_path, f) for f in os.listdir(dysarthria_validation_path) if f.endswith('.wav')]
non_dysarthria_validation_files = [os.path.join(non_dysarthria_validation_path, f) for f in os.listdir(non_dysarthria_validation_path) if f.endswith('.wav')]

validation_data = dysarthria_validation_files + non_dysarthria_validation_files
validation_labels = [1] * len(dysarthria_validation_files) + [0] * len(non_dysarthria_validation_files)"""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device)
# model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)

### NEW CODES
# It seems like the classifier layer is excluded from the model's forward method (i.e., model(**inputs)).
# That's why the number of labels in the output was 32 instead of 2 even when you had already changed the classifier.
# Instead, huggingface offers the option of loading the Wav2Vec2 model with an adjustable classifier head on top (by setting num_labels).

model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
##
model_path = "/content/dysarthria_classifier1.pth"
if os.path.exists(model_path):
    print(f"Loading saved model {model_path}")
    model.load_state_dict(torch.load(model_path))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

from torch.optim.lr_scheduler import StepLR

scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

# dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS/testing"
# non_dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS/testing"

# dysarthria_validation_files = get_wav_files(dysarthria_validation_path)
# non_dysarthria_validation_files = get_wav_files(non_dysarthria_validation_path)

# validation_data = dysarthria_validation_files + non_dysarthria_validation_files
# validation_labels = [1] * len(dysarthria_validation_files) + [0] * len(non_dysarthria_validation_files)

epochs = 10
plt.ion()
fig, ax = plt.subplots()
x_vals = np.arange(len(train_loader)*epochs)
loss_vals = []
for epoch in range(epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device, loss_vals, epochs, epoch)
    print(f"Epoch {epoch + 1}, Train Loss: {train_loss}")

    val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device)
    print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}")
    print("Misclassified Files")
    for file_path in wrong_files:
        print(file_path)

    sentence_pattern = re.compile(r"_(\d+)\.wav$")

    sentence_counts = Counter()
    for file_path in wrong_files:
        match = sentence_pattern.search(file_path)
        if match:
            sentence_number = int(match.group(1))
            sentence_counts[sentence_number] += 1

    total_wrong = len(wrong_files)
    print("Total wrong files:", total_wrong)
    print()

    for sentence_number, count in sentence_counts.most_common():
        percent = count / total_wrong * 100
        print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)")
    scheduler.step()
    print(classification_report(true_labels, pred_labels, target_names=['non_dysarthria', 'dysarthria']))

audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
predicted_label = predict(model, audio_file, train_dataset.processor, device)
print(f"Predicted label: {predicted_label}")

# Test on a specific audio file
##audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
##predicted_label = predict(model, audio_file, train_dataset.processor, device)
##print(f"Predicted label: {predicted_label}")

torch.save(model.state_dict(), "dysarthria_classifier1.pth")
print("Predicting...")

"""#audio aug"""

!pip install audiomentations
from audiomentations import Compose, PitchShift, TimeStretch

augmenter = Compose([
    PitchShift(min_semitones=-2, max_semitones=2, p=0.1),
    TimeStretch(min_rate=0.9, max_rate=1.1, p=0.1)
])

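# --- Editor's illustration (not part of the original notebook, kept commented out):
# --- a self-contained check of what the augmenter does to a synthetic one-second clip.
# --- With p=0.1 on both transforms, most calls return the audio unchanged.
# _demo_clip = np.random.uniform(-1.0, 1.0, 16000).astype(np.float32)  # 1 s of noise at 16 kHz
# _demo_aug = augmenter(_demo_clip, sample_rate=16000)
# print(_demo_clip.shape, _demo_aug.shape)  # TimeStretch may change the sample count
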
# from torch.optim.lr_scheduler import StepLR

# scheduler = StepLR(optimizer, step_size=2, gamma=0.5)

from transformers import get_linear_schedule_with_warmup

# Define the total number of training steps
# It is usually the number of epochs times the number of batches per epoch
num_training_steps = epochs * len(train_loader)

# Define the number of warmup steps
# Usually set to a fraction of num_training_steps, e.g. 0.1 * num_training_steps (0.3 is used here)
num_warmup_steps = int(num_training_steps * 0.3)

# Create the learning rate scheduler
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

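# --- Editor's note (not part of the original notebook): get_linear_schedule_with_warmup
# --- is normally stepped once per optimizer step (per batch), which is what the
# --- num_training_steps = epochs * len(train_loader) computation above assumes. The loops
# --- below call scheduler.step() once per epoch, so the warmup/decay is traversed more
# --- slowly than configured. A per-batch sketch would look like:
# for batch in train_loader:
#     ...  # forward, loss.backward(), optimizer.step()
#     scheduler.step()
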
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
##
model_path = "/content/models/my_model_06/pytorch_model.bin"
if os.path.exists(model_path):
    print(f"Loading saved model {model_path}")
    model.load_state_dict(torch.load(model_path))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

import numpy as np

def trainaug(model, dataloader, criterion, optimizer, device, loss_vals, epochs, current_epoch):
    model.train()
    running_loss = 0

    for i, (inputs, labels) in enumerate(dataloader):
        inputs = {key: value.squeeze().to(device) for key, value in inputs.items() if torch.is_tensor(value)}
        labels = labels.to(device)

        # Apply audio augmentation
        augmented_audio = []
        for audio in inputs['input_values']:
            # The augmenter works with numpy arrays, so we need to convert the tensor to a numpy array
            audio_np = audio.cpu().numpy()

            # Apply the augmentation
            augmented = augmenter(audio_np, sample_rate=16000)  # Assuming a sample rate of 16000 Hz

            augmented_audio.append(augmented)

        # Convert the list of numpy arrays back to a tensor
        # Note: TimeStretch can change the number of samples, so the augmented clips may need
        # to be re-padded/truncated to a common length before np.array() can stack them.
        inputs['input_values'] = torch.from_numpy(np.array(augmented_audio)).to(device)

        optimizer.zero_grad()
        logits = model(**inputs).logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # append loss value to list
        loss_vals.append(loss.item())
        running_loss += loss.item()

        if i % 10 == 0:  # Update the plot every 10 iterations
            plt.clf()  # Clear the previous plot
            plt.plot(loss_vals)
            plt.xlim([0, len(dataloader)*epochs])
            plt.ylim([0, max(loss_vals) + 2])
            plt.xlabel('Training Iterations')
            plt.ylabel('Loss')
            plt.title(f"Training Loss at Epoch {current_epoch + 1}")
            plt.pause(0.001)  # Pause to update the plot

    avg_loss = running_loss / len(dataloader)
    print(f"Average Loss after Epoch {current_epoch + 1}: {avg_loss}\n")
    return avg_loss

epochs = 20
plt.ion()
fig, ax = plt.subplots()
x_vals = np.arange(len(train_loader)*epochs)
loss_vals = []
for epoch in range(epochs):
    train_loss = trainaug(model, train_loader, criterion, optimizer, device, loss_vals, epochs, epoch)
    print(f"Epoch {epoch + 1}, Train Loss: {train_loss}")

    val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device)
    print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}")
    print("Misclassified Files")
    for file_path in wrong_files:
        print(file_path)

    sentence_pattern = re.compile(r"_(\d+)\.wav$")

    sentence_counts = Counter()
    for file_path in wrong_files:
        match = sentence_pattern.search(file_path)
        if match:
            sentence_number = int(match.group(1))
            sentence_counts[sentence_number] += 1

    total_wrong = len(wrong_files)
    print("Total wrong files:", total_wrong)
    print()

    for sentence_number, count in sentence_counts.most_common():
        percent = count / total_wrong * 100
        print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)")
    scheduler.step()
    print(classification_report(true_labels, pred_labels, target_names=['non_dysarthria', 'dysarthria']))

audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
# predicted_label = predict(model, audio_file, train_dataset.processor, device)
# print(f"Predicted label: {predicted_label}")

# Test on a specific audio file
##audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
##predicted_label = predict(model, audio_file, train_dataset.processor, device)
##print(f"Predicted label: {predicted_label}")

import re
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report

# Define the pattern to extract the sentence number from the file path
sentence_pattern = re.compile(r"_(\d+)\.wav$")

# Counter for the total number of each sentence type in the dataset
total_sentence_counts = Counter()

for file_path in train_loader.dataset.data:  # Access the file paths directly
    match = sentence_pattern.search(file_path)
    if match:
        sentence_number = int(match.group(1))
        total_sentence_counts[sentence_number] += 1

epochs = 1
plt.ion()
fig, ax = plt.subplots()
x_vals = np.arange(len(train_loader)*epochs)
loss_vals = []

for epoch in range(epochs):
    # train_loss = trainaug(model, train_loader, criterion, optimizer, device, loss_vals, epochs, epoch)
    # print(f"Epoch {epoch + 1}, Train Loss: {train_loss}")

    val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device)
    print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}")
    print("Misclassified Files")
    for file_path in wrong_files:
        print(file_path)

    # Counter for the misclassified sentences
    sentence_counts = Counter()

    for file_path in wrong_files:
        match = sentence_pattern.search(file_path)
        if match:
            sentence_number = int(match.group(1))
            sentence_counts[sentence_number] += 1

    print("Total wrong files:", len(wrong_files))
    print()

    for sentence_number, count in sentence_counts.most_common():
        percent = count / total_sentence_counts[sentence_number] * 100
        print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)")

    scheduler.step()
    print(classification_report(true_labels, pred_labels, target_names=['non_dysarthria', 'dysarthria']))

torch.save(model.state_dict(), "dysarthria_classifier2.pth")

save_dir = "models/my_model_06"
model.save_pretrained(save_dir)

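# --- Editor's illustration (not part of the original notebook, kept commented out):
# --- either checkpoint written above can be restored later.
# reloaded = Wav2Vec2ForSequenceClassification.from_pretrained("models/my_model_06").to(device)
# # or: model.load_state_dict(torch.load("dysarthria_classifier2.pth", map_location=device))
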
"""## Cross testing

"""

# dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS/testing"
# non_dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS/testing"

#dysarthria_validation_files = get_wav_files(dysarthria_validation_path)
# non_dysarthria_validation_files = get_wav_files(non_dysarthria_validation_path)

#validation_data = dysarthria_validation_files + non_dysarthria_validation_files
#validation_labels = [1] * len(dysarthria_validation_files) + [0] * len(non_dysarthria_validation_files)

epochs = 1
plt.ion()
fig, ax = plt.subplots()
x_vals = np.arange(len(train_loader)*epochs)
loss_vals = []
for epoch in range(epochs):
    #train_loss = train(model, train_loader, criterion, optimizer, device, loss_vals, epochs, epoch)
    #print(f"Epoch {epoch + 1}, Train Loss: {train_loss}")

    val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device)
    print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}")
    print("Misclassified Files")
    for file_path in wrong_files:
        print(file_path)

    sentence_pattern = re.compile(r"_(\d+)\.wav$")

    sentence_counts = Counter()
    for file_path in wrong_files:
        match = sentence_pattern.search(file_path)
        if match:
            sentence_number = int(match.group(1))
            sentence_counts[sentence_number] += 1

    total_wrong = len(wrong_files)
    print("Total wrong files:", total_wrong)
    print()

    for sentence_number, count in sentence_counts.most_common():
        percent = count / total_wrong * 100
        print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)")
    scheduler.step()
    print(classification_report(true_labels, pred_labels, target_names=['non_dysarthria', 'dysarthria']))

audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
predicted_label = predict(model, audio_file, train_dataset.processor, device)
print(f"Predicted label: {predicted_label}")

# Test on a specific audio file
##audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
##predicted_label = predict(model, audio_file, train_dataset.processor, device)
##print(f"Predicted label: {predicted_label}")

"""## DEBUGGING"""

dysarthria_path = "/content/drive/MyDrive/torgo_data/dysarthria_male/training"
non_dysarthria_path = "/content/drive/MyDrive/torgo_data/non_dysarthria_male/training"

dysarthria_files = [os.path.join(dysarthria_path, f) for f in os.listdir(dysarthria_path) if f.endswith('.wav')]
non_dysarthria_files = [os.path.join(non_dysarthria_path, f) for f in os.listdir(non_dysarthria_path) if f.endswith('.wav')]

data = dysarthria_files + non_dysarthria_files
labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files)

train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2)

train_dataset = DysarthriaDataset(train_data, train_labels)
test_dataset = DysarthriaDataset(test_data, test_labels)

train_loader = DataLoader(train_dataset, batch_size=4, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=4, drop_last=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device)
# model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)

model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)

max_length = 100_000
processor = train_dataset.processor

model.eval()
audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
# predicted_label = predict(model, audio_file, train_dataset.processor, device)
# print(f"Predicted label: {predicted_label}")

wav_data, _ = sf.read(audio_file)
inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
input_values = inputs.input_values.squeeze(0)  # Squeeze the batch dimension
if max_length - input_values.shape[-1] > 0:
    input_values = torch.cat([input_values, torch.zeros((max_length - input_values.shape[-1],))], dim=-1)
else:
    input_values = input_values[:max_length]

input_values = input_values.unsqueeze(0).to(device)
input_values.shape

with torch.no_grad():
    outputs = model(**{"input_values": input_values})
    logits = outputs.logits

input_values.shape, logits.shape

import torch.nn.functional as F
# Remove the batch dimension.
logits = logits.squeeze()
predicted_class_id = torch.argmax(logits, dim=-1)
predicted_class_id

"""Cross testing

## original code
"""

import os
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from sklearn.model_selection import train_test_split

# Custom Dataset class
class DysarthriaDataset(Dataset):
    def __init__(self, data, labels, max_length=100000):
        self.data = data
        self.labels = labels
        self.max_length = max_length
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            wav_data, _ = sf.read(self.data[idx])
        except:
            print(f"Error opening file: {self.data[idx]}. Skipping...")
            return self.__getitem__((idx + 1) % len(self.data))
        inputs = self.processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)  # Squeeze the batch dimension
        if self.max_length - input_values.shape[-1] > 0:
            input_values = torch.cat([input_values, torch.zeros((self.max_length - input_values.shape[-1],))], dim=-1)
        else:
            input_values = input_values[:self.max_length]

        # Remove unsqueezing the channel dimension
        # input_values = input_values.unsqueeze(0)

        # label = torch.zeros(32, dtype=torch.long)
        # label[self.labels[idx]] = 1

        ### CHANGES: simply return the label as a single integer
        return {"input_values": input_values}, self.labels[idx]
        ###

def train(model, dataloader, criterion, optimizer, device, ax, loss_vals, x_vals, fig, train_loader, epochs):
    model.train()
    running_loss = 0

    for i, (inputs, labels) in enumerate(dataloader):
        inputs = {key: value.squeeze().to(device) for key, value in inputs.items()}
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(**inputs).logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # append loss value to list
        loss_vals.append(loss.item())
        running_loss += loss.item()

        if i:
            # update plot
            ax.clear()
            ax.set_xlim([0, len(train_loader)*epochs])
            ax.set_xlabel('Training Iterations')
            ax.set_ylim([0, max(loss_vals) + 2])
            ax.set_ylabel('Loss')
            ax.plot(x_vals[:len(loss_vals)], loss_vals)
            fig.canvas.draw()
            plt.pause(0.001)

    avg_loss = running_loss / len(dataloader)
    print(avg_loss)
    print("\n")
    return avg_loss

def main():
    dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS/training"
    non_dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS/training"

    dysarthria_files = get_wav_files(dysarthria_path)
    non_dysarthria_files = get_wav_files(non_dysarthria_path)

    data = dysarthria_files + non_dysarthria_files
    labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files)

    train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2)

    train_dataset = DysarthriaDataset(train_data, train_labels)
    test_dataset = DysarthriaDataset(test_data, test_labels)

    train_loader = DataLoader(train_dataset, batch_size=8, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=8, drop_last=True)
    validation_loader = DataLoader(test_dataset, batch_size=8, drop_last=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device)
    # model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)

    ### NEW CODES
    # It seems like the classifier layer is excluded from the model's forward method (i.e., model(**inputs)).
    # That's why the number of labels in the output was 32 instead of 2 even when you had already changed the classifier.
    # Instead, huggingface offers the option for loading the Wav2Vec model with an adjustable classifier head on top (by setting num_labels).

    model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
    ###
    # model_path = "/content/dysarthria_classifier3.pth"
    # if os.path.exists(model_path):
    #     print(f"Loading saved model {model_path}")
    #     model.load_state_dict(torch.load(model_path))

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
    dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS/testing"
    non_dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS/testing"

    dysarthria_validation_files = get_wav_files(dysarthria_validation_path)
    non_dysarthria_validation_files = get_wav_files(non_dysarthria_validation_path)

    validation_data = dysarthria_validation_files + non_dysarthria_validation_files
    validation_labels = [1] * len(dysarthria_validation_files) + [0] * len(non_dysarthria_validation_files)

    epochs = 10
    fig, ax = plt.subplots()
    x_vals = np.arange(len(train_loader)*epochs)
    loss_vals = []
    nume = 1
    for epoch in range(epochs):
        train_loss = train(model, train_loader, criterion, optimizer, device, ax, loss_vals, x_vals, fig, train_loader, epoch+1)
        print(f"Epoch {epoch + 1}, Train Loss: {train_loss}")

        val_loss, val_accuracy, wrong_files = evaluate(model, validation_loader, criterion, device)
        print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}")
        print("Misclassified Files")
        for file_path in wrong_files:
            print(file_path)

        sentence_pattern = re.compile(r"_(\d+)\.wav$")

        sentence_counts = Counter()
        for file_path in wrong_files:
            match = sentence_pattern.search(file_path)
            if match:
                sentence_number = int(match.group(1))
                sentence_counts[sentence_number] += 1

        total_wrong = len(wrong_files)
        print("Total wrong files:", total_wrong)
        print()

        for sentence_number, count in sentence_counts.most_common():
            percent = count / total_wrong * 100
            print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)")

    torch.save(model.state_dict(), "dysarthria_classifier4.pth")
    print("Predicting...")
    # Test on a specific audio file
    ##audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
    ##predicted_label = predict(model, audio_file, train_dataset.processor, device)
    ##print(f"Predicted label: {predicted_label}")

def predict(model, file_path, processor, device, max_length=100000):  ### CHANGES: added max_length as an argument.
    model.eval()
    with torch.no_grad():
        wav_data, _ = sf.read(file_path)
        inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        # inputs = {key: value.squeeze().to(device) for key, value in inputs.items()}

        ### NEW CODES HERE
        input_values = inputs.input_values.squeeze(0)  # Squeeze the batch dimension
        if max_length - input_values.shape[-1] > 0:
            input_values = torch.cat([input_values, torch.zeros((max_length - input_values.shape[-1],))], dim=-1)
        else:
            input_values = input_values[:max_length]
        input_values = input_values.unsqueeze(0).to(device)
        inputs = {"input_values": input_values}
        ###

        logits = model(**inputs).logits
        # _, predicted = torch.max(logits, dim=0)

        ### NEW CODES HERE
        # Remove the batch dimension.
        logits = logits.squeeze()
        predicted_class_id = torch.argmax(logits, dim=-1).item()
        ###

        # return predicted.item()
        return predicted_class_id

def evaluate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0
    correct_predictions = 0
    total_predictions = 0
    wrong_files = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = {key: value.squeeze().to(device) for key, value in inputs.items()}
            labels = labels.to(device)

            logits = model(**inputs).logits
            loss = criterion(logits, labels)
            running_loss += loss.item()

            _, predicted = torch.max(logits, 1)
            correct_predictions += (predicted == labels).sum().item()
            total_predictions += labels.size(0)

            wrong_idx = (predicted != labels).nonzero().squeeze().cpu().numpy()
            if wrong_idx.ndim > 0:
                for idx in wrong_idx:
                    wrong_files.append(dataloader.dataset.data[idx])
            elif wrong_idx.size > 0:
                wrong_files.append(dataloader.dataset.data[wrong_idx])

    avg_loss = running_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions
    return avg_loss, accuracy, wrong_files

def get_wav_files(base_path):
    wav_files = []
    for subject_folder in os.listdir(base_path):
        subject_path = os.path.join(base_path, subject_folder)
        if os.path.isdir(subject_path):
            for wav_file in os.listdir(subject_path):
                if wav_file.endswith('.wav'):
                    wav_files.append(os.path.join(subject_path, wav_file))
    return wav_files


if __name__ == "__main__":
    main()