Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- utils/__pycache__/utils.cpython-310.pyc +0 -0
- utils/utils.py +77 -0
utils/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.03 kB). View file
|
|
utils/utils.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
import torch.distributed as dist
|
6 |
+
|
7 |
+
|
8 |
+
def count_params(model, verbose=False):
|
9 |
+
total_params = sum(p.numel() for p in model.parameters())
|
10 |
+
if verbose:
|
11 |
+
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
12 |
+
return total_params
|
13 |
+
|
14 |
+
|
15 |
+
def check_istarget(name, para_list):
|
16 |
+
"""
|
17 |
+
name: full name of source para
|
18 |
+
para_list: partial name of target para
|
19 |
+
"""
|
20 |
+
istarget=False
|
21 |
+
for para in para_list:
|
22 |
+
if para in name:
|
23 |
+
return True
|
24 |
+
return istarget
|
25 |
+
|
26 |
+
|
27 |
+
def instantiate_from_config(config):
|
28 |
+
if not "target" in config:
|
29 |
+
if config == '__is_first_stage__':
|
30 |
+
return None
|
31 |
+
elif config == "__is_unconditional__":
|
32 |
+
return None
|
33 |
+
raise KeyError("Expected key `target` to instantiate.")
|
34 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
35 |
+
|
36 |
+
|
37 |
+
def get_obj_from_str(string, reload=False):
|
38 |
+
module, cls = string.rsplit(".", 1)
|
39 |
+
if reload:
|
40 |
+
module_imp = importlib.import_module(module)
|
41 |
+
importlib.reload(module_imp)
|
42 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
43 |
+
|
44 |
+
|
45 |
+
def load_npz_from_dir(data_dir):
|
46 |
+
data = [np.load(os.path.join(data_dir, data_name))['arr_0'] for data_name in os.listdir(data_dir)]
|
47 |
+
data = np.concatenate(data, axis=0)
|
48 |
+
return data
|
49 |
+
|
50 |
+
|
51 |
+
def load_npz_from_paths(data_paths):
|
52 |
+
data = [np.load(data_path)['arr_0'] for data_path in data_paths]
|
53 |
+
data = np.concatenate(data, axis=0)
|
54 |
+
return data
|
55 |
+
|
56 |
+
|
57 |
+
def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None):
|
58 |
+
h, w = image.shape[:2]
|
59 |
+
if resize_short_edge is not None:
|
60 |
+
k = resize_short_edge / min(h, w)
|
61 |
+
else:
|
62 |
+
k = max_resolution / (h * w)
|
63 |
+
k = k**0.5
|
64 |
+
h = int(np.round(h * k / 64)) * 64
|
65 |
+
w = int(np.round(w * k / 64)) * 64
|
66 |
+
image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
|
67 |
+
return image
|
68 |
+
|
69 |
+
|
70 |
+
def setup_dist(args):
|
71 |
+
if dist.is_initialized():
|
72 |
+
return
|
73 |
+
torch.cuda.set_device(args.local_rank)
|
74 |
+
torch.distributed.init_process_group(
|
75 |
+
'nccl',
|
76 |
+
init_method='env://'
|
77 |
+
)
|