Spaces:
Runtime error
Runtime error
from huggingface_hub import from_pretrained_keras | |
import tensorflow as tf | |
import gradio as gr | |
# Hugging Face Hub repository holding the pretrained involution visualisation
# model. It is loaded once at import time, so every `infer` call reuses the
# same in-memory model instead of fetching it per request.
MODEL_REPO_ID = "ariG23498/involution"
vis_model = from_pretrained_keras(MODEL_REPO_ID)
def infer(test_image):
    """Render the summed involution kernels of ``vis_model`` for one image.

    Args:
        test_image: Image handed over by the Gradio ``Image`` input, or
            ``None`` when the form is submitted without an image. Presumably
            an (H, W, 3) uint8 array (Gradio's default ``type="numpy"``) --
            TODO confirm.

    Returns:
        A tuple with one image per model output, each produced by
        ``tf.keras.utils.array_to_img``; ``(None, None, None)`` when
        ``test_image`` is ``None`` so the three output panes are cleared
        instead of surfacing a tensor-conversion error.
    """
    # Gradio submits None when no image was provided; bail out before any
    # tensor conversion is attempted.
    if test_image is None:
        return (None, None, None)

    # Convert to a tensor and resize to the constant 32x32 resolution used
    # for the model input. tf.image.resize returns float32 values in the
    # input's original range; nothing is rescaled here.
    # NOTE(review): if the model expects inputs scaled to [0, 1], this may
    # need `image / 255.0` -- confirm against how the model was trained.
    image = tf.image.resize(tf.constant(test_image), (32, 32))

    # Add a batch axis and run the model; it yields one output per
    # involution layer.
    layer_outputs = vis_model.predict(image[None, ...])

    kernel_images = []
    for layer_output in layer_outputs:
        # Each output unpacks into two items; only the second (the kernel)
        # is visualised.
        _, kernel = layer_output
        # Sum over the three trailing kernel axes, leaving one 2-D map per
        # batch item.
        kernel_map = tf.reduce_sum(kernel, axis=[-1, -2, -3])
        # Take the single batch item and append a channel axis, because
        # array_to_img expects a rank-3 (H, W, C) array.
        kernel_images.append(
            tf.keras.utils.array_to_img(kernel_map[0, ..., None])
        )
    return tuple(kernel_images)
# HTML passed to gr.Interface through its `article=` argument: author links,
# the paper link, and the convolution / involution kernel diagrams.
article = (
    "<center>\n"
    "Authors: <a href='https://twitter.com/ariG23498' target='_blank'>Aritra Roy Gosthipaty</a> |\n"
    "<a href='https://twitter.com/ritwik_raha' target='_blank'>Ritwik Raha</a>\n"
    "<br>\n"
    "<a href='https://arxiv.org/abs/2103.06255' target='_blank'>Involution: Inverting the Inherence of Convolution for Visual Recognition</a>\n"
    "<br>\n"
    "Convolution Kernel\n"
    "<img src='https://i.imgur.com/Y7xVrwb.png' alt='Convolution'>\n"
    "<br>\n"
    "Involution Kernel\n"
    "<img src='https://i.imgur.com/jHIW26g.png' alt='Involution'>\n"
    "</center>"
)
# Short HTML blurb passed to gr.Interface through its `description=` argument.
description = (
    "<center>\n"
    "Visualize the activation maps from the Involution Kernel.\n"
    "</center>\n"
)
# Build the Gradio UI around `infer`: one image in, three kernel maps out.
# The legacy `gr.inputs` / `gr.outputs` namespaces and the Interface
# `layout=` argument were deprecated in Gradio 3.x and removed in Gradio 4,
# so the top-level `gr.Image` component is used for both directions.
iface = gr.Interface(
    fn=infer,
    title="Involutional Neural Networks",
    article=article,
    description=description,
    inputs=gr.Image(label="Input Image"),
    outputs=[
        gr.Image(label="Activation from Kernel 1"),
        gr.Image(label="Activation from Kernel 2"),
        gr.Image(label="Activation from Kernel 3"),
    ],
    examples=[["examples/lama.jpeg"], ["examples/dalai_lama.jpeg"]],
)
# Launch as a separate statement so `iface` stays bound to the Interface
# object rather than to whatever `launch()` returns.
iface.launch(share=True)