# app / app2.py
# Abdel's picture
# Update app2.py
# 8d3cd93 verified
import os
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import gradio as gr
# Load the training table stored next to this script.
current_dir = os.path.dirname(os.path.realpath(__file__))
data = pd.read_csv(os.path.join(current_dir, "datav4.csv"))

# Every column except "targets" is used as a feature; "targets" is the label.
X_all = data.drop(["targets"], axis=1)
y_all = data["targets"]

# Hold out 20% of the rows. The fixed seed keeps the split identical between
# runs. NOTE(review): X_test / y_test are not used in the visible code.
num_test = 0.20
X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all, test_size=num_test, random_state=23
)

# Seed the forest with the same value as the split: without a seed every
# restart of the app trains a different model, so the same input could
# receive a different answer from one launch to the next.
clf = RandomForestClassifier(random_state=23)
clf.fit(X_train, y_train)
from joblib import dump, load
# JavaScript source of `createGradioAnimation`, handed to `gr.Interface(js=...)`
# further down in this file. The function:
#   * builds a flex header <div> holding a logo image, a caption <span> and a
#     second image, and inserts it as the first child of `.gradio-container`;
#   * builds a second <div> into which the characters of a sentence are
#     appended one at a time (one every 250 ms, each fading in through an
#     opacity transition started 50 ms after insertion) and inserts it as the
#     first child too, so it ends up above the header;
#   * returns the string 'Animation created'.
# The inline comments inside the string are partly in French: "Ajouter le
# texte entre les deux images" = "add the text between the two images",
# "Remplacez par votre texte" = "replace with your text".
# NOTE(review): both <div>s are given the same id 'gradio-animation'; element
# ids are normally expected to be unique in a document - confirm nothing
# relies on this.
# NOTE(review): `logo2.src` is an empty string, so the second image has no
# source - presumably a placeholder for a logo that was never filled in.
# NOTE(review): the result of `querySelector('.gradio-container')` is used
# without a null check; this assumes the container already exists when the
# function runs - TODO confirm.
js = """
function createGradioAnimation() {
var containerlog = document.createElement('div');
containerlog.id = 'gradio-animation';
containerlog.style.fontSize = '2em';
containerlog.style.fontWeight = 'bold';
containerlog.style.textAlign = 'center';
containerlog.style.marginBottom = '50px';
var logo = document.createElement('img');
logo.src = 'https://data2innov.fr/build/img/data2innov_logo.svg'; // Replace with your logo URL
logo.style.width = '200px'; // Adjust the size of the logo
containerlog.appendChild(logo);
var textElement = document.createElement('span');
// Ajouter le texte entre les deux images
textElement.textContent = "L'IA au service de l'implantologie Pr. Zoubeir TOURKI "; // Remplacez par votre texte
containerlog.appendChild(textElement);
var logo2 = document.createElement('img');
logo2.src=''
logo2.style.width = '100px'; // Adjust the size of the logo
containerlog.appendChild(logo2);
containerlog.style.display = 'flex';
containerlog.style.flexWrap = 'wrap';
containerlog.style.justifyContent = 'space-between';
var container = document.createElement('div');
container.id = 'gradio-animation';
container.style.fontSize = '2em';
container.style.fontWeight = 'bold';
container.style.textAlign = 'center';
container.style.marginBottom = '50px';
var text = 'Cette application est realisée dans le cadre du congrés ATMO 1-2 mars 2024 à Monastir';
for (var i = 0; i < text.length; i++) {
(function(i){
setTimeout(function(){
var letter = document.createElement('span');
letter.style.opacity = '0';
letter.style.color='red'
letter.style.transition = 'opacity 0.5s';
letter.innerText = text[i];
container.appendChild(letter);
setTimeout(function() {
letter.style.opacity = '1';
}, 50);
}, i * 250);
})(i);
}
var gradioContainerlog = document.querySelector('.gradio-container');
gradioContainerlog.insertBefore(containerlog, gradioContainerlog.firstChild);
var gradioContainer = document.querySelector('.gradio-container');
gradioContainer.insertBefore(container, gradioContainer.firstChild);
return 'Animation created';
}
"""
#clf = load('filename.model')
#predictions = clf.predict(X_test)
def predict_survival(densites, diametres):
    """Tell whether an implant can be placed for a given density and diameter.

    Args:
        densites: Bone density; accepted range is 300 to 1200 inclusive
            (HU according to the message shown to the user).
        diametres: Implant diameter; accepted range is 3 to 6 inclusive
            (mm according to the message shown to the user).

    Returns:
        An HTML snippet announcing the verdict together with the winning
        class probability as a truncated integer percentage: green when the
        second probability returned by ``clf.predict_proba`` is strictly
        greater than the first, red otherwise. ``None`` is returned, after
        a ``gr.Info`` notification, when an input is missing or out of range.
    """
    # A cleared gr.Number field is delivered as None, and comparing None with
    # an int raises TypeError, so a missing value is rejected exactly like an
    # out-of-range one. The chained comparison also rejects NaN.
    # NOTE(review): the first message ends with a stray apostrophe; the text
    # is left untouched here.
    if densites is None or not 300 <= densites <= 1200:
        gr.Info("La densité doit être entre 300 et 1200 HU'")
        return None
    if diametres is None or not 3 <= diametres <= 6:
        gr.Info("le diametre doit être entre 3 et 6 mm")
        return None

    # Build the single-row feature frame only once the inputs are valid.
    # The column names are presumably those of the training CSV - verify
    # against datav4.csv.
    features = pd.DataFrame({"densites": [densites], "diametres": [diametres]})
    pred = clf.predict_proba(features)[0]
    proba_no, proba_yes = pred[0], pred[1]

    # Index 0 is treated as the "No" class and index 1 as the "Yes" class;
    # ties fall on the "No" side.
    if proba_no < proba_yes:
        color, verdict, confidence = "green", "Vous pouvez poser l'implant", proba_yes
    else:
        color, verdict, confidence = "red", "Vous ne pouvez pas poser l'implant", proba_no

    return (
        f"<div style='max-width:100%; color:{color}; font-size: 20px; background-color:white; text-align:center; max-height:360px; overflow:auto'>"
        f"<br> <br> {verdict} (précison : {int(confidence * 100)}%)<br> <br>"
        "</div>"
    )
# Gradio front end: two numeric fields feed predict_survival, whose HTML
# verdict is rendered in a single "html" output. Predictions run only when
# the "Prédire" button is pressed (live mode is off), and the `js` snippet
# defined above is passed to the interface.
demo = gr.Interface(
    fn=predict_survival,
    inputs=[
        gr.Number(label="Densité 'entre 300 et 1200 HU'"),
        gr.Number(label="Diametre 'entre 3 et 6 mm'"),
    ],
    outputs=["html"],
    # Each example row is [density, diameter].
    examples=[[700, 5], [1200, 3], [311, 5.2]],
    live=False,
    js=js,
    submit_btn="Prédire",
)

# Start the server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=True)