# Hugging Face Space: Gradio app — car price & make/model classifier.
# (Original page header showed "Runtime error"; see gr.Image note below.)
import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

from constants import MAKES_MODELS, PRICE_BIN_LABELS

# Load the TorchScript-compiled MobileNetV2 checkpoint ("calib" in the
# filename suggests a calibrated/quantized export — confirm) and switch
# it to inference mode.
model = torch.jit.load("mobilenetv2_432000_calib.pt")
model.eval()

# Preprocessing: PIL image -> normalized float tensor.
# NOTE(review): these mean/std values are the CLIP normalization
# constants, not the usual ImageNet ones — presumably matching how the
# model was trained; verify against the training pipeline.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
def classify(img: Image.Image):
    """Run the car classifier on a single image.

    Returns a 3-tuple matching the Gradio outputs:
      * formatted median-price string (e.g. "$12345"),
      * dict mapping each price-bin label to its softmax probability,
      * dict mapping "<make> <model>" (plus "Unknown") to its probability.
    """
    # The network is fed 224x224 inputs (see the Interface declaration);
    # resize defensively so arbitrary uploads still work. No-op when the
    # image is already 224x224.
    if img.size != (224, 224):
        img = img.resize((224, 224))
    in_tensor = transform(img)[None]  # add batch dimension
    # NOTE(review): the scripted model appears to return a dict with keys
    # "price_bin", "make_model", "price_median" — confirm against export.
    outputs = model(in_tensor)
    price_bins = dict(
        zip(PRICE_BIN_LABELS, F.softmax(outputs["price_bin"], dim=-1)[0].tolist())
    )
    # `mdl` instead of `model` so the comprehension does not shadow the
    # module-level `model`.
    make_models = dict(
        zip(
            [f"{make} {mdl}" for make, mdl in MAKES_MODELS] + ["Unknown"],
            F.softmax(outputs["make_model"], dim=-1)[0].tolist(),
        )
    )
    return f"${int(round(outputs['price_median'].item()))}", price_bins, make_models
# Build the web UI: one image input, three outputs (price string plus two
# probability-label panels).
# NOTE(review): the original passed shape=(224, 224) to gr.Image. That
# keyword was removed in Gradio 4.x and raises a TypeError at startup —
# the likely cause of the Space's "Runtime error". Input sizing belongs
# in the model preprocessing, not the widget.
iface = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Text(label="Price Prediction"),
        gr.Label(label="Price Bin"),
        gr.Label(label="Make/Model"),
    ],
)

if __name__ == "__main__":
    # Guarded so importing this module (e.g. for tests) does not start
    # the server.
    iface.launch()