# auto-ml-gradio / app.py
# Hugging Face Space by harikrishnad1997 — commit 03371ed ("Update app.py"), 2.68 kB.
import gradio as gr
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from io import StringIO
# Function to plot histogram
def plot_histogram(file_contents, column, ax=None):
    """Draw a histogram of one column of a CSV document.

    Args:
        file_contents: The CSV data as text. ``bytes``/``bytearray`` are
            accepted too and decoded as UTF-8 (a leading BOM is dropped).
        column: Name of the CSV column to plot.
        ax: Matplotlib Axes to draw on. When ``None``, seaborn chooses the
            Axes (the current one).

    Returns:
        The Axes the histogram was drawn on.

    Raises:
        KeyError: If ``column`` is not a column of the CSV.
    """
    if isinstance(file_contents, (bytes, bytearray)):
        file_contents = file_contents.decode("utf-8-sig")
    custom_df = pd.read_csv(StringIO(file_contents))
    # sns.histplot returns the Axes it drew on; rebinding `ax` makes the
    # ax=None default usable (previously ax.set_title crashed on None).
    ax = sns.histplot(custom_df[column], ax=ax)
    ax.set_title(f'Histogram for {column}')
    ax.set_xlabel(column)
    ax.set_ylabel('Frequency')
    return ax
# Function to plot scatter plot
def plot_scatter(file_contents, x_axis, y_axis, ax=None):
    """Draw a scatter plot of two columns of a CSV document.

    Args:
        file_contents: The CSV data as text. ``bytes``/``bytearray`` are
            accepted too and decoded as UTF-8 (a leading BOM is dropped).
        x_axis: Name of the column plotted on the X axis.
        y_axis: Name of the column plotted on the Y axis.
        ax: Matplotlib Axes to draw on. When ``None``, seaborn chooses the
            Axes (the current one).

    Returns:
        The Axes the scatter plot was drawn on.
    """
    if isinstance(file_contents, (bytes, bytearray)):
        file_contents = file_contents.decode("utf-8-sig")
    custom_df = pd.read_csv(StringIO(file_contents))
    # sns.scatterplot returns the Axes it drew on; rebinding `ax` makes the
    # ax=None default usable (previously ax.set_title crashed on None).
    ax = sns.scatterplot(x=x_axis, y=y_axis, data=custom_df, ax=ax)
    ax.set_title(f'Scatter Plot ({x_axis} vs {y_axis})')
    ax.set_xlabel(x_axis)
    ax.set_ylabel(y_axis)
    return ax
def _read_upload_text(file):
    """Return the text content of an uploaded file.

    Handles a path (``str``/``os.PathLike``), an object exposing ``.name``
    (e.g. a tempfile wrapper), raw ``bytes``, or an in-memory buffer with
    ``getvalue()``. Bytes are decoded as UTF-8 with any leading BOM dropped.
    NOTE(review): which of these shapes Gradio's File component passes
    depends on the Gradio version and the component's ``type`` — confirm
    against the pinned version.
    """
    import os

    if isinstance(file, (bytes, bytearray)):
        data = file
    elif hasattr(file, "getvalue"):
        data = file.getvalue()
    else:
        path = file if isinstance(file, (str, os.PathLike)) else file.name
        with open(path, "rb") as handle:
            data = handle.read()
    if isinstance(data, (bytes, bytearray)):
        return data.decode("utf-8-sig")
    return data


def layout_fn(file, text, text_1, text_2):
    """Render a two-panel figure: a histogram and a scatter plot of a CSV.

    Args:
        file: The uploaded CSV as passed by the UI (see ``_read_upload_text``
            for the accepted shapes), or a falsy value when nothing is uploaded.
        text: Column to draw the histogram for (left panel).
        text_1: Column for the scatter plot's X axis (right panel).
        text_2: Column for the scatter plot's Y axis (right panel).

    Returns:
        A matplotlib Figure with two Axes. A panel whose inputs are missing
        shows a placeholder message instead of a plot.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    # Read the upload once and share the text between both plots.
    contents = _read_upload_text(file) if file else None

    if contents and text:
        plot_histogram(contents, text, ax=axes[0])
    else:
        axes[0].text(0.5, 0.5, "Upload a CSV and select a column", ha='center', va='center')

    if contents and text_1 and text_2:
        plot_scatter(contents, text_1, text_2, ax=axes[1])
    else:
        axes[1].text(0.5, 0.5, "Upload a CSV, select X and Y columns", ha='center', va='center')

    fig.suptitle("Data Visualization")
    fig.tight_layout()
    # Each call creates a new pyplot-managed figure. Unregister it so a
    # long-running server does not accumulate figures. The returned Figure
    # can still be saved/rendered after plt.close.
    plt.close(fig)
    return fig
def update_choices(file):
    """List the uploaded CSV's columns as choices for the three dropdowns.

    Args:
        file: The value of the File component. It is a path string or an
            object exposing ``.name`` (NOTE(review): the exact shape depends
            on the Gradio version), or a falsy value when the upload is cleared.

    Returns:
        A 3-tuple of ``gr.update`` objects for the histogram-column, X-axis
        and Y-axis dropdowns. Each one also clears the current selection, so
        a column from a previous file is not sent to ``layout_fn``.
    """
    choices = []
    if file:
        source = getattr(file, "name", file)
        try:
            # nrows=0 parses only the header row, which is all we need here.
            choices = list(pd.read_csv(source, nrows=0).columns)
        except pd.errors.EmptyDataError:
            choices = []
    return tuple(gr.update(choices=choices, value=None) for _ in range(3))


# Create the Gradio interface. gr.Blocks is used because it allows the
# dropdown choices to be refreshed when a file is uploaded.
with gr.Blocks(title="Data Visualization Tool") as interface:
    gr.Markdown("# Data Visualization Tool")
    gr.Markdown("Upload a CSV file, select columns for histogram and scatter plots.")
    file_input = gr.File(label="Upload CSV file")
    hist_column = gr.Dropdown(label="Select Column (Histogram)", choices=[])
    x_column = gr.Dropdown(label="Select X-axis (Scatter)", choices=[])
    y_column = gr.Dropdown(label="Select Y-axis (Scatter)", choices=[])
    plot_output = gr.Plot()

    file_input.change(
        update_choices,
        inputs=file_input,
        outputs=[hist_column, x_column, y_column],
    )
    plot_inputs = [file_input, hist_column, x_column, y_column]
    for component in plot_inputs:
        component.change(layout_fn, inputs=plot_inputs, outputs=plot_output)

if __name__ == "__main__":
    interface.launch(share=True)