|
import random
from functools import lru_cache

import gradio as gr
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torchvision.models import mobilenet_v2, resnet18
from torchvision.transforms.functional import InterpolationMode
|
|
|
# Number of output classes of the classifier head, per dataset.
datasets_n_classes = {
    "Imagenette": 10,
    "Imagewoof": 10,
    "Stanford_dogs": 120,
}

# Checkpoint variants that exist for every dataset.
_SHARED_MODEL_TYPES = (
    "base_200",
    "base_200+100",
    "synthetic_200",
    "augment_noisy_200",
    "augment_noisy_200+100",
)

# Checkpoint variants available per dataset. Stanford_dogs has no
# "augment_clean_200" entry; the other two datasets do.
datasets_model_types = {
    "Imagenette": [*_SHARED_MODEL_TYPES, "augment_clean_200"],
    "Imagewoof": [*_SHARED_MODEL_TYPES, "augment_clean_200"],
    "Stanford_dogs": list(_SHARED_MODEL_TYPES),
}

# Architecture identifiers, used both in model keys and checkpoint paths.
model_arch = ["resnet18", "mobilenet_v2"]

# UI labels of the training-dataset choices offered for the "200 Epochs" method.
list_200 = [
    "Original",
    "Synthetic",
    "Original + Synthetic (Noisy)",
    "Original + Synthetic (Clean)",
]

# UI labels of the training-dataset choices offered for the "+100" method.
list_200_100 = ["Base+100", "AugmentNoisy+100"]

# Method label -> training-dataset labels selectable for that method.
methods_map = {
    "200 Epochs": list_200,
    "200 Epochs on Original + 100": list_200_100,
}
|
|
|
# Maps every human-readable UI label to the identifier fragment used when
# building model keys, checkpoint paths and data paths.
label_map = {
    # Datasets
    "Imagenette (10 classes)": "Imagenette",
    "Imagewoof (10 classes)": "Imagewoof",
    "Stanford Dogs (120 classes)": "Stanford_dogs",
    # Architectures
    "ResNet-18": "resnet18",
    "MobileNetV2": "mobilenet_v2",
    # Training methods
    "200 Epochs": "200",
    "200 Epochs on Original + 100": "200+100",
    # Training datasets offered for "200 Epochs"
    "Original": "base",
    "Synthetic": "synthetic",
    "Original + Synthetic (Noisy)": "augment_noisy",
    "Original + Synthetic (Clean)": "augment_clean",
    # Training datasets offered for "200 Epochs on Original + 100"
    "Base+100": "base",
    "AugmentNoisy+100": "augment_noisy",
}
|
|
|
# Architecture identifier -> torchvision constructor for that architecture.
_MODEL_BUILDERS = {
    "resnet18": resnet18,
    "mobilenet_v2": mobilenet_v2,
}

# dataset -> {"<arch>_<model_type>": (uninitialised model, checkpoint path)}.
# Models are created without weights here; the checkpoint is read from the
# stored path later, when a model is actually requested.
dataset_models = {}
for dataset, n_classes in datasets_n_classes.items():
    models = {}
    for model_type in datasets_model_types[dataset]:
        for arch in model_arch:
            if arch not in _MODEL_BUILDERS:
                raise ValueError(f"Model architecture unavailable: {arch}")
            models[f"{arch}_{model_type}"] = (
                _MODEL_BUILDERS[arch](weights=None, num_classes=n_classes),
                f"./models/{arch}/{dataset}/{dataset}_{model_type}.pth",
            )
    dataset_models[dataset] = models
|
|
|
|
|
@lru_cache(maxsize=None)
def _load_image_folder(root: str) -> torchvision.datasets.ImageFolder:
    """Index the image folder at *root* once and reuse it on later calls.

    Building an ``ImageFolder`` walks the whole directory tree, so the result
    is cached per root path instead of being rebuilt for every sampled image.
    """
    return torchvision.datasets.ImageFolder(root)


def get_random_image(dataset, label_map=label_map) -> Image.Image:
    """Return a random validation image of *dataset*, resized to 256x256.

    Args:
        dataset: UI label of the dataset; translated through *label_map* to
            the directory name under ``./data``.
        label_map: Mapping from UI labels to identifiers.

    Returns:
        A PIL image of size 256x256 taken from ``./data/<dataset>/val``.

    Raises:
        KeyError: If *dataset* is not a key of *label_map*.
    """
    samples = _load_image_folder(f"./data/{label_map[dataset]}/val")
    # Without a transform, ImageFolder hands back the PIL image produced by its
    # loader, so no tensor round-trip is needed before resizing.
    image, _ = samples[random.randrange(len(samples))]
    return image.resize((256, 256))
