Spaces:
Sleeping
Sleeping
# Importation des bibliothèques nécessaires | |
import pandas as pd | |
import numpy as np | |
import glob | |
import os | |
import joblib | |
import gradio as gr | |
from sklearn.preprocessing import LabelEncoder, StandardScaler | |
from xgboost import XGBRegressor | |
# 1. Data loading: find the Parquet input(s), read each one and stack them
# into a single DataFrame with a fresh RangeIndex.
print("Chargement des données...")
parquet_files = glob.glob('subset_top5_per_station_fuel.parquet')
if not parquet_files:
    raise FileNotFoundError("Aucun fichier Parquet trouvé dans le répertoire spécifié.")


def _read_parquet_verbose(path):
    """Announce which file is being read, then load it as a DataFrame."""
    print(f"Chargement du fichier {path}")
    return pd.read_parquet(path)


# Files are read one after another (log line, then read), and the
# intermediate list is never bound to a name, so nothing needs deleting.
df = pd.concat(
    [_read_parquet_verbose(path) for path in parquet_files],
    ignore_index=True,
)
print(f"Nombre total d'enregistrements: {len(df)}")
# 2. Preprocessing: parse both date columns, order rows chronologically,
# drop rows containing any missing value, exclude the E85 and GPLc fuels and
# keep only the columns used by the rest of the app.
print("Prétraitement des données...")
# NOTE(review): dropna() runs before the column selection below, so a NaN in
# a column that is later discarded also removes the row — confirm intended.
df = (
    df.assign(
        rate_date=pd.to_datetime(df['rate_date']),
        brent_date=pd.to_datetime(df['brent_date']),
    )
    .sort_values('rate_date')
    .dropna()
)
cols_to_use = [
    'station_id', 'commune', 'marque', 'departement', 'regioncode',
    'coordlatitude', 'coordlongitude', 'fuel_name', 'price',
    'rate_date', 'brent_rate_eur', 'brent_date',
]
# Fuel exclusion (rows) and column selection done in a single .loc lookup.
df = df.loc[~df['fuel_name'].isin(['E85', 'GPLc']), cols_to_use]
# Categorical encoding: each categorical column is replaced in-place by
# integer codes; the fitted encoders are kept so values can be decoded later.
print("Encodage des variables catégorielles...")
categorical_cols = ['station_id', 'commune', 'marque', 'departement',
                    'regioncode', 'fuel_name']
label_encoders = {}
for col in categorical_cols:
    encoder = LabelEncoder()
    # Values are encoded as strings (astype(str)), so every decoded class
    # returned by the encoder is a string as well.
    df[col] = encoder.fit_transform(df[col].astype(str))
    label_encoders[col] = encoder


def _code_mapping(column):
    """Return a DataFrame pairing each integer code of *column* with its label."""
    classes = label_encoders[column].classes_
    return pd.DataFrame({
        f'{column}_encoded': np.arange(len(classes)),
        f'{column}_decoded': classes,
    })


