# Spaces: Running on Zero
import torch
import yaml

from audiosr import LatentDiffusion, default_audioldm_config, download_checkpoint
def load_audiosr(ckpt_path=None, config=None, device=None, model_name="basic"):
    """Load a pretrained AudioSR latent-diffusion model onto a device.

    Args:
        ckpt_path: Optional path to a model checkpoint. If None, the
            checkpoint for ``model_name`` is downloaded.
        config: Optional path to a YAML config file. If None, the default
            AudioLDM config for ``model_name`` is used.
        device: Torch device (or string). If None or "auto", picks
            CUDA > MPS > CPU automatically.
        model_name: Name of the pretrained AudioSR variant (default "basic").

    Returns:
        The ``LatentDiffusion`` model, weights loaded, in eval mode, on
        ``device``.

    Raises:
        TypeError: If ``config`` is given but is not a path string.
    """
    if device is None or device == "auto":
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

    print("Loading AudioSR: %s" % model_name)
    print("Loading model on %s" % device)

    # Bug fix: the original unconditionally re-downloaded the checkpoint,
    # ignoring a caller-supplied ckpt_path. Honor it when provided.
    if ckpt_path is None:
        ckpt_path = download_checkpoint(model_name)

    if config is not None:
        # Explicit validation instead of `assert` (asserts vanish under -O).
        if not isinstance(config, str):
            raise TypeError("config must be a path to a YAML file")
        # `with` closes the file handle; the original leaked it.
        with open(config, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    # Use text as condition instead of using waveform during training.
    config["model"]["params"]["device"] = device
    # config["model"]["params"]["cond_stage_key"] = "text"
    # No normalization here.
    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    # Load on CPU first so weights map cleanly, then move to the target device.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=True)
    latent_diffusion.eval()
    return latent_diffusion.to(device)