|
|
|
|
|
def load_model(model_dict, model_name: str) -> nn.Module:
    """Load the checkpoint registered for *model_name* into its model.

    Args:
        model_dict: Mapping ``name -> (model, checkpoint_path)``.
        model_name: Name of the model; looked up case-insensitively by
            lower-casing it.

    Returns:
        The model object stored in *model_dict*, with the checkpoint weights
        loaded into it (the same instance, modified in place).

    Raises:
        ValueError: If the lower-cased *model_name* is not in *model_dict*.
    """
    key = model_name.lower()
    if key not in model_dict:
        raise ValueError(
            f"Model {model_name} is not available for image prediction. Please choose from {[name.capitalize() for name in model_dict]}."
        )

    model, model_path = model_dict[key][0], model_dict[key][1]

    # Without a GPU, force tensors onto the CPU; otherwise keep the device
    # recorded in the checkpoint (map_location=None is torch.load's default).
    map_location = None if torch.cuda.is_available() else "cpu"
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=map_location)

    if "setup" not in checkpoint:
        # The file is a bare state dict.
        model.load_state_dict(checkpoint)
        return model

    # The file is a training checkpoint: weights live under "model".
    if checkpoint["setup"]["distributed"]:
        # Strip the "module." key prefix so keys match the plain model.
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint["model"], "module."
        )
    model.load_state_dict(checkpoint["model"])
    return model
|
|
|
|
|
def postprocess_default(labels, output, top_k: int = 5) -> dict:
    """Turn raw model logits into a ``{label: probability}`` mapping.

    Args:
        labels: Mapping from class index to label text.
        output: Model output; only the first row (``output[0]``) is used and
            softmax is applied over it.
        top_k: Maximum number of classes to report. It is clamped to the
            number of classes so small heads do not make ``topk`` fail.

    Returns:
        A dict of at most *top_k* entries, ordered from most to least likely.
    """
    probabilities = nn.functional.softmax(output[0], dim=0)
    k = min(top_k, probabilities.shape[0])
    top_prob, top_catid = torch.topk(probabilities, k)
    # Convert each tensor to a list once instead of once per item.
    return {
        labels[class_id]: prob
        for class_id, prob in zip(top_catid.tolist(), top_prob.tolist())
    }
