File size: 1,218 Bytes
481fb54
ddd296c
 
 
 
481fb54
ddd296c
481fb54
ddd296c
 
 
 
 
 
 
 
 
 
 
481fb54
 
ddd296c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481fb54
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

from constants import MAKES_MODELS, PRICE_BIN_LABELS

# TorchScript model, loaded once at import time. `classify` builds its input
# tensor on the CPU and never moves it, so the weights are pinned to the CPU
# too: without map_location a checkpoint saved from a GPU would be restored
# onto that device (or fail to load on a CPU-only host).
model = torch.jit.load("mobilenetv2_432000_calib.pt", map_location="cpu")
model.eval()  # switch the module to evaluation mode before serving requests

# Preprocessing applied to every input image: convert the PIL image to a
# tensor, then normalize each of the three channels with the statistics below.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


def classify(img: Image.Image):
    """Run the model on one image and format its predictions for the UI.

    Args:
        img: PIL image to classify. The Gradio input that feeds this function
            is declared with ``shape=(224, 224)`` and ``type="pil"``.

    Returns:
        A 3-tuple ``(price, price_bins, make_models)``:

        * ``price`` -- the model's ``"price_median"`` output rounded to an
          integer and formatted as ``"$<int>"``.
        * ``price_bins`` -- dict mapping each label in ``PRICE_BIN_LABELS`` to
          its softmax probability from the ``"price_bin"`` output.
        * ``make_models`` -- dict mapping ``"<make> <model>"`` for every pair
          in ``MAKES_MODELS``, followed by ``"Unknown"``, to its softmax
          probability from the ``"make_model"`` output.
    """
    # Pure inference: disable autograd so no graph is recorded for a forward
    # pass whose gradients are never used.
    with torch.no_grad():
        batch = transform(img)[None]  # add the leading batch dimension
        outputs = model(batch)
        price_bin_probs = F.softmax(outputs["price_bin"], dim=-1)[0].tolist()
        make_model_probs = F.softmax(outputs["make_model"], dim=-1)[0].tolist()
        price_median = outputs["price_median"].item()

    # Labels are paired with probabilities by position; "Unknown" labels the
    # slot after the last MAKES_MODELS entry. The loop variable is deliberately
    # not called `model`, to avoid confusion with the global network.
    make_model_labels = [f"{make} {name}" for make, name in MAKES_MODELS]
    make_model_labels.append("Unknown")

    price_bins = dict(zip(PRICE_BIN_LABELS, price_bin_probs))
    make_models = dict(zip(make_model_labels, make_model_probs))
    return f"${int(round(price_median))}", price_bins, make_models


# Gradio UI wiring: one image input, and one output widget per element of the
# tuple returned by `classify` (price text, price-bin label, make/model label).
_image_input = gr.Image(shape=(224, 224), type="pil")
_result_widgets = [
    gr.Text(label="Price Prediction"),
    gr.Label(label="Price Bin"),
    gr.Label(label="Make/Model"),
]

iface = gr.Interface(fn=classify, inputs=_image_input, outputs=_result_widgets)
# Start the app as soon as the script is executed.
iface.launch()