import functools
import os
import threading
from types import SimpleNamespace

import gradio as gr
import pandas as pd
import torch
import torchvision
from PIL import Image

from utils.experiment_utils import get_model

# File to store the visitor count
visitor_count_file = "visitor_count.txt"

# Serializes the read-increment-write of the counter file: the count is now
# updated on every page load, and Gradio handles events on worker threads, so
# unguarded concurrent updates could lose increments.
_visitor_count_lock = threading.Lock()

# Class names in the order of the model's output logits.
# NOTE(review): assumes get_model() builds a head with exactly these classes in
# this order — confirm against utils.experiment_utils.
CLASS_NAMES = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]

# Preprocessing applied to every uploaded image: resize to 224x224, convert to
# a tensor and normalize with the ImageNet channel mean/std. Built once instead
# of on every prediction.
_TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Backbone used for Few-Shot classification (independent of sensing modality).
_FEW_SHOT_MODEL = "DINOv2"
# Backbone used for Fully-Supervised classification, keyed by sensing modality.
_MODALITY_MODELS = {"Texture": "DINOv2", "Heightmap": "ResNet152"}


def update_visitor_count():
    """Increment the persistent visitor counter and return the new value.

    The count is stored as a decimal integer in ``visitor_count_file``. A
    missing, empty or unparsable file is treated as a stored count of zero.
    """
    with _visitor_count_lock:
        try:
            with open(visitor_count_file, "r", encoding="utf-8") as file:
                count = int(file.read().strip())
        except (FileNotFoundError, ValueError):
            count = 0  # Start from zero if no (valid) count has been stored yet
        count += 1
        # Save the updated count back to the file
        with open(visitor_count_file, "w", encoding="utf-8") as file:
            file.write(str(count))
    return count


class CustomFlagging(gr.FlaggingCallback):
    """Flagging callback that records user-corrected labels in a CSV file.

    Each flagged image is saved as a PNG under ``<dir_name>/uploaded_images``
    and one row referencing it is appended to ``<dir_name>/flagged_data.csv``.
    """

    def __init__(self, dir_name="flagged_data"):
        self.dir = dir_name
        self.image_dir = os.path.join(self.dir, "uploaded_images")
        # Creates self.dir as well, since it is the parent of self.image_dir.
        os.makedirs(self.image_dir, exist_ok=True)

    # Define setup as a no-op to fulfill abstract class requirement
    def setup(self, *args, **kwargs):
        pass

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
        """Persist one flagged sample.

        Args:
            flag_data: Sequence ``(classification_mode, image, sensing_modality,
                predicted_class, correct_class)``; ``image`` is a PIL image or
                None (in which case no image file is written and the
                "Image Path" column is left empty).
            flag_option, flag_index, username: Unused; accepted for
                compatibility with the ``gr.FlaggingCallback`` interface.

        Returns:
            The number of samples written (always 1).
        """
        classification_mode, image, sensing_modality, predicted_class, correct_class = flag_data

        image_filename = ""
        if image is not None:
            # Microsecond resolution so two flags submitted within the same
            # second do not overwrite each other's image.
            timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S_%f")
            image_filename = os.path.join(self.image_dir, f"flagged_image_{timestamp}.png")
            image.save(image_filename)  # Save image in PNG format

        # Columns: Classification Mode, Image Path, Sensing Modality, Predicted Class, Correct Class
        row = {
            "Classification Mode": classification_mode,
            "Image Path": image_filename,  # Save path to image in CSV
            "Sensing Modality": sensing_modality,
            "Predicted Class": predicted_class,
            "Correct Class": correct_class,
        }
        csv_file = os.path.join(self.dir, "flagged_data.csv")
        # Append to the CSV; write the header only when the file is created.
        write_header = not os.path.exists(csv_file)
        pd.DataFrame([row]).to_csv(csv_file, mode="a", header=write_header, index=False)
        return 1


@functools.lru_cache(maxsize=None)
def _get_model_by_name(model_name):
    """Build the eval-mode model for ``model_name`` once and reuse it afterwards."""
    args = SimpleNamespace(model=model_name, pretrained="pretrained", frozen="unfrozen")
    model = get_model(args)
    model.eval()  # Set the model to evaluation mode
    return model


def load_model(modality, mode):
    """Return the model matching the user's classification mode and modality.

    Few-Shot classification always uses DINOv2. Fully-Supervised
    classification uses DINOv2 for "Texture" and ResNet152 for "Heightmap".
    Models are built on first use and cached, so repeated predictions do not
    reload the weights.

    Raises:
        ValueError: if ``mode`` is not "Few-Shot" and ``modality`` is unknown.
    """
    if mode == "Few-Shot":
        model_name = _FEW_SHOT_MODEL
    else:
        model_name = _MODALITY_MODELS.get(modality)
        if model_name is None:
            raise ValueError("Invalid modality selected!")
    return _get_model_by_name(model_name)


