|
# --- Environment setup: install/upgrade runtime dependencies, then import. ---
import os
import subprocess
import sys


def _pip(*args):
    """Run pip for the interpreter executing this script and return its exit code.

    ``sys.executable -m pip`` installs into the same environment this script
    imports from, unlike whichever ``python``/``pip`` is first on PATH. The
    arguments are passed as a list, so no shell is involved. Installs stay
    best-effort (as with the original ``os.system`` calls), but a non-zero exit
    is reported on stderr instead of passing unnoticed.
    """
    returncode = subprocess.run([sys.executable, "-m", "pip", *args]).returncode
    if returncode != 0:
        print(
            f"warning: 'pip {' '.join(args)}' exited with status {returncode}",
            file=sys.stderr,
        )
    return returncode


_pip("install", "--upgrade", "pip")
# Install/upgrade BEFORE importing. The original upgraded matplotlib after it
# had already been imported, so the running process kept the old version.
_pip("install", "-U", "scikit-learn", "scipy", "matplotlib")
_pip("install", "git+https://github.com/openai/whisper.git")

import gradio as gr
import matplotlib.pyplot as plt
import pandas
import transformers
from transformers import pipeline

from sklearn import model_selection
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

import whisper  # installed from GitHub just above
|
|
|
|
|
|
|
|
|
|
|
# Fine-tuned Whisper checkpoints from the Hugging Face Hub, each wrapped in a
# transformers `pipeline`. No task is passed explicitly, so `pipeline` is left
# to determine it from the checkpoint.
whisper_esc50, whisper_miso = (
    pipeline(model=repo_id)
    for repo_id in ("mskov/whisper_esc50", "mskov/whisper_miso")
)

# Stock OpenAI Whisper checkpoints, loaded in the order "tiny" then "base".
whisper_tiny, whisper_base = (
    whisper.load_model(size_name) for size_name in ("tiny", "base")
)
|
|
|
# --- Data loading ---
# `load_dataset` was called but never imported, which raised NameError.
from datasets import load_dataset

dataset = load_dataset("mskov/miso_test")
# NOTE(review): `dataset` is not used anywhere below in this script; the
# classifiers are fitted on the CSV instead. Confirm whether it is needed.

# `url` was never defined in the original script (NameError). Read the CSV
# location from the environment and fail with an actionable message if unset.
url = os.environ.get("MISO_CSV_URL")
if not url:
    raise RuntimeError(
        "Set MISO_CSV_URL to the path or URL of the metadata CSV "
        "(columns: path, file_name, category)."
    )

# Explicit column names. Passing `names` makes pandas treat the first row as
# data, so this assumes the CSV has no header row -- confirm. This list is
# named `column_names` so it is not confused with the model-name list
# (`names`) built during evaluation.
column_names = ['path', 'file_name', 'category']
dataframe = pandas.read_csv(url, names=column_names)

array = dataframe.values
# Features are the first two columns (path, file_name); the target is category.
# NOTE(review): these feature columns are presumably strings, which the
# scikit-learn classifiers cannot fit directly. Numeric features (e.g. audio
# embeddings) likely need to be extracted first -- TODO confirm.
X = array[:, 0:2]
Y = array[:, 2]
|
|
|
# --- Cross-validated comparison of scikit-learn classifiers ---
# Seed for the shuffled K-fold split, so fold assignment is reproducible.
seed = 7

# (name, estimator) pairs evaluated with scikit-learn cross-validation.
# The Whisper models loaded earlier in this script used to be listed here as
# bare objects. They are not (name, model) tuples, so unpacking in the loop
# below failed. They are also presumably not scikit-learn estimators (no
# fit/get_params), so `cross_val_score` could not evaluate them anyway. They
# stay excluded until wrapped in an estimator-compatible adapter.
models = [
    ('LR', LogisticRegression()),
    ('LDA', LinearDiscriminantAnalysis()),
    ('KNN', KNeighborsClassifier()),
    ('CART', DecisionTreeClassifier()),
    ('NB', GaussianNB()),
    ('SVM', SVC()),
]

results = []  # per-model arrays of fold scores, consumed by the plot
names = []    # model names, parallel to `results`
scoring = 'accuracy'

# `random_state` only takes effect with `shuffle=True`; scikit-learn >= 0.24
# raises ValueError when it is set with shuffle=False. With an int seed the
# split is identical on every call, so build it once outside the loop.
kfold = model_selection.KFold(n_splits=10, shuffle=True, random_state=seed)

for name, model in models:
    cv_results = model_selection.cross_val_score(
        model, X, Y, cv=kfold, scoring=scoring
    )
    results.append(cv_results)
    names.append(name)
    # Mean fold score, with the standard deviation across folds in parentheses.
    print(f"{name}: {cv_results.mean():f} ({cv_results.std():f})")
|
|
|
# Box plot of the per-fold scores: one box per entry of `results`, with the
# x tick labels taken from the parallel `names` list.
fig, ax = plt.subplots()
fig.suptitle('Algorithm Comparison')
ax.boxplot(results)
ax.set_xticklabels(names)
plt.show()