Spaces:
Sleeping
Sleeping
File size: 6,302 Bytes
290c238 2cfb891 a33c93d 290c238 a33c93d 290c238 a33c93d 290c238 a33c93d 290c238 2cfb891 290c238 2cfb891 9cbc5ff 2cfb891 a33c93d 2cfb891 290c238 2cfb891 290c238 2cfb891 290c238 2cfb891 290c238 a33c93d d4005aa a33c93d d4005aa a33c93d d4005aa a33c93d d4005aa a33c93d 6e5adf0 2cfb891 290c238 2cfb891 290c238 2cfb891 290c238 2cfb891 290c238 2cfb891 a33c93d 290c238 a33c93d 290c238 a33c93d 290c238 a33c93d 6e5adf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
"""
Makes the entire set of text embeddings for all possible names in the tree of life.
Uses the catalog.csv file from TreeOfLife-10M.
"""
import argparse
import csv
import json
import os
import logging
import numpy as np
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from tqdm import tqdm
import lib
from templates import openai_imagenet_template
# Logging setup: timestamp, level and logger name precede every message.
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(format=log_format, level=logging.INFO)
logger = logging.getLogger()

# Identifiers handed to open_clip's create_model / get_tokenizer in the
# script entry point.
model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Taxonomic ranks, ordered from the root of the tree down to the leaves.
ranks = (
    "Kingdom",
    "Phylum",
    "Class",
    "Order",
    "Family",
    "Genus",
    "Species",
)
def _save_features(path, features):
    """Write `features` to exactly `path` in .npy format.

    Saving through an open file handle stops `np.save` from appending ".npy"
    to paths that lack the suffix, so the file is always found again at
    `path` by the existence checks and `np.load` calls elsewhere.
    """
    with open(path, "wb") as fd:
        np.save(fd, features)


