|
import gradio as gr
|
|
import pandas as pd
|
|
import os
|
|
import pickle
|
|
from datetime import datetime
|
|
from sklearn import set_config
|
|
# Make every scikit-learn transformer return pandas DataFrames instead of NumPy
# arrays, so column names survive the imputation / scaling / encoding steps and
# the transformed numeric and categorical parts can be concatenated by column.
set_config(transform_output="pandas")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Every path below is anchored to this file's directory, so the app does not
# depend on the working directory it is launched from.
DIRPATH = os.path.dirname(os.path.realpath(__file__))

_ASSETS_DIR = os.path.join(DIRPATH, "src", "assets")

# Scratch directory holding the per-day prediction history CSV.
tmp_dir = os.path.join(_ASSETS_DIR, "tmp")

# The date is evaluated once, at import time: a process that keeps running past
# midnight keeps writing to the file named after the day it started.
tmp_df_fp = os.path.join(
    tmp_dir, f"history_{datetime.now().strftime('%d-%m-%Y')}.csv"
)

# Pickled bundle of fitted ML components (model, optional preprocessors, labels).
ml_core_fp = os.path.join(_ASSETS_DIR, "ml", "crop_recommandation2.pkl")

# Empty template for the history table: one (empty) column per input feature.
_HISTORY_COLUMNS = ("N", "P", "K", "temperature", "humidity", "ph", "rainfall")
init_df = pd.DataFrame({column: [] for column in _HISTORY_COLUMNS})
|
|
|
|
|
|
|
|
|
|
def load_ml_components(fp):
    """Load and return the pickled ML components bundle stored at ``fp``.

    Args:
        fp: path to a pickle file.

    Returns:
        Whatever object was pickled. In this app it is expected to be a dict
        of fitted components (e.g. "model", "scaler", "labels").

    Security:
        ``pickle.load`` can execute arbitrary code while unpickling. Only load
        files that ship with the app / come from a trusted source.
    """
    # Return directly instead of binding to a local named `object`, which
    # shadowed the builtin of the same name.
    with open(fp, "rb") as f:
        return pickle.load(f)
|
|
|
|
|
|
def setup(fp):
|
|
"Setup the required elements like files, models, global variables, etc"
|
|
|
|
|
|
if not os.path.exists(fp):
|
|
df_history = init_df.copy()
|
|
else:
|
|
df_history = pd.read_csv(fp)
|
|
|
|
df_history.to_csv(fp, index=False)
|
|
|
|
return df_history
|
|
|
|
|
|
def select_categorical_widget(col_index, col_name, encoder):
    """Return a Gradio input widget suited to one categorical feature.

    The widget is chosen from the number of categories the fitted ``encoder``
    learned for that column:
      * up to 5 categories (binary features included): ``gr.Radio``
      * more than 5 categories: ``gr.Dropdown``

    Args:
        col_index: position of the feature in ``encoder.categories_``
            (presumably the feature's index within ``cat_cols`` - confirm the
            encoder was fitted on ``cat_cols`` in that order).
        col_name: feature name, used in the widget label.
        encoder: fitted encoder exposing a ``categories_`` sequence of arrays.

    Returns:
        A Gradio input component whose value is one of the category values.
    """
    categories = encoder.categories_[col_index].tolist()
    label = f"Enter {col_name}"

    if len(categories) == 2:
        print(
            f"[Info] unique categories for feature '{col_index}' {col_name} ({type(categories)}) are : {categories}")

    # Binary features previously got gr.Checkbox(value=categories). A Checkbox
    # holds a single bool, so it cannot carry the list of categories, and its
    # True/False output could not be fed to the encoder. A Radio restricted
    # to the two categories returns the chosen category itself.
    if len(categories) <= 5:
        return gr.Radio(label=label, choices=categories)
    return gr.Dropdown(label=label, choices=categories)
|
|
|
|
|
|
def _apply_fitted_steps(df, *steps):
    """Pass ``df`` through each fitted transformer in ``steps`` that is not None."""
    for step in steps:
        # Compare to None explicitly: estimators such as Pipeline define
        # __len__, so their truthiness is not a reliable "is present" test.
        if step is not None:
            df = step.transform(df)
    return df


