unixpickle commited on
Commit
fff6fcf
1 Parent(s): 120e140

download model, predict year

Browse files
Files changed (2) hide show
  1. app.py +13 -1
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,4 +1,7 @@
 
 
1
  import gradio as gr
 
2
  import torch
3
  import torch.nn.functional as F
4
  import torchvision.transforms as transforms
@@ -6,7 +9,14 @@ from PIL import Image
6
 
7
  from constants import MAKES_MODELS, PRICE_BIN_LABELS, YEARS
8
 
9
- model = torch.jit.load("mobilenetv2_432000_calib.pt")
 
 
 
 
 
 
 
10
  model.eval()
11
  transform = transforms.Compose(
12
  [
@@ -18,6 +28,8 @@ transform = transforms.Compose(
18
  ]
19
  )
20
 
 
 
21
 
22
  def classify(img: Image.Image):
23
  in_tensor = transform(img)[None]
 
1
+ import io
2
+
3
  import gradio as gr
4
+ import requests
5
  import torch
6
  import torch.nn.functional as F
7
  import torchvision.transforms as transforms
 
9
 
10
  from constants import MAKES_MODELS, PRICE_BIN_LABELS, YEARS
11
 
12
+ print("downloading checkpoint...")
13
+ data = requests.get(
14
+ "https://data.aqnichol.com/car-data/models/mobilenetv2_432000_calib_torchscript.pt",
15
+ stream=True,
16
+ ).content
17
+
18
+ print("creating model...")
19
+ model = torch.jit.load(io.BytesIO(data))
20
  model.eval()
21
  transform = transforms.Compose(
22
  [
 
28
  ]
29
  )
30
 
31
+ print("done.")
32
+
33
 
34
  def classify(img: Image.Image):
35
  in_tensor = transform(img)[None]
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  torch
2
  torchvision
3
  Pillow
 
 
1
  torch
2
  torchvision
3
  Pillow
4
+ requests