|
import random
|
|
|
|
import gradio as gr
|
|
import torch
|
|
import torchvision
|
|
import torchvision.transforms as transforms
|
|
from PIL import Image
|
|
from torch import nn
|
|
from torchvision.models import mobilenet_v2, resnet18
|
|
from torchvision.transforms.functional import InterpolationMode
|
|
|
|
# Number of output classes per dataset; passed as ``num_classes`` when the
# classifier models are instantiated.
datasets_n_classes = dict(
    Imagenette=10,
    Imagewoof=10,
    Stanford_dogs=120,
)

# Checkpoint variants trained for each dataset, named "<training set>_<epochs>".
# Stanford Dogs has no "augment_clean_200" variant.
datasets_model_types = dict(
    Imagenette=[
        "base_200",
        "base_200+100",
        "synthetic_200",
        "augment_noisy_200",
        "augment_noisy_200+100",
        "augment_clean_200",
    ],
    Imagewoof=[
        "base_200",
        "base_200+100",
        "synthetic_200",
        "augment_noisy_200",
        "augment_noisy_200+100",
        "augment_clean_200",
    ],
    Stanford_dogs=[
        "base_200",
        "base_200+100",
        "synthetic_200",
        "augment_noisy_200",
        "augment_noisy_200+100",
    ],
)
|
|
|
|
# Backbone architectures trained for every dataset / checkpoint variant.
model_arch = ["resnet18", "mobilenet_v2"]

# UI choices for the training set of a plain 200-epoch run. The final entry
# ("Clean") must stay last: it is the one hidden for Stanford Dogs.
list_200 = [
    "Original",
    "Synthetic",
    "Original + Synthetic (Noisy)",
    "Original + Synthetic (Clean)",
]

# UI choices for runs fine-tuned for 100 more epochs on top of 200.
list_200_100 = [
    "Base+100",
    "AugmentNoisy+100",
]

# Training schedule label -> training-set choices offered for that schedule.
methods_map = {
    "200 Epochs": list_200,
    "200 Epochs on Original + 100": list_200_100,
}
|
|
|
|
# Translate the human-readable labels shown in the UI into the identifiers
# used for data directories, model keys and checkpoint file names.
label_map = {
    # Dataset radio -> dataset directory / dataset_models key.
    "Imagenette (10 classes)": "Imagenette",
    "Imagewoof (10 classes)": "Imagewoof",
    "Stanford Dogs (120 classes)": "Stanford_dogs",
    # Architecture radio -> architecture identifier.
    "ResNet-18": "resnet18",
    "MobileNetV2": "mobilenet_v2",
    # Training schedule radio -> epoch suffix of the model key.
    "200 Epochs": "200",
    "200 Epochs on Original + 100": "200+100",
    # Training-set radio (200-epoch schedule) -> training-set prefix.
    "Original": "base",
    "Synthetic": "synthetic",
    "Original + Synthetic (Noisy)": "augment_noisy",
    "Original + Synthetic (Clean)": "augment_clean",
    # Training-set radio (200+100 schedule) -> training-set prefix.
    "Base+100": "base",
    "AugmentNoisy+100": "augment_noisy",
}
|
|
|
|
# Pair an untrained model instance with the checkpoint path for every
# (dataset, checkpoint variant, architecture) combination:
#   dataset_models[dataset][f"{arch}_{model_type}"] == (model, checkpoint_path)
# NOTE(review): every model is instantiated eagerly at import time; consider
# lazy construction if start-up memory becomes a concern.
dataset_models = {}
for dataset, n_classes in datasets_n_classes.items():
    models = {}
    for model_type in datasets_model_types[dataset]:
        for arch in model_arch:
            # Only the constructor differs between architectures.
            if arch == "resnet18":
                model = resnet18(weights=None, num_classes=n_classes)
            elif arch == "mobilenet_v2":
                model = mobilenet_v2(weights=None, num_classes=n_classes)
            else:
                raise ValueError(f"Model architecture unavailable: {arch}")
            # Checkpoint layout: ./models/<arch>/<dataset>/<dataset>_<model_type>.pth
            models[f"{arch}_{model_type}"] = (
                model,
                f"./models/{arch}/{dataset}/{dataset}_{model_type}.pth",
            )
    dataset_models[dataset] = models
|
|
|
|
|
|
def get_random_image(dataset, label_map=label_map) -> Image.Image:
    """Sample a random image from a dataset's validation split.

    Args:
        dataset: UI label of the dataset (a key of ``label_map``), e.g.
            ``"Imagenette (10 classes)"``.
        label_map: Mapping from UI labels to dataset directory names under
            ``./data``.

    Returns:
        A PIL image drawn uniformly at random from ``./data/<name>/val`` and
        resized to 256x256 (the aspect ratio is not preserved).
    """
    dataset_root = f"./data/{label_map[dataset]}/val"
    # With no transform, ImageFolder yields the decoded PIL image directly,
    # so there is no need for a PIL -> tensor -> PIL round trip.
    dataset_img = torchvision.datasets.ImageFolder(dataset_root)
    image, _ = dataset_img[random.randrange(len(dataset_img))]
    return image.resize((256, 256))
