Spaces:
Runtime error
Runtime error
import torch
import torchaudio  # not referenced below; presumably imported so the TorchScript archives can resolve torchaudio ops — TODO confirm

# Private torch._C switches that control TorchScript kernel fusion.
_JIT_FUSER_SWITCHES = (
    "_jit_override_can_fuse_on_cpu",
    "_jit_override_can_fuse_on_gpu",
    "_jit_set_texpr_fuser_enabled",
    "_jit_set_nvfuser_enabled",
)


def _disable_jit_fusers():
    """Turn off every TorchScript fuser switch this torch build exposes.

    The original author disabled these fusers to fix a bug seen on the second
    prediction. They are private ``torch._C`` APIs and are not guaranteed to
    exist in every PyTorch release (NOTE(review): nvFuser support in
    TorchScript appears to have been removed in newer releases — confirm).
    A missing switch is therefore skipped instead of crashing at import time.
    """
    for name in _JIT_FUSER_SWITCHES:
        switch = getattr(torch._C, name, None)
        if switch is not None:
            switch(False)


_disable_jit_fusers()

# TorchScript archives, loaded from the current working directory at import time.
loader = torch.jit.load("audio_loader.pt")
model = torch.jit.load('QuartzNet_thunderspeech_3.pt')

# Id-to-string table used by convert_probs. The last entry is replaced with ''
# so that convert_probs, which skips empty strings, never emits it
# (presumably the CTC blank token — TODO confirm).
vocab = model.text_transform.vocab.itos
vocab[-1] = ''
def convert_probs(probs, itos=None):
    """Greedy-decode model output into a list of token strings.

    Takes the argmax over dimension 1 of ``probs`` for batch item 0. Runs of
    consecutive identical ids are collapsed into one id. Each remaining id is
    mapped through the vocabulary, and ids whose string is empty are dropped.

    Args:
        probs: score tensor indexed as ``probs[batch, class, time]``. This
            layout is assumed from the ``argmax(1)[0]`` access — TODO confirm
            against the model's output.
        itos: optional id-to-string sequence. Defaults to the module-level
            ``vocab``.

    Returns:
        List of non-empty token strings. The list is empty when the input
        has no time steps.
    """
    table = vocab if itos is None else itos
    # One argmax id per time step for the first batch item. Converting to a
    # plain list avoids per-element tensor indexing and comparison in the loop,
    # and makes a zero-length sequence simply yield no tokens.
    ids = probs.argmax(1)[0].tolist()

    tokens = []
    previous = None  # no id has been seen yet, so the first id is never a repeat
    for token_id in ids:
        if token_id != previous:
            token = table[token_id]
            if token:  # empty strings (e.g. the blanked last vocab entry) are skipped
                tokens.append(token)
        previous = token_id
    return tokens
def predict(path):
    """Transcribe the audio at ``path`` and return its decoded token list.

    The audio is loaded with the module-level TorchScript ``loader``. The
    ``model`` is run on it, and the first element of the model's output is
    decoded with ``convert_probs``.

    Args:
        path: value passed straight to ``loader``. Presumably a file path —
            TODO confirm against the loader's signature.

    Returns:
        The list of token strings produced by ``convert_probs``.
    """
    # Inference only: disabling autograd avoids building a gradient graph.
    with torch.no_grad():
        audio = loader(path)
        # One length entry per item along dim 0, each equal to the size of
        # the last dimension (every item is treated as full length).
        lengths = torch.tensor(
            audio.shape[0] * [audio.shape[-1]], device=audio.device
        )
        probs = model(audio, lengths)[0]
    return convert_probs(probs)