Commit e615809, committed by SunderAli17
1 Parent(s): abb3181

Create ToonMage/utils.py

Files changed: ToonMage/utils.py (+76, -0)
ToonMage/utils.py
ADDED
@@ -0,0 +1,76 @@
import importlib
import os
import random

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from transformers import PretrainedConfig


def seed_everything(seed):
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def is_torch2_available():
    return hasattr(F, "scaled_dot_product_attention")


def instantiate_from_config(config):
    if "target" not in config:
        if config == '__is_first_stage__' or config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", {}))


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def drop_seq_token(seq, drop_rate=0.5):
    idx = torch.randperm(seq.size(1))
    num_keep_tokens = int(len(idx) * (1 - drop_rate))
    idx = idx[:num_keep_tokens]
    seq = seq[:, idx]
    return seq


def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":  # noqa RET505
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")


def resize_numpy_image_long(image, resize_long_edge=768):
    h, w = image.shape[:2]
    if max(h, w) <= resize_long_edge:
        return image
    k = resize_long_edge / max(h, w)
    h = int(h * k)
    w = int(w * k)
    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
    return image
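
The commit itself ships no tests, so here is a minimal, hedged usage sketch for the seeding, image, and sequence helpers. The import path `ToonMage.utils`, the image path "assets/example.jpg", and the tensor shapes are illustrative assumptions, not values from this commit.

# Usage sketch (not part of the commit). Assumes the file is importable as
# ToonMage.utils and that a test image exists at the made-up path below.
import cv2
import torch

from ToonMage.utils import (
    seed_everything,
    is_torch2_available,
    resize_numpy_image_long,
    drop_seq_token,
)

seed_everything(42)                        # fixes python/numpy/torch RNG state
print(is_torch2_available())               # True iff F.scaled_dot_product_attention exists (torch >= 2.0)

image = cv2.imread("assets/example.jpg")   # hypothetical HxWx3 uint8 input
image = resize_numpy_image_long(image, resize_long_edge=768)
print(image.shape)                         # long edge is now at most 768 px

seq = torch.randn(2, 77, 768)              # (batch, tokens, dim) dummy sequence
seq = drop_seq_token(seq, drop_rate=0.5)
print(seq.shape)                           # random half of the tokens kept: (2, 38, 768)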
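
The remaining helpers perform dynamic imports. The sketch below is equally illustrative: `torch.nn.Identity` is just a stand-in target, and the checkpoint name is a guess at the kind of diffusers-style repo these utilities are normally pointed at (any repo with a `text_encoder` subfolder would do); fetching its config needs network or local cache access.

# Illustrative sketch only; the config dict and model id are examples,
# not values taken from the commit.
from ToonMage.utils import (
    instantiate_from_config,
    get_obj_from_str,
    import_model_class_from_model_name_or_path,
)

# get_obj_from_str resolves a dotted path to the object it names.
Identity = get_obj_from_str("torch.nn.Identity")

# instantiate_from_config builds the `target` class with optional `params`.
layer = instantiate_from_config({"target": "torch.nn.Identity", "params": {}})

# import_model_class_from_model_name_or_path reads a repo's text-encoder config
# and returns the matching transformers class (CLIPTextModel or
# CLIPTextModelWithProjection), raising ValueError for anything else.
cls = import_model_class_from_model_name_or_path(
    "stabilityai/stable-diffusion-xl-base-1.0", revision=None
)
print(cls.__name__)  # "CLIPTextModel" for the default "text_encoder" subfolder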