|
|
|
|
|
|
def load_model(model_dict, model_name: str) -> nn.Module:
    """Load trained weights into a pre-built model and return it.

    Args:
        model_dict: Mapping of lower-case model keys to ``(model, checkpoint_path)``
            tuples (the per-dataset dicts of ``dataset_models``).
        model_name: Model key to look up; matched case-insensitively.

    Returns:
        The model object stored in ``model_dict``, with its weights loaded in
        place from the checkpoint file.

    Raises:
        ValueError: If ``model_name`` is not a key of ``model_dict``.
    """
    model_name_lower = model_name.lower()
    if model_name_lower not in model_dict:
        raise ValueError(
            f"Model {model_name} is not available for image prediction. Please choose from {[name.capitalize() for name in model_dict.keys()]}."
        )

    entry = model_dict[model_name_lower]
    model, model_path = entry[0], entry[1]

    # Map tensors to CPU so checkpoints saved on a GPU also load on CPU-only
    # hosts; callers move the model to the target device afterwards.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location="cpu")

    if "setup" in checkpoint:
        # Training-script checkpoint: {"setup": {...}, "model": state_dict}.
        if checkpoint["setup"]["distributed"]:
            # Weights saved from a DistributedDataParallel wrapper carry a
            # "module." prefix on every key; strip it in place.
            torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
                checkpoint["model"], "module."
            )
        state_dict = checkpoint["model"]
    else:
        # Bare state_dict checkpoint.
        state_dict = checkpoint

    model.load_state_dict(state_dict)
    return model
|
|
|
|
|
|
def postprocess_default(labels, output, top_k: int = 5) -> dict:
    """Convert raw classifier logits into a ``{class name: probability}`` dict.

    Args:
        labels: Mapping (or sequence) from integer class index to class name.
        output: Logits tensor; its first row ``output[0]`` is the 1-D score
            vector that is reported.
        top_k: Maximum number of classes to return. It is clamped to the
            number of classes, so models with fewer than ``top_k`` classes
            do not make ``torch.topk`` fail.

    Returns:
        Dict of up to ``top_k`` class names mapped to their softmax
        probabilities, ordered from most to least likely.
    """
    probabilities = nn.functional.softmax(output[0], dim=0)
    k = min(top_k, probabilities.shape[0])
    top_prob, top_catid = torch.topk(probabilities, k)
    # Convert each tensor to a Python list once, not once per element.
    return {
        labels[idx]: prob
        for idx, prob in zip(top_catid.tolist(), top_prob.tolist())
    }