@torch.no_grad()
def write_txt_features(name_lookup):
    """
    Compute one text embedding per name in `name_lookup` and save them to
    `args.out_path`.

    Every name is rendered through each template in `openai_imagenet_template`,
    encoded with the module-level `model` and `tokenizer`, normalized, averaged
    over the templates and normalized again. The result is stored as one
    column of a (512, len(name_lookup)) float32 array.

    If `args.out_path` already exists it is loaded and used as a checkpoint:
    batches whose columns already contain non-zero values are skipped. The
    array is written back every 100 batches and once more at the end.

    Args:
        name_lookup: tree of names; must support `len()` and a `.values()`
            that `lib.batched` can split into (names, indices) batches.
    """
    if os.path.isfile(args.out_path):
        all_features = np.load(args.out_path)
    else:
        all_features = np.zeros((512, len(name_lookup)), dtype=np.float32)

    n_templates = len(openai_imagenet_template)
    # Every name expands into `n_templates` prompts, so the number of names per
    # batch is the prompt budget divided by the template count. Clamp to one so
    # a small --batch-size cannot produce a zero batch size (and a division by
    # zero below).
    batch_size = max(1, args.batch_size // n_templates)
    # Ceiling division: a trailing partial batch still counts as a batch.
    n_batches = -(-len(name_lookup) // batch_size)

    for batch, (names, indices) in enumerate(
        tqdm(
            lib.batched(name_lookup.values(), batch_size),
            desc="txt feats",
            total=n_batches,
        )
    ):
        # Skip if any non-zero elements
        if all_features[:, indices].any():
            logger.info("Skipping batch %d", batch)
            continue

        # Name-major order: all templates of the first name, then the second...
        # The reshape below relies on this ordering.
        txts = [
            template(name) for name in names for template in openai_imagenet_template
        ]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        txt_features = torch.reshape(txt_features, (len(names), n_templates, 512))
        # Normalize each prompt embedding, average over templates, re-normalize.
        txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
        txt_features /= txt_features.norm(dim=1, keepdim=True)
        all_features[:, indices] = txt_features.T.cpu().numpy()

        # Periodic checkpoint so an interrupted run can resume.
        if batch % 100 == 0:
            _save_features(args.out_path, all_features)

    _save_features(args.out_path, all_features)
def convert_txt_features_to_avgs(name_lookup):
    """
    Replace the embedding of every non-species node with the normalized mean
    of its genus/species-level descendants and save the result next to
    `args.out_path` with an "_avgs" suffix.

    Ranks are processed from Genus up to Kingdom. For each node, the columns
    of all descendants whose name has at least six components are averaged and
    L2-normalized. Nodes without any such descendant are set to zero and
    counted; the count is reported in a warning at the end.

    Args:
        name_lookup: tree of names providing `.values()` (pairs of name and
            column index) and `.descendants(prefix=...)`.

    Raises:
        FileNotFoundError: if `args.out_path` does not exist.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.isfile(args.out_path):
        raise FileNotFoundError(f"Text features not found at {args.out_path}")

    # Put that big boy on the GPU. We're going fast.
    all_features = torch.from_numpy(np.load(args.out_path)).to(device)
    logger.info("Loaded text features from disk to %s.", device)

    # Bucket every (name, index) pair by the depth of its name; a name with
    # k components belongs to ranks[k - 1].
    names_by_rank = [set() for _ in ranks]
    for name, index in tqdm(name_lookup.values()):
        names_by_rank[len(name) - 1].add((name, index))

    zeroed = 0
    for depth, rank in reversed(list(enumerate(ranks))):
        # Species embeddings are kept as computed.
        if rank == "Species":
            continue

        for name, index in tqdm(names_by_rank[depth], desc=rank):
            # Columns of all genus- and species-level entries below this node.
            # NOTE(review): whether descendants() includes the node itself is
            # not visible here — confirm against lib.TaxonomicTree.
            member_indices = [
                child_index
                for child_name, child_index in name_lookup.descendants(prefix=name)
                if len(child_name) >= 6
            ]
            if not member_indices:
                logger.warning("No species for %s.", " ".join(name))
                all_features[:, index] = 0.0
                zeroed += 1
                continue

            mean = all_features[:, member_indices].mean(dim=1)
            all_features[:, index] = F.normalize(mean, dim=0)

    out_path, ext = os.path.splitext(args.out_path)
    np.save(f"{out_path}_avgs{ext}", all_features.cpu().numpy())

    if zeroed:
        logger.warning(
            "Zeroed out %d nodes because they didn't have any genus or species-level labels.",
            zeroed,
        )
def convert_txt_features_to_species_only(name_lookup):
    """
    Extract the embeddings of species-level names only.

    Loads the feature matrix from `args.out_path`, keeps the columns of every
    entry in `name_lookup` whose name has exactly seven components, and writes
    two files next to `args.out_path`: "<stem>_species<ext>" with the float32
    feature columns and "<stem>_species.json" with the matching names, in the
    same order as the columns.

    Args:
        name_lookup: tree of names providing `.descendants()`, which yields
            (name, column index) pairs.

    Raises:
        FileNotFoundError: if `args.out_path` does not exist.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.isfile(args.out_path):
        raise FileNotFoundError(f"Text features not found at {args.out_path}")

    all_features = np.load(args.out_path)
    logger.info("Loaded text features from disk.")

    species = [(d, i) for d, i in name_lookup.descendants() if len(d) == 7]
    species_names = [name for name, _ in species]
    # An explicit integer dtype keeps the gather valid when there are no species.
    species_indices = np.asarray([index for _, index in species], dtype=np.intp)

    # Gather all species columns in one indexing operation rather than copying
    # them one by one; keep the float32, C-ordered layout of the output.
    species_features = np.ascontiguousarray(
        all_features[:, species_indices], dtype=np.float32
    )

    out_path, ext = os.path.splitext(args.out_path)
    np.save(f"{out_path}_species{ext}", species_features)
    with open(f"{out_path}_species.json", "w") as fd:
        json.dump(species_names, fd, indent=2)
def get_name_lookup(catalog_path, cache_path):
    """
    Build (or load from cache) the taxonomic tree of all names in the catalog.

    If `cache_path` exists, the tree is restored from that JSON file and the
    catalog is not read. Otherwise every row of the CSV at `catalog_path` is
    turned into a name (kingdom through species, truncated at the first empty
    rank) and added to a new tree, which is then written to `cache_path`.

    Args:
        catalog_path: path to the catalog CSV; must have the columns kingdom,
            phylum, class, order, family, genus and species.
        cache_path: path of the JSON cache to read from / write to.

    Returns:
        The populated `lib.TaxonomicTree`.
    """
    if os.path.isfile(cache_path):
        with open(cache_path) as fd:
            return lib.TaxonomicTree.from_dict(json.load(fd))

    rank_columns = (
        "kingdom",
        "phylum",
        "class",
        "order",
        "family",
        "genus",
        "species",
    )
    lookup = lib.TaxonomicTree()

    # newline="" is what the csv module requires for correct handling of
    # quoted fields that contain line breaks.
    with open(catalog_path, newline="") as fd:
        reader = csv.DictReader(fd)
        for row in tqdm(reader, desc="catalog"):
            name = [row[column] for column in rank_columns]
            # Keep only the ranks above the first missing one. Testing
            # truthiness (rather than searching for "") also covers None.
            for depth, value in enumerate(name):
                if not value:
                    name = name[:depth]
                    break
            lookup.add(name)

    # Write to the path we were asked to use, not to the global CLI args, so
    # the cache that is read above is the same one that gets written here.
    with open(cache_path, "w") as fd:
        json.dump(lookup, fd, cls=lib.TaxonomicJsonEncoder)

    return lookup
if __name__ == "__main__":
    # Command-line interface. `args` stays a module-level name because the
    # functions above read it directly.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--catalog-path",
        required=True,
        help="Path to the catalog.csv file from TreeOfLife-10M.",
    )
    parser.add_argument("--out-path", required=True, help="Path to the output file.")
    parser.add_argument(
        "--name-cache-path",
        default="name_lookup.json",
        help="Path to the name cache file.",
    )
    parser.add_argument("--batch-size", type=int, default=2**15, help="Batch size.")
    args = parser.parse_args()

    # Taxonomic tree of every name in the catalog (restored from cache if present).
    name_lookup = get_name_lookup(args.catalog_path, cache_path=args.name_cache_path)
    logger.info("Got name lookup.")

    # `model` and `tokenizer` are module-level names used by write_txt_features.
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    logger.info("Created model.")

    model = torch.compile(model)
    logger.info("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

    # Pipeline: raw embeddings, then rank averages, then the species-only subset.
    write_txt_features(name_lookup)
    convert_txt_features_to_avgs(name_lookup)
    convert_txt_features_to_species_only(name_lookup)
|