Spaces: Running on A10G
import importlib | |
import requests | |
from pathlib import Path | |
from os.path import dirname | |
from omegaconf import OmegaConf | |
from tqdm import tqdm | |
# Directory three levels above this module's file, used as the project root.
PROJECT_DIR = dirname(dirname(dirname(__file__)))
# Asset locations under the project root, as plain '/'-joined strings.
CONFIG_FOLDER = PROJECT_DIR + '/assets/config'
MODEL_FOLDER = PROJECT_DIR + '/assets/models'
def download_file(url, save_path, chunk_size=1024, timeout=60):
    """Download ``url`` to ``save_path`` with a tqdm progress bar.

    If ``save_path`` already exists, nothing is downloaded. The body is
    streamed into a sibling ``<name>.part`` file, which is renamed onto
    ``save_path`` only after the whole response has been written. A failed
    or interrupted download therefore never leaves a truncated file that a
    later call would mistake for a finished one.

    Args:
        url: Address to fetch with an HTTP GET.
        save_path: Destination file (str or Path). Missing parent
            directories are created.
        chunk_size: Bytes requested per ``iter_content`` iteration.
        timeout: Passed to ``requests.get``. Per the requests docs this
            limits connect and per-read waits, not the total download time.

    Raises:
        Exception: ``"Download failed: ..."`` wrapping any underlying
            error (network, HTTP status, filesystem). The original error
            is chained as ``__cause__``.
    """
    tmp_path = None
    try:
        save_path = Path(save_path)
        if save_path.exists():
            print(f'{save_path.name} exists')
            return
        save_path.parent.mkdir(exist_ok=True, parents=True)
        tmp_path = save_path.with_name(save_path.name + '.part')
        with requests.get(url, stream=True, timeout=timeout) as resp:
            # Without this, an HTTP error body (404/500 page) would be saved
            # as if it were the requested file.
            resp.raise_for_status()
            total = int(resp.headers.get('content-length', 0))
            with open(tmp_path, 'wb') as file, tqdm(
                desc=save_path.name,
                total=total,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in resp.iter_content(chunk_size=chunk_size):
                    bar.update(file.write(data))
        # Publish the completed file in a single rename.
        tmp_path.replace(save_path)
        print(f'{save_path.name} download finished')
    except Exception as e:
        # Best-effort removal of a partial download; never mask the real error.
        if tmp_path is not None:
            try:
                if tmp_path.exists():
                    tmp_path.unlink()
            except OSError:
                pass
        raise Exception(f"Download failed: {e}") from e
def get_obj_from_str(string):
    """Resolve a dotted ``"module.Name"`` string to the named object.

    The module part is first imported exactly as given. If that import
    cannot be found, or the module has no such attribute, the lookup is
    retried with the module under the ``lib.`` package (``"lib.module"``).

    Args:
        string: Dotted path whose last component is the attribute name,
            e.g. ``"collections.OrderedDict"``.

    Returns:
        The attribute named by the last component.

    Raises:
        ValueError: If ``string`` has no module part before the last dot.
        ImportError / AttributeError: If neither location provides the object.
        Any other exception raised while the target module's code executes
            propagates unchanged.
    """
    module, sep, cls = string.rpartition(".")
    if not sep or not module:
        raise ValueError(f"expected a dotted 'module.Name' path, got {string!r}")
    try:
        return getattr(importlib.import_module(module, package=None), cls)
    except (ImportError, AttributeError):
        # Only "not found" failures trigger the fallback. Errors raised while
        # the module itself runs (SyntaxError, RuntimeError, ...) propagate
        # instead of being hidden behind a misleading lib.* import error.
        return getattr(importlib.import_module('lib.' + module, package=None), cls)
def load_obj(path):
    """Instantiate an object described by an OmegaConf YAML file.

    The file must have a ``__class__`` key holding a dotted class path
    (resolved via ``get_obj_from_str``). It may also have an ``__init__``
    mapping, whose entries are passed as keyword arguments to the class.
    A missing ``__init__`` means no arguments.

    Returns:
        The newly constructed instance.
    """
    spec = OmegaConf.load(path)
    # Resolve the class before reading the kwargs, matching the original
    # evaluation order.
    factory = get_obj_from_str(spec['__class__'])
    init_kwargs = spec.get("__init__", {})
    return factory(**init_kwargs)