|
|
|
|
|
|
def classify(
    input_image: Image.Image,
    dataset_type: str,
    arch_type: str,
    methods: str,
    training_ds: str,
    dataset_models=dataset_models,
    label_map=label_map,
) -> dict:
    """Classify an image with the checkpoint selected in the UI.

    Args:
        input_image: PIL image to classify; converted to RGB before preprocessing.
        dataset_type: UI label of the dataset, e.g. ``"Imagenette (10 classes)"``.
        arch_type: UI label of the architecture, e.g. ``"ResNet-18"``.
        methods: UI label of the training schedule, e.g. ``"200 Epochs"``.
        training_ds: UI label of the training set, e.g. ``"Original"``.
        dataset_models: Per-dataset mapping of model keys to ``(model, path)``.
        label_map: Mapping from UI labels to internal identifiers.

    Returns:
        Dict of the top-5 class names to their softmax probabilities, as
        produced by ``postprocess_default``.

    Raises:
        ValueError: If an option or the image is missing, or the selected
            combination has no trained checkpoint (raised by ``load_model``).
    """
    if any(option is None for option in (dataset_type, arch_type, methods, training_ds)):
        raise ValueError("Please select all options.")
    if input_image is None:
        raise ValueError("No image was provided.")

    dataset_key = label_map[dataset_type]
    model_key = f"{label_map[arch_type]}_{label_map[training_ds]}_{label_map[methods]}"

    # Resize shorter side to 256, center-crop 224, normalize with ImageNet
    # channel statistics.
    preprocess_input = transforms.Compose(
        [
            transforms.Resize(
                256,
                interpolation=InterpolationMode.BILINEAR,
                antialias=True,
            ),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Grayscale or RGBA images would not match the 3-channel Normalize above.
    input_tensor: torch.Tensor = preprocess_input(input_image.convert("RGB"))
    input_batch = input_tensor.unsqueeze(0)

    model = load_model(dataset_models[dataset_key], model_key)
    if torch.cuda.is_available():
        input_batch = input_batch.to("cuda")
        model.to("cuda")

    model.eval()
    with torch.inference_mode():
        output: torch.Tensor = model(input_batch)

    # One class name per line; the line index is the class index.
    with open(f"./data/{dataset_key}.txt", "r", encoding="utf-8") as f:
        labels = {i: line.strip() for i, line in enumerate(f)}

    return postprocess_default(labels, output)
|
|
|
|
|
|
def update_methods(method, ds_type):
    """Refresh the "Training Dataset" choices after a dataset or schedule change.

    Stanford Dogs has no "Original + Synthetic (Clean)" checkpoint, so that
    option (the last entry of ``list_200``) is hidden for its 200-epoch runs.
    The current selection is cleared, so the user has to pick again.
    """
    is_stanford_200 = ds_type == "Stanford Dogs (120 classes)" and method == "200 Epochs"
    choices = list_200[:-1] if is_stanford_200 else methods_map[method]
    return gr.update(choices=choices, value=None)
|
|
|
|
|
|
def downloadModel(
    dataset_type,
    arch_type,
    methods,
    training_ds,
    dataset_models=dataset_models,
    label_map=label_map,
):
    """Point the download button at the checkpoint for the current selection.

    Args:
        dataset_type: UI label of the dataset.
        arch_type: UI label of the architecture.
        methods: UI label of the training schedule.
        training_ds: UI label of the training set.
        dataset_models: Per-dataset mapping of model keys to ``(model, path)``.
        label_map: Mapping from UI labels to internal identifiers.

    Returns:
        A ``gr.update`` for the download button. It carries the model name
        and checkpoint path. If any option is unset or that combination was
        never trained, it resets the button to "Select Model" with no file.
    """
    if any(option is None for option in (dataset_type, arch_type, methods, training_ds)):
        return gr.update(label="Select Model", value=None)

    dataset_key = label_map[dataset_type]
    arch_key = label_map[arch_type]
    method_key = label_map[methods]
    training_key = label_map[training_ds]

    # Single lookup instead of a membership test followed by indexing.
    entry = dataset_models[dataset_key].get(f"{arch_key}_{training_key}_{method_key}")
    if entry is None:
        return gr.update(label="Select Model", value=None)

    return gr.update(
        label=f"Download Model: '{dataset_key}_{arch_key}_{training_key}_{method_key}'",
        value=entry[1],
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Two-column UI. The left column holds the selection controls and a
    # sampled image; the right column holds predictions and the checkpoint
    # download.
    with gr.Blocks(title="Generative Augmented Image Classifiers") as demo:
        gr.Markdown(
            """


            # Generative Augmented Image Classifiers


            This demo showcases the performance of image classifiers trained on various datasets.


            """
        )

        with gr.Row():
            with gr.Column():
                dataset_type = gr.Radio(
                    choices=[
                        "Imagenette (10 classes)",
                        "Imagewoof (10 classes)",
                        "Stanford Dogs (120 classes)",
                    ],
                    label="Dataset",
                    value="Imagenette (10 classes)",
                )
                arch_type = gr.Radio(
                    choices=["ResNet-18", "MobileNetV2"],
                    label="Model Architecture",
                    value="ResNet-18",
                    interactive=True,
                )
                methods = gr.Radio(
                    label="Methods",
                    choices=["200 Epochs", "200 Epochs on Original + 100"],
                    interactive=True,
                    value="200 Epochs",
                )
                training_ds = gr.Radio(
                    label="Training Dataset",
                    choices=methods_map["200 Epochs"],
                    interactive=True,
                    value="Original",
                )

                # The dataset and the schedule both determine which training
                # sets exist, so either one refreshes the training-set radio.
                for trigger in (dataset_type, methods):
                    trigger.change(
                        fn=update_methods,
                        inputs=[methods, dataset_type],
                        outputs=[training_ds],
                    )

                generate_button = gr.Button("Sample Random Image")
                random_image_output = gr.Image(
                    type="pil", label="Random Image from Validation Set"
                )
                classify_button_random = gr.Button("Classify")

            with gr.Column():
                output_label_random = gr.Label(num_top_classes=5)

                # Offer the checkpoint matching the default selection at start-up.
                initial_dataset = label_map[dataset_type.value]
                initial_model_key = (
                    f"{label_map[arch_type.value]}"
                    f"_{label_map[training_ds.value]}"
                    f"_{label_map[methods.value]}"
                )
                download_model = gr.DownloadButton(
                    label=f"Download Model: '{initial_dataset}_{initial_model_key}'",
                    value=dataset_models[initial_dataset][initial_model_key][1],
                )

                # Every selector affects which checkpoint is offered.
                for selector in (dataset_type, arch_type, methods, training_ds):
                    selector.change(
                        fn=downloadModel,
                        inputs=[dataset_type, arch_type, methods, training_ds],
                        outputs=[download_model],
                    )

        generate_button.click(
            get_random_image,
            inputs=[dataset_type],
            outputs=random_image_output,
        )
        classify_button_random.click(
            classify,
            inputs=[random_image_output, dataset_type, arch_type, methods, training_ds],
            outputs=output_label_random,
        )

    demo.launch(show_error=True)
|
|
|