ak0601 commited on
Commit
1e46507
1 Parent(s): 683504d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -18
app.py CHANGED
@@ -1,20 +1,35 @@
1
- import gradio as gr
2
- import tensorflow as tf
3
- from tensorflow.keras.preprocessing import image
4
- import matplotlib.pyplot as plt
5
  import numpy as np
6
- model = tf.keras.models.load_model('dogcat_model_bak.h5')
7
- def image_classifier(img):
8
- img1 = image.load_img(str(img), target_size=(64, 64))
9
- img1 = image.img_to_array(img1)
10
- img1 = img1/255
11
- img1 = np.expand_dims(img1, axis=0)
12
- res = model.predict(img, batch_size=None,steps=1)
13
- if(res[:,:]>0.5):
14
- value ='Dog :%1.2f'%(prediction[0,0])
15
- else:
16
- value ='Cat :%1.2f'%(1.0-prediction[0,0])
17
- return value
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")
20
- demo.launch()
 
 
 
1
+ import os
2
+ env_var = os.environ.get('env')
3
+ import torch
4
+ import time
5
  import numpy as np
6
+ import gradio as gr
7
+ from PIL import Image
8
+ import torchvision
9
+ from torchvision import transforms
10
+
11
+ device = 'cpu'
12
+ model = torch.load('model.pkl').to(device).eval()
13
+ transform = transforms.Resize(size=500)
14
+ labels = ['Cat', 'Dog']
15
+
16
+ def predict(image):
17
+ start = time.time()
18
+ with torch.no_grad():
19
+ image = Image.fromarray(np.uint8(image)).convert('RGB')
20
+ image = transform(image)
21
+ image = np.array(image)
22
+ image = torch.from_numpy(image).permute(2,0,1).float()
23
+ image = image.unsqueeze(0)
24
+ prediction = model(image.to(device))
25
+ pred_idx = np.argmax(prediction.to(device))
26
+ pred_label = "Cat" if pred_idx == 0 else "Dog"
27
+ label = [l for l in labels if l!=pred_label]
28
+ confidences = {pred_label: float(prediction[0][pred_idx])/100, label[len(label)-1]: 1-(float(prediction[0][pred_idx]))/100 }
29
+ infer = time.time()-start
30
+ return confidences, infer
31
 
32
+ gr.Interface(fn=predict,
33
+ inputs=gr.inputs.Image(shape=(512, 512)),
34
+ outputs=[gr.outputs.Label(num_top_classes=3), gr.outputs.Textbox('infer',label='Inference Time')],
35
+ examples='1.jpg 2.jpg'.split(' ')).launch()