File size: 3,641 Bytes
225b9d3 893807d ec033ea f38a7f2 225b9d3 893807d 225b9d3 84ee041 225b9d3 893807d 225b9d3 893807d 225b9d3 893807d fd1a078 225b9d3 893807d 9b35b7c 225b9d3 893807d 225b9d3 ec033ea 225b9d3 61de3b8 ec033ea 893807d ec033ea 225b9d3 ec033ea 893807d ec033ea 225b9d3 ec033ea 7245754 ec033ea 893807d ec033ea fd1a078 ec033ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
from diffusers import DiffusionPipeline
import os
import sys
from huggingface_hub import HfApi, hf_hub_download
# from .tools import build_dataset_json_from_list
import torch
class MOSDiffusionPipeline(DiffusionPipeline):
    """Text-to-audio pipeline wrapping the QA-MDT ("MOS as token") inference code.

    The pipeline reads a fixed YAML experiment configuration from the local
    filesystem and forwards prompts to ``infer`` from the package's
    ``infer.infer_mos5`` module.
    """

    def __init__(self, reload_from_ckpt="./qa_mdt/checkpoint_389999.ckpt", base_folder=None):
        """
        Initialize the MOS Diffusion pipeline from its local YAML configuration.

        The configuration file at ``self.config_yaml`` is resolved relative to
        the current working directory. This constructor does not download
        anything; call ``download_required_folders`` for that.

        Args:
            reload_from_ckpt (str, optional): Checkpoint path to reload from.
                Stored in the loaded config under the ``"reload_from_ckpt"`` key.
            base_folder (str, optional): Folder where ``download_required_folders``
                stores downloaded files. Defaults to the current working directory.
        """
        super().__init__()
        self.base_folder = base_folder if base_folder else os.getcwd()
        self.repo_id = "jadechoghari/qa-mdt"
        self.config_yaml = "./qa_mdt/audioldm_train/config/mos_as_token/qa_mdt.yaml"
        self.reload_from_ckpt = reload_from_ckpt
        self.configs = self.load_yaml(self.config_yaml)
        self.configs["reload_from_ckpt"] = self.reload_from_ckpt
        # File stem of the config ("qa_mdt"). The former `config_yaml.split(".")[0]`
        # split on the leading "./" and therefore always produced "".
        self.exp_name = os.path.splitext(os.path.basename(self.config_yaml))[0]
        # Name of the directory that contains the config ("mos_as_token").
        self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))

    def download_required_folders(self):
        """
        Download the required folders from the Hugging Face Hub into ``self.base_folder``.

        Files that already exist locally are skipped. Afterwards
        ``self.base_folder`` is appended to ``sys.path`` (only if it is not
        already there), so the downloaded packages can be imported.
        """
        import shutil

        api = HfApi()
        files = api.list_repo_files(repo_id=self.repo_id)
        required_folders = ("audioldm_train", "checkpoints", "infer", "log", "taming", "test_prompts")
        # Match whole top-level directories only, so "log/" does not also pick up e.g. "logo.png".
        prefixes = tuple(f"{folder}/" for folder in required_folders)
        for file in files:
            if not file.startswith(prefixes):
                continue
            local_file_path = os.path.join(self.base_folder, file)
            if os.path.exists(local_file_path):
                continue
            downloaded_file = hf_hub_download(repo_id=self.repo_id, filename=file)
            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
            # hf_hub_download returns a path inside the shared HF cache. Moving it
            # with os.rename fails across filesystems and can move a cache symlink
            # out of its snapshot directory (breaking it). Copy the resolved file
            # contents instead and leave the cache intact.
            shutil.copyfile(downloaded_file, local_file_path)
        if self.base_folder not in sys.path:
            sys.path.append(self.base_folder)

    def load_yaml(self, yaml_path):
        """
        Load a YAML file with ``yaml.safe_load`` and return the parsed object.

        Args:
            yaml_path (str): Path of the YAML file to read.
        """
        import yaml
        with open(yaml_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)

    @torch.no_grad()
    def __call__(self, prompt: str):
        """
        Run the MOS Diffusion pipeline for a single text prompt.

        Gradient tracking is disabled for the whole call.

        Args:
            prompt: Text description of the audio to generate.

        Returns:
            Whatever ``infer`` from ``infer/infer_mos5.py`` returns. Previously
            this was discarded, so callers always received ``None``.
        """
        from .infer.infer_mos5 import infer
        dataset_key = self.build_dataset_json_from_prompt(prompt)
        # NOTE(review): group/name are hard-coded and differ from
        # self.exp_group_name / self.exp_name derived from the config path;
        # confirm which values infer() expects.
        return infer(
            dataset_key=dataset_key,
            configs=self.configs,
            config_yaml_path=self.config_yaml,
            exp_group_name="qa_mdt",
            exp_name="mos_as_token"
        )

    def build_dataset_json_from_prompt(self, prompt: str):
        """
        Build an in-memory dataset description for a single prompt.

        Returns:
            dict: ``{"data": [{"wav": "", "caption": prompt}]}``. There is one
            entry, with an empty ``wav`` path and the prompt as its caption.
        """
        # Text-only generation: no reference audio, so "wav" is left empty.
        data = [{"wav": "", "caption": prompt}]
        return {"data": data}
# Usage example: construct the pipeline and run it on a single text prompt.
if __name__ == "__main__":
    example_prompt = "A modern synthesizer creating futuristic soundscapes."
    pipeline = MOSDiffusionPipeline()
    output = pipeline(example_prompt)
    print(output)
|