# Spaces: Build error  (page-status text captured with this listing; not part of the script)
import argparse
import json
import random
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple

import numpy as np
import sklearn.metrics
from h5py import File
from tqdm import tqdm
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the four positional command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``train_feature``, ``train_corpus``, ``pred_feature``
        and ``output_json`` string attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("train_feature", type=str)
    parser.add_argument("train_corpus", type=str)
    parser.add_argument("pred_feature", type=str)
    parser.add_argument("output_json", type=str)
    return parser.parse_args(argv)


def load_embeddings(feature_path: str) -> Tuple[List[str], np.ndarray]:
    """Read every top-level entry of an HDF5 feature file.

    Args:
        feature_path: Path of the HDF5 file; each top-level key is used as an
            audio id and its value as that audio's embedding.

    Returns:
        ``(audio_ids, embeddings)`` where ``embeddings[i]`` belongs to
        ``audio_ids[i]`` and all embeddings are stacked along a new first axis.

    Raises:
        ValueError: If the file contains no entries (``np.stack`` would fail
            with a less helpful message).
    """
    audio_ids: List[str] = []
    embeddings = []
    with File(feature_path, "r") as store:
        for audio_id, embedding in tqdm(store.items(), ascii=True):
            audio_ids.append(audio_id)
            # ``[()]`` pulls the stored value into memory so it stays usable
            # after the file is closed.
            embeddings.append(embedding[()])
    if not embeddings:
        raise ValueError(f"no embeddings found in {feature_path!r}")
    return audio_ids, np.stack(embeddings)


def index_caption_tokens(
    annotation: Iterable[Mapping[str, Any]],
) -> Dict[str, List[Any]]:
    """Map each audio id to the ``tokens`` of all of its captions.

    Args:
        annotation: Items carrying an ``"audio_id"`` and a ``"captions"``
            list whose entries each carry ``"tokens"``.

    Returns:
        ``{audio_id: [tokens, ...]}``; a repeated audio id keeps its last
        occurrence.
    """
    return {
        item["audio_id"]: [caption["tokens"] for caption in item["captions"]]
        for item in annotation
    }


def load_caption_tokens(corpus_path: str) -> Dict[str, List[Any]]:
    """Load the training corpus JSON and index caption tokens by audio id.

    Args:
        corpus_path: Path of a JSON file whose top-level ``"audios"`` entry
            is the annotation list.

    Returns:
        The mapping built by :func:`index_caption_tokens`.
    """
    # Context manager closes the handle; explicit UTF-8 avoids depending on
    # the platform's default encoding.
    with open(corpus_path, "r", encoding="utf-8") as corpus_file:
        annotation = json.load(corpus_file)["audios"]
    return index_caption_tokens(annotation)


def select_captions(
    similarity: np.ndarray,
    pred_audio_ids: Sequence[str],
    train_audio_ids: Sequence[str],
    tokens_by_audio_id: Mapping[str, Sequence[Any]],
    rng: Any = random,
) -> List[Dict[str, Any]]:
    """Pick one caption per predicted audio from its most similar train audio.

    Args:
        similarity: 2-D matrix; row ``i`` scores ``pred_audio_ids[i]`` against
            every entry of ``train_audio_ids``.
        pred_audio_ids: Audio ids to produce predictions for.
        train_audio_ids: Audio ids matching the columns of ``similarity``.
        tokens_by_audio_id: Candidate caption tokens for each training audio.
        rng: Object providing ``choice``; defaults to the ``random`` module.

    Returns:
        One ``{"filename": ..., "tokens": ...}`` dict per predicted audio, in
        input order.

    Raises:
        KeyError: If the nearest training audio has no entry in
            ``tokens_by_audio_id``.
    """
    # One vectorised argmax; ties resolve to the lowest index, exactly like
    # a per-row ``argmax()``.
    nearest = np.asarray(similarity).argmax(axis=1)
    predictions = []
    for audio_id, train_idx in zip(pred_audio_ids, nearest):
        train_id = train_audio_ids[train_idx]
        try:
            candidates = tokens_by_audio_id[train_id]
        except KeyError:
            raise KeyError(
                f"audio id {train_id!r} has an embedding but no captions "
                "in the training corpus"
            ) from None
        predictions.append({
            "filename": audio_id,
            "tokens": rng.choice(candidates),
        })
    return predictions


def write_predictions(pred_data: List[Dict[str, Any]], output_path: str) -> None:
    """Write predictions as ``{"predictions": [...]}`` JSON.

    Args:
        pred_data: Prediction dicts to serialise.
        output_path: Destination file; overwritten if it exists.
    """
    # ``ensure_ascii=False`` emits raw non-ASCII characters, so the file must
    # be opened as UTF-8; the ``with`` block guarantees flush and close.
    with open(output_path, "w", encoding="utf-8") as output_file:
        json.dump(
            {"predictions": pred_data},
            output_file,
            ensure_ascii=False,
            indent=4,
        )


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Caption each audio in ``pred_feature`` by nearest-neighbour retrieval.

    For every embedding in the prediction file, find the training embedding
    with the highest cosine similarity and copy one of its captions, chosen
    at random, into the output JSON.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    # Fixed seed keeps the randomly chosen captions reproducible across runs.
    random.seed(1)

    train_audio_ids, train_embs = load_embeddings(args.train_feature)
    tokens_by_audio_id = load_caption_tokens(args.train_corpus)
    pred_audio_ids, pred_embs = load_embeddings(args.pred_feature)

    similarity = sklearn.metrics.pairwise.cosine_similarity(pred_embs, train_embs)
    pred_data = select_captions(
        similarity, pred_audio_ids, train_audio_ids, tokens_by_audio_id
    )
    write_predictions(pred_data, args.output_json)


if __name__ == "__main__":
    main()