# Author: tomaseo2022
# Last commit: 6e3f070 "Update app.py" (verified)
import os
def _install_runtime_dependencies():
    """Best-effort install/upgrade of the packages this app imports below.

    pip is invoked as ``sys.executable -m pip`` so packages are installed into
    the interpreter that is actually running this script; a bare ``pip`` on
    PATH may belong to a different Python. Each package is installed in its
    own pip call, in the original order, so one failure does not block the
    others. Failures are reported on stderr but are not fatal, which keeps
    the original fire-and-forget behaviour.
    """
    import subprocess
    import sys

    requirements = (
        ("--upgrade", "httpx"),
        ("--upgrade", "gradio"),
        ("opencv-python",),
        ("torch",),
        ("--upgrade", "pillow"),
        ("torchvision",),
    )
    for args in requirements:
        # Argument list, no shell: nothing is interpreted by /bin/sh.
        cmd = [sys.executable, "-m", "pip", "install", *args]
        result = subprocess.run(cmd, check=False)
        if result.returncode != 0:
            print(
                f"warning: {' '.join(cmd)} exited with status {result.returncode}",
                file=sys.stderr,
            )


_install_runtime_dependencies()
import functools

import gradio as gr
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
@functools.lru_cache(maxsize=1)
def _load_pipeline():
    """Build the classifier, its preprocessing transform and device exactly once.

    Previously the network was rebuilt, and its pretrained weights reloaded,
    on every call to ``predict``. Caching the result means each request pays
    only for inference.

    Returns:
        tuple: ``(model, transform, device)``.
            ``model``: a torchvision ResNet-50 with pretrained weights, in
            eval mode, on ``device``.
            ``transform``: maps an RGB PIL image to a normalised
            ``(3, 224, 224)`` float tensor.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = torchvision.models.resnet50(pretrained=True).to(device)
    model.eval()  # inference behaviour for layers such as batch-norm
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Per-channel normalisation statistics expected by torchvision's
        # pretrained ImageNet models.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return model, transform, device


def _to_rgb_image(image):
    """Coerce a Gradio image value to an RGB ``PIL.Image.Image``.

    Accepts any of:
      * a file path (``gr.Image(type="filepath")``);
      * a PIL image (``type="pil"``);
      * an array with ``astype`` (``type="numpy"``).

    Grayscale and RGBA inputs are converted to 3-channel RGB, which the
    3-channel ``Normalize`` step requires.

    Raises:
        ValueError: if *image* is ``None`` (e.g. the user submitted nothing).
    """
    if image is None:
        raise ValueError("predict() received no image")
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    if isinstance(image, (str, os.PathLike)):
        # convert() returns a fully loaded copy, so closing the file is safe.
        with Image.open(image) as img:
            return img.convert("RGB")
    # Let Pillow infer the mode from the array shape (L / RGB / RGBA), then
    # normalise to RGB instead of forcing 'RGB' onto mismatched data.
    return Image.fromarray(image.astype('uint8')).convert("RGB")


def predict(image):
    """Classify *image* with a pretrained torchvision ResNet-50.

    Args:
        image: a file path, a PIL image, or an array that
            ``PIL.Image.fromarray`` accepts after casting to uint8. This is
            whichever form the Gradio ``Image`` component is configured to send.

    Returns:
        int: index of the highest-scoring output class.

    Raises:
        ValueError: if *image* is ``None``.
    """
    model, transform, device = _load_pipeline()
    # unsqueeze(0) adds the batch dimension the model expects.
    batch = transform(_to_rgb_image(image)).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(batch)
    return logits.argmax(dim=1).item()
# Gradio UI.
# - `gr.component` never existed and `gr.outputs` was removed in current
#   Gradio; use the top-level `gr.Image` / `gr.Textbox` components.
# - type="numpy" hands `predict` an image array. The original `predict`
#   calls `.astype` on its input, which fails on the filepath string that
#   type="filepath" produced.
input_image = gr.Image(type="numpy", label="Input")
output_text = gr.Textbox(label="Predicted class index")
gr.Interface(fn=predict, inputs=input_image, outputs=output_text).launch()