def predict(image, modality, mode):
    """Classify an uploaded image.

    Args:
        image: PIL image from the upload widget, or None if nothing was uploaded.
        modality: Sensing modality ("Texture" or "Heightmap").
        mode: Classification mode ("Fully Supervised" or "Few-Shot").

    Returns:
        ``(predicted_class, results)``: the most probable class name and a dict
        mapping each class name to its softmax probability.

    Raises:
        gr.Error: if no image was uploaded (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("Please upload an image before running a prediction.")

    model = load_model(modality, mode)

    # Print the selected mode and modality for debugging purposes
    print(f"User selected Mode: {mode}, Modality: {modality}")

    # Uploaded PNGs may be RGBA and heightmaps may be grayscale; the
    # normalization step (and the backbones) expect exactly 3 channels.
    image_tensor = _TRANSFORM(image.convert("RGB")).unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        output = model(image_tensor)  # assumes logits of shape (1, len(CLASS_NAMES))
        probabilities = torch.nn.functional.softmax(output, dim=1).squeeze(0).tolist()

    predicted_class = CLASS_NAMES[probabilities.index(max(probabilities))]
    results = dict(zip(CLASS_NAMES, probabilities))
    return predicted_class, results


def create_interface():
    """Build the Gradio Blocks UI for prediction, flagging and visitor counting."""
    with gr.Blocks() as interface:
        # Title at the top of the interface (centered and larger)
        gr.Markdown("<h1 style='text-align: center;'>LUWA Dataset Image Classification</h1>")

        # Add description for the interface (gr.Markdown dedents the text)
        description = """
        ### Image Classification Options
        - **Fully-Supervised Classification**: Choose this for common or well-known materials with plenty of data (e.g., bone, wood).
        - **Few-Shot Classification**: Choose this for rare or newly discovered materials where only a few examples exist.

        ### **Don't forget to choose the Sensing Modality based on your uploaded images.**

        ### **Please help us to flag the correct class for your uploaded image if you know it, it will help us to further develop our dataset. If you cannot find the correct class in the option, please click on the option 'Other' and type the correct class for us!**
        """
        gr.Markdown(description)

        # Top-level selector for Fully-Supervised vs. Few-Shot classification
        mode_selector = gr.Radio(choices=["Fully Supervised", "Few-Shot"], label="Classification Mode", value="Fully Supervised")

        # Sensing modality selector
        modality_selector = gr.Radio(choices=["Texture", "Heightmap"], label="Sensing Modality", value="Texture")

        # Image upload input
        image_input = gr.Image(type="pil", label="Image")

        # Predicted classification output and class probabilities
        with gr.Row():
            predicted_output = gr.Label(num_top_classes=1, label="Predicted Classification")
            probabilities_output = gr.Label(label="Prediction Probabilities")

        # Add the "Run Prediction" button under the Prediction Probabilities
        predict_button = gr.Button("Run Prediction")

        # Radio buttons for the user to select the correct class if the model prediction is wrong
        correct_class_selector = gr.Radio(
            choices=CLASS_NAMES + ["Other"],
            label="Select Correct Class",
        )

        # Text box for user to type the correct class if "Other" is selected
        other_class_input = gr.Textbox(label="If Other, enter the correct class", visible=False)

        def update_visibility(selected_class):
            """Show the free-text class box only when "Other" is selected."""
            return gr.update(visible=selected_class == "Other")

        correct_class_selector.change(fn=update_visibility, inputs=correct_class_selector, outputs=other_class_input)

        # Create a flagging instance
        flagging_instance = CustomFlagging(dir_name="flagged_data")

        def confirm_flag_selection(correct_class, other_class):
            """Reveal the confirmation message, Yes/No choice and confirm button."""
            chosen_class = other_class if correct_class == "Other" else correct_class
            message = f"Are you sure the class you selected is '{chosen_class}' for this picture?"
            # The message box starts hidden, so it must be made visible
            # together with its value or the user never sees the question.
            return gr.update(value=message, visible=True), gr.update(visible=True), gr.update(visible=True)

        def flag_data_save(correct_class, other_class, mode, image, modality, predicted_class, confirmed):
            """Save the flagged sample if confirmed; return a status message."""
            if confirmed != "Yes":
                return "No flag submitted, please select again."
            if correct_class == "Other":
                correct_class_final = (other_class or "").strip()
            else:
                correct_class_final = correct_class
            if not correct_class_final:
                return "Please select (or type) the correct class before flagging."
            if image is None:
                return "Please upload an image before flagging."
            flagging_instance.flag([mode, image, modality, predicted_class, correct_class_final])
            return "Flagged successfully!"

        # Flagging button
        flag_button = gr.Button("Flag")

        # Confirmation box for user input and confirmation flag
        confirmation_text = gr.Textbox(visible=False)
        yes_no_choice = gr.Radio(choices=["Yes", "No"], label="Are you sure?", visible=False)
        confirmation_button = gr.Button("Confirm Flag", visible=False)
        flag_status = gr.Textbox(label="Flagging Status")

        # Prediction action
        predict_button.click(
            fn=predict,
            inputs=[image_input, modality_selector, mode_selector],
            outputs=[predicted_output, probabilities_output],
        )

        # Flagging action with confirmation
        flag_button.click(
            fn=confirm_flag_selection,
            inputs=[correct_class_selector, other_class_input],
            outputs=[confirmation_text, yes_no_choice, confirmation_button],
        )

        # Final flag submission after confirmation
        confirmation_button.click(
            fn=flag_data_save,
            inputs=[correct_class_selector, other_class_input, mode_selector, image_input,
                    modality_selector, predicted_output, yes_no_choice],
            outputs=flag_status,
        )

        # Visitor count displayed at the bottom. It is updated on every page
        # load rather than once at startup, so it counts visits, not restarts.
        visitor_count_display = gr.Markdown()

        def show_visitor_count():
            """Count this visit and return the Markdown line to display."""
            return f"### The Number of Visitors since October 2024: {update_visitor_count()}"

        interface.load(fn=show_visitor_count, inputs=None, outputs=visitor_count_display)

    return interface


if __name__ == "__main__":
    interface = create_interface()
    interface.launch(share=True)