"""Gradio demo for the "Generative Augmented Image Classifiers" models.

The script builds one (untrained) network per dataset / architecture /
training-variant combination, loads the matching checkpoint on demand and
exposes a small Gradio UI to sample a validation image and classify it.
"""

import random

import gradio as gr
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torchvision.models import mobilenet_v2, resnet18
from torchvision.transforms.functional import InterpolationMode

# Number of output classes of the classifier head, per dataset.
datasets_n_classes = {
    "Imagenette": 10,
    "Imagewoof": 10,
    "Stanford_dogs": 120,
}

# Training variants for which a checkpoint path is registered, per dataset.
datasets_model_types = {
    "Imagenette": [
        "base_200",
        "base_200+100",
        "synthetic_200",
        "augment_noisy_200",
        "augment_noisy_200+100",
        "augment_clean_200",
    ],
    "Imagewoof": [
        "base_200",
        "base_200+100",
        "synthetic_200",
        "augment_noisy_200",
        "augment_noisy_200+100",
        "augment_clean_200",
    ],
    "Stanford_dogs": [
        "base_200",
        "base_200+100",
        "synthetic_200",
        "augment_noisy_200",
        "augment_noisy_200+100",
    ],
}

model_arch = ["resnet18", "mobilenet_v2"]

# "Training Dataset" radio choices shown for each "Methods" selection.
list_200 = [
    "Original",
    "Synthetic",
    "Original + Synthetic (Noisy)",
    "Original + Synthetic (Clean)",
]
list_200_100 = ["Base+100", "AugmentNoisy+100"]
methods_map = {
    "200 Epochs": list_200,
    "200 Epochs on Original + 100": list_200_100,
}

# Maps every user-facing UI label to the token used in model keys and paths.
label_map = {
    "Imagenette (10 classes)": "Imagenette",
    "Imagewoof (10 classes)": "Imagewoof",
    "Stanford Dogs (120 classes)": "Stanford_dogs",
    "ResNet-18": "resnet18",
    "MobileNetV2": "mobilenet_v2",
    "200 Epochs": "200",
    "200 Epochs on Original + 100": "200+100",
    "Original": "base",
    "Synthetic": "synthetic",
    "Original + Synthetic (Noisy)": "augment_noisy",
    "Original + Synthetic (Clean)": "augment_clean",
    "Base+100": "base",
    "AugmentNoisy+100": "augment_noisy",
}

# Constructor for each supported architecture name.
_MODEL_BUILDERS = {
    "resnet18": resnet18,
    "mobilenet_v2": mobilenet_v2,
}

# dataset -> {"<arch>_<model_type>": (model instance, checkpoint path)}
dataset_models = {}
for dataset, n_classes in datasets_n_classes.items():
    models = {}
    for model_type in datasets_model_types[dataset]:
        for arch in model_arch:
            if arch not in _MODEL_BUILDERS:
                raise ValueError(f"Model architecture unavailable: {arch}")
            model = _MODEL_BUILDERS[arch](weights=None, num_classes=n_classes)
            models[f"{arch}_{model_type}"] = (
                model,
                f"./models/{arch}/{dataset}/{dataset}_{model_type}.pth",
            )
    dataset_models[dataset] = models

