openmusic / pipeline.py
jadechoghari's picture
Update pipeline.py
ec033ea verified
raw
history blame
3.67 kB
from diffusers import DiffusionPipeline
import os
import sys
from huggingface_hub import HfApi, hf_hub_download
# from .tools import build_dataset_json_from_list
import torch
class MOSDiffusionPipeline(DiffusionPipeline):
    """Diffusers-style wrapper around the QA-MDT ("MOS diffusion") inference entry point.

    Construction loads the YAML experiment configuration; calling the pipeline
    forwards a text prompt, together with that configuration, to ``infer`` from
    the package's ``infer/infer_mos5.py`` module.
    """

    def __init__(self, config_yaml, list_inference, reload_from_ckpt=None, base_folder=None):
        """
        Initialize the MOS Diffusion pipeline from a YAML configuration.

        Nothing is downloaded here; call ``download_required_folders()`` explicitly
        when the Hub folders are not yet available locally.

        Args:
            config_yaml (str): Path to the YAML configuration file.
            list_inference (str): Path to the file containing inference prompts.
            reload_from_ckpt (str, optional): Checkpoint path to reload from. When given,
                it overrides the ``"reload_from_ckpt"`` entry of the loaded config.
            base_folder (str, optional): Base folder to store downloaded files.
                Defaults to the current working directory.
        """
        super().__init__()
        self.base_folder = base_folder if base_folder else os.getcwd()
        self.repo_id = "jadechoghari/qa-mdt"
        self.config_yaml = config_yaml
        self.list_inference = list_inference
        self.reload_from_ckpt = reload_from_ckpt
        # The path is used as given (a single-argument os.path.join was a no-op).
        self.configs = self.load_yaml(self.config_yaml)
        if self.reload_from_ckpt is not None:
            self.configs["reload_from_ckpt"] = self.reload_from_ckpt
        self.dataset_key = self._build_dataset_key_from_list(self.list_inference)
        # Strip only the final extension of the file name. Splitting the whole path
        # on "." broke relative paths such as "./configs/x.yaml" (empty name).
        self.exp_name = os.path.splitext(os.path.basename(self.config_yaml))[0]
        self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))

    @staticmethod
    def _build_dataset_key_from_list(list_inference):
        """Build a dataset key from ``list_inference`` using the optional ``.tools`` helper.

        The module-level import of ``build_dataset_json_from_list`` is commented out,
        which made every construction fail with ``NameError``. The helper is now
        imported lazily. ``None`` is returned (with a warning when the helper is
        missing) if there is no list file or the helper cannot be imported.
        """
        if not list_inference:
            return None
        try:
            from .tools import build_dataset_json_from_list
        except ImportError:
            import warnings

            warnings.warn(
                "build_dataset_json_from_list is unavailable; dataset_key set to None",
                RuntimeWarning,
                stacklevel=3,
            )
            return None
        return build_dataset_json_from_list(list_inference)

    def download_required_folders(self):
        """
        Download the required folders from the Hugging Face Hub into ``self.base_folder``.

        Only files located under one of ``required_folders`` are fetched. Files that
        already exist locally are skipped. Afterwards ``self.base_folder`` is appended
        to ``sys.path`` (at most once) so the downloaded packages can be imported.
        """
        api = HfApi()
        files = api.list_repo_files(repo_id=self.repo_id)
        required_folders = ["audioldm_train", "checkpoints", "infer", "log", "taming", "test_prompts"]
        # Match whole directory components. A bare startswith("log") would also pick
        # up any other top-level entry whose name merely starts with "log".
        folder_prefixes = tuple(folder + "/" for folder in required_folders)
        for file in files:
            if not file.startswith(folder_prefixes):
                continue
            local_file_path = os.path.join(self.base_folder, file)
            if os.path.exists(local_file_path):
                continue
            # local_dir makes the hub client write base_folder/<file> itself (creating
            # parent dirs). The previous os.rename moved the cached path out of the HF
            # cache, which can leave a dangling relative symlink and fails across
            # filesystems.
            hf_hub_download(repo_id=self.repo_id, filename=file, local_dir=self.base_folder)
        if self.base_folder not in sys.path:
            sys.path.append(self.base_folder)

    def load_yaml(self, yaml_path):
        """
        Load and return the YAML document at ``yaml_path`` using ``yaml.safe_load``.
        """
        import yaml

        with open(yaml_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)

    @torch.no_grad()
    def __call__(self, prompt: str):
        """
        Run the MOS Diffusion pipeline on a single text prompt (gradients disabled).

        Args:
            prompt: Text prompt to generate from.

        Returns:
            Whatever ``infer`` from ``infer/infer_mos5.py`` returns. The result used
            to be discarded, so callers always received ``None``.
        """
        from .infer.infer_mos5 import infer

        dataset_key = self.build_dataset_json_from_prompt(prompt)
        # NOTE(review): the experiment names are hard-coded rather than taken from
        # self.exp_group_name / self.exp_name (derived from config_yaml) — confirm intended.
        return infer(
            dataset_key=dataset_key,
            configs=self.configs,
            config_yaml_path=self.config_yaml,
            exp_group_name="qa_mdt",
            exp_name="mos_as_token",
        )

    def build_dataset_json_from_prompt(self, prompt: str):
        """
        Build the dataset key for a single prompt.

        Returns:
            dict: ``{"prompt": prompt}``.
        """
        # For simplicity the prompt itself is wrapped as the dataset key.
        return {"prompt": prompt}
# Example of how to use the pipeline
if __name__ == "__main__":
    # The previous example called MOSDiffusionPipeline() with no arguments, which
    # raises TypeError because config_yaml and list_inference are required.
    # NOTE: __call__ uses a package-relative import (.infer.infer_mos5), so this file
    # must be run as a module of its package (python -m <package>.pipeline ...).
    # Executed directly as a script, that import has no parent package and fails.
    import argparse

    parser = argparse.ArgumentParser(description="Run the MOS Diffusion pipeline on a prompt.")
    parser.add_argument("config_yaml", help="Path to the YAML configuration file.")
    parser.add_argument("list_inference", help="Path to the file containing inference prompts.")
    parser.add_argument("--reload-from-ckpt", default=None, help="Checkpoint path to reload from.")
    parser.add_argument("--base-folder", default=None, help="Base folder to store downloaded files.")
    parser.add_argument(
        "--prompt",
        default="Generate a description of a sunny day.",
        help="Text prompt to run through the pipeline.",
    )
    args = parser.parse_args()

    pipe = MOSDiffusionPipeline(
        args.config_yaml,
        args.list_inference,
        reload_from_ckpt=args.reload_from_ckpt,
        base_folder=args.base_folder,
    )
    result = pipe(args.prompt)
    print(result)