# Scraped from a Hugging Face Space file view.
# Author: Hnabil · commit 2d09772 ("Add application file") · 1.64 kB · no virus detected.
from functools import partial
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import ConfusionMatrixDisplay
import gradio as gr
def train_model(normalize):
    """Fit a linear SVM on the iris data and plot its test-set confusion matrix.

    The classifier is deliberately over-regularized (``C=0.01``) so the
    effect on the confusion matrix is visible.

    Args:
        normalize: When truthy, the matrix is normalized over the true
            labels (``normalize="true"``); otherwise raw counts are shown.

    Returns:
        The matplotlib ``Figure`` containing the confusion-matrix plot.
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import list (matplotlib is already a dependency of the file).
    from matplotlib.figure import Figure

    iris = datasets.load_iris()

    # Split the data into a training set and a test set (fixed seed).
    X_train, X_test, y_train, y_test = train_test_split(
        iris.data, iris.target, random_state=0
    )

    # C is intentionally too low (too much regularization) for demonstration.
    classifier = svm.SVC(kernel="linear", C=0.01).fit(X_train, y_train)

    # Draw on a standalone Figure instead of letting the display create one
    # through pyplot: pyplot keeps every figure it creates alive in a global
    # registry until it is explicitly closed, so a handler that runs on every
    # UI event would otherwise accumulate figures for the life of the process.
    fig = Figure()
    ax = fig.subplots()

    disp = ConfusionMatrixDisplay.from_estimator(
        classifier,
        X_test,
        y_test,
        display_labels=iris.target_names,
        cmap=plt.cm.Blues,
        normalize="true" if normalize else None,
        ax=ax,
    )
    disp.ax_.set_title(
        "Normalized confusion matrix"
        if normalize
        else "Confusion matrix, without normalization"
    )
    return disp.figure_
# Page heading and introductory text shown above the controls.
title = "Confusion matrix"
description = "Example of confusion matrix usage to evaluate the quality of the output of a classifier on the iris data set"

with gr.Blocks() as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)

    normalize = gr.Checkbox(label="Normalize")
    plot = gr.Plot(label="Confusion matrix")

    # Redraw the matrix every time the checkbox is toggled.
    normalize.change(fn=train_model, inputs=[normalize], outputs=plot)
    # Also draw once when the page loads, so the plot is not blank until the
    # user first touches the checkbox.
    demo.load(fn=train_model, inputs=[normalize], outputs=plot)

# `enable_queue` as a launch() argument is deprecated in favour of the
# queue() method, which returns the Blocks instance so the calls can chain.
demo.queue().launch(debug=True)