# Input pipeline applied to every image before it is fed to a model.
_PREPROCESS_INPUT = transforms.Compose(
    [
        transforms.Resize(
            256,
            interpolation=InterpolationMode.BILINEAR,
            antialias=True,
        ),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


def get_random_image(dataset, label_map=label_map) -> Image.Image:
    """Return a random image from a dataset's validation folder.

    Args:
        dataset: UI label of the dataset (a key of ``label_map``).
        label_map: Mapping from UI labels to dataset directory names.

    Returns:
        The sampled image, resized to 256x256.
    """
    dataset_root = f"./data/{label_map[dataset]}/val"
    # Without a transform, ImageFolder's default loader already yields an RGB
    # PIL image, so no tensor round trip is needed.
    dataset_img = torchvision.datasets.ImageFolder(dataset_root)
    random_idx = random.randint(0, len(dataset_img) - 1)
    image, _ = dataset_img[random_idx]
    return image.resize((256, 256))


def load_model(model_dict, model_name: str) -> nn.Module:
    """Load checkpoint weights into the model registered under ``model_name``.

    Args:
        model_dict: Mapping from lower-case model key to a
            ``(model, checkpoint_path)`` pair.
        model_name: Model key; matched case-insensitively.

    Returns:
        The registered model instance with the checkpoint weights loaded.
        The same instance is returned (and reloaded) on every call.

    Raises:
        ValueError: If ``model_name`` is not a key of ``model_dict``.
    """
    model_key = model_name.lower()
    if model_key not in model_dict:
        raise ValueError(
            f"Model {model_name} is not available for image prediction. Please choose from {[name.capitalize() for name in model_dict]}."
        )

    model, model_path = model_dict[model_key]
    # map_location keeps checkpoints that were saved on a GPU loadable on
    # CPU-only hosts; load_state_dict copies onto the model's own device.
    checkpoint = torch.load(model_path, map_location="cpu")

    if "setup" in checkpoint:
        # Full training checkpoint: weights live under "model"; models saved
        # from a distributed run carry a "module." prefix on every key.
        if checkpoint["setup"]["distributed"]:
            torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
                checkpoint["model"], "module."
            )
        model.load_state_dict(checkpoint["model"])
    else:
        # Otherwise the checkpoint is treated as a bare state dict.
        model.load_state_dict(checkpoint)
    return model


def postprocess_default(labels, output) -> dict:
    """Turn raw model output into a ``{label: confidence}`` mapping.

    Args:
        labels: Mapping from class index to class name.
        output: Batched logits; only the first row is used.

    Returns:
        The (at most) five most probable class names with their softmax
        probabilities.
    """
    probabilities = nn.functional.softmax(output[0], dim=0)
    # Never ask for more entries than there are classes.
    k = min(5, probabilities.shape[0])
    top_prob, top_catid = torch.topk(probabilities, k)
    return {
        labels[class_id]: prob
        for class_id, prob in zip(top_catid.tolist(), top_prob.tolist())
    }


def classify(
    input_image: Image.Image,
    dataset_type: str,
    arch_type: str,
    methods: str,
    training_ds: str,
    dataset_models=dataset_models,
    label_map=label_map,
) -> dict:
    """Classify an image with the model selected in the UI.

    Args:
        input_image: Image to classify.
        dataset_type: UI label of the dataset.
        arch_type: UI label of the model architecture.
        methods: UI label of the training schedule.
        training_ds: UI label of the training data variant.
        dataset_models: Registry of models per dataset.
        label_map: Mapping from UI labels to model-key tokens.

    Returns:
        Mapping from class name to confidence for the top predictions.

    Raises:
        ValueError: If an option or the image is missing, or if no model is
            registered for the selected combination.
    """
    if any(option is None for option in (dataset_type, arch_type, methods, training_ds)):
        raise ValueError("Please select all options.")
    if input_image is None:
        raise ValueError("No image was provided.")

    dataset_type = label_map[dataset_type]
    arch_type = label_map[arch_type]
    methods = label_map[methods]
    training_ds = label_map[training_ds]

    input_tensor: torch.Tensor = _PREPROCESS_INPUT(input_image)
    input_batch = input_tensor.unsqueeze(0)

    model = load_model(
        dataset_models[dataset_type], f"{arch_type}_{training_ds}_{methods}"
    )

    if torch.cuda.is_available():
        input_batch = input_batch.to("cuda")
        model.to("cuda")

    model.eval()
    with torch.inference_mode():
        output: torch.Tensor = model(input_batch)

    # One class name per line; the line number is the class index.
    with open(f"./data/{dataset_type}.txt", "r", encoding="utf-8") as f:
        labels = dict(enumerate(line.strip() for line in f))

    return postprocess_default(labels, output)


def update_methods(method, ds_type):
    """Refresh the "Training Dataset" choices for the current selection.

    Stanford Dogs with the "200 Epochs" schedule omits the last entry of
    ``list_200`` ("Original + Synthetic (Clean)"), for which no model is
    registered. The current value is cleared in every case.
    """
    if ds_type == "Stanford Dogs (120 classes)" and method == "200 Epochs":
        methods = list_200[:-1]
    else:
        methods = methods_map[method]
    return gr.update(choices=methods, value=None)


def downloadModel(
    dataset_type, arch_type, methods, training_ds, dataset_models=dataset_models
):
    """Point the download button at the checkpoint of the selected model.

    Returns an update that resets the button to "Select Model" when an option
    is missing or no model is registered for the combination; otherwise an
    update carrying the checkpoint path and a descriptive label.
    """
    if any(option is None for option in (dataset_type, arch_type, methods, training_ds)):
        return gr.update(label="Select Model", value=None)

    dataset_type = label_map[dataset_type]
    arch_type = label_map[arch_type]
    methods = label_map[methods]
    training_ds = label_map[training_ds]

    entry = dataset_models[dataset_type].get(f"{arch_type}_{training_ds}_{methods}")
    if entry is None:
        return gr.update(label="Select Model", value=None)

    return gr.update(
        label=f"Download Model: '{dataset_type}_{arch_type}_{training_ds}_{methods}'",
        value=entry[1],
    )


if __name__ == "__main__":
    with gr.Blocks(title="Generative Augmented Image Classifiers") as demo:
        gr.Markdown(
            """
            # Generative Augmented Image Classifiers
            This demo showcases the performance of image classifiers trained on various datasets.
            """
        )
        with gr.Row():
            with gr.Column():
                dataset_type = gr.Radio(
                    choices=[
                        "Imagenette (10 classes)",
                        "Imagewoof (10 classes)",
                        "Stanford Dogs (120 classes)",
                    ],
                    label="Dataset",
                    value="Imagenette (10 classes)",
                )
                arch_type = gr.Radio(
                    choices=["ResNet-18", "MobileNetV2"],
                    label="Model Architecture",
                    value="ResNet-18",
                    interactive=True,
                )
                methods = gr.Radio(
                    label="Methods",
                    choices=["200 Epochs", "200 Epochs on Original + 100"],
                    interactive=True,
                    value="200 Epochs",
                )
                training_ds = gr.Radio(
                    label="Training Dataset",
                    choices=methods_map["200 Epochs"],
                    interactive=True,
                    value="Original",
                )
                # Changing the dataset or the schedule changes which training
                # variants are selectable.
                dataset_type.change(
                    fn=update_methods,
                    inputs=[methods, dataset_type],
                    outputs=[training_ds],
                )
                methods.change(
                    fn=update_methods,
                    inputs=[methods, dataset_type],
                    outputs=[training_ds],
                )
                generate_button = gr.Button("Sample Random Image")
                random_image_output = gr.Image(
                    type="pil", label="Random Image from Validation Set"
                )
                classify_button_random = gr.Button("Classify")
            with gr.Column():
                output_label_random = gr.Label(num_top_classes=5)
                initial_dataset = label_map[dataset_type.value]
                initial_key = (
                    f"{label_map[arch_type.value]}"
                    f"_{label_map[training_ds.value]}"
                    f"_{label_map[methods.value]}"
                )
                download_model = gr.DownloadButton(
                    label=f"Download Model: '{initial_dataset}_{initial_key}'",
                    value=dataset_models[initial_dataset][initial_key][1],
                )

        # Any selector change may change which checkpoint is offered.
        selectors = [dataset_type, arch_type, methods, training_ds]
        for selector in selectors:
            selector.change(
                fn=downloadModel,
                inputs=selectors,
                outputs=[download_model],
            )

        generate_button.click(
            get_random_image,
            inputs=[dataset_type],
            outputs=random_image_output,
        )
        classify_button_random.click(
            classify,
            inputs=[random_image_output, dataset_type, arch_type, methods, training_ds],
            outputs=output_label_random,
        )

    demo.launch(show_error=True)