|
from diffusers import DiffusionPipeline |
|
import os |
|
import sys |
|
from huggingface_hub import HfApi, hf_hub_download |
|
|
|
import torch |
|
|
|
class MOSDiffusionPipeline(DiffusionPipeline):
    """Pipeline wrapper that forwards a text prompt to ``infer`` from ``infer/infer_mos5.py``.

    The instance holds the parsed YAML configuration and the paths it was
    built from. Fetching the supporting folders from the Hugging Face Hub is
    a separate, explicit step (`download_required_folders`).
    """

    def __init__(self, config_yaml, list_inference, reload_from_ckpt=None, base_folder=None):
        """
        Store the run settings and load the YAML configuration.

        Nothing is downloaded here; call `download_required_folders()` to
        fetch the repository folders into `base_folder`.

        Args:
            config_yaml (str): Path to the YAML configuration file.
            list_inference (str): Path to the file containing inference prompts.
            reload_from_ckpt (str, optional): Checkpoint path to reload from. When
                given, it is written into the loaded config under the
                "reload_from_ckpt" key.
            base_folder (str, optional): Base folder to store downloaded files.
                Defaults to the current working directory.
        """
        super().__init__()

        self.base_folder = base_folder if base_folder else os.getcwd()
        self.repo_id = "jadechoghari/qa-mdt"
        self.config_yaml = config_yaml
        self.list_inference = list_inference
        self.reload_from_ckpt = reload_from_ckpt

        self.configs = self.load_yaml(self.config_yaml)
        if self.reload_from_ckpt is not None:
            self.configs["reload_from_ckpt"] = self.reload_from_ckpt

        # NOTE(review): `build_dataset_json_from_list` is neither defined nor
        # imported in the visible part of this file; it must be available at
        # module scope or this line raises NameError - confirm where it lives.
        self.dataset_key = build_dataset_json_from_list(self.list_inference)

        # Strip only the extension of the file name. Splitting the whole path
        # on the first "." breaks for paths such as "./cfg/run.yaml" (empty
        # name) or "v1.2/run.yaml" (truncated inside the directory part).
        config_file_name = os.path.basename(self.config_yaml)
        self.exp_name = os.path.splitext(config_file_name)[0]
        self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))

    def download_required_folders(self):
        """
        Download the required repository folders from the Hugging Face Hub.

        Every repository file located under one of the required top-level
        folders is fetched into `base_folder` (keeping its repository-relative
        path) unless a file already exists at that location. Afterwards
        `base_folder` is added to `sys.path` if it is not already there.
        """
        api = HfApi()
        files = api.list_repo_files(repo_id=self.repo_id)

        required_folders = ("audioldm_train", "checkpoints", "infer", "log", "taming", "test_prompts")
        # Match on "<folder>/" so that, for example, "log" does not also pick
        # up "logo.png" and "infer" does not pick up "inference.py".
        folder_prefixes = tuple(f"{folder}/" for folder in required_folders)
        files_to_download = [f for f in files if f.startswith(folder_prefixes)]

        for file in files_to_download:
            local_file_path = os.path.join(self.base_folder, file)
            if not os.path.exists(local_file_path):
                # Let the Hub client place the file under base_folder itself.
                # Renaming the path returned by hf_hub_download would move an
                # entry out of the shared Hub cache and fails when the cache
                # and base_folder are on different filesystems.
                hf_hub_download(repo_id=self.repo_id, filename=file, local_dir=self.base_folder)

        # Avoid piling up duplicate entries when this method is called again.
        if self.base_folder not in sys.path:
            sys.path.append(self.base_folder)

    def load_yaml(self, yaml_path):
        """
        Helper method to load the YAML configuration.

        Args:
            yaml_path (str): Path of the YAML file to read.

        Returns:
            The object produced by `yaml.safe_load` for the file's content.
        """
        # Imported lazily so that PyYAML is only needed when a config is read.
        import yaml

        with open(yaml_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)

    @torch.no_grad()
    def __call__(self, prompt: str):
        """
        Run the MOS Diffusion Pipeline. This method calls the infer function from infer_mos5.py.

        Args:
            prompt (str): Text prompt to run inference on.

        Returns:
            Whatever `infer` returns; its return value is not visible from this file.
        """
        from .infer.infer_mos5 import infer

        dataset_key = self.build_dataset_json_from_prompt(prompt)

        # The experiment group/name passed here are fixed strings; the
        # `exp_group_name` / `exp_name` attributes computed in __init__ are
        # not used for this call.
        return infer(
            dataset_key=dataset_key,
            configs=self.configs,
            config_yaml_path=self.config_yaml,
            exp_group_name="qa_mdt",
            exp_name="mos_as_token",
        )

    def build_dataset_json_from_prompt(self, prompt: str):
        """
        Build dataset_key dynamically from the provided prompt.

        Returns:
            dict: A one-entry mapping ``{"prompt": prompt}``.
        """
        return {"prompt": prompt}
|
|
|
|
|
|
|
if __name__ == "__main__":
    # The pipeline constructor requires a config path and a prompt-list path,
    # so they are read from the command line; calling it with no arguments
    # raises TypeError.
    import argparse

    parser = argparse.ArgumentParser(description="Run the MOS Diffusion pipeline on a single prompt.")
    parser.add_argument("config_yaml", help="Path to the YAML configuration file.")
    parser.add_argument("list_inference", help="Path to the file containing inference prompts.")
    parser.add_argument("--reload_from_ckpt", default=None, help="Checkpoint path to reload from.")
    parser.add_argument("--base_folder", default=None, help="Base folder to store downloaded files.")
    parser.add_argument(
        "--prompt",
        default="Generate a description of a sunny day.",
        help="Prompt passed to the pipeline.",
    )
    args = parser.parse_args()

    pipe = MOSDiffusionPipeline(
        args.config_yaml,
        args.list_inference,
        reload_from_ckpt=args.reload_from_ckpt,
        base_folder=args.base_folder,
    )
    result = pipe(args.prompt)
    print(result)
|
|