|
|
|
|
|
def classify(
    input_image: Image.Image,
    dataset_type: str,
    arch_type: str,
    methods: str,
    training_ds: str,
    dataset_models=dataset_models,
    label_map=label_map,
) -> dict:
    """Classify *input_image* with the model selected in the UI.

    Args:
        input_image: PIL image to classify.
        dataset_type: UI label of the dataset.
        arch_type: UI label of the model architecture.
        methods: UI label of the training method.
        training_ds: UI label of the training dataset.
        dataset_models: Mapping ``dataset -> {model key: (model, path)}``.
        label_map: Mapping from UI labels to identifiers.

    Returns:
        The ``{label: probability}`` mapping produced by
        ``postprocess_default``.

    Raises:
        ValueError: If any option is ``None``, if no image was provided, or if
            the selected combination has no registered model.
    """
    if any(
        option is None
        for option in (dataset_type, arch_type, methods, training_ds)
    ):
        raise ValueError("Please select all options.")
    if input_image is None:
        raise ValueError("No image was provided.")

    dataset_type = label_map[dataset_type]
    arch_type = label_map[arch_type]
    methods = label_map[methods]
    training_ds = label_map[training_ds]

    preprocess_input = transforms.Compose(
        [
            transforms.Resize(
                256,
                interpolation=InterpolationMode.BILINEAR,
                antialias=True,
            ),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Normalize is configured for exactly three channels, so grayscale or
    # RGBA inputs are converted to RGB first.
    input_tensor: torch.Tensor = preprocess_input(input_image.convert("RGB"))
    input_batch = input_tensor.unsqueeze(0)

    model = load_model(
        dataset_models[dataset_type], f"{arch_type}_{training_ds}_{methods}"
    )

    if torch.cuda.is_available():
        input_batch = input_batch.to("cuda")
        model.to("cuda")

    model.eval()
    with torch.inference_mode():
        output: torch.Tensor = model(input_batch)

    # One label per line; the line number is the class index.
    with open(f"./data/{dataset_type}.txt", "r", encoding="utf-8") as f:
        labels = {index: line.strip() for index, line in enumerate(f)}
    return postprocess_default(labels, output)
|
|
|
|
|
def update_methods(method, ds_type):
    """Refresh the training-dataset choices for the selected method/dataset.

    The candidates come from ``methods_map[method]``. When *ds_type* is a
    known dataset label, candidates without a registered checkpoint variant
    in ``datasets_model_types`` are dropped (this is what removes the
    "Clean" option for Stanford Dogs under "200 Epochs").

    Args:
        method: UI label of the training method.
        ds_type: UI label of the dataset.

    Returns:
        A ``gr.update`` carrying the new choices and clearing the selection.

    Raises:
        KeyError: If *method* is not a key of ``methods_map``.
    """
    choices = methods_map[method]
    available = datasets_model_types.get(label_map.get(ds_type))
    if available is not None:
        suffix = label_map[method]
        choices = [
            choice
            for choice in choices
            if f"{label_map[choice]}_{suffix}" in available
        ]
    return gr.update(choices=choices, value=None)
|
|
|
|
|
def downloadModel(
    dataset_type, arch_type, methods, training_ds, dataset_models=dataset_models
):
    """Point the download button at the checkpoint of the current selection.

    Args:
        dataset_type: UI label of the dataset.
        arch_type: UI label of the model architecture.
        methods: UI label of the training method.
        training_ds: UI label of the training dataset.
        dataset_models: Mapping ``dataset -> {model key: (model, path)}``.

    Returns:
        A ``gr.update`` whose value is the checkpoint path and whose label
        names the model, or a "Select Model" update with no value when an
        option is ``None`` or the combination has no registered model.
    """
    selections = (dataset_type, arch_type, methods, training_ds)
    if any(selection is None for selection in selections):
        return gr.update(label="Select Model", value=None)

    dataset_type, arch_type, methods, training_ds = (
        label_map[selection] for selection in selections
    )
    model_key = f"{arch_type}_{training_ds}_{methods}"
    entry = dataset_models[dataset_type].get(model_key)
    if entry is None:
        # Not every method/training-dataset pair has a checkpoint.
        return gr.update(label="Select Model", value=None)

    return gr.update(
        label=f"Download Model: '{dataset_type}_{model_key}'",
        value=entry[1],
    )
|
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): the layout nesting below (which components sit in which
    # Row/Column) was reconstructed from a source whose indentation was lost;
    # confirm against the intended UI.
    with gr.Blocks(title="Generative Augmented Image Classifiers") as demo:
        gr.Markdown(
            """
            # Generative Augmented Image Classifiers
            Main GitHub Repo: [Generative Data Augmentation](https://github.com/zhulinchng/generative-data-augmentation) | Generative Data Augmentation Demo: [Generative Data Augmented](https://huggingface.co/spaces/czl/generative-data-augmentation-demo).
            """
        )
        with gr.Row():
            with gr.Column():
                dataset_type = gr.Radio(
                    choices=[
                        "Imagenette (10 classes)",
                        "Imagewoof (10 classes)",
                        "Stanford Dogs (120 classes)",
                    ],
                    label="Dataset",
                    value="Imagenette (10 classes)",
                )
                arch_type = gr.Radio(
                    choices=["ResNet-18", "MobileNetV2"],
                    label="Model Architecture",
                    value="ResNet-18",
                    interactive=True,
                )
                methods = gr.Radio(
                    label="Methods",
                    choices=["200 Epochs", "200 Epochs on Original + 100"],
                    interactive=True,
                    value="200 Epochs",
                )
                training_ds = gr.Radio(
                    label="Training Dataset",
                    choices=methods_map["200 Epochs"],
                    interactive=True,
                    value="Original",
                )
                # The training-dataset choices depend on both the dataset and
                # the method, so either one changing refreshes them.
                for selector in (dataset_type, methods):
                    selector.change(
                        fn=update_methods,
                        inputs=[methods, dataset_type],
                        outputs=[training_ds],
                    )
                random_image_output = gr.Image(type="pil", label="Image to Classify")
                with gr.Row():
                    generate_button = gr.Button("Sample Random Image")
                    classify_button_random = gr.Button("Classify")
            with gr.Column():
                output_label_random = gr.Label(num_top_classes=5)
                # Initial button state mirrors the default radio selections.
                initial_dataset = label_map[dataset_type.value]
                initial_key = (
                    f"{label_map[arch_type.value]}_{label_map[training_ds.value]}"
                    f"_{label_map[methods.value]}"
                )
                download_model = gr.DownloadButton(
                    label=f"Download Model: '{initial_dataset}_{initial_key}'",
                    value=dataset_models[initial_dataset][initial_key][1],
                )
                # Any change to the selection re-targets the download button.
                for selector in (dataset_type, arch_type, methods, training_ds):
                    selector.change(
                        fn=downloadModel,
                        inputs=[dataset_type, arch_type, methods, training_ds],
                        outputs=[download_model],
                    )
                gr.Markdown(
                    """
                    This demo showcases the performance of image classifiers trained on various datasets as part of the project 'Improving Fine-Grained Image Classification Using Diffusion-Based Generated Synthetic Images' dissertation.

                    View the models and files used in this demo [here](https://huggingface.co/spaces/czl/generative-augmented-classifiers/tree/main).

                    Usage Instructions & Documentation [here](https://huggingface.co/spaces/czl/generative-augmented-classifiers/blob/main/README.md).
                    """
                )

        generate_button.click(
            get_random_image,
            inputs=[dataset_type],
            outputs=random_image_output,
        )
        classify_button_random.click(
            classify,
            inputs=[random_image_output, dataset_type, arch_type, methods, training_ds],
            outputs=output_label_random,
        )
    demo.launch(show_error=True)
|
|