|
import contextlib |
|
import importlib |
|
from huggingface_hub import hf_hub_download |
|
import numpy as np |
|
import torch |
|
|
|
from inspect import isfunction |
|
import os |
|
import subprocess |
|
import tempfile |
|
import json |
|
import soundfile as sf |
|
import time |
|
import wave |
|
import torchaudio |
|
import progressbar |
|
from librosa.filters import mel as librosa_mel_fn |
|
from audiosr.lowpass import lowpass |
|
|
|
hann_window = {} |
|
mel_basis = {} |
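# mel_basis / hann_window: per-device caches so the mel filterbank and window are built only once per device.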
|
|
|
|
|
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): |
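    """Log-compress spectrogram magnitudes, clamping to ``clip_val`` to avoid log(0)."""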
|
return torch.log(torch.clamp(x, min=clip_val) * C) |
|
|
|
|
|
def dynamic_range_decompression_torch(x, C=1): |
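    """Invert dynamic_range_compression_torch (exponentiate, then undo the ``C`` scaling)."""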
|
return torch.exp(x) / C |
|
|
|
|
|
def spectral_normalize_torch(magnitudes): |
|
output = dynamic_range_compression_torch(magnitudes) |
|
return output |
|
|
|
|
|
def spectral_de_normalize_torch(magnitudes): |
|
output = dynamic_range_decompression_torch(magnitudes) |
|
return output |
|
|
|
|
|
def _locate_cutoff_freq(stft, percentile=0.97): |
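    """Estimate the cutoff-frequency bin index: the highest STFT bin below which
    ``percentile`` of the total spectral energy is contained."""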
|
    def _find_cutoff(x, percentile=0.95):
        threshold = x[-1] * percentile
        for i in range(1, x.shape[0]):
            if x[-i] < threshold:
                return x.shape[0] - i
        return 0
|
|
|
magnitude = torch.abs(stft) |
|
energy = torch.cumsum(torch.sum(magnitude, dim=0), dim=0) |
|
return _find_cutoff(energy, percentile) |
|
|
|
|
|
def pad_wav(waveform, target_length): |
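    """Zero-pad a (1, N) waveform on the right up to ``target_length`` samples."""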
|
waveform_length = waveform.shape[-1] |
|
assert waveform_length > 100, "Waveform is too short, %s" % waveform_length |
|
|
|
if waveform_length == target_length: |
|
return waveform |
|
|
|
|
|
temp_wav = np.zeros((1, target_length), dtype=np.float32) |
|
rand_start = 0 |
|
|
|
temp_wav[:, rand_start : rand_start + waveform_length] = waveform |
|
return temp_wav |
|
|
|
|
|
def lowpass_filtering_prepare_inference(dl_output): |
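    """Detect the effective cutoff frequency of the input and return a low-pass-filtered
    copy of the waveform, length-aligned with the original."""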
|
waveform = dl_output["waveform"] |
|
sampling_rate = dl_output["sampling_rate"] |
|
|
|
    # Map the detected cutoff bin (out of n_fft/2 = 1024) to Hz, 24 kHz being the Nyquist frequency at 48 kHz.
    cutoff_freq = (
        _locate_cutoff_freq(dl_output["stft"], percentile=0.985) / 1024
    ) * 24000
|
|
|
|
|
    if cutoff_freq < 1000:
|
cutoff_freq = 24000 |
|
|
|
order = 8 |
|
ftype = np.random.choice(["butter", "cheby1", "ellip", "bessel"]) |
|
filtered_audio = lowpass( |
|
waveform.numpy().squeeze(), |
|
highcut=cutoff_freq, |
|
fs=sampling_rate, |
|
order=order, |
|
_type=ftype, |
|
) |
|
|
|
filtered_audio = torch.FloatTensor(filtered_audio.copy()).unsqueeze(0) |
|
|
|
if waveform.size(-1) <= filtered_audio.size(-1): |
|
filtered_audio = filtered_audio[..., : waveform.size(-1)] |
|
else: |
|
        filtered_audio = torch.nn.functional.pad(
|
filtered_audio, (0, waveform.size(-1) - filtered_audio.size(-1)) |
|
) |
|
|
|
return {"waveform_lowpass": filtered_audio} |
|
|
|
|
|
def mel_spectrogram_train(y): |
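    """Compute the log-mel spectrogram and magnitude STFT of a (1, N) waveform using the
    48 kHz front end (n_fft=2048, hop=480, win=2048, 256 mel bins, 20 Hz-24 kHz)."""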
|
global mel_basis, hann_window |
|
|
|
sampling_rate = 48000 |
|
filter_length = 2048 |
|
hop_length = 480 |
|
win_length = 2048 |
|
n_mel = 256 |
|
mel_fmin = 20 |
|
mel_fmax = 24000 |
|
|
|
    if str(mel_fmax) + "_" + str(y.device) not in mel_basis:
|
mel = librosa_mel_fn(sr=sampling_rate, n_fft=filter_length, n_mels=n_mel, fmin=mel_fmin, fmax=mel_fmax) |
|
mel_basis[str(mel_fmax) + "_" + str(y.device)] = ( |
|
torch.from_numpy(mel).float().to(y.device) |
|
) |
|
hann_window[str(y.device)] = torch.hann_window(win_length).to(y.device) |
|
|
|
y = torch.nn.functional.pad( |
|
y.unsqueeze(1), |
|
(int((filter_length - hop_length) / 2), int((filter_length - hop_length) / 2)), |
|
mode="reflect", |
|
) |
|
|
|
y = y.squeeze(1) |
|
|
|
stft_spec = torch.stft( |
|
y, |
|
filter_length, |
|
hop_length=hop_length, |
|
win_length=win_length, |
|
window=hann_window[str(y.device)], |
|
center=False, |
|
pad_mode="reflect", |
|
normalized=False, |
|
onesided=True, |
|
return_complex=True, |
|
) |
|
|
|
stft_spec = torch.abs(stft_spec) |
|
|
|
mel = spectral_normalize_torch( |
|
torch.matmul(mel_basis[str(mel_fmax) + "_" + str(y.device)], stft_spec) |
|
) |
|
|
|
return mel[0], stft_spec[0] |
|
|
|
|
|
def pad_spec(log_mel_spec, target_frame): |
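    """Zero-pad or truncate the spectrogram along the time axis to ``target_frame`` frames,
    and drop the last column if the feature dimension is odd."""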
|
n_frames = log_mel_spec.shape[0] |
|
p = target_frame - n_frames |
|
|
|
if p > 0: |
|
m = torch.nn.ZeroPad2d((0, 0, 0, p)) |
|
log_mel_spec = m(log_mel_spec) |
|
elif p < 0: |
|
log_mel_spec = log_mel_spec[0:target_frame, :] |
|
|
|
if log_mel_spec.size(-1) % 2 != 0: |
|
log_mel_spec = log_mel_spec[..., :-1] |
|
|
|
return log_mel_spec |
|
|
|
|
|
def wav_feature_extraction(waveform, target_frame): |
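    """Extract (T, n_mel) log-mel and (T, n_freq) magnitude-STFT features from a waveform,
    both padded or truncated to ``target_frame`` frames."""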
|
waveform = waveform[0, ...] |
|
waveform = torch.FloatTensor(waveform) |
|
|
|
log_mel_spec, stft = mel_spectrogram_train(waveform.unsqueeze(0)) |
|
|
|
log_mel_spec = torch.FloatTensor(log_mel_spec.T) |
|
stft = torch.FloatTensor(stft.T) |
|
|
|
log_mel_spec, stft = pad_spec(log_mel_spec, target_frame), pad_spec( |
|
stft, target_frame |
|
) |
|
return log_mel_spec, stft |
|
|
|
|
|
def normalize_wav(waveform): |
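    """Remove the DC offset and peak-normalize the waveform to roughly +/-0.5."""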
|
waveform = waveform - np.mean(waveform) |
|
waveform = waveform / (np.max(np.abs(waveform)) + 1e-8) |
|
return waveform * 0.5 |
|
|
|
def read_wav_file(filename): |
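    """Load an audio file, resample it to 48 kHz, normalize it, and zero-pad it so its
    duration is a multiple of 5.12 s; returns (waveform, target_frame, pad_duration)."""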
|
waveform, sr = torchaudio.load(filename) |
|
duration = waveform.size(-1) / sr |
|
|
|
    if duration > 10.24:
        print("\033[93m {}\033[00m".format("Warning: audio longer than 10.24 seconds may degrade model performance. It's recommended to truncate your audio to 5.12 seconds before passing it to AudioSR for the best results."))
|
|
|
    if duration % 5.12 != 0:
|
pad_duration = duration + (5.12 - duration % 5.12) |
|
else: |
|
pad_duration = duration |
|
|
|
target_frame = int(pad_duration * 100) |
|
|
|
waveform = torchaudio.functional.resample(waveform, sr, 48000) |
|
|
|
waveform = waveform.numpy()[0, ...] |
|
|
|
    waveform = normalize_wav(waveform)
|
|
|
waveform = waveform[None, ...] |
|
waveform = pad_wav(waveform, target_length=int(48000 * pad_duration)) |
|
return waveform, target_frame, pad_duration |
|
|
|
def read_audio_file(filename): |
|
waveform, target_frame, duration = read_wav_file(filename) |
|
log_mel_spec, stft = wav_feature_extraction(waveform, target_frame) |
|
return log_mel_spec, stft, waveform, duration, target_frame |
|
|
|
|
|
def read_list(fname): |
|
result = [] |
|
with open(fname, "r", encoding="utf-8") as f: |
|
for each in f.readlines(): |
|
each = each.strip("\n") |
|
result.append(each) |
|
return result |
|
|
|
|
|
def get_duration(fname): |
|
with contextlib.closing(wave.open(fname, "r")) as f: |
|
frames = f.getnframes() |
|
rate = f.getframerate() |
|
return frames / float(rate) |
|
|
|
|
|
def get_bit_depth(fname): |
|
with contextlib.closing(wave.open(fname, "r")) as f: |
|
bit_depth = f.getsampwidth() * 8 |
|
return bit_depth |
|
|
|
|
|
def get_time(): |
|
t = time.localtime() |
|
return time.strftime("%d_%m_%Y_%H_%M_%S", t) |
|
|
|
|
|
def seed_everything(seed): |
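    """Seed Python, NumPy, and PyTorch RNGs for reproducible sampling."""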
|
    import random

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN pick non-deterministic kernels, undoing the seeding above.
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
def strip_silence(original_path, input_path, output_path):
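    """Trim the processed file at ``input_path`` to the duration of ``original_path`` (via
    ffprobe/ffmpeg stream copy), write it to ``output_path``, and delete ``input_path``."""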
|
get_dur = subprocess.run([ |
|
'ffprobe', |
|
'-v', 'error', |
|
'-select_streams', 'a:0', |
|
'-show_entries', 'format=duration', |
|
'-sexagesimal', |
|
'-of', 'json', |
|
        original_path
|
], stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
|
|
|
duration = json.loads(get_dur.stdout)['format']['duration'] |
|
|
|
subprocess.run([ |
|
'ffmpeg', |
|
'-y', |
|
'-ss', '00:00:00', |
|
'-i', input_path, |
|
'-t', duration, |
|
'-c', 'copy', |
|
output_path |
|
]) |
|
os.remove(input_path) |
|
|
|
|
|
|
|
def save_wave(waveform, inputpath, savepath, name="outwav", samplerate=16000): |
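    """Write each waveform in the (B, 1, N) batch to ``savepath`` as a WAV file, trimmed to
    the duration of ``inputpath``."""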
|
    if not isinstance(name, list):
|
name = [name] * waveform.shape[0] |
|
|
|
for i in range(waveform.shape[0]): |
|
        if waveform.shape[0] > 1:
            fname = "%s_%s.wav" % (
                os.path.basename(name[i])
                if ".wav" not in name[i]
                else os.path.basename(name[i]).split(".")[0],
                i,
            )
        else:
            fname = (
                "%s.wav" % os.path.basename(name[i])
                if ".wav" not in name[i]
                else os.path.basename(name[i]).split(".")[0] + ".wav"
            )
|
|
|
if len(fname) > 255: |
|
fname = f"{hex(hash(fname))}.wav" |
|
|
|
save_path = os.path.join(savepath, fname) |
|
temp_path = os.path.join(tempfile.gettempdir(), fname) |
|
print("\033[98m {}\033[00m" .format("Don't forget to try different seeds by setting --seed <int> so that AudioSR can have optimal performance on your hardware.")) |
|
print("Save audio to %s." % save_path) |
|
sf.write(temp_path, waveform[i, 0], samplerate=samplerate) |
|
strip_silence(inputpath, temp_path, save_path) |
|
|
|
|
|
def exists(x): |
|
return x is not None |
|
|
|
|
|
def default(val, d): |
|
if exists(val): |
|
return val |
|
return d() if isfunction(d) else d |
|
|
|
|
|
def count_params(model, verbose=False): |
|
total_params = sum(p.numel() for p in model.parameters()) |
|
if verbose: |
|
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") |
|
return total_params |
|
|
|
|
|
def get_obj_from_str(string, reload=False): |
|
module, cls = string.rsplit(".", 1) |
|
if reload: |
|
module_imp = importlib.import_module(module) |
|
importlib.reload(module_imp) |
|
return getattr(importlib.import_module(module, package=None), cls) |
|
|
|
|
|
def instantiate_from_config(config): |
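    """Instantiate the object described by ``config``: import ``config["target"]`` (a dotted
    path) and call it with ``config["params"]``."""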
|
if not "target" in config: |
|
if config == "__is_first_stage__": |
|
return None |
|
elif config == "__is_unconditional__": |
|
return None |
|
raise KeyError("Expected key `target` to instantiate.") |
|
    try:
        return get_obj_from_str(config["target"])(**config.get("params", dict()))
    except Exception:
        # Drop into a debugger for inspection, then re-raise instead of silently returning None.
        import ipdb

        ipdb.set_trace()
        raise
|
|
|
|
|
def default_audioldm_config(model_name="basic"): |
|
basic_config = get_basic_config() |
|
return basic_config |
|
|
|
|
|
class MyProgressBar: |
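    """Reporthook-style progress callback ``__call__(block_num, block_size, total_size)``
    that drives a progressbar.ProgressBar."""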
|
def __init__(self): |
|
self.pbar = None |
|
|
|
def __call__(self, block_num, block_size, total_size): |
|
if not self.pbar: |
|
self.pbar = progressbar.ProgressBar(maxval=total_size) |
|
self.pbar.start() |
|
|
|
downloaded = block_num * block_size |
|
if downloaded < total_size: |
|
self.pbar.update(downloaded) |
|
else: |
|
self.pbar.finish() |
|
|
|
|
|
def download_checkpoint(checkpoint_name="basic"): |
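    """Download the pretrained AudioSR checkpoint ("basic" or "speech") from the Hugging Face
    Hub and return the local file path."""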
|
if checkpoint_name == "basic": |
|
model_id = "haoheliu/audiosr_basic" |
|
|
|
checkpoint_path = hf_hub_download( |
|
repo_id=model_id, filename="pytorch_model.bin" |
|
) |
|
elif checkpoint_name == "speech": |
|
model_id = "haoheliu/audiosr_speech" |
|
|
|
checkpoint_path = hf_hub_download( |
|
repo_id=model_id, filename="pytorch_model.bin" |
|
) |
|
else: |
|
raise ValueError("Invalid Model Name %s" % checkpoint_name) |
|
return checkpoint_path |
|
|
|
|
|
def get_basic_config(): |
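    """Return the default AudioSR configuration: 48 kHz preprocessing, 256-bin mel front end,
    and the latent-diffusion model definition."""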
|
return { |
|
"preprocessing": { |
|
"audio": { |
|
"sampling_rate": 48000, |
|
"max_wav_value": 32768, |
|
"duration": 10.24, |
|
}, |
|
"stft": {"filter_length": 2048, "hop_length": 480, "win_length": 2048}, |
|
"mel": {"n_mel_channels": 256, "mel_fmin": 20, "mel_fmax": 24000}, |
|
}, |
|
"augmentation": {"mixup": 0.5}, |
|
"model": { |
|
"target": "audiosr.latent_diffusion.models.ddpm.LatentDiffusion", |
|
"params": { |
|
"first_stage_config": { |
|
"base_learning_rate": 0.000008, |
|
"target": "audiosr.latent_encoder.autoencoder.AutoencoderKL", |
|
"params": { |
|
"reload_from_ckpt": "/mnt/bn/lqhaoheliu/project/audio_generation_diffusion/log/vae/vae_48k_256/ds_8_kl_1/checkpoints/ckpt-checkpoint-484999.ckpt", |
|
"sampling_rate": 48000, |
|
"batchsize": 4, |
|
"monitor": "val/rec_loss", |
|
"image_key": "fbank", |
|
"subband": 1, |
|
"embed_dim": 16, |
|
"time_shuffle": 1, |
|
"ddconfig": { |
|
"double_z": True, |
|
"mel_bins": 256, |
|
"z_channels": 16, |
|
"resolution": 256, |
|
"downsample_time": False, |
|
"in_channels": 1, |
|
"out_ch": 1, |
|
"ch": 128, |
|
"ch_mult": [1, 2, 4, 8], |
|
"num_res_blocks": 2, |
|
"attn_resolutions": [], |
|
"dropout": 0.1, |
|
}, |
|
}, |
|
}, |
|
"base_learning_rate": 0.0001, |
|
"warmup_steps": 5000, |
|
"optimize_ddpm_parameter": True, |
|
"sampling_rate": 48000, |
|
"batchsize": 16, |
|
"beta_schedule": "cosine", |
|
"linear_start": 0.0015, |
|
"linear_end": 0.0195, |
|
"num_timesteps_cond": 1, |
|
"log_every_t": 200, |
|
"timesteps": 1000, |
|
"unconditional_prob_cfg": 0.1, |
|
"parameterization": "v", |
|
"first_stage_key": "fbank", |
|
"latent_t_size": 128, |
|
"latent_f_size": 32, |
|
"channels": 16, |
|
"monitor": "val/loss_simple_ema", |
|
"scale_by_std": True, |
|
"unet_config": { |
|
"target": "audiosr.latent_diffusion.modules.diffusionmodules.openaimodel.UNetModel", |
|
"params": { |
|
"image_size": 64, |
|
"in_channels": 32, |
|
"out_channels": 16, |
|
"model_channels": 128, |
|
"attention_resolutions": [8, 4, 2], |
|
"num_res_blocks": 2, |
|
"channel_mult": [1, 2, 3, 5], |
|
"num_head_channels": 32, |
|
"extra_sa_layer": True, |
|
"use_spatial_transformer": True, |
|
"transformer_depth": 1, |
|
}, |
|
}, |
|
"evaluation_params": { |
|
"unconditional_guidance_scale": 3.5, |
|
"ddim_sampling_steps": 200, |
|
"n_candidates_per_samples": 1, |
|
}, |
|
"cond_stage_config": { |
|
"concat_lowpass_cond": { |
|
"cond_stage_key": "lowpass_mel", |
|
"conditioning_key": "concat", |
|
"target": "audiosr.latent_diffusion.modules.encoders.modules.VAEFeatureExtract", |
|
"params": { |
|
"first_stage_config": { |
|
"base_learning_rate": 0.000008, |
|
"target": "audiosr.latent_encoder.autoencoder.AutoencoderKL", |
|
"params": { |
|
"sampling_rate": 48000, |
|
"batchsize": 4, |
|
"monitor": "val/rec_loss", |
|
"image_key": "fbank", |
|
"subband": 1, |
|
"embed_dim": 16, |
|
"time_shuffle": 1, |
|
"ddconfig": { |
|
"double_z": True, |
|
"mel_bins": 256, |
|
"z_channels": 16, |
|
"resolution": 256, |
|
"downsample_time": False, |
|
"in_channels": 1, |
|
"out_ch": 1, |
|
"ch": 128, |
|
"ch_mult": [1, 2, 4, 8], |
|
"num_res_blocks": 2, |
|
"attn_resolutions": [], |
|
"dropout": 0.1, |
|
}, |
|
}, |
|
} |
|
}, |
|
} |
|
}, |
|
}, |
|
}, |
|
} |