import sys
import os

import librosa
import numpy as np
import torch

import audio_to_text.captioning.models
import audio_to_text.captioning.models.encoder
import audio_to_text.captioning.models.decoder
import audio_to_text.captioning.utils.train_util as train_util
def load_model(config, checkpoint):
    """Instantiate the encoder, decoder and full captioning model from `config`,
    then load the weights stored in `checkpoint`."""
    ckpt = torch.load(checkpoint, map_location="cpu")

    # Build the encoder, optionally warm-starting it from a pretrained checkpoint.
    encoder_cfg = config["model"]["encoder"]
    encoder = train_util.init_obj(
        audio_to_text.captioning.models.encoder,
        encoder_cfg
    )
    if "pretrained" in encoder_cfg:
        pretrained = encoder_cfg["pretrained"]
        train_util.load_pretrained_model(encoder,
                                         pretrained,
                                         sys.stdout.write)

    # Build the decoder; the vocabulary size falls back to the checkpoint's
    # vocabulary if the config does not specify it.
    decoder_cfg = config["model"]["decoder"]
    if "vocab_size" not in decoder_cfg["args"]:
        decoder_cfg["args"]["vocab_size"] = len(ckpt["vocabulary"])
    decoder = train_util.init_obj(
        audio_to_text.captioning.models.decoder,
        decoder_cfg
    )
    if "word_embedding" in decoder_cfg:
        decoder.load_word_embedding(**decoder_cfg["word_embedding"])
    if "pretrained" in decoder_cfg:
        pretrained = decoder_cfg["pretrained"]
        train_util.load_pretrained_model(decoder,
                                         pretrained,
                                         sys.stdout.write)

    # Assemble the full model, load the trained weights and switch to eval mode.
    model = train_util.init_obj(audio_to_text.captioning.models, config["model"],
                                encoder=encoder, decoder=decoder)
    train_util.load_pretrained_model(model, ckpt)
    model.eval()
    return {
        "model": model,
        "vocabulary": ckpt["vocabulary"]
    }
def decode_caption(word_ids, vocabulary):
    """Turn a sequence of word ids into a caption string, skipping <start>
    and truncating at the first <end> token."""
    candidate = []
    for word_id in word_ids:
        word = vocabulary[word_id]
        if word == "<end>":
            break
        elif word == "<start>":
            continue
        candidate.append(word)
    candidate = " ".join(candidate)
    return candidate
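# A minimal sketch of decode_caption on a toy vocabulary (both the vocabulary
# and the id sequence below are hypothetical, not taken from a real checkpoint):
#     decode_caption([0, 2, 3, 4, 1], ["<start>", "<end>", "a", "dog", "barks"])
#     returns "a dog barks"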
class AudioCapModel(object):
    """Inference wrapper: loads a trained captioning model from `weight_dir`
    (which must contain config.yaml and swa.pth) and captions raw audio."""

    def __init__(self, weight_dir, device='cpu'):
        config = os.path.join(weight_dir, 'config.yaml')
        self.config = train_util.parse_config_or_kwargs(config)
        checkpoint = os.path.join(weight_dir, 'swa.pth')
        resumed = load_model(self.config, checkpoint)
        model = resumed["model"]
        self.vocabulary = resumed["vocabulary"]
        self.model = model.to(device)
        self.device = device

    def caption(self, audio_list):
        """Caption a single waveform (np.ndarray), a path to an audio file (str),
        or a list of waveforms; returns a list of caption strings."""
        if isinstance(audio_list, np.ndarray):
            audio_list = [audio_list]
        elif isinstance(audio_list, str):
            # The model expects 32 kHz mono audio.
            audio_list = [librosa.load(audio_list, sr=32000)[0]]
        captions = []
        for wav in audio_list:
            inputwav = torch.as_tensor(wav).float().unsqueeze(0).to(self.device)
            wav_len = torch.LongTensor([len(wav)])
            input_dict = {
                "mode": "inference",
                "wav": inputwav,
                "wav_len": wav_len,
                "specaug": False,
                "sample_method": "beam",
            }
            print(input_dict)
            with torch.no_grad():
                out_dict = self.model(input_dict)
            caption_batch = [decode_caption(seq, self.vocabulary)
                             for seq in out_dict["seq"].cpu().numpy()]
            captions.extend(caption_batch)
        return captions

    def __call__(self, audio_list):
        return self.caption(audio_list)
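

# A minimal usage sketch, assuming a directory "weights/" containing the exported
# config.yaml and swa.pth, and "example.wav" as any audio file on disk (both paths
# are placeholders, not files shipped with this module):
if __name__ == "__main__":
    captioner = AudioCapModel("weights", device="cpu")
    print(captioner("example.wav"))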