import os
import gradio as gr
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from torchvision import transforms
from templates import openai_imagenet_template
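
# Flagged examples are written back to the "bioclip-demo" dataset on the Hub;
# HF_TOKEN is expected in the environment (e.g., as a Space secret).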
hf_token = os.getenv("HF_TOKEN")
hf_writer = gr.HuggingFaceDatasetSaver(hf_token, "bioclip-demo")
model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
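
# No resizing here: the Gradio image input below is configured with
# shape=(224, 224), so inputs already arrive at the model's expected
# resolution. Only tensor conversion and CLIP's standard normalization apply.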
preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
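
# Build one text embedding per class by ensembling the OpenAI ImageNet prompt
# templates (e.g., "a photo of a {classname}.") and averaging.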
@torch.no_grad()
def get_txt_features(classnames, templates):
    all_features = []
    for classname in classnames:
        txts = [template(classname) for template in templates]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        # Normalize each template embedding, average them, then re-normalize
        # so every class is represented by a single unit vector.
        txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
        txt_features /= txt_features.norm()
        all_features.append(txt_features)
    # Stack class vectors as columns: (embed_dim, n_classes).
    all_features = torch.stack(all_features, dim=1)
    return all_features
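
# Zero-shot classification: softmax over scaled cosine similarities between
# the image embedding and each class's text embedding.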
@torch.no_grad()
def predict(img, classes: list[str]) -> dict[str, float]:
    classes = [cls.strip() for cls in classes if cls.strip()]
    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # (1, embed_dim) @ (embed_dim, n_classes) -> (n_classes,) after squeeze.
    logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze()
    probs = F.softmax(logits, dim=0).to("cpu").tolist()
    return {cls: prob for cls, prob in zip(classes, probs)}
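
# Fallback when the user supplies no classes: intended to walk the taxonomy
# from the top of the tree of life down to species (see the stub below).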
@torch.no_grad()
def hierarchical_predict(img) -> list[str]:
    """
    Predicts from the top of the tree of life down to the species.
    """
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)
    # The taxonomic rollout itself is not implemented yet; fail loudly instead
    # of stopping at a debugger breakpoint.
    raise NotImplementedError("Hierarchical prediction is not implemented yet.")
def run(img, cls_str: str) -> dict[str, float]:
    if cls_str:
        classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
        return predict(img, classes)
    else:
        return hierarchical_predict(img)
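
# Model and tokenizer are created once at startup; the handlers above pick
# them up as module-level globals.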
if __name__ == "__main__":
    print("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    print("Created model.")

    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

    demo = gr.Interface(
        fn=run,
        inputs=[
            gr.Image(shape=(224, 224)),
            gr.Textbox(
                placeholder="dog\ncat\n...",
                lines=3,
                label="Classes",
                show_label=True,
                info="If empty, will predict from the entire tree of life.",
            ),
        ],
        outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
        allow_flagging="manual",
        flagging_options=["Incorrect", "Other"],
        flagging_callback=hf_writer,
    )

    demo.launch()