micole66 committed on
Commit
3de953a
•
1 Parent(s): 70ad7c2

Create app.py

Files changed (1)
1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+ import kelip
+ import gradio as gr
+
+ def load_model():
+     model, preprocess_img, tokenizer = kelip.build_model('ViT-B/32')
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = model.to(device)
+     model.eval()
+
+     model_dict = {'model': model,
+                   'preprocess_img': preprocess_img,
+                   'tokenizer': tokenizer
+                   }
+     return model_dict
+
+ def classify(img, user_text):
+     preprocess_img = model_dict['preprocess_img']
+
+     input_img = preprocess_img(img).unsqueeze(0)
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     input_img = input_img.to(device)
+
+     # extract image features
+     with torch.no_grad():
+         image_features = model_dict['model'].encode_image(input_img)
+
+     # extract text features; strip whitespace and drop empty entries
+     # so a blank prompt never reaches the tokenizer
+     user_texts = [t.strip() for t in user_text.split(',') if t.strip()]
+     if not user_texts:
+         return {}
+
+     input_texts = model_dict['tokenizer'].encode(user_texts)
+     if torch.cuda.is_available():
+         input_texts = input_texts.cuda()
+     with torch.no_grad():
+         text_features = model_dict['model'].encode_text(input_texts)
+
+     # l2 normalize, then softmax over the scaled cosine similarities
+     image_features /= image_features.norm(dim=-1, keepdim=True)
+     text_features /= text_features.norm(dim=-1, keepdim=True)
+     similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+     values, indices = similarity[0].topk(len(user_texts))
+     result = {}
+     for value, index in zip(values, indices):
+         result[user_texts[index]] = value.item()
+
+     return result
+
+ if __name__ == '__main__':
+     model_dict = load_model()
+
+     inputs = [gr.inputs.Image(type="pil", label="Image"),
+               gr.inputs.Textbox(lines=5, label="Caption"),
+               ]
+
+     outputs = ['label']
+
+     title = "KELIP"
+     description = "Zero-shot classification with KELIP -- a Korean and English bilingual contrastive Language-Image Pre-training model trained on 1.1 billion collected image-text pairs (708 million Korean and 476 million English).<br> <br><a href='https://arxiv.org/abs/2203.14463' target='_blank'>arXiv</a> | <a href='https://github.com/navervision/KELIP' target='_blank'>GitHub</a>"
+     examples = [
+         # Korean labels: squid sundae, gimbap, sundae, tteokbokki
+         ["squid_sundae.jpg", "오징어 순대,김밥,순대,떡볶이"],
+         # Korean labels: World Peace Gate, Olympic Park, Lotte World, Seokchon Lake
+         ["seokchon_lake.jpg", "평화의문,올림픽공원,롯데월드,석촌호수"],
+         ["seokchon_lake.jpg", "spring,summer,autumn,winter"],
+         ["dog.jpg", "a dog,a cat,a tiger,a rabbit"],
+     ]
+
+     article = ""
+
+     iface = gr.Interface(
+         fn=classify,
+         inputs=inputs,
+         outputs=outputs,
+         examples=examples,
+         title=title,
+         description=description,
+         article=article
+     )
+     iface.launch()
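
A note on the scoring step in classify() above: the returned label scores are a softmax over scaled cosine similarities between the image embedding and each candidate label embedding. Below is a minimal, self-contained sketch of that computation with plain torch, using random tensors as stand-ins for real KELIP features; the 512-dimensional width and the fixed 100.0 logit scale are assumptions that simply mirror the code above.

import torch

# Stand-ins for KELIP embeddings: one image vector, four candidate labels.
image_features = torch.randn(1, 512)
text_features = torch.randn(4, 512)

# L2-normalize so the dot product equals cosine similarity.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# Scale by the logit factor and softmax over the label axis, as in classify().
probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
print(probs)  # shape (1, 4): one probability per label, summing to 1

With real embeddings from encode_image and encode_text, the highest-probability entry is the zero-shot prediction that the Gradio label output displays.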