Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,590 Bytes
dcca7d2 4416228 dcca7d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import os
import numpy as np
import time
import torch
from utilities import pad_truncate_sequence
def move_data_to_device(x, device):
    """Convert an array-like to a torch tensor and place it on `device`.

    Floating dtypes become float32 tensors (torch.Tensor), integer dtypes
    become int64 tensors (torch.LongTensor); any other dtype is returned
    unchanged and NOT moved to the device.

    Args:
      x: array-like with a `.dtype` attribute (e.g. numpy array)
      device: torch device (or device string) to move the tensor to

    Returns:
      torch.Tensor on `device`, or `x` itself for unsupported dtypes.
    """
    dtype_name = str(x.dtype)
    if 'float' in dtype_name:
        tensor = torch.Tensor(x)
    elif 'int' in dtype_name:
        tensor = torch.LongTensor(x)
    else:
        # Unsupported dtype: hand the input back untouched.
        return x
    return tensor.to(device)
def append_to_dict(dict, key, value):
    """Append `value` to the list stored at `key`, creating the list if absent.

    NOTE(review): the parameter name shadows the builtin `dict`; kept as-is
    so any keyword-argument callers are not broken.

    Args:
      dict: mutable mapping of key -> list, mutated in place
      key: hashable key
      value: item to append under `key`
    """
    # setdefault replaces the original check-then-index double lookup
    # with a single idiomatic operation.
    dict.setdefault(key, []).append(value)
def forward(model, x, batch_size):
    """Forward data to model in mini-batches.

    Args:
      model: torch.nn.Module; its __call__ is expected to return a dict of
        tensors (assumed from the usage below — TODO confirm)
      x: (N, segment_samples) array of input segments
      batch_size: int, number of segments per mini-batch

    Returns:
      output_dict: dict, e.g. {
        'frame_output': (segments_num, frames_num, classes_num),
        'onset_output': (segments_num, frames_num, classes_num),
        ...}
    """
    output_dict = {}
    # Run inference on whichever device the model's parameters live on.
    device = next(model.parameters()).device
    total_segments = int(np.ceil(len(x) / batch_size))

    # Switch to eval mode once, outside the loop; the original re-invoked
    # model.eval() on every iteration.
    model.eval()

    for batch_idx in range(total_segments):
        # Fixed: report the batch index against the batch count. The
        # original printed the sample pointer (0, batch_size, 2*batch_size,
        # ...) against total_segments, e.g. "Segment 128 / 4".
        print('Segment {} / {}'.format(batch_idx, total_segments))

        pointer = batch_idx * batch_size
        batch_waveform = move_data_to_device(x[pointer : pointer + batch_size], device)

        with torch.no_grad():
            batch_output_dict = model(batch_waveform)

        for key in batch_output_dict.keys():
            append_to_dict(output_dict, key, batch_output_dict[key].data.cpu().numpy())

    # Stitch the per-batch outputs back into full-length arrays.
    for key in output_dict.keys():
        output_dict[key] = np.concatenate(output_dict[key], axis=0)

    return output_dict
|