Atharv Subhekar commited on
Commit
c846a27
1 Parent(s): 2330284

Application Commit

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
sample_images/train_142.jpg ADDED
sample_images/train_32.jpg ADDED
sample_images/train_59.jpg ADDED
sample_images/train_67.jpg ADDED
sample_images/train_75.jpg ADDED
sample_images/train_92.jpg ADDED
satellite_app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """satellite_app.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
8
+ """
9
+
10
+ !pip install gradio --quiet
11
+
12
+ !pip install -Uq transformers datasets timm accelerate evaluate
13
+
14
+ import gradio as gr
15
+ from safetensors.torch import load_model
16
+ from timm import create_model
17
+ from huggingface_hub import hf_hub_download
18
+ from datasets import load_dataset
19
+ import torch
20
+ import torchvision.transforms as T
21
+ import cv2
22
+ import matplotlib.pyplot as plt
23
+ import numpy as np
24
+ from PIL import Image
25
+
26
+ safe_tensors = hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification", filename="model.safetensors")
27
+
28
+ model_name = 'swin_s3_base_224'
29
+ # intialize the model
30
+ model = create_model(
31
+ model_name,
32
+ num_classes=17
33
+ )
34
+
35
+ load_model(model,safe_tensors)
36
+
37
+ def one_hot_decoding(labels):
38
+ class_names = ['conventional_mine', 'habitation', 'primary', 'water', 'agriculture', 'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy', 'blooming', 'partly_cloudy', 'selective_logging', 'artisinal_mine', 'slash_burn', 'clear', 'haze']
39
+ id2label = {idx:c for idx,c in enumerate(class_names)}
40
+
41
+ id_list = []
42
+ for idx,i in enumerate(labels):
43
+ if i == 1:
44
+ id_list.append(idx)
45
+
46
+ true_labels = []
47
+ for i in id_list:
48
+ true_labels.append(id2label[i])
49
+ return true_labels
50
+
51
+ def model_output(image):
52
+ image = cv2.imread(name)
53
+ PIL_image = Image.fromarray(image.astype('uint8'), 'RGB')
54
+
55
+ img_size = (224,224)
56
+ test_tfms = T.Compose([
57
+ T.Resize(img_size),
58
+ T.ToTensor(),
59
+ ])
60
+
61
+ img = test_tfms(PIL_image)
62
+
63
+ with torch.no_grad():
64
+ logits = model(img.unsqueeze(0))
65
+
66
+ predictions = logits.sigmoid() > 0.5
67
+ predictions = predictions.float().numpy().flatten()
68
+ pred_labels = one_hot_decoding(predictions)
69
+ output_text = " ".join(pred_labels)
70
+
71
+ return output_text
72
+
73
+ app = gr.Interface(fn=model_output, inputs="image", outputs="text")
74
+ app.launch()
75
+