Spaces:
Runtime error
Runtime error
Samuel Stevens
commited on
Commit
•
2cfb891
1
Parent(s):
290c238
v0.1
Browse files- README.md +2 -2
- app.py +169 -31
- lib.py +11 -7
- make_txt_embedding.py +46 -16
- txt_emb.npy +3 -0
README.md
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
---
|
2 |
title: Bioclip Demo
|
3 |
-
emoji:
|
4 |
colorFrom: indigo
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.7.1
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
10 |
license: mit
|
11 |
---
|
|
|
1 |
---
|
2 |
title: Bioclip Demo
|
3 |
+
emoji: 🐘
|
4 |
colorFrom: indigo
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.7.1
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
10 |
license: mit
|
11 |
---
|
app.py
CHANGED
@@ -1,24 +1,29 @@
|
|
|
|
1 |
import os
|
2 |
|
3 |
import gradio as gr
|
|
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
6 |
from open_clip import create_model, get_tokenizer
|
7 |
from torchvision import transforms
|
8 |
|
|
|
9 |
from templates import openai_imagenet_template
|
10 |
|
11 |
hf_token = os.getenv("HF_TOKEN")
|
12 |
-
hf_writer = gr.HuggingFaceDatasetSaver(hf_token, "bioclip-demo")
|
13 |
|
14 |
model_str = "hf-hub:imageomics/bioclip"
|
15 |
tokenizer_str = "ViT-B-16"
|
|
|
|
|
16 |
|
17 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
18 |
|
19 |
preprocess_img = transforms.Compose(
|
20 |
[
|
21 |
transforms.ToTensor(),
|
|
|
22 |
transforms.Normalize(
|
23 |
mean=(0.48145466, 0.4578275, 0.40821073),
|
24 |
std=(0.26862954, 0.26130258, 0.27577711),
|
@@ -26,6 +31,28 @@ preprocess_img = transforms.Compose(
|
|
26 |
]
|
27 |
)
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
@torch.no_grad()
|
31 |
def get_txt_features(classnames, templates):
|
@@ -42,8 +69,8 @@ def get_txt_features(classnames, templates):
|
|
42 |
|
43 |
|
44 |
@torch.no_grad()
|
45 |
-
def
|
46 |
-
classes = [cls.strip() for cls in
|
47 |
txt_features = get_txt_features(classes, openai_imagenet_template)
|
48 |
|
49 |
img = preprocess_img(img).to(device)
|
@@ -55,7 +82,8 @@ def predict(img, classes: list[str]) -> dict[str, float]:
|
|
55 |
return {cls: prob for cls, prob in zip(classes, probs)}
|
56 |
|
57 |
|
58 |
-
|
|
|
59 |
"""
|
60 |
Predicts from the top of the tree of life down to the species.
|
61 |
"""
|
@@ -63,16 +91,44 @@ def hierarchical_predict(img) -> list[str]:
|
|
63 |
img_features = model.encode_image(img.unsqueeze(0))
|
64 |
img_features = F.normalize(img_features, dim=-1)
|
65 |
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
|
69 |
-
def
|
70 |
-
|
71 |
-
|
72 |
-
classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
|
73 |
-
return predict(img, classes)
|
74 |
-
else:
|
75 |
-
return hierarchical_predict(img)
|
76 |
|
77 |
|
78 |
if __name__ == "__main__":
|
@@ -86,22 +142,104 @@ if __name__ == "__main__":
|
|
86 |
|
87 |
tokenizer = get_tokenizer(tokenizer_str)
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
import os
|
3 |
|
4 |
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
import torch
|
7 |
import torch.nn.functional as F
|
8 |
from open_clip import create_model, get_tokenizer
|
9 |
from torchvision import transforms
|
10 |
|
11 |
+
import lib
|
12 |
from templates import openai_imagenet_template
|
13 |
|
14 |
hf_token = os.getenv("HF_TOKEN")
|
|
|
15 |
|
16 |
model_str = "hf-hub:imageomics/bioclip"
|
17 |
tokenizer_str = "ViT-B-16"
|
18 |
+
name_lookup_json = "name_lookup.json"
|
19 |
+
txt_emb_npy = "txt_emb.npy"
|
20 |
|
21 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
22 |
|
23 |
preprocess_img = transforms.Compose(
|
24 |
[
|
25 |
transforms.ToTensor(),
|
26 |
+
transforms.Resize((224, 224), antialias=True),
|
27 |
transforms.Normalize(
|
28 |
mean=(0.48145466, 0.4578275, 0.40821073),
|
29 |
std=(0.26862954, 0.26130258, 0.27577711),
|
|
|
31 |
]
|
32 |
)
|
33 |
|
34 |
+
ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
|
35 |
+
|
36 |
+
open_domain_examples = [
|
37 |
+
["examples/Ursus-arctos.jpeg", "Species"],
|
38 |
+
["examples/Phoca-vitulina.png", "Species"],
|
39 |
+
["examples/Felis-catus.jpeg", "Genus"],
|
40 |
+
]
|
41 |
+
zero_shot_examples = [
|
42 |
+
[
|
43 |
+
"examples/Carnegiea-gigantea.png",
|
44 |
+
"Carnegiea gigantea\nSchlumbergera opuntioides\nMammillaria albicoma",
|
45 |
+
],
|
46 |
+
[
|
47 |
+
"examples/Amanita-muscaria.jpeg",
|
48 |
+
"Amanita fulva\nAmanita vaginata (grisette)\nAmanita calyptrata (coccoli)\nAmanita crocea\nAmanita rubescens (blusher)\nAmanita caesarea (Caesar's mushroom)\nAmanita jacksonii (American Caesar's mushroom)\nAmanita muscaria (fly agaric)\nAmanita pantherina (panther cap)",
|
49 |
+
],
|
50 |
+
[
|
51 |
+
"examples/Actinostola-abyssorum.png",
|
52 |
+
"Animalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola abyssorum\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola bulbosa\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola callosa\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola capensis\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola carlgreni",
|
53 |
+
],
|
54 |
+
]
|
55 |
+
|
56 |
|
57 |
@torch.no_grad()
|
58 |
def get_txt_features(classnames, templates):
|
|
|
69 |
|
70 |
|
71 |
@torch.no_grad()
|
72 |
+
def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
|
73 |
+
classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
|
74 |
txt_features = get_txt_features(classes, openai_imagenet_template)
|
75 |
|
76 |
img = preprocess_img(img).to(device)
|
|
|
82 |
return {cls: prob for cls, prob in zip(classes, probs)}
|
83 |
|
84 |
|
85 |
+
@torch.no_grad()
|
86 |
+
def open_domain_classification(img, rank: int) -> list[dict[str, float]]:
|
87 |
"""
|
88 |
Predicts from the top of the tree of life down to the species.
|
89 |
"""
|
|
|
91 |
img_features = model.encode_image(img.unsqueeze(0))
|
92 |
img_features = F.normalize(img_features, dim=-1)
|
93 |
|
94 |
+
outputs = []
|
95 |
+
|
96 |
+
name = []
|
97 |
+
for _ in range(rank + 1):
|
98 |
+
children = tuple(zip(*name_lookup.children(name)))
|
99 |
+
if not children:
|
100 |
+
break
|
101 |
+
values, indices = children
|
102 |
+
txt_features = txt_emb[:, indices].to(device)
|
103 |
+
logits = (model.logit_scale.exp() * img_features @ txt_features).view(-1)
|
104 |
+
|
105 |
+
probs = F.softmax(logits, dim=0).to("cpu").tolist()
|
106 |
+
parent = " ".join(name)
|
107 |
+
outputs.append(
|
108 |
+
{f"{parent} {value}": prob for value, prob in zip(values, probs)}
|
109 |
+
)
|
110 |
+
|
111 |
+
top = values[logits.argmax()]
|
112 |
+
name.append(top)
|
113 |
+
|
114 |
+
while len(outputs) < 7:
|
115 |
+
outputs.append({})
|
116 |
+
|
117 |
+
return list(reversed(outputs))
|
118 |
+
|
119 |
+
|
120 |
+
def change_output(choice):
|
121 |
+
return [
|
122 |
+
gr.Label(
|
123 |
+
num_top_classes=5, label=rank, show_label=True, visible=(6 - i <= choice)
|
124 |
+
)
|
125 |
+
for i, rank in enumerate(reversed(ranks))
|
126 |
+
]
|
127 |
|
128 |
|
129 |
+
def get_name_lookup(path):
|
130 |
+
with open(path) as fd:
|
131 |
+
return lib.TaxonomicTree.from_dict(json.load(fd))
|
|
|
|
|
|
|
|
|
132 |
|
133 |
|
134 |
if __name__ == "__main__":
|
|
|
142 |
|
143 |
tokenizer = get_tokenizer(tokenizer_str)
|
144 |
|
145 |
+
name_lookup = get_name_lookup(name_lookup_json)
|
146 |
+
txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r"))
|
147 |
+
|
148 |
+
done = txt_emb.any(axis=0).sum().item()
|
149 |
+
total = txt_emb.shape[1]
|
150 |
+
status_msg = ""
|
151 |
+
if done != total:
|
152 |
+
status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"
|
153 |
+
|
154 |
+
with gr.Blocks() as app:
|
155 |
+
img_input = gr.Image()
|
156 |
+
|
157 |
+
with gr.Tab("Open-Ended"):
|
158 |
+
with gr.Row():
|
159 |
+
with gr.Column():
|
160 |
+
rank_dropdown = gr.Dropdown(
|
161 |
+
label="Taxonomic Rank",
|
162 |
+
info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
|
163 |
+
choices=ranks,
|
164 |
+
value="Species",
|
165 |
+
type="index",
|
166 |
+
)
|
167 |
+
open_domain_btn = gr.Button("Submit", variant="primary")
|
168 |
+
gr.Examples(
|
169 |
+
examples=open_domain_examples,
|
170 |
+
inputs=[img_input, rank_dropdown],
|
171 |
+
)
|
172 |
+
|
173 |
+
with gr.Column():
|
174 |
+
open_domain_outputs = [
|
175 |
+
gr.Label(num_top_classes=5, label=rank, show_label=True)
|
176 |
+
for rank in reversed(ranks)
|
177 |
+
]
|
178 |
+
open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
|
179 |
+
|
180 |
+
open_domain_callback = gr.HuggingFaceDatasetSaver(
|
181 |
+
hf_token, "imageomics/bioclip-demo-open-domain-mistakes", private=True
|
182 |
+
)
|
183 |
+
open_domain_callback.setup(
|
184 |
+
[img_input, *open_domain_outputs], flagging_dir="logs/flagged"
|
185 |
+
)
|
186 |
+
open_domain_flag_btn.click(
|
187 |
+
lambda *args: open_domain_callback.flag(args),
|
188 |
+
[img_input, *open_domain_outputs],
|
189 |
+
None,
|
190 |
+
preprocess=False,
|
191 |
+
)
|
192 |
+
|
193 |
+
with gr.Tab("Zero-Shot"):
|
194 |
+
with gr.Row():
|
195 |
+
with gr.Column():
|
196 |
+
classes_txt = gr.Textbox(
|
197 |
+
placeholder="Canis familiaris (dog)\nFelis catus (cat)\n...",
|
198 |
+
lines=3,
|
199 |
+
label="Classes",
|
200 |
+
show_label=True,
|
201 |
+
info="Use taxonomic names where possible; include common names if possible.",
|
202 |
+
)
|
203 |
+
zero_shot_btn = gr.Button("Submit", variant="primary")
|
204 |
+
gr.Examples(
|
205 |
+
examples=zero_shot_examples,
|
206 |
+
inputs=[img_input, classes_txt],
|
207 |
+
)
|
208 |
+
|
209 |
+
with gr.Column():
|
210 |
+
zero_shot_output = gr.Label(
|
211 |
+
num_top_classes=5, label="Prediction", show_label=True
|
212 |
+
)
|
213 |
+
zero_shot_flag_btn = gr.Button("Flag Mistake", variant="primary")
|
214 |
+
|
215 |
+
zero_shot_callback = gr.HuggingFaceDatasetSaver(
|
216 |
+
hf_token, "imageomics/bioclip-demo-zero-shot-mistakes", private=True
|
217 |
+
)
|
218 |
+
zero_shot_callback.setup(
|
219 |
+
[img_input, zero_shot_output], flagging_dir="logs/flagged"
|
220 |
+
)
|
221 |
+
zero_shot_flag_btn.click(
|
222 |
+
lambda *args: zero_shot_callback.flag(args),
|
223 |
+
[img_input, zero_shot_output],
|
224 |
+
None,
|
225 |
+
preprocess=False,
|
226 |
+
)
|
227 |
+
|
228 |
+
rank_dropdown.change(
|
229 |
+
fn=change_output, inputs=rank_dropdown, outputs=open_domain_outputs
|
230 |
+
)
|
231 |
+
|
232 |
+
open_domain_btn.click(
|
233 |
+
fn=open_domain_classification,
|
234 |
+
inputs=[img_input, rank_dropdown],
|
235 |
+
outputs=open_domain_outputs,
|
236 |
+
)
|
237 |
+
|
238 |
+
zero_shot_btn.click(
|
239 |
+
fn=zero_shot_classification,
|
240 |
+
inputs=[img_input, classes_txt],
|
241 |
+
outputs=zero_shot_output,
|
242 |
+
)
|
243 |
+
|
244 |
+
app.queue(max_size=20)
|
245 |
+
app.launch()
|
lib.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
-
import json
|
2 |
import itertools
|
|
|
3 |
|
4 |
|
5 |
class TaxonomicNode:
|
@@ -43,11 +43,12 @@ class TaxonomicNode:
|
|
43 |
@classmethod
|
44 |
def from_dict(cls, dct, root):
|
45 |
node = cls(dct["name"], dct["index"], root)
|
46 |
-
node._children = {
|
|
|
|
|
47 |
return node
|
48 |
|
49 |
|
50 |
-
|
51 |
class TaxonomicTree:
|
52 |
"""
|
53 |
Efficient structure for finding taxonomic names and their descendants.
|
@@ -85,11 +86,15 @@ class TaxonomicTree:
|
|
85 |
for kingdom in self.kingdoms.values():
|
86 |
yield from kingdom
|
87 |
|
|
|
|
|
|
|
88 |
@classmethod
|
89 |
def from_dict(cls, dct):
|
90 |
tree = cls()
|
91 |
tree.kingdoms = {
|
92 |
-
kingdom["name"]: TaxonomicNode.from_dict(kingdom, tree)
|
|
|
93 |
}
|
94 |
tree.size = dct["size"]
|
95 |
return tree
|
@@ -112,11 +117,10 @@ class TaxonomicJsonEncoder(json.JSONEncoder):
|
|
112 |
super().default(self, obj)
|
113 |
|
114 |
|
115 |
-
|
116 |
def batched(iterable, n):
|
117 |
# batched('ABCDEFG', 3) --> ABC DEF G
|
118 |
if n < 1:
|
119 |
-
raise ValueError(
|
120 |
it = iter(iterable)
|
121 |
while batch := tuple(itertools.islice(it, n)):
|
122 |
-
yield zip(*batch)
|
|
|
|
|
1 |
import itertools
|
2 |
+
import json
|
3 |
|
4 |
|
5 |
class TaxonomicNode:
|
|
|
43 |
@classmethod
|
44 |
def from_dict(cls, dct, root):
|
45 |
node = cls(dct["name"], dct["index"], root)
|
46 |
+
node._children = {
|
47 |
+
child["name"]: cls.from_dict(child, root) for child in dct["children"]
|
48 |
+
}
|
49 |
return node
|
50 |
|
51 |
|
|
|
52 |
class TaxonomicTree:
|
53 |
"""
|
54 |
Efficient structure for finding taxonomic names and their descendants.
|
|
|
86 |
for kingdom in self.kingdoms.values():
|
87 |
yield from kingdom
|
88 |
|
89 |
+
def __len__(self):
|
90 |
+
return self.size
|
91 |
+
|
92 |
@classmethod
|
93 |
def from_dict(cls, dct):
|
94 |
tree = cls()
|
95 |
tree.kingdoms = {
|
96 |
+
kingdom["name"]: TaxonomicNode.from_dict(kingdom, tree)
|
97 |
+
for kingdom in dct["kingdoms"]
|
98 |
}
|
99 |
tree.size = dct["size"]
|
100 |
return tree
|
|
|
117 |
super().default(self, obj)
|
118 |
|
119 |
|
|
|
120 |
def batched(iterable, n):
|
121 |
# batched('ABCDEFG', 3) --> ABC DEF G
|
122 |
if n < 1:
|
123 |
+
raise ValueError("n must be at least one")
|
124 |
it = iter(iterable)
|
125 |
while batch := tuple(itertools.islice(it, n)):
|
126 |
+
yield zip(*batch)
|
make_txt_embedding.py
CHANGED
@@ -5,6 +5,7 @@ Uses the catalog.csv file from TreeOfLife-10M.
|
|
5 |
import argparse
|
6 |
import csv
|
7 |
import json
|
|
|
8 |
|
9 |
import numpy as np
|
10 |
import torch
|
@@ -22,29 +23,53 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
22 |
|
23 |
@torch.no_grad()
|
24 |
def write_txt_features(name_lookup):
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
28 |
|
29 |
batch_size = args.batch_size // len(openai_imagenet_template)
|
30 |
-
for names, indices in
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
txts = tokenizer(txts).to(device)
|
33 |
txt_features = model.encode_text(txts)
|
34 |
-
txt_features = torch.reshape(
|
|
|
|
|
35 |
txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
|
36 |
txt_features /= txt_features.norm(dim=1, keepdim=True)
|
37 |
-
all_features[:, indices] = txt_features.cpu().numpy()
|
|
|
|
|
|
|
38 |
|
39 |
-
|
40 |
|
41 |
|
42 |
-
def get_name_lookup(catalog_path):
|
|
|
|
|
|
|
|
|
|
|
43 |
lookup = lib.TaxonomicTree()
|
44 |
|
45 |
with open(catalog_path) as fd:
|
46 |
reader = csv.DictReader(fd)
|
47 |
-
for row in tqdm(reader):
|
48 |
name = [
|
49 |
row["kingdom"],
|
50 |
row["phylum"],
|
@@ -58,6 +83,9 @@ def get_name_lookup(catalog_path):
|
|
58 |
name = name[: name.index("")]
|
59 |
lookup.add(name)
|
60 |
|
|
|
|
|
|
|
61 |
return lookup
|
62 |
|
63 |
|
@@ -69,15 +97,17 @@ if __name__ == "__main__":
|
|
69 |
required=True,
|
70 |
)
|
71 |
parser.add_argument("--out-path", help="Path to the output file.", required=True)
|
72 |
-
parser.add_argument(
|
73 |
-
|
|
|
|
|
|
|
|
|
74 |
args = parser.parse_args()
|
75 |
|
76 |
-
name_lookup = get_name_lookup(args.catalog_path)
|
77 |
-
|
78 |
-
json.dump(name_lookup, fd, cls=lib.TaxonomicJsonEncoder)
|
79 |
|
80 |
-
print("Starting.")
|
81 |
model = create_model(model_str, output_dict=True, require_pretrained=True)
|
82 |
model = model.to(device)
|
83 |
print("Created model.")
|
|
|
5 |
import argparse
|
6 |
import csv
|
7 |
import json
|
8 |
+
import os
|
9 |
|
10 |
import numpy as np
|
11 |
import torch
|
|
|
23 |
|
24 |
@torch.no_grad()
|
25 |
def write_txt_features(name_lookup):
|
26 |
+
if os.path.isfile(args.out_path):
|
27 |
+
all_features = np.load(args.out_path)
|
28 |
+
else:
|
29 |
+
all_features = np.zeros((512, len(name_lookup)), dtype=np.float32)
|
30 |
|
31 |
batch_size = args.batch_size // len(openai_imagenet_template)
|
32 |
+
for batch, (names, indices) in enumerate(
|
33 |
+
tqdm(
|
34 |
+
lib.batched(name_lookup, batch_size),
|
35 |
+
desc="txt feats",
|
36 |
+
total=len(name_lookup) // batch_size,
|
37 |
+
)
|
38 |
+
):
|
39 |
+
# Skip if any non-zero elements
|
40 |
+
if all_features[:, indices].any():
|
41 |
+
print(f"Skipping batch {batch}")
|
42 |
+
continue
|
43 |
+
|
44 |
+
txts = [
|
45 |
+
template(name) for name in names for template in openai_imagenet_template
|
46 |
+
]
|
47 |
txts = tokenizer(txts).to(device)
|
48 |
txt_features = model.encode_text(txts)
|
49 |
+
txt_features = torch.reshape(
|
50 |
+
txt_features, (len(names), len(openai_imagenet_template), 512)
|
51 |
+
)
|
52 |
txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
|
53 |
txt_features /= txt_features.norm(dim=1, keepdim=True)
|
54 |
+
all_features[:, indices] = txt_features.T.cpu().numpy()
|
55 |
+
|
56 |
+
if batch % 100 == 0:
|
57 |
+
np.save(args.out_path, all_features)
|
58 |
|
59 |
+
np.save(args.out_path, all_features)
|
60 |
|
61 |
|
62 |
+
def get_name_lookup(catalog_path, cache_path):
|
63 |
+
if os.path.isfile(cache_path):
|
64 |
+
with open(cache_path) as fd:
|
65 |
+
lookup = lib.TaxonomicTree.from_dict(json.load(fd))
|
66 |
+
return lookup
|
67 |
+
|
68 |
lookup = lib.TaxonomicTree()
|
69 |
|
70 |
with open(catalog_path) as fd:
|
71 |
reader = csv.DictReader(fd)
|
72 |
+
for row in tqdm(reader, desc="catalog"):
|
73 |
name = [
|
74 |
row["kingdom"],
|
75 |
row["phylum"],
|
|
|
83 |
name = name[: name.index("")]
|
84 |
lookup.add(name)
|
85 |
|
86 |
+
with open(args.name_cache_path, "w") as fd:
|
87 |
+
json.dump(lookup, fd, cls=lib.TaxonomicJsonEncoder)
|
88 |
+
|
89 |
return lookup
|
90 |
|
91 |
|
|
|
97 |
required=True,
|
98 |
)
|
99 |
parser.add_argument("--out-path", help="Path to the output file.", required=True)
|
100 |
+
parser.add_argument(
|
101 |
+
"--name-cache-path",
|
102 |
+
help="Path to the name cache file.",
|
103 |
+
default="name_lookup.json",
|
104 |
+
)
|
105 |
+
parser.add_argument("--batch-size", help="Batch size.", default=2**15, type=int)
|
106 |
args = parser.parse_args()
|
107 |
|
108 |
+
name_lookup = get_name_lookup(args.catalog_path, cache_path=args.name_cache_path)
|
109 |
+
print("Got name lookup.")
|
|
|
110 |
|
|
|
111 |
model = create_model(model_str, output_dict=True, require_pretrained=True)
|
112 |
model = model.to(device)
|
113 |
print("Created model.")
|
txt_emb.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b4a3c3412c3dae49cf92cc760aba5ee84227362adf1eb08f04dd50ee2a756e43
|
3 |
+
size 969818240
|