# cat_dog-breed / app.py
# Hugging Face Space by shikharyashmaurya; last commit 9745587 ("Update app.py", verified).
from pathlib import Path

from fastai.vision.all import *
import gradio as gr  # Import Gradio directly
import timm  # not referenced below; presumably needed so a timm-based architecture in model.pkl can be unpickled — confirm before removing
# Load the exported fastai Learner. model.pkl is resolved relative to this
# file rather than the current working directory, so the app starts no matter
# which directory it is launched from.
# NOTE: load_learner unpickles the file — only ever load a model file you trust.
learn = load_learner(Path(__file__).parent / 'model.pkl')
# Class labels of the learner; classify_image pairs them positionally with the
# predicted probabilities.
categories = learn.dls.vocab
# Define your prediction function
def classify_image(img):
pred, idx, probs = learn.predict(img)
# Align with Gradio's formatting for outputs
return {category: prob for category, prob in zip(categories, probs)}
# Create the Gradio interface: an image input wired to classify_image, with a
# Label output that lists the probability of every category.
image = gr.Image()  # Use gr.Image directly for input
label = gr.Label(num_top_classes=len(categories))  # show all classes, not just the top few
interface = gr.Interface(
    fn=classify_image,  # Reference the function directly
    inputs=image,
    outputs=label,
)

# Start the web server only when run as a script (e.g. `python app.py`), so
# that importing this module does not launch the app as a side effect.
if __name__ == "__main__":
    interface.launch()