# Code <-> label lookup tables for communes and departments.
commune_mapping = _code_mapping('commune')
departement_mapping = _code_mapping('departement')
# Decoded fuel names and departments, used as UI choices.
fuel_types = label_encoders['fuel_name'].classes_.tolist()
departments = label_encoders['departement'].classes_.tolist()
# Station-dropdown refresh, driven by the commune / department filters.
def update_stations(commune_input, departments):
    """Build a Gradio update that repopulates the station dropdown.

    The commune text filter takes precedence over the department filter;
    when neither is set, every station in ``df`` is offered.

    Args:
        commune_input: Text typed by the user. It is matched as a literal,
            case-insensitive substring of the decoded commune names (never
            as a regular expression). Empty, ``None`` or whitespace-only
            text means "no commune filter".
        departments: Decoded department values selected in the dropdown;
            may be ``None`` or empty. Values unknown to the encoder are
            ignored.

    Returns:
        A ``gr.update`` whose ``choices`` are ``(display, station_id)``
        tuples, where display is "commune - marque (station_id)", and whose
        value is reset to ``None``. Choices are empty when nothing matches.
    """
    if commune_input and commune_input.strip():
        # regex=False: free user text such as "Saint-(" or "." must be a
        # literal substring, not a pattern (it could raise re.error or
        # match every commune).
        is_match = commune_mapping['commune_decoded'].str.contains(
            commune_input, case=False, na=False, regex=False
        )
        matching_codes = commune_mapping.loc[is_match, 'commune_encoded'].values
        if len(matching_codes) == 0:
            return gr.update(choices=[], value=None)
        filtered_df = df[df['commune'].isin(matching_codes)]
    elif departments:
        # Set membership instead of scanning the classes_ array per value.
        known_departments = set(label_encoders['departement'].classes_)
        valid_departments = [dept for dept in departments if dept in known_departments]
        if not valid_departments:
            return gr.update(choices=[], value=None)
        departments_encoded = label_encoders['departement'].transform(valid_departments)
        filtered_df = df[df['departement'].isin(departments_encoded)]
    else:
        # No filter: every station. filtered_df is only read below, so the
        # global frame is used directly instead of being copied.
        filtered_df = df

    if filtered_df.empty:
        return gr.update(choices=[], value=None)

    # One row per distinct (station, commune, brand) combination.
    station_info = filtered_df[['station_id', 'commune', 'marque']].drop_duplicates().copy()
    station_info['station_id_decoded'] = label_encoders['station_id'].inverse_transform(station_info['station_id'])
    station_info['commune_decoded'] = label_encoders['commune'].inverse_transform(station_info['commune'])
    station_info['marque_decoded'] = label_encoders['marque'].inverse_transform(station_info['marque'])

    # Vectorised equivalent of "{commune} - {marque} ({station_id})" per row.
    station_info['station_display'] = (
        station_info['commune_decoded'].astype(str)
        + " - "
        + station_info['marque_decoded'].astype(str)
        + " ("
        + station_info['station_id_decoded'].astype(str)
        + ")"
    )

    # Gradio dropdown choices as (label shown, value returned) tuples.
    station_choices = list(zip(station_info['station_display'], station_info['station_id_decoded']))
    return gr.update(choices=station_choices, value=None)
# Multi-horizon price forecast for one station/fuel row.
def forecast_prices(model, last_known_data, scaler, required_columns, brent_price, horizons=(3, 7, 15, 30)):
    """Predict the price ``horizon`` days after the last known observation.

    For each horizon the last known row is copied and its date moved
    forward. The calendar features (day of week, month, year) are recomputed
    and every Brent lag feature is set to ``brent_price``. The row is then
    aligned to ``required_columns`` (absent features filled with 0), scaled
    with ``scaler`` and passed to ``model.predict``.

    Args:
        model: Fitted regressor exposing ``predict``.
        last_known_data: Most recent row (``pd.Series``) for the station and
            fuel. It must contain ``price``, ``rate_date`` (a Timestamp) and
            ``brent_date``. It is not modified.
        scaler: Fitted transformer exposing ``transform``.
        required_columns: Feature names, in the order the model expects.
        brent_price: Brent price used for every lag feature.
        horizons: Iterable of day offsets to forecast.

    Returns:
        Dict mapping each horizon to its predicted price as a float, in the
        order of ``horizons``.
    """
    lag_columns = [f'brent_rate_eur_lag_{lag}' for lag in (1, 3, 7, 15, 30)]
    forecasts = {}
    for horizon in horizons:
        future_date = last_known_data['rate_date'] + pd.Timedelta(days=horizon)
        input_data = last_known_data.copy()
        input_data['rate_date'] = future_date
        input_data['day_of_week'] = future_date.dayofweek
        input_data['month'] = future_date.month
        input_data['year'] = future_date.year
        # Assume the user-supplied Brent price holds for every lag window.
        for col in lag_columns:
            input_data[col] = brent_price

        # One-row frame in model column order; features the row lacks are 0.
        # infer_objects() turns the object-typed columns (the source Series
        # mixed Timestamps and numbers) back into numeric dtypes.
        input_features = (
            input_data.drop(['price', 'rate_date', 'brent_date'])
            .to_frame()
            .T
            .reindex(columns=required_columns, fill_value=0)
            .infer_objects()
        )

        predicted_price = model.predict(scaler.transform(input_features))
        forecasts[horizon] = float(predicted_price[0])
    return forecasts
