# Spaces: Running (web-page header residue, not part of the program)
# This code is modified from https://github.com/ZFTurbo/
import os
import pdb
import subprocess
import warnings

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
from tqdm import tqdm

warnings.filterwarnings("ignore")
from bs_roformer.bs_roformer import BSRoformer
class BsRoformer_Loader:
    """Load a BSRoformer checkpoint and split an audio file into two stems.

    The vocal stem is the model output; the instrumental stem is computed as
    ``mixture - vocals``. Both are written to disk by :meth:`run_folder`.
    """

    def __init__(self, model_path, device, is_half):
        """Build the model, load its weights and move it to ``device``.

        Args:
            model_path: Path of the checkpoint passed to ``torch.load``.
            device: Target device for the model and for each audio chunk.
            is_half: When truthy, the model and every chunk are cast to fp16.
        """
        self.device = device
        self.extract_instrumental = True
        self.is_half = is_half
        model = self.get_model_from_config()
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True if the installed
        # torch version supports it).
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        if is_half:
            model = model.half()
        self.model = model.to(device)

    def get_model_from_config(self):
        """Return a freshly constructed ``BSRoformer`` with the fixed config."""
        config = {
            "attn_dropout": 0.1,
            "depth": 12,
            "dim": 512,
            "dim_freqs_in": 1025,
            "dim_head": 64,
            "ff_dropout": 0.1,
            "flash_attn": True,
            "freq_transformer_depth": 1,
            "freqs_per_bands": (
                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
                4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
                12, 12, 12, 12, 12, 12, 12, 12,
                24, 24, 24, 24, 24, 24, 24, 24,
                48, 48, 48, 48, 48, 48, 48, 48,
                128, 129,
            ),
            "heads": 8,
            "linear_transformer_depth": 0,
            "mask_estimator_depth": 2,
            "multi_stft_hop_size": 147,
            "multi_stft_normalized": False,
            "multi_stft_resolution_loss_weight": 1.0,
            "multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
            "num_stems": 1,
            "stereo": True,
            "stft_hop_length": 441,
            "stft_n_fft": 2048,
            "stft_normalized": False,
            "stft_win_length": 2048,
            "time_transformer_depth": 1,
        }
        return BSRoformer(**config)

    @staticmethod
    def _build_windows(chunk_size, fade_size):
        """Return the ``(first, middle, last)`` weighting windows for chunks.

        ``first`` only fades out, ``last`` only fades in and ``middle`` does
        both; each is a 1-D float tensor of ``chunk_size`` samples.
        """
        fadein = torch.linspace(0, 1, fade_size)
        fadeout = torch.linspace(1, 0, fade_size)
        window_start = torch.ones(chunk_size)
        window_middle = torch.ones(chunk_size)
        window_finish = torch.ones(chunk_size)
        window_start[-fade_size:] *= fadeout  # first chunk: no fade-in
        window_finish[:fade_size] *= fadein  # last chunk: no fade-out
        window_middle[-fade_size:] *= fadeout
        window_middle[:fade_size] *= fadein
        return window_start, window_middle, window_finish

    @staticmethod
    def _pick_window(start, step, total, windows):
        """Select the window for the chunk that begins at sample ``start``.

        A chunk starting at 0 is the first one; a chunk whose successor would
        begin at or after ``total`` is the last one. "First" wins when a
        single chunk is both.
        """
        window_start, window_middle, window_finish = windows
        if start == 0:
            return window_start
        if start + step >= total:
            return window_finish
        return window_middle

    @staticmethod
    def _output_stem(path):
        """Return the file name of ``path`` without directory or extension."""
        return os.path.splitext(os.path.basename(path))[0]

    @staticmethod
    def _convert_with_ffmpeg(wav_path, out_path):
        """Transcode ``wav_path`` to ``out_path`` and drop the wav on success.

        Best effort: a missing ffmpeg binary or a failed conversion leaves the
        intermediate wav file in place instead of raising.
        """
        if not os.path.exists(wav_path):
            return
        # Argument list (no shell) so paths with quotes or spaces are safe.
        # The argument order is the one the original command line used.
        command = ["ffmpeg", "-i", wav_path, "-vn", out_path, "-q:a", "2", "-y"]
        try:
            subprocess.run(command, check=False)
        except OSError as e:
            print("Cannot run ffmpeg: {}".format(str(e)))
            return
        if os.path.exists(out_path):
            try:
                os.remove(wav_path)
            except OSError:
                # Cleanup only; the converted file already exists.
                pass

    def demix_track(self, model, mix, device):
        """Run ``model`` over ``mix`` in fixed-size chunks and stitch the result.

        Args:
            model: Callable applied to a stacked batch of chunks.
            mix: 2-D tensor indexed as ``mix[:, samples]``.
            device: Device each chunk is moved to before inference.

        Returns:
            A dict built by zipping ``['vocals', 'other']`` with the leading
            axis of the stitched array. That axis has length 1, so only the
            ``'vocals'`` key is present.
        """
        chunk_size = 352800
        num_overlap = 1
        fade_size = chunk_size // 10
        step = chunk_size // num_overlap
        border = chunk_size - step
        batch_size = 4
        length_init = mix.shape[-1]
        use_border = length_init > 2 * border and border > 0

        progress_bar = tqdm(total=length_init // step + 1)
        progress_bar.set_description("Processing")
        try:
            # Pad both ends so the sliding window also covers the edges.
            if use_border:
                mix = nn.functional.pad(mix, (border, border), mode="reflect")
            # Windows are built once; they soften clicks at chunk boundaries.
            windows = self._build_windows(chunk_size, fade_size)
            total = mix.shape[1]

            with torch.amp.autocast("cuda"), torch.inference_mode():
                req_shape = (1,) + tuple(mix.shape)
                result = torch.zeros(req_shape, dtype=torch.float32)
                counter = torch.zeros(req_shape, dtype=torch.float32)
                batch_data = []
                batch_locations = []
                i = 0
                while i < total:
                    part = mix[:, i:i + chunk_size].to(device)
                    length = part.shape[-1]
                    if length < chunk_size:
                        missing = chunk_size - length
                        if length > chunk_size // 2 + 1:
                            part = nn.functional.pad(
                                input=part, pad=(0, missing), mode="reflect"
                            )
                        else:
                            # Too short to reflect: fill the tail with zeros.
                            part = nn.functional.pad(
                                input=part,
                                pad=(0, missing, 0, 0),
                                mode="constant",
                                value=0,
                            )
                    if self.is_half:
                        part = part.half()
                    batch_data.append(part)
                    batch_locations.append((i, length))
                    i += step
                    progress_bar.update(1)

                    if len(batch_data) >= batch_size or i >= total:
                        x = model(torch.stack(batch_data, dim=0))
                        for j, (start, l) in enumerate(batch_locations):
                            # Chosen per chunk, not per batch, so only the true
                            # first/last chunk skips its fade.
                            window = self._pick_window(start, step, total, windows)
                            window = window[..., :l]
                            result[..., start:start + l] += x[j][..., :l].cpu() * window
                            counter[..., start:start + l] += window
                        batch_data = []
                        batch_locations = []

                # Samples whose accumulated weight is 0 divide to NaN and are
                # replaced by 0.0 below.
                estimated_sources = result / counter
                estimated_sources = estimated_sources.cpu().numpy()
                np.nan_to_num(estimated_sources, copy=False, nan=0.0)
                if use_border:
                    # Remove pad
                    estimated_sources = estimated_sources[..., border:-border]
        finally:
            progress_bar.close()
        return dict(zip(["vocals", "other"], estimated_sources))

    def run_folder(self, input, vocal_root, others_root, format):
        """Separate one audio file and write the vocal and instrumental stems.

        Args:
            input: Path of the audio file to separate.
            vocal_root: Directory for the vocal stem (created if missing).
            others_root: Directory for the instrumental stem (created if missing).
            format: Output extension. ``"wav"`` and ``"flac"`` are written
                directly; anything else is written as wav and then converted
                with ffmpeg.

        Returns:
            None. If the input cannot be loaded, the error is printed and
            nothing is written.
        """
        self.model.eval()
        path = input
        os.makedirs(vocal_root, exist_ok=True)
        os.makedirs(others_root, exist_ok=True)

        try:
            mix, sr = librosa.load(path, sr=44100, mono=False)
        except Exception as e:
            print("Cannot read track: {}".format(path))
            print("Error message: {}".format(str(e)))
            return

        # Convert mono to stereo if needed
        if len(mix.shape) == 1:
            mix = np.stack([mix, mix], axis=0)
        mix_orig = mix.copy()
        mixture = torch.tensor(mix, dtype=torch.float32)
        res = self.demix_track(self.model, mixture, self.device)
        estimates = res["vocals"].T
        instrumental = mix_orig.T - estimates

        # splitext instead of a fixed [:-4] slice: extensions vary in length.
        stem = self._output_stem(path)
        if format in ["wav", "flac"]:
            sf.write("{}/{}_{}.{}".format(vocal_root, stem, "vocals", format), estimates, sr)
            sf.write("{}/{}_{}.{}".format(others_root, stem, "instrumental", format), instrumental, sr)
            return

        path_vocal = "%s/%s_vocals.wav" % (vocal_root, stem)
        path_other = "%s/%s_instrumental.wav" % (others_root, stem)
        sf.write(path_vocal, estimates, sr)
        sf.write(path_other, instrumental, sr)
        self._convert_with_ffmpeg(path_vocal, path_vocal[:-4] + ".%s" % format)
        self._convert_with_ffmpeg(path_other, path_other[:-4] + ".%s" % format)

    def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False):
        """Forward to :meth:`run_folder`; ``is_hp3`` is accepted but unused."""
        self.run_folder(input, vocal_root, others_root, format)