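"""Gradio demo: run a TorchScript MobileNetV2 checkpoint on a car photo and
report a median price estimate, a price bin, a model year, and a make/model
label (the label sets come from constants.py)."""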
import io

import gradio as gr
import requests
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

from constants import MAKES_MODELS, PRICE_BIN_LABELS, YEARS
print("downloading checkpoint...") | |
data = requests.get( | |
"https://data.aqnichol.com/car-data/models/mobilenetv2_432000_calib_torchscript.pt", | |
stream=True, | |
).content | |
print("creating model...") | |
model = torch.jit.load(io.BytesIO(data)) | |
model.eval() | |
# Preprocessing: convert the PIL image to a tensor and normalize with fixed
# per-channel mean/std values.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
print("done.")
def classify(img: Image.Image):
    # Preprocess and add a batch dimension.
    in_tensor = transform(img)[None]
    with torch.no_grad():
        outputs = model(in_tensor)
    # Turn each head's logits into {label: probability} dicts for gr.Label.
    price_bins = dict(
        zip(PRICE_BIN_LABELS, F.softmax(outputs["price_bin"], dim=-1)[0].tolist())
    )
    years = dict(
        zip(
            [str(year) for year in YEARS] + ["Unknown"],
            F.softmax(outputs["year"], dim=-1)[0].tolist(),
        )
    )
    make_models = dict(
        zip(
            [f"{make} {model_name}" for make, model_name in MAKES_MODELS] + ["Unknown"],
            F.softmax(outputs["make_model"], dim=-1)[0].tolist(),
        )
    )
    return (
        f"${int(round(outputs['price_median'].item()))}",
        price_bins,
        years,
        make_models,
        img,
    )
# Note: gr.Image(shape=...) and queue(concurrency_count=...) are Gradio 3.x
# arguments; later Gradio releases removed or renamed them, so the installed
# Gradio version should be pinned accordingly (assumption about the
# environment, not verified here).
iface = gr.Interface(
    fn=classify,
    inputs=gr.Image(shape=(224, 224), type="pil"),
    outputs=[
        gr.Text(label="Price Prediction"),
        gr.Label(label="Price Bin", num_top_classes=5),
        gr.Label(label="Year", num_top_classes=5),
        gr.Label(label="Make/Model", num_top_classes=10),
        gr.Image(label="Cropped Input"),
    ],
)
iface.queue(concurrency_count=2).launch()