# Helper: feature column names expected by the per-fuel model and scaler.
def _required_feature_columns(columns):
    """Return the model feature names derived from the prepared data columns.

    Reproduces the column layout previously obtained by re-running the
    feature engineering on the whole fuel subset. That layout is every input
    column except ``price``/``rate_date``/``brent_date`` (original order),
    followed by the calendar features and the Brent lag features. Only
    names are computed, so no data is copied.

    Args:
        columns: Iterable of column names of the prepared frame.

    Returns:
        List of feature column names in model order.
    """
    excluded = {'price', 'rate_date', 'brent_date'}
    base = [col for col in columns if col not in excluded]
    engineered = ['day_of_week', 'month', 'year'] + [
        f'brent_rate_eur_lag_{lag}' for lag in (1, 3, 7, 15, 30)
    ]
    # Matches pandas column-assignment semantics: a column that already
    # exists keeps its position; only genuinely new ones are appended.
    return base + [col for col in engineered if col not in base]


# Main prediction entry point wired to the "Prédire" button.
def get_predictions(station_selection, fuel_types_selected, brent_price, commune_input, departments):
    """Return a text report with recent prices and forecasts per fuel type.

    For each selected fuel, the function loads ``fuel_price_model_<fuel>.pkl``
    and ``scaler_<fuel>.pkl`` from the working directory. It then lists the
    station's 5 most recent prices and appends the forecasts from
    ``forecast_prices``. Missing model files or missing data are reported
    inline, and the remaining fuels are still processed.

    Args:
        station_selection: Decoded station ID chosen in the station dropdown.
        fuel_types_selected: Decoded fuel names ticked by the user.
        brent_price: Brent price entered by the user; ``None`` when the
            number field has been cleared.
        commune_input: Not used here; part of the UI callback signature.
        departments: Not used here; part of the UI callback signature.

    Returns:
        The report string, or a single validation message when a required
        input is missing or the station is unknown.
    """
    if not station_selection or not fuel_types_selected:
        return "Veuillez sélectionner une station et au moins un type de carburant."
    # A cleared number field yields None; without this check the None would
    # be fed into every Brent feature and still produce a (meaningless)
    # forecast.
    if brent_price is None:
        return "Veuillez entrer le cours du Brent."

    # station_selection is the decoded station ID (the dropdown's value).
    station_id = station_selection
    if station_id not in label_encoders['station_id'].classes_:
        return f"Station ID {station_id} non trouvé dans les données."
    station_id_encoded = label_encoders['station_id'].transform([station_id])[0]

    # The feature names depend only on df's columns, so compute them once
    # instead of rebuilding the engineered frame for every fuel type.
    required_columns = _required_feature_columns(df.columns)

    parts = []
    for fuel_type in fuel_types_selected:
        model_filename = f'fuel_price_model_{fuel_type}.pkl'
        scaler_filename = f'scaler_{fuel_type}.pkl'
        if not (os.path.exists(model_filename) and os.path.exists(scaler_filename)):
            parts.append(f"\nModèle ou scaler pour le carburant {fuel_type} non trouvé.")
            continue
        # NOTE(review): joblib.load unpickles the file; only load trusted models.
        model = joblib.load(model_filename)
        scaler = joblib.load(scaler_filename)

        # Rows for this station and fuel, most recent first.
        fuel_name_encoded = label_encoders['fuel_name'].transform([fuel_type])[0]
        df_station_fuel = df[(df['station_id'] == station_id_encoded) & (df['fuel_name'] == fuel_name_encoded)]
        df_station_fuel = df_station_fuel.sort_values('rate_date', ascending=False)
        if df_station_fuel.empty:
            parts.append(f"\nAucune donnée trouvée pour la station {station_id} et le carburant {fuel_type}.")
            continue

        # Explicit copy: the date column is reformatted in place below.
        last_5_prices = df_station_fuel.head(5)[['rate_date', 'price']].copy()
        last_5_prices['rate_date'] = last_5_prices['rate_date'].dt.strftime('%Y-%m-%d %H:%M:%S')
        parts.append(f"\n\nType de carburant : {fuel_type}\nLes 5 derniers prix :\n{last_5_prices.to_string(index=False)}")

        # Forecast from the most recent observation with the user's Brent price.
        last_known_data = df_station_fuel.iloc[0].copy()
        last_known_data['brent_rate_eur'] = brent_price

        forecasts = forecast_prices(model, last_known_data, scaler, required_columns, brent_price)
        parts.append("\nPrévisions :\n")
        parts.extend(f"Dans {horizon} jours : {price:.4f} €\n" for horizon, price in forecasts.items())
    return "".join(parts)