def make_prediction(*args):
    """Predict the crop for one set of form inputs and record it in the history.

    Args:
        *args: widget values, in the order ``num_cols + cat_cols``. A value
            that arrives as a list is reduced to its first element.

    Returns:
        A DataFrame with the input features plus:
          * "confidence score": the highest class probability;
          * "predicted crop": the label of that class (via ``idx_to_labels``;
            indices without a label are left as-is).

    Side effects:
        Appends the result to the module-level ``df_history`` (dropping
        duplicate rows, keeping the latest) and rewrites ``tmp_df_fp``.
    """
    global df_history

    print(
        f"[Info] input args of the function {args} ")

    raw = {
        col: [val[0] if isinstance(val, list) else val]
        for val, col in zip(args, num_cols + cat_cols)
    }
    print(
        f"[Info] input modified a bit {raw}\n")

    df_input = pd.DataFrame(raw)
    df_input.drop_duplicates(inplace=True, ignore_index=True)
    # Message fixed: this prints the input *after* duplicates were dropped.
    print(f"\n[Info] Input without duplicated rows: \n{df_input.to_string()}")

    # Preprocess each feature group with its own fitted steps. Numeric columns
    # come first, then categorical ones, matching the original concat order.
    parts = []
    if len(num_cols) > 0:
        parts.append(_apply_fitted_steps(df_input[num_cols].copy(), num_imputer, scaler))
    if len(cat_cols) > 0:
        parts.append(_apply_fitted_steps(df_input[cat_cols].copy(), cat_imputer, encoder))
    df_input_ok = pd.concat(parts, axis=1)

    # Computed once (it was previously computed twice on the same input).
    probabilities = model.predict_proba(df_input_ok)
    print(
        f"[Info] Prediction output (of type '{type(probabilities)}') from passed input: {probabilities} of shape {probabilities.shape}")

    df_input["confidence score"] = probabilities.max(axis=-1)

    # NOTE(review): assumes predict_proba's column order matches `labels`
    # (i.e. model.classes_ is 0..n-1) - confirm against the training code.
    df_input["predicted crop"] = probabilities.argmax(axis=-1)
    df_input["predicted crop"] = df_input["predicted crop"].replace(idx_to_labels)

    print(
        f"\n[Info] output information as dataframe: \n{df_input.to_string()}")

    df_history = pd.concat([df_history, df_input], ignore_index=True).drop_duplicates(
        ignore_index=True, keep='last')
    df_history.to_csv(tmp_df_fp, index=False)

    return df_input
|
|
|
|
|
|
def download():
    """Reveal the history-file component, pointing it at the current history CSV.

    Returns:
        A Gradio update for the ``gr.File`` output. ``gr.update`` is used
        because the per-component ``gr.File.update`` classmethod was removed
        in Gradio 4; ``gr.update`` works on both 3.x and 4.x.
    """
    return gr.update(label="History File",
                     visible=True,
                     value=tmp_df_fp)
|
|
|
|
|
|
def hide_download():
    """Hide the history-file component again.

    Returns:
        A Gradio update for the ``gr.File`` output. ``gr.update`` is used
        instead of ``gr.File.update``, which no longer exists in Gradio 4.
    """
    return gr.update(label="History File",
                     visible=False)
|
|
|
|
|
|
|
|
# Unpack the fitted components bundle. Optional entries fall back to None
# (transformers) or an empty list (column/label lists) when absent.
ml_components_dict = load_ml_components(fp=ml_core_fp)

# Numeric features, in the order the UI collects them and the model reads them.
num_cols = ["N", "P", "K", "temperature", "humidity", "ph", "rainfall"]
cat_cols = ml_components_dict.get('cat_cols', [])

# NOTE(review): num_imputer is only used when the bundle also has a
# 'num_cols' entry, even though num_cols is hard-coded above - confirm intent.
num_imputer = None
if 'num_cols' in ml_components_dict and 'num_imputer' in ml_components_dict:
    num_imputer = ml_components_dict['num_imputer'].set_output(transform="pandas")

cat_imputer = None
if 'cat_cols' in ml_components_dict and 'cat_imputer' in ml_components_dict:
    cat_imputer = ml_components_dict['cat_imputer'].set_output(transform="pandas")

scaler = None
if 'scaler' in ml_components_dict:
    scaler = ml_components_dict['scaler'].set_output(transform="pandas")

encoder = ml_components_dict.get('encoder')
model = ml_components_dict['model']

# Class index -> human-readable label, used to decode the argmax prediction.
labels = ml_components_dict.get('labels', [])
idx_to_labels = dict(enumerate(labels))

print(f"\n[Info] ML components loaded: {list(ml_components_dict.keys())}")

# Load (or create) today's prediction history.
df_history = setup(tmp_df_fp)
|
|
|
|
|
|
|
|
|
|
|
|
# Input components, in the exact order make_prediction maps its *args onto
# columns: every numeric feature first, then every categorical feature.
demo_inputs = []

with gr.Blocks() as demo:
    gr.Markdown('''<img class="center" src="https://www.verdict.co.uk/wp-content/uploads/2018/12/Agri-tech.jpg" width="60%" height="60%">
<style>
.center {
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 50%;
}
</style>''')
    # Closing tag fixed: the header previously ended with a second <center>.
    gr.Markdown('''<center><h1> π Agri-Tech App π </h1></center>''')
    gr.Markdown('''
This is a ML API for classification of crop to plant on a land regarding some features
''')

    with gr.Row():
        # Iterate num_cols directly instead of indexing it through
        # init_df.shape[1], which silently coupled two separate lists.
        for col in num_cols:
            demo_inputs.append(gr.Number(label=f"Enter {col}"))
        # make_prediction also expects one value per categorical column;
        # previously no widget was ever created for them.
        for idx, col in enumerate(cat_cols):
            if encoder is not None:
                widget = select_categorical_widget(idx, col, encoder)
            else:
                widget = gr.Textbox(label=f"Enter {col}")
            demo_inputs.append(widget)

    # Starts out showing the loaded history; each prediction replaces it with
    # the DataFrame returned by make_prediction.
    output = gr.Dataframe(df_history)

    btn_predict = gr.Button("Predict")
    btn_predict.click(fn=make_prediction, inputs=demo_inputs, outputs=output)

    file_obj = gr.File(label="History File",
                       visible=False
                       )

    btn_download = gr.Button("Download")
    btn_download.click(fn=download, inputs=[], outputs=file_obj)
    # Hide the download component again whenever the displayed table changes.
    output.change(fn=hide_download, inputs=[], outputs=file_obj)

demo.launch(
    debug=True
)
|
|
|