File size: 902 Bytes
757451d
 
 
 
 
 
 
 
1eec766
757451d
 
 
 
 
 
 
 
1eec766
757451d
 
1eec766
757451d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from PIL import Image
import gradio as gr
from transformers import ViTFeatureExtractor, ViTForImageClassification
import torch

# Hugging Face Hub checkpoint that supplies both the ViT classifier weights
# and the matching image-preprocessing configuration.
_CHECKPOINT = 'sreeramajay/pollution'

# Image classifier invoked by polln_classify for each request.
model = ViTForImageClassification.from_pretrained(_CHECKPOINT)
# Preprocessor that converts an input image into PyTorch tensors for `model`.
transforms = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)

def polln_classify(image):
    """Classify an image as air, land, or water pollution.

    Args:
        image: The image to classify. The Gradio input below is configured
            with ``type="pil"``, so presumably a ``PIL.Image`` -- confirm for
            other callers.

    Returns:
        A dict mapping label names ("Air Pollution", "Land Pollution",
        "Water Pollution") to their softmax probabilities as Python floats,
        in descending order of probability.
    """
    labels = {0: "Air Pollution", 1: "Land Pollution", 2: "Water Pollution"}
    inputs = transforms(image, return_tensors='pt')
    # Pure inference: disable autograd so no gradient graph is built per call
    # (this also removes the need to .detach() the results).
    with torch.no_grad():
        output = model(**inputs)
    # Normalise logits into probabilities across the class dimension.
    probability = output.logits.softmax(1)
    # topk raises if k exceeds the number of classes; never ask for more
    # entries than both the label table and the model provide.
    k = min(len(labels), probability.shape[-1])
    values, indices = torch.topk(probability, k=k)
    # A single image is passed in, so read row 0 of the batch; tolist()
    # yields plain Python ints/floats directly.
    return {labels[i]: v for i, v in zip(indices[0].tolist(), values[0].tolist())}


# Example images offered in the UI (relative file paths).
_EXAMPLE_IMAGES = ["air_pollution.jpg", "land_pollution.jpg", "water_pollution.jpg"]

# Web UI: one PIL image in; a 'label' output showing the label -> probability
# dict returned by polln_classify.
demo = gr.Interface(
    polln_classify,
    inputs=gr.inputs.Image(type="pil", label="Chosen Image"),
    outputs='label',
    examples=_EXAMPLE_IMAGES,
    theme="seafoam",
)

# Start the app with Gradio's debug mode enabled.
demo.launch(debug=True)