# Custom CSS passed to gr.Blocks: black background and white text for both
# the page body and the Gradio container.
custom_css = """
body {
    background-color: black;
    color: white;
}
.gradio-container {
    background-color: black;
    color: white;
}
"""
# 7. Build the Gradio interface (black background via custom_css).
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Prédiction du Prix des Carburants")
    # Fuel types to forecast.
    with gr.Row():
        fuel_type_checkbox = gr.CheckboxGroup(
            choices=fuel_types,
            label="Sélectionnez les types de carburant",
            value=fuel_types  # all selected by default
        )
    # Location filters used to narrow down the station dropdown.
    # NOTE(review): the original indentation was lost; the commune textbox and
    # department dropdown are assumed to share this Row, with the station
    # dropdown placed below it — confirm the intended layout.
    with gr.Row():
        commune_input = gr.Textbox(
            label="Entrez la commune",
            placeholder="Tapez le nom de la commune..."
        )
        department_dropdown = gr.Dropdown(
            choices=departments,
            label="Sélectionnez le(s) département(s)",
            multiselect=True
        )
    # Filled dynamically by update_stations; starts empty.
    station_dropdown = gr.Dropdown(
        choices=[],
        label="Sélectionnez la station"
    )

    # Refresh the station list whenever the commune or the department changes.
    # The wrapper's name is kept as-is (it may be exposed as the event's API
    # name); its `departments` parameter shadows the module-level list.
    def update_stations_wrapper(commune, departments):
        return update_stations(commune, departments)

    commune_input.change(
        fn=update_stations_wrapper,
        inputs=[commune_input, department_dropdown],
        outputs=station_dropdown
    )
    department_dropdown.change(
        fn=update_stations_wrapper,
        inputs=[commune_input, department_dropdown],
        outputs=station_dropdown
    )

    # Brent price used for every Brent lag feature in the forecast.
    brent_price_input = gr.Number(
        label="Entrez le cours du Brent (€)",
        value=70.0
    )
    predict_button = gr.Button("Prédire")
    output = gr.Textbox(label="Résultats")

    # Thin wrapper forwarding the button inputs to get_predictions.
    def on_predict_click(station_selection, fuel_types_selected, brent_price, commune_input, departments):
        return get_predictions(station_selection, fuel_types_selected, brent_price, commune_input, departments)

    predict_button.click(
        fn=on_predict_click,
        inputs=[station_dropdown, fuel_type_checkbox, brent_price_input, commune_input, department_dropdown],
        outputs=output
    )

# share=True requests a public share link when run locally.
demo.launch(share=True)