Spaces:
Sleeping
Sleeping
Samuel Stevens
commited on
Commit
•
d1c1a86
1
Parent(s):
d86aa61
add app.py
Browse files- app.py +73 -0
- requirements.txt +4 -0
app.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torchvision import transforms
|
5 |
+
|
6 |
+
from open_clip import create_model, get_tokenizer
|
7 |
+
from open_clip.training.imagenet_zeroshot_data import openai_imagenet_template
|
8 |
+
|
9 |
+
model_str = "ViT-B-16"
|
10 |
+
pretrained = "/fs/ess/PAS2136/foundation_model/model/10m/2023_09_22-21_14_04-model_ViT-B-16-lr_0.0001-b_4096-j_8-p_amp/checkpoints/epoch_99.pt"
|
11 |
+
|
12 |
+
preprocess_img = transforms.Compose(
|
13 |
+
[
|
14 |
+
transforms.ToTensor(),
|
15 |
+
transforms.Normalize(
|
16 |
+
mean=(0.48145466, 0.4578275, 0.40821073),
|
17 |
+
std=(0.26862954, 0.26130258, 0.27577711),
|
18 |
+
),
|
19 |
+
]
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
@torch.no_grad()
|
24 |
+
def get_txt_features(classnames, templates):
|
25 |
+
all_features = []
|
26 |
+
for classname in classnames:
|
27 |
+
txts = [template(classname) for template in templates]
|
28 |
+
txts = tokenizer(txts)
|
29 |
+
txt_features = model.encode_text(txts)
|
30 |
+
txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
|
31 |
+
txt_features /= txt_features.norm()
|
32 |
+
all_features.append(txt_features)
|
33 |
+
all_features = torch.stack(all_features, dim=1)
|
34 |
+
return all_features
|
35 |
+
|
36 |
+
|
37 |
+
@torch.no_grad()
|
38 |
+
def predict(img, cls_str: str) -> dict[str, float]:
|
39 |
+
classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
|
40 |
+
txt_features = get_txt_features(classes, openai_imagenet_template)
|
41 |
+
|
42 |
+
img = preprocess_img(img)
|
43 |
+
|
44 |
+
img_features = model.encode_image(img.unsqueeze(0))
|
45 |
+
img_features = F.normalize(img_features, dim=-1)
|
46 |
+
logits = (img_features @ txt_features).squeeze()
|
47 |
+
probs = F.softmax(logits, dim=0).tolist()
|
48 |
+
return {cls: prob for cls, prob in zip(classes, probs)}
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == "__main__":
|
52 |
+
print("Starting.")
|
53 |
+
model = create_model(model_str, pretrained, output_dict=True)
|
54 |
+
print("Created model.")
|
55 |
+
|
56 |
+
model = torch.compile(model)
|
57 |
+
print("Compiled model.")
|
58 |
+
|
59 |
+
tokenizer = get_tokenizer(model_str)
|
60 |
+
|
61 |
+
demo = gr.Interface(
|
62 |
+
fn=predict,
|
63 |
+
inputs=[
|
64 |
+
gr.Image(shape=(224, 224)),
|
65 |
+
gr.Textbox(
|
66 |
+
placeholder="dog\ncat\n...", lines=3, label="Classes", show_label=True
|
67 |
+
),
|
68 |
+
],
|
69 |
+
outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
|
70 |
+
)
|
71 |
+
|
72 |
+
demo.launch()
|
73 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
open_clip_torch
|
2 |
+
torchvision
|
3 |
+
torch
|
4 |
+
gradio
|