Spaces:
Sleeping
Sleeping
import os | |
import random | |
import numpy as np | |
import librosa | |
import torch | |
import pytorch_lightning as pl | |
import soundfile as sf | |
from torch.nn.utils.rnn import pad_sequence | |
from transformers import T5Config, T5ForConditionalGeneration | |
from midi_tokenizer import MidiTokenizer, extrapolate_beat_times | |
from layer.input import LogMelSpectrogram, ConcatEmbeddingToMel | |
from preprocess.beat_quantizer import extract_rhythm, interpolate_beat_times | |
from utils.dsp import get_stereo | |
# Fallback composer -> feature-token mapping. TransformerWrapper uses it in
# place of ``config.composer_to_feature_token`` when mel input is enabled but
# not composer-conditioned.
DEFAULT_COMPOSERS = {"various composer": 2052}
class TransformerWrapper(pl.LightningModule):
    """Lightning module that wraps a T5 model for audio-to-MIDI generation.

    The module owns a ``MidiTokenizer``, a ``T5ForConditionalGeneration``
    built from the "t5-small" configuration with project overrides, and,
    when ``config.dataset.use_mel`` is set, a log-mel front end that can
    optionally be conditioned on a composer token.
    """

    def __init__(self, config):
        """Build tokenizer, transformer and the optional mel front end.

        Args:
            config: Configuration object. This constructor reads
                ``config.tokenizer``, ``config.t5`` (mapping of T5 config
                overrides), ``config.dataset.use_mel``,
                ``config.dataset.mel_is_conditioned``,
                ``config.composer_to_feature_token`` and
                ``config.training.lr``.
        """
        super().__init__()
        self.config = config

        self.tokenizer = MidiTokenizer(config.tokenizer)

        # Start from the stock "t5-small" config and apply project overrides.
        self.t5config = T5Config.from_pretrained("t5-small")
        for key, value in config.t5.items():
            setattr(self.t5config, key, value)
        self.transformer = T5ForConditionalGeneration(self.t5config)

        self.use_mel = self.config.dataset.use_mel
        self.mel_is_conditioned = self.config.dataset.mel_is_conditioned
        self.composer_to_feature_token = config.composer_to_feature_token

        # Without composer conditioning only the single default entry is used.
        if self.use_mel and not self.mel_is_conditioned:
            self.composer_to_feature_token = DEFAULT_COMPOSERS

        if self.use_mel:
            self.spectrogram = LogMelSpectrogram()
            if self.mel_is_conditioned:
                n_dim = 512
                composer_n_vocab = len(self.composer_to_feature_token)
                # The smallest composer token id is passed as the embedding offset.
                embedding_offset = min(self.composer_to_feature_token.values())
                self.mel_conditioner = ConcatEmbeddingToMel(
                    embedding_offset=embedding_offset,
                    n_vocab=composer_n_vocab,
                    n_dim=n_dim,
                )
        else:
            self.spectrogram = None

        self.lr = config.training.lr

    def forward(self, input_ids, labels):
        """Run the wrapped transformer on ``input_ids`` / ``labels``.

        Deprecated.
        """
        return self.transformer(input_ids=input_ids, labels=labels)

    def single_inference(
        self,
        feature_tokens=None,
        audio=None,
        beatstep=None,
        max_length=256,
        max_batch_size=64,
        n_bars=None,
        composer_value=None,
    ):
        """Generate tokens and MIDI for one long audio sequence.

        Args:
            feature_tokens: Optional 1-D input; only its rank is validated
                here (the non-mel path is not implemented).
            audio: 1-D audio, shape (time, ).
            beatstep: Beat-step times, shape (time, ).
                - the beatstep values that the input corresponds to, with the
                  offset removed (i.e. ``beatstep[0] == 0``)
                - ``beatstep[-1]`` is the time at which the input ends
                  (i.e. ``beatstep[-1] == len(y) // sr``)
            max_length: ``max_length`` forwarded to ``transformer.generate``.
            max_batch_size: Maximum number of segments generated per call to
                ``transformer.generate``.
            n_bars: Bars per segment; defaults to ``config.dataset.n_bars``.
            composer_value: Composer token forwarded to
                ``prepare_inference_mel``.

        Returns:
            ``(relative_tokens, notes, pm)`` where ``relative_tokens`` is the
            padded, concatenated numpy array of generated tokens and
            ``notes`` / ``pm`` are whatever
            ``tokenizer.relative_batch_tokens_to_midi`` returns.

        Raises:
            AssertionError: If both inputs are missing, ``beatstep`` is
                missing, or an input is not 1-D.
            NotImplementedError: If the model was built without mel input.
        """
        assert feature_tokens is not None or audio is not None
        assert beatstep is not None
        if feature_tokens is not None:
            assert len(feature_tokens.shape) == 1
        if audio is not None:
            assert len(audio.shape) == 1

        # Only the mel-spectrogram input path exists.
        if not self.use_mel:
            raise NotImplementedError

        pad_id = self.t5config.pad_token_id
        n_bars = self.config.dataset.n_bars if n_bars is None else n_bars

        if beatstep[0] > 0.01:
            print(
                f"inference warning : beatstep[0] is not 0 ({beatstep[0]}). all beatstep will be shifted."
            )
            beatstep = beatstep - beatstep[0]

        inputs_embeds, ext_beatstep = self.prepare_inference_mel(
            audio,
            beatstep,
            n_bars=n_bars,
            padding_value=pad_id,
            composer_value=composer_value,
        )
        batch_size = inputs_embeds.shape[0]

        # Considering GPU capacity, the segments are generated in chunks of at
        # most ``max_batch_size`` rather than all at once.
        chunks = []
        for start in range(0, batch_size, max_batch_size):
            generated = self.transformer.generate(
                input_ids=None,
                inputs_embeds=inputs_embeds[start : start + max_batch_size],
                max_length=max_length,
            )
            chunks.append(generated.cpu().numpy())

        # Chunks may differ in generated length; right-pad each with the pad
        # token so they can be concatenated along the batch axis.
        longest = max(chunk.shape[-1] for chunk in chunks)
        relative_tokens = np.concatenate(
            [
                np.pad(
                    chunk,
                    [(0, 0), (0, longest - chunk.shape[-1])],
                    constant_values=pad_id,
                )
                for chunk in chunks
            ]
        )

        pm, notes = self.tokenizer.relative_batch_tokens_to_midi(
            relative_tokens,
            beatstep=ext_beatstep,
            bars_per_batch=n_bars,
            cutoff_time_idx=(n_bars + 1) * 4,
        )

        return relative_tokens, notes, pm

    def prepare_inference_mel(self, audio, beatstep, n_bars, padding_value, composer_value=None):
        """Cut ``audio`` into beat-aligned segments and convert them to model inputs.

        Args:
            audio: 1-D audio tensor sampled at ``config.dataset.sample_rate``.
            beatstep: Beat-step times in seconds; every ``n_bars * 4`` steps
                form one segment.
            n_bars: Number of bars per segment.
            padding_value: Value used to right-pad shorter audio segments.
                NOTE(review): callers pass the T5 pad token id here, which is
                only a sensible audio padding value if it is 0 -- confirm.
            composer_value: Composer token; required when the model is
                composer-conditioned.

        Returns:
            ``(inputs_embeds, ext_beatstep)``: the transposed spectrogram
            batch (composer-conditioned if enabled) and the extrapolated
            beat-step times.

        Raises:
            ValueError: If the model is composer-conditioned and
                ``composer_value`` is ``None``.
        """
        if self.mel_is_conditioned and composer_value is None:
            raise ValueError("composer_value is required when mel_is_conditioned is True")

        n_steps = n_bars * 4
        n_target_step = len(beatstep)
        sample_rate = self.config.dataset.sample_rate
        ext_beatstep = extrapolate_beat_times(beatstep, (n_bars + 1) * 4 + 1)

        # Split the audio at beat boundaries. Segment lengths differ because
        # the beat intervals they span differ in duration.
        segments = []
        for start_idx in range(0, n_target_step, n_steps):
            end_idx = min(start_idx + n_steps, n_target_step)
            start_sample = int(ext_beatstep[start_idx] * sample_rate)
            end_sample = int(ext_beatstep[end_idx] * sample_rate)
            segments.append(audio[start_sample:end_sample])

        # Right-pad to a common length and stack into one batch tensor.
        batch = pad_sequence(segments, batch_first=True, padding_value=padding_value)

        inputs_embeds = self.spectrogram(batch).transpose(-1, -2)
        if self.mel_is_conditioned:
            composer_value = torch.tensor(composer_value).to(self.device)
            # One composer token per segment in the batch.
            composer_value = composer_value.repeat(inputs_embeds.shape[0])
            inputs_embeds = self.mel_conditioner(inputs_embeds, composer_value)
        return inputs_embeds, ext_beatstep

    def generate(
        self,
        audio_path=None,
        composer=None,
        model="generated",
        steps_per_beat=2,
        stereo_amp=0.5,
        n_bars=2,
        ignore_duplicate=True,
        show_plot=False,
        save_midi=False,
        save_mix=False,
        midi_path=None,
        mix_path=None,
        click_amp=0.2,
        add_click=False,
        max_batch_size=None,
        beatsteps=None,
        mix_sample_rate=None,
        audio_y=None,
        audio_sr=None,
    ):
        """Generate MIDI for an audio file or an in-memory waveform.

        Args:
            audio_path: Path of the input audio. Required unless both
                ``beatsteps`` and ``audio_y`` are given.
            composer: Key of ``self.composer_to_feature_token``; a random key
                is drawn when ``None``.
            model: Tag inserted into the default output file names.
            steps_per_beat: Passed to ``interpolate_beat_times`` when beats
                are estimated from the audio.
            stereo_amp: ``pop_scale`` passed to ``get_stereo`` for the mix.
            n_bars: Bars per generated segment.
            ignore_duplicate: When ``False`` and ``midi_path`` already exists,
                nothing is generated and ``None`` is returned.
            show_plot: Plot the result with ``note_seq``.
            save_midi: Write the MIDI to ``midi_path``.
            save_mix: Write the audio + synthesized MIDI mix to ``mix_path``.
            midi_path: Output MIDI path; derived from ``audio_path`` if ``None``.
            mix_path: Output mix path; derived from ``audio_path`` if ``None``.
            click_amp: Amplitude factor of the optional click track.
            add_click: Add clicks at ``beatsteps`` to the audio in the mix.
            max_batch_size: Segments per generation call; defaults to
                ``64 // n_bars``.
            beatsteps: Precomputed beat-step times; estimated from
                ``audio_path`` when ``None``.
            mix_sample_rate: Sample rate of the mix; defaults to
                ``config.dataset.sample_rate``.
            audio_y: Optional waveform (numpy array) used instead of the
                audio loaded from ``audio_path`` as model input.
            audio_sr: Sample rate of ``audio_y``.

        Returns:
            ``(pm, composer, mix_path, midi_path)``, or ``None`` when the
            duplicate check short-circuits.

        Raises:
            ValueError: If a required path or sample rate is missing.
            NotImplementedError: If the model was built without mel input.
        """
        config = self.config
        device = self.device
        target_sr = config.dataset.sample_rate

        # Resolve the composer before deriving file names so that a randomly
        # drawn composer (not the literal "None") ends up in the paths.
        composer_to_feature_token = self.composer_to_feature_token
        if composer is None:
            composer = random.sample(list(composer_to_feature_token.keys()), 1)[0]
        composer_value = composer_to_feature_token[composer]

        if audio_path is not None:
            # Swap only the trailing extension. str.replace would rewrite
            # every occurrence of the extension text, and would corrupt the
            # path entirely when the extension is empty.
            stem = os.path.splitext(audio_path)[0]
            if mix_path is None:
                mix_path = f"{stem}.{model}.{composer}.wav"
            if midi_path is None:
                midi_path = f"{stem}.{model}.{composer}.mid"

        # Fail before the expensive inference if an output cannot be written.
        if save_mix and mix_path is None:
            raise ValueError("save_mix=True requires mix_path or audio_path")
        if save_midi and midi_path is None:
            raise ValueError("save_midi=True requires midi_path or audio_path")

        max_batch_size = 64 // n_bars if max_batch_size is None else max_batch_size
        mix_sample_rate = target_sr if mix_sample_rate is None else mix_sample_rate

        if not ignore_duplicate and midi_path is not None and os.path.exists(midi_path):
            return None

        # Only the mel-spectrogram input path exists.
        if not self.use_mel:
            raise NotImplementedError

        if audio_path is None and (beatsteps is None or audio_y is None):
            raise ValueError("audio_path is required unless both beatsteps and audio_y are given")

        ESSENTIA_SAMPLERATE = 44100

        y = None
        sr = None
        if beatsteps is None:
            # Beat extraction runs on audio loaded at ESSENTIA_SAMPLERATE.
            y, sr = librosa.load(audio_path, sr=ESSENTIA_SAMPLERATE)
            _bpm, beat_times, *_unused = extract_rhythm(audio_path, y=y)
            beat_times = np.array(beat_times)
            beatsteps = interpolate_beat_times(beat_times, steps_per_beat, extend=True)

        # Obtain the model input at the dataset sample rate.
        if audio_y is not None:
            if audio_sr is None:
                raise ValueError("audio_sr is required when audio_y is given")
            if audio_sr != target_sr:
                audio_y = librosa.core.resample(audio_y, orig_sr=audio_sr, target_sr=target_sr)
            y, sr = audio_y, target_sr
        elif y is None:
            # Beats were supplied by the caller, so nothing has been loaded yet.
            y, sr = librosa.load(audio_path, sr=target_sr)
        elif target_sr != ESSENTIA_SAMPLERATE:
            y = librosa.core.resample(y, orig_sr=ESSENTIA_SAMPLERATE, target_sr=target_sr)
            sr = target_sr

        # Keep only the audio between the first and the last beat step.
        start_sample = int(beatsteps[0] * sr)
        end_sample = int(beatsteps[-1] * sr)
        _audio = torch.from_numpy(y)[start_sample:end_sample].to(device)

        relative_tokens, notes, pm = self.single_inference(
            feature_tokens=None,
            audio=_audio,
            beatstep=beatsteps - beatsteps[0],
            max_length=config.dataset.target_length * max(1, (n_bars // config.dataset.n_bars)),
            max_batch_size=max_batch_size,
            n_bars=n_bars,
            composer_value=composer_value,
        )

        # Inference ran on beat steps shifted to start at 0; shift the notes
        # back so they line up with the original audio.
        for note in pm.instruments[0].notes:
            note.start += beatsteps[0]
            note.end += beatsteps[0]

        if show_plot:
            import note_seq

            note_seq.plot_sequence(note_seq.midi_to_note_sequence(pm))

        if save_mix:
            # The synthesized mix is only needed when it is written to disk.
            if mix_sample_rate != sr:
                y = librosa.core.resample(y, orig_sr=sr, target_sr=mix_sample_rate)
                sr = mix_sample_rate
            if add_click:
                clicks = librosa.clicks(times=beatsteps, sr=sr, length=len(y)) * click_amp
                y = y + clicks
            pm_y = pm.fluidsynth(sr)
            stereo = get_stereo(y, pm_y, pop_scale=stereo_amp)
            sf.write(
                file=mix_path,
                data=stereo.T,
                samplerate=sr,
                format="wav",
            )

        if save_midi:
            pm.write(midi_path)

        return pm, composer, mix_path, midi_path