cfe-gen / app.py
anindya-hf-2002's picture
updated app.py
a9c06f8 verified
import gradio as gr
from src.inference import load_classifier, load_model, generate_images, convert_into_image, classify_image
from src.models import ResUNetGenerator
from src.explainer import GradCAM, preprocess_image
# Loading Models
# Checkpoint locations are relative paths (presumably resolved against the
# directory the app is started from — confirm for the deployment).
classifier_path = './models/efficientnet_b1-epoch16-val_loss0.46_ft.ckpt'
g_NP_checkpoint = './models/g_NP_best.ckpt'
g_PN_checkpoint = './models/g_PN_best.ckpt'


def _load_generator(checkpoint):
    """Build a fresh ResUNetGenerator(gf=32, channels=1) and load *checkpoint* into it."""
    architecture = ResUNetGenerator(gf=32, channels=1)
    return load_model(checkpoint, architecture)


# Both generators share one architecture; g_NP is loaded first, then g_PN.
g_NP = _load_generator(g_NP_checkpoint)
g_PN = _load_generator(g_PN_checkpoint)

classifier = load_classifier(classifier_path)
# Grad-CAM is attached to the last entry of `classifier.model.features`
# (presumably the final feature block of the backbone — confirm in src).
target_layer = classifier.model.features[-1]
grad_cam = GradCAM(classifier, target_layer)
def counterfactual_generation(input_image):
    """Produce translated (counterfactual) and reconstructed images for one input.

    Args:
        input_image: Image from the "pil"-typed Gradio input component, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A ``(translated_images, recon_images)`` pair, each passed through
        ``convert_into_image`` so it can be shown in the output image
        components.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a traceback from the model code.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before submitting.")

    # generate_images receives the classifier and both generators (g_PN, g_NP);
    # the translation/reconstruction logic itself lives in src.inference.
    translated_images, recon_images = generate_images(input_image, classifier, g_PN, g_NP)
    translated_images = convert_into_image(translated_images)
    recon_images = convert_into_image(recon_images)
    return translated_images, recon_images
def image_classification(input_image):
    """Classify an image and build a Grad-CAM visualisation for it.

    Args:
        input_image: Image from the "pil"-typed Gradio input component, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A ``(result, cam_image)`` pair: the classification result returned by
        ``classify_image`` (shown in a ``gr.Label``) and the Grad-CAM overlay
        returned by ``grad_cam.visualize_cam``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a traceback from the model code.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before submitting.")

    result, target_class = classify_image(input_image, classifier=classifier)
    # Grad-CAM works on a tensor prepared separately by preprocess_image, and
    # targets the class index returned by classify_image (presumably the
    # predicted class — confirm in src.inference).
    input_tensor = preprocess_image(input_image)
    cam = grad_cam.generate_cam(input_tensor, target_class)
    cam_image = grad_cam.visualize_cam(cam, input_tensor)
    return result, cam_image
# Defining the components
def _png_image(**options):
    """Return a gr.Image configured with format="png" plus the given keyword options."""
    return gr.Image(format="png", **options)


# Separate input instances are created for the two tabs.
inputs1 = _png_image(type="pil")
inputs2 = _png_image(type="pil")

outputs1 = [
    _png_image(type="pil", label="Translated Images"),
    _png_image(type="pil", label="Reconstructed Images"),
]

outputs2 = [
    gr.Label(label="Classification Result"),
    _png_image(label="Grad-CAM"),
]
# Two-tab UI: each tab hosts its own Interface with its own input/output components.
with gr.Blocks() as demo:
    with gr.Tab("Counterfactual Generation"):
        app1 = gr.Interface(
            fn=counterfactual_generation,
            inputs=inputs1,
            outputs=outputs1,
            title="Counterfactual Image Generation",
            allow_flagging="never",
            description="Generate counterfactual images to explain the classifier's decisions.",
        )
    with gr.Tab("Classification"):
        app2 = gr.Interface(
            fn=image_classification,
            inputs=inputs2,
            outputs=outputs2,
            title="Image Classification",
            allow_flagging="never",
            description="Classify the input medical image and visualize Grad-CAM.",
        )

# Launch the app only when executed as a script, so importing this module
# (e.g. from tests or Gradio's auto-reload runner, which picks up `demo`)
# does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch(share=True)