import os
import numpy as np
import time
import librosa
from pathlib import Path
import torch

from utilities import (create_folder, get_filename, RegressionPostProcessor,
    write_events_to_midi)
from models import Regress_onset_offset_frame_velocity_CRNN, Note_pedal
from pytorch_utils import move_data_to_device, forward
import config


class PianoTranscription(object):
    def __init__(self, model_type='Note_pedal', checkpoint_path=None,
            segment_samples=16000*10, device=torch.device('cuda')):
        """Class for transcribing a piano solo recording.

        Args:
            model_type: str
            checkpoint_path: str
            segment_samples: int
            device: 'cuda' | 'cpu'
        """
checkpoint_path="CRNN_note_F1%3D0.9677_pedal_F1%3D0.9186.pth" | |
print('Checkpoint path: {}'.format(checkpoint_path)) | |
print('Using {} for inference.'.format(device)) | |
self.segment_samples = segment_samples | |
self.frames_per_second = config.frames_per_second | |
self.classes_num = config.classes_num | |
self.onset_threshold = 0.3 | |
self.offset_threshod = 0.3 | |
self.frame_threshold = 0.1 | |
self.pedal_offset_threshold = 0.2 | |
        # Build model
        Model = eval(model_type)
        self.model = Model(frames_per_second=self.frames_per_second,
            classes_num=self.classes_num)

        # Load model
        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.model.load_state_dict(checkpoint['model'], strict=False)

        # Parallel
        if 'cuda' in str(device):
            self.model.to(device)
            print('GPU number: {}'.format(torch.cuda.device_count()))
            self.model = torch.nn.DataParallel(self.model)
        else:
            print('Using CPU.')

    def transcribe(self, audio, midi_path):
        """Transcribe an audio recording.

        Args:
            audio: (audio_samples,)
            midi_path: str, path to write out the transcribed MIDI.

        Returns:
            transcribed_dict, dict: {'output_dict': ..., 'est_note_events': ...,
                'est_pedal_events': ...}
        """
        audio = audio[None, :]  # (1, audio_samples)

        # Pad audio to be evenly divided by segment_samples
        audio_len = audio.shape[1]
        pad_len = int(np.ceil(audio_len / self.segment_samples)) \
            * self.segment_samples - audio_len
        audio = np.concatenate((audio, np.zeros((1, pad_len))), axis=1)

        # Enframe to segments
        segments = self.enframe(audio, self.segment_samples)
        """(N, segment_samples)"""

        # Forward
        output_dict = forward(self.model, segments, batch_size=1)
        """{'reg_onset_output': (N, segment_frames, classes_num), ...}"""

        # Deframe to original length
        for key in output_dict.keys():
            output_dict[key] = self.deframe(output_dict[key])[0 : audio_len]
        """output_dict: {
            'reg_onset_output': (N, segment_frames, classes_num),
            'reg_offset_output': (N, segment_frames, classes_num),
            'frame_output': (N, segment_frames, classes_num),
            'velocity_output': (N, segment_frames, classes_num)}"""
        # Post processor
        post_processor = RegressionPostProcessor(self.frames_per_second,
            classes_num=self.classes_num, onset_threshold=self.onset_threshold,
            offset_threshold=self.offset_threshold,
            frame_threshold=self.frame_threshold,
            pedal_offset_threshold=self.pedal_offset_threshold)
        # Post process output_dict to MIDI events
        (est_note_events, est_pedal_events) = \
            post_processor.output_dict_to_midi_events(output_dict)

        # Write MIDI events to file
        if midi_path:
            write_events_to_midi(start_time=0, note_events=est_note_events,
                pedal_events=est_pedal_events, midi_path=midi_path)
            print('Write out to {}'.format(midi_path))

        transcribed_dict = {
            'output_dict': output_dict,
            'est_note_events': est_note_events,
            'est_pedal_events': est_pedal_events}

        return transcribed_dict

    def enframe(self, x, segment_samples):
        """Enframe long sequence to short segments.

        Args:
            x: (1, audio_samples)
            segment_samples: int

        Returns:
            batch: (N, segment_samples)
        """
        assert x.shape[1] % segment_samples == 0
        batch = []

        pointer = 0
        while pointer + segment_samples <= x.shape[1]:
            batch.append(x[:, pointer : pointer + segment_samples])
            pointer += segment_samples // 2

        batch = np.concatenate(batch, axis=0)
        return batch
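
    # Note on how enframe() and deframe() fit together: enframe() cuts the
    # padded audio into segments that overlap by 50% (the pointer advances by
    # segment_samples // 2), and deframe() below undoes that overlap on the
    # model outputs by keeping the first 75% of the first segment, the middle
    # 50% of every interior segment, and the last 75% of the final segment,
    # so each output frame is taken from exactly one segment.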
    def deframe(self, x):
        """Deframe predicted segments to original sequence.

        Args:
            x: (N, segment_frames, classes_num)

        Returns:
            y: (audio_frames, classes_num)
        """
        if x.shape[0] == 1:
            return x[0]

        else:
            x = x[:, 0 : -1, :]
            """Remove an extra frame at the end of each segment caused by the
            'center=True' argument when calculating the spectrogram."""
            (N, segment_frames, classes_num) = x.shape
            assert segment_frames % 4 == 0

            y = []
            y.append(x[0, 0 : int(segment_frames * 0.75)])
            for i in range(1, N - 1):
                y.append(x[i, int(segment_frames * 0.25) : int(segment_frames * 0.75)])
            y.append(x[-1, int(segment_frames * 0.25) :])

            y = np.concatenate(y, axis=0)
            return y
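

# Minimal usage sketch (not part of the original class). It assumes the
# project's `config` module defines `sample_rate` (16 kHz for the released
# checkpoints), that the checkpoint file named in __init__ has been downloaded
# to the working directory, and that a recording exists at the hypothetical
# path 'resources/example.wav'; adjust these to your setup.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the recording as a mono float array at the model's sample rate.
    (audio, _) = librosa.load('resources/example.wav', sr=config.sample_rate, mono=True)

    # Build the transcriber and write the transcription to a MIDI file.
    transcriptor = PianoTranscription(device=device)
    transcribed_dict = transcriptor.transcribe(audio, 'output.mid')
    print('Transcribed {} notes.'.format(len(transcribed_dict['est_note_events'])))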