# LUWA / run_gradio.py
# DanielXu0208's picture
# Update run_gradio.py
# 645dad2 verified
import functools
import os

import gradio as gr
import pandas as pd
import torch
import torchvision
from PIL import Image

from utils.experiment_utils import get_model
# File (relative to the current working directory) that stores the visitor count.
visitor_count_file = "visitor_count.txt"


def update_visitor_count(path=None):
    """Increment the persisted visitor counter and return the new value.

    Args:
        path: File holding the count as a decimal integer. Defaults to
            ``visitor_count_file``.

    Returns:
        The count after incrementing (1 when no usable count exists yet).

    A missing, empty, or non-numeric file is treated as a count of 0 rather
    than raising, so a corrupted counter file cannot take the app down.
    """
    if path is None:
        path = visitor_count_file
    try:
        with open(path, "r", encoding="utf-8") as file:
            count = int(file.read().strip())
    except (FileNotFoundError, ValueError):
        # No file yet, or empty/garbled content (e.g. an interrupted write).
        count = 0
    count += 1
    # Persist the updated count, overwriting the previous value.
    with open(path, "w", encoding="utf-8") as file:
        file.write(str(count))
    return count
class CustomFlagging(gr.FlaggingCallback):
    """Flagging callback that records user-supplied correct labels.

    Each flag saves the uploaded image as a PNG under
    ``<dir_name>/uploaded_images`` and appends one row to
    ``<dir_name>/flagged_data.csv`` with the columns: Classification Mode,
    Image Path, Sensing Modality, Predicted Class, Correct Class.
    """

    def __init__(self, dir_name="flagged_data"):
        self.dir = dir_name
        self.image_dir = os.path.join(self.dir, "uploaded_images")
        # makedirs also creates self.dir (the parent); exist_ok avoids a
        # check-then-create race and makes repeated construction harmless.
        os.makedirs(self.image_dir, exist_ok=True)

    def setup(self, *args, **kwargs):
        """No-op; directories are created in ``__init__``. Required by the abstract base."""

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
        """Persist one flagged sample.

        Args:
            flag_data: Sequence of (classification_mode, image, sensing_modality,
                predicted_class, correct_class). ``image`` is expected to be a
                PIL image (anything with ``.save(path)``) or ``None``.
            flag_option, flag_index, username: Accepted for compatibility with
                the ``FlaggingCallback`` signature; unused.
        """
        classification_mode, image, sensing_modality, predicted_class, correct_class = flag_data

        image_filename = ""
        if image is not None:
            # Microsecond resolution keeps two flags submitted within the same
            # second from overwriting each other's image.
            timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S_%f")
            image_filename = os.path.join(self.image_dir, f"flagged_image_{timestamp}.png")
            image.save(image_filename)  # Saved in PNG format (from the extension)

        data = {
            "Classification Mode": classification_mode,
            "Image Path": image_filename,  # Path to the saved image ("" if none was given)
            "Sensing Modality": sensing_modality,
            "Predicted Class": predicted_class,
            "Correct Class": correct_class,
        }
        df = pd.DataFrame([data])
        csv_file = os.path.join(self.dir, "flagged_data.csv")
        # Append mode creates the file when missing; write the header only then.
        write_header = not os.path.exists(csv_file)
        df.to_csv(csv_file, mode="a", header=write_header, index=False)
# Backbone used for each sensing modality in Fully-Supervised mode.
_FULLY_SUPERVISED_BACKBONES = {"Texture": "DINOv2", "Heightmap": "ResNet152"}


def _backbone_for(modality, mode):
    """Return the backbone name to build for the given modality/mode selection.

    Few-Shot always uses DINOv2 (the modality is ignored); Fully-Supervised
    picks the backbone from ``_FULLY_SUPERVISED_BACKBONES``.

    Raises:
        ValueError: Fully-Supervised mode with an unknown modality.
    """
    if mode == "Few-Shot":
        return "DINOv2"
    try:
        return _FULLY_SUPERVISED_BACKBONES[modality]
    except KeyError:
        raise ValueError("Invalid modality selected!") from None


