|
import logging |
|
from pathlib import Path |
|
from typing import Union |
|
|
|
import torch |
|
|
|
# Name of the run. It is used as the remote folder segment in the download URL
# and as the default local folder name under "model_repo".
RUN_NAME = "enhancer_stage2"

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
def get_source_url(relpath):
    """Return the remote download URL for a file of the current run.

    Args:
        relpath: Path of the file relative to the run folder; it is
            interpolated into the URL as-is.

    Returns:
        The URL string pointing at ``RUN_NAME/relpath`` in the remote
        repository, with the ``download=true`` query parameter appended.
    """
    url = f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"
    return url
|
|
|
|
|
def get_target_path(relpath: Union[str, Path], run_dir: Union[str, Path, None] = None):
    """Return the local path of ``relpath`` inside the run directory.

    Args:
        relpath: Path of the file relative to the run directory.
        run_dir: Directory to resolve against. When ``None``, the directory
            ``model_repo/RUN_NAME`` located two levels above this file is used.

    Returns:
        A ``Path`` formed by joining the run directory with ``relpath``.
    """
    # Caller supplied an explicit directory: resolve against it directly.
    if run_dir is not None:
        return Path(run_dir) / relpath

    # Otherwise fall back to the run folder under "model_repo".
    default_root = Path(__file__).parent.parent / "model_repo" / RUN_NAME
    return default_root / relpath
|
|
|
|
|
def download(run_dir: Union[str, Path, None] = None):
    """Fetch the run's files that are not yet present locally.

    Each required file is resolved to its local path with
    ``get_target_path``; a file that already exists is left untouched,
    otherwise its parent directory is created and the file is downloaded
    from the URL given by ``get_source_url``.

    Args:
        run_dir: Directory to store the files in; passed through to
            ``get_target_path`` (``None`` selects its default location).

    Returns:
        The path of the run directory, as produced by
        ``get_target_path("", run_dir=run_dir)``.
    """
    required_files = (
        "hparams.yaml",
        "ds/G/latest",
        "ds/G/default/mp_rank_00_model_states.pt",
    )

    for name in required_files:
        destination = get_target_path(name, run_dir=run_dir)
        # Only files missing on disk are downloaded.
        if not destination.exists():
            source = get_source_url(name)
            destination.parent.mkdir(parents=True, exist_ok=True)
            torch.hub.download_url_to_file(source, str(destination))

    return get_target_path("", run_dir=run_dir)
|
|