import logging
from pathlib import Path, PurePath
from typing import Union

import torch

RUN_NAME = "enhancer_stage2"

logger = logging.getLogger(__name__)

# Files (relative to the run directory) that make up the pretrained checkpoint.
_MODEL_FILES = (
    "hparams.yaml",
    "ds/G/latest",
    "ds/G/default/mp_rank_00_model_states.pt",
)


def get_source_url(relpath: Union[str, PurePath]) -> str:
    """Return the Hugging Face download URL for *relpath* under ``RUN_NAME``.

    Args:
        relpath: Path of the file relative to the run directory. Path objects
            are converted to forward-slash form so the URL stays valid even
            when a Windows path is passed in.

    Returns:
        The ``resolve/main`` URL with ``?download=true`` appended.
    """
    if isinstance(relpath, PurePath):
        # str() of a Windows path uses backslashes, which are not URL separators.
        relpath = relpath.as_posix()
    return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"


def get_target_path(relpath: Union[str, Path], run_dir: Union[str, Path, None] = None) -> Path:
    """Return the local path where *relpath* is stored.

    Args:
        relpath: Path of the file relative to the run directory.
        run_dir: Local run directory. When ``None``, defaults to
            ``model_repo/<RUN_NAME>`` inside the parent of this module's directory.

    Returns:
        ``Path(run_dir) / relpath``.
    """
    if run_dir is None:
        run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME
    return Path(run_dir) / relpath


def download(run_dir: Union[str, Path, None] = None) -> Path:
    """Fetch the pretrained checkpoint files into *run_dir* if they are missing.

    Files that already exist locally are skipped, so repeated calls only hit
    the network for files not yet present.

    Args:
        run_dir: Local run directory; see ``get_target_path`` for the default.

    Returns:
        The run directory path (``get_target_path("", run_dir)``).
    """
    for relpath in _MODEL_FILES:
        path = get_target_path(relpath, run_dir=run_dir)
        if path.exists():
            continue
        url = get_source_url(relpath)
        path.parent.mkdir(parents=True, exist_ok=True)
        logger.info("Downloading %s to %s", url, path)
        torch.hub.download_url_to_file(url, str(path))
    return get_target_path("", run_dir=run_dir)