@functools.lru_cache(maxsize=None)
def _load_backbone(backbone):
    """Build ``backbone`` via ``get_model`` once, put it in eval mode, and cache it.

    The key space is only the backbone names above, so the unbounded cache
    holds at most one model per backbone. The returned model object is shared
    by all callers; do not mutate it (e.g. by calling ``.train()``).
    """
    class Args:
        # Same attribute set the per-branch Args classes passed to get_model.
        model = backbone
        pretrained = 'pretrained'
        frozen = 'unfrozen'

    model = get_model(Args())
    model.eval()  # Set the model to evaluation mode
    return model


def load_model(modality, mode):
    """Return the eval-mode model for the user's selection.

    Args:
        modality: "Texture" or "Heightmap".
        mode: "Few-Shot" selects DINOv2 regardless of modality; any other value
            is treated as Fully-Supervised and picks DINOv2 for Texture and
            ResNet152 for Heightmap.

    Returns:
        A cached model instance (built on first use, then reused instead of
        being rebuilt on every prediction request).

    Raises:
        ValueError: Fully-Supervised mode with an unknown modality.
    """
    return _load_backbone(_backbone_for(modality, mode))
def predict(image, modality, mode):
    """Classify an uploaded image with the model for the selected mode/modality.

    Args:
        image: PIL image from the upload widget, or ``None`` if nothing was
            uploaded.
        modality: Sensing modality passed to ``load_model``.
        mode: Classification mode passed to ``load_model``.

    Returns:
        ``(predicted_class, results)``, where ``results`` maps each class name
        to its softmax probability and ``predicted_class`` is the class with
        the highest probability (first one on ties).

    Raises:
        gr.Error: No image was uploaded (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("Please upload an image before running the prediction.")

    # Load the appropriate model based on the user's selections
    model = load_model(modality, mode)
    # Print the selected mode and modality for debugging purposes
    print(f"User selected Mode: {mode}, Modality: {modality}")

    # Resize, convert to a tensor, and normalize with the ImageNet mean/std constants.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # The 3-channel Normalize above requires RGB; RGBA/grayscale/palette
    # images would otherwise fail or be mis-normalized.
    image_tensor = transform(image.convert("RGB")).unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        output = model(image_tensor)  # Get model predictions
    # squeeze(0) drops only the batch axis, so the result is always a list.
    probabilities = torch.nn.functional.softmax(output, dim=1).squeeze(0).tolist()

    # Class names for the predictions
    class_names = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]
    results = dict(zip(class_names, probabilities))
    # max over the dict keeps insertion order on ties, i.e. the first class wins.
    predicted_class = max(results, key=results.get)
    return predicted_class, results
def create_interface():
    """Build the LUWA image-classification web UI.

    The page lets a user pick a classification mode and sensing modality,
    upload an image, and run ``predict``. They can then flag the correct class,
    which is recorded through ``CustomFlagging`` after a Yes/No confirmation. A
    visitor counter at the bottom is refreshed on every page load.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    # Same class list as used by ``predict``; offered as flagging choices.
    class_names = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]

    with gr.Blocks() as interface:
        # Title at the top of the interface (centered and larger)
        gr.Markdown("<h1 style='text-align: center; font-size: 36px;'>LUWA Dataset Image Classification</h1>")
        # Description of the interface
        description = """
### Image Classification Options
- **Fully-Supervised Classification**: Choose this for common or well-known materials with plenty of data (e.g., bone, wood).
- **Few-Shot Classification**: Choose this for rare or newly discovered materials where only a few examples exist.
### **Don't forget to choose the Sensing Modality based on your uploaded images.**
### **Please help us to flag the correct class for your uploaded image if you know it, it will help us to further develop our dataset. If you cannot find the correct class in the option, please click on the option 'Other' and type the correct class for us!**
"""
        gr.Markdown(description)

        # Top-level selector for Fully-Supervised vs. Few-Shot classification
        mode_selector = gr.Radio(choices=["Fully Supervised", "Few-Shot"], label="Classification Mode",
                                 value="Fully Supervised")
        # Sensing modality selector
        modality_selector = gr.Radio(choices=["Texture", "Heightmap"], label="Sensing Modality", value="Texture")
        # Image upload input
        image_input = gr.Image(type="pil", label="Image")

        # Predicted classification output and class probabilities
        with gr.Row():
            predicted_output = gr.Label(num_top_classes=1, label="Predicted Classification")
            probabilities_output = gr.Label(label="Prediction Probabilities")
        predict_button = gr.Button("Run Prediction")

        # Radio buttons for the user to pick the correct class if the prediction is wrong
        correct_class_selector = gr.Radio(choices=class_names + ["Other"], label="Select Correct Class")
        # Free-text class entry, shown only when "Other" is selected
        other_class_input = gr.Textbox(label="If Other, enter the correct class", visible=False)

        def update_visibility(selected_class):
            """Show the free-text box only while "Other" is selected."""
            return gr.update(visible=selected_class == "Other")

        correct_class_selector.change(fn=update_visibility, inputs=correct_class_selector, outputs=other_class_input)

        flagging_instance = CustomFlagging(dir_name="flagged_data")

        def _resolve_correct_class(correct_class, other_class):
            """Return the class the user means: the typed text for "Other", else the radio choice."""
            return other_class if correct_class == "Other" else correct_class

        def _label_text(value):
            """Reduce a ``gr.Label`` input value to the plain class name.

            NOTE(review): depending on the Gradio version, a Label used as an
            input may arrive as a string or as a dict like {"label": ...};
            only the class name is stored.
            """
            if isinstance(value, dict) and "label" in value:
                return value["label"]
            return value

        def confirm_flag_selection(correct_class, other_class):
            """Show the confirmation prompt plus the Yes/No and Confirm widgets."""
            chosen = _resolve_correct_class(correct_class, other_class)
            if not (chosen or "").strip():
                # Nothing to confirm yet: show a hint and keep the confirmation widgets hidden.
                hint = "Please select the correct class (or type it after choosing 'Other') before flagging."
                return gr.update(value=hint, visible=True), gr.update(visible=False), gr.update(visible=False)
            message = f"Are you sure the class you selected is '{chosen}' for this picture?"
            # confirmation_text is created hidden, so it must be made visible here;
            # updating only its value would never show the prompt.
            return gr.update(value=message, visible=True), gr.update(visible=True), gr.update(visible=True)

        def flag_data_save(correct_class, other_class, mode, image, modality, predicted_class, confirmed):
            """Save the flag if the user confirmed with "Yes"; return a status message."""
            if confirmed != "Yes":
                return "No flag submitted, please select again."
            correct_class_final = _resolve_correct_class(correct_class, other_class)
            if not (correct_class_final or "").strip():
                return "No flag submitted, please select the correct class first."
            if image is None:
                return "No flag submitted, please upload an image first."
            flagging_instance.flag([mode, image, modality, _label_text(predicted_class), correct_class_final])
            return "Flagged successfully!"

        flag_button = gr.Button("Flag")
        # Confirmation widgets, revealed by confirm_flag_selection
        confirmation_text = gr.Textbox(visible=False)
        yes_no_choice = gr.Radio(choices=["Yes", "No"], label="Are you sure?", visible=False)
        confirmation_button = gr.Button("Confirm Flag", visible=False)

        predict_button.click(
            fn=predict,
            inputs=[image_input, modality_selector, mode_selector],
            outputs=[predicted_output, probabilities_output],
        )
        flag_button.click(
            fn=confirm_flag_selection,
            inputs=[correct_class_selector, other_class_input],
            outputs=[confirmation_text, yes_no_choice, confirmation_button],
        )
        flag_status = gr.Textbox(label="Flagging Status")
        confirmation_button.click(
            fn=flag_data_save,
            inputs=[correct_class_selector, other_class_input, mode_selector, image_input, modality_selector,
                    predicted_output, yes_no_choice],
            outputs=flag_status,
        )

        # Visitor count displayed at the bottom.
        visitor_count_display = gr.Markdown()

        def render_visitor_count():
            """Increment the persisted counter and render it as Markdown."""
            visitor_count = update_visitor_count()
            return f"### The Number of Visitors since October 2024: {visitor_count}"

        # Count on every page load. Computing it here while building the UI
        # would count app launches, not visitors.
        interface.load(fn=render_visitor_count, inputs=None, outputs=visitor_count_display)
    return interface
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then serve it. share=True
    # additionally asks Gradio for a public share link.
    demo = create_interface()
    demo.launch(share=True)