import gradio as gr import torch import torchvision import pandas as pd import os from PIL import Image from utils.experiment_utils import get_model # File to store the visitor count visitor_count_file = "visitor_count.txt" # Function to update visitor count def update_visitor_count(): if os.path.exists(visitor_count_file): with open(visitor_count_file, "r") as file: count = int(file.read()) else: count = 0 # Start from zero if no file exists # Increment visitor count count += 1 # Save the updated count back to the file with open(visitor_count_file, "w") as file: file.write(str(count)) return count # Custom flagging logic to save flagged data to a CSV file class CustomFlagging(gr.FlaggingCallback): def __init__(self, dir_name="flagged_data"): self.dir = dir_name self.image_dir = os.path.join(self.dir, "uploaded_images") if not os.path.exists(self.dir): os.makedirs(self.dir) if not os.path.exists(self.image_dir): os.makedirs(self.image_dir) # Define setup as a no-op to fulfill abstract class requirement def setup(self, *args, **kwargs): pass def flag(self, flag_data, flag_option=None, flag_index=None, username=None): # Extract data classification_mode, image, sensing_modality, predicted_class, correct_class = flag_data # Save the uploaded image in the "uploaded_images" folder image_filename = os.path.join(self.image_dir, f"flagged_image_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.png") image.save(image_filename) # Save image in PNG format # Columns: Classification, Image Path, Sensing Modality, Predicted Class, Correct Class data = { "Classification Mode": classification_mode, "Image Path": image_filename, # Save path to image in CSV "Sensing Modality": sensing_modality, "Predicted Class": predicted_class, "Correct Class": correct_class, } df = pd.DataFrame([data]) csv_file = os.path.join(self.dir, "flagged_data.csv") # Append to CSV, or create if it doesn't exist if os.path.exists(csv_file): df.to_csv(csv_file, mode='a', header=False, index=False) else: df.to_csv(csv_file, mode='w', header=True, index=False) # Function to load the appropriate model based on the user's selection def load_model(modality, mode): # For Few-Shot classification, always use the DINOv2 model if mode == "Few-Shot": class Args: model = 'DINOv2' pretrained = 'pretrained' frozen = 'unfrozen' args = Args() model = get_model(args) # Load DINOv2 model for Few-Shot classification else: # For Fully-Supervised classification, choose model based on the sensing modality if modality == "Texture": class Args: model = 'DINOv2' pretrained = 'pretrained' frozen = 'unfrozen' args = Args() model = get_model(args) # Load DINOv2 model for Texture modality elif modality == "Heightmap": class Args: model = 'ResNet152' pretrained = 'pretrained' frozen = 'unfrozen' args = Args() model = get_model(args) # Load ResNet152 model for Heightmap modality else: raise ValueError("Invalid modality selected!") model.eval() # Set the model to evaluation mode return model # Prediction function that processes the image and returns the prediction results def predict(image, modality, mode): # Load the appropriate model based on the user's selections model = load_model(modality, mode) # Print the selected mode and modality for debugging purposes print(f"User selected Mode: {mode}, Modality: {modality}") # Preprocess the image transform = torchvision.transforms.Compose([ torchvision.transforms.Resize((224, 224)), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) image_tensor = transform(image).unsqueeze(0) # Add batch dimension with torch.no_grad(): output = model(image_tensor) # Get model predictions probabilities = torch.nn.functional.softmax(output, dim=1).squeeze().tolist() # Class names for the predictions class_names = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"] # Pair class names with their corresponding probabilities predicted_class = class_names[probabilities.index(max(probabilities))] # Get the predicted class results = {class_names[i]: probabilities[i] for i in range(len(class_names))} return predicted_class, results # Return predicted class and probabilities # Create the Gradio interface using gr.Blocks def create_interface(): with gr.Blocks() as interface: # Title at the top of the interface (centered and larger) gr.Markdown("