siriuszeina commited on
Commit
2e614a5
1 Parent(s): 1216549

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -34
app.py CHANGED
@@ -39,42 +39,12 @@ def load_labels() -> list[str]:
39
  labels = [line.strip() for line in f.readlines()]
40
  return labels
41
 
42
- def get_image(url) -> PIL.Image:
43
- response = requests.get(url)
44
- image = PIL.Image.open(BytesIO(response.content))
45
- return image
46
-
47
 
48
 
49
 
50
  model = load_model()
51
  labels = load_labels()
52
 
53
- def predictx(url: str, score_threshold: float) -> tuple[dict[str, float], dict[str, float], str]:
54
- _, height, width, _ = model.input_shape
55
- response = requests.get(url)
56
- image = PIL.Image.open(BytesIO(response.content))
57
-
58
- image = np.asarray(image)
59
- image = tf.image.resize(image, size=(height, width), method=tf.image.ResizeMethod.AREA, preserve_aspect_ratio=True)
60
- image = image.numpy()
61
- image = dd.image.transform_and_pad_image(image, width, height)
62
- image = image / 255.0
63
- probs = model.predict(image[None, ...])[0]
64
- probs = probs.astype(float)
65
-
66
- indices = np.argsort(probs)[::-1]
67
- result_all = dict()
68
- result_threshold = dict()
69
- for index in indices:
70
- label = labels[index]
71
- prob = probs[index]
72
- result_all[label] = prob
73
- if prob < score_threshold:
74
- break
75
- result_threshold[label] = prob
76
- result_text = ", ".join(result_all.keys())
77
- return result_threshold, result_all, result_text
78
 
79
  def predict(image: PIL.Image.Image, score_threshold: float) -> tuple[dict[str, float], dict[str, float], str]:
80
  _, height, width, _ = model.input_shape
@@ -107,7 +77,6 @@ with gr.Blocks(css="style.css") as demo:
107
  gr.Markdown(DESCRIPTION)
108
  with gr.Row():
109
  with gr.Column():
110
- url = gr.Textbox(value="https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png")
111
  image = gr.Image(label="Input", type="pil")
112
  score_threshold = gr.Slider(label="Score threshold", minimum=0, maximum=1, step=0.05, value=0.5)
113
  run_button = gr.Button("Run")
@@ -128,10 +97,10 @@ with gr.Blocks(css="style.css") as demo:
128
  )
129
 
130
  run_button.click(
131
- fn=predictx,
132
- inputs=[url, score_threshold],
133
  outputs=[result, result_json, result_text],
134
- api_name="predictx",
135
  )
136
 
137
  if __name__ == "__main__":
 
39
  labels = [line.strip() for line in f.readlines()]
40
  return labels
41
 
 
 
 
 
 
42
 
43
 
44
 
45
  model = load_model()
46
  labels = load_labels()
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def predict(image: PIL.Image.Image, score_threshold: float) -> tuple[dict[str, float], dict[str, float], str]:
50
  _, height, width, _ = model.input_shape
 
77
  gr.Markdown(DESCRIPTION)
78
  with gr.Row():
79
  with gr.Column():
 
80
  image = gr.Image(label="Input", type="pil")
81
  score_threshold = gr.Slider(label="Score threshold", minimum=0, maximum=1, step=0.05, value=0.5)
82
  run_button = gr.Button("Run")
 
97
  )
98
 
99
  run_button.click(
100
+ fn=predict,
101
+ inputs=[image, score_threshold],
102
  outputs=[result, result_json, result_text],
103
+ api_name="predict",
104
  )
105
 
106
  if